import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.data import Dataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
    LoadImaged, AddChanneld, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, Compose, EnsureTyped
)
# Root folder holding the NIfTI CT volumes and their segmentation masks.
data_dir = "./medical_images/"
# One dict per training case, keyed "image"/"label" as the dict-based
# MONAI transforms below expect.
train_files = [
    {
        "image": f"{data_dir}patient{case}_CT.nii.gz",
        "label": f"{data_dir}patient{case}_mask.nii.gz",
    }
    for case in (1, 2)
]
# Preprocessing pipeline applied to each {"image", "label"} pair.
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    # AddChanneld was deprecated in MONAI 0.8 and removed in 1.3;
    # EnsureChannelFirstd is the drop-in replacement (it uses the loaded
    # metadata to create/move the channel axis to position 0).
    EnsureChannelFirstd(keys=["image", "label"]),
    # Resample to a common voxel spacing; nearest-neighbour for the label
    # so integer class ids are never interpolated.
    Spacingd(
        keys=["image", "label"],
        pixdim=(1.0, 1.0, 2.0),
        mode=("bilinear", "nearest"),
    ),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # Clip HU values to [-1000, 1000] and rescale into [0, 1].
    ScaleIntensityRanged(
        keys=["image"], a_min=-1000, a_max=1000,
        b_min=0.0, b_max=1.0, clip=True,
    ),
    # Draw 4 random 96x96x64 patches per volume, balanced 1:1 between
    # foreground-centred and background-centred crops.
    # NOTE: num_samples=4 makes Dataset.__getitem__ return a LIST of dicts.
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label",
        spatial_size=(96, 96, 64), pos=1, neg=1, num_samples=4,
    ),
    EnsureTyped(keys=["image", "label"], dtype=torch.float32),
])
# Vanilla (uncached) dataset; transforms run on every __getitem__ call.
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 3D residual UNet: 1 input channel (CT), 2 output channels
# (background / foreground logits), 5 resolution levels.
model = UNet(
spatial_dims=3, in_channels=1, out_channels=2,
channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2
).to(device)
# Dice loss on softmax probabilities; labels are one-hot encoded internally.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Mean foreground Dice (background channel excluded).
dice_metric = DiceMetric(include_background=False, reduction="mean")
max_epochs = 50
# Training loop: one optimization pass per epoch followed by a Dice
# evaluation pass with sliding-window inference.
# NOTE(review): evaluation reuses train_loader (shuffled random crops), so
# the reported Dice is an optimistic training-set estimate, not validation —
# a dedicated validation split/loader should replace it.
for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0.0
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    model.eval()
    with torch.no_grad():
        for val_data in train_loader:
            val_images = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            val_outputs = sliding_window_inference(val_images, (96, 96, 64), 4, model)
            # decollate_batch yields per-sample tensors of shape (C, H, W, D),
            # so the class axis is dim 0 (dim 1 would argmax over a spatial
            # axis); keepdim=True preserves the channel axis so predictions
            # match the (1, H, W, D) labels expected by DiceMetric.
            val_outputs = [torch.argmax(i, dim=0, keepdim=True) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=decollate_batch(val_labels))
        metric = dice_metric.aggregate().item()
        dice_metric.reset()
    print(f"Epoch {epoch+1}/{max_epochs}, Loss: {epoch_loss:.4f}, Dice: {metric:.4f}")
def visualize_slice(image, label, prediction, slice_index=25):
    """Display one axial slice of the CT, the ground-truth mask, and the
    prediction side by side. All three arrays are indexed as
    [batch, channel, H, W, depth]."""
    panels = [
        (image, "Input Image", "gray"),
        (label, "Ground Truth", "jet"),
        (prediction, "Prediction", "jet"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (volume, title, cmap) in zip(axes, panels):
        ax.imshow(volume[0, 0, :, :, slice_index], cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
# Qualitative check on the first training case.
# RandCropByPosNegLabeld(num_samples=4) makes Dataset.__getitem__ return a
# LIST of 4 cropped samples rather than a single dict, so train_ds[0] must
# be indexed again — otherwise test_data["image"] raises
# "TypeError: list indices must be integers".
test_data = train_ds[0][0]
image = test_data["image"].unsqueeze(0).to(device)
with torch.no_grad():
    prediction = sliding_window_inference(image, (96, 96, 64), 4, model)
    # Collapse the class logits to a label map, keeping the channel axis
    # so shapes match what visualize_slice indexes.
    prediction = torch.argmax(prediction, dim=1, keepdim=True)
visualize_slice(image.cpu().numpy(), test_data["label"].unsqueeze(0).numpy(), prediction.cpu().numpy())
from monai.transforms import RandRotated, RandFlipd, RandZoomd

# Spatial augmentations applied to image and label together so they stay
# aligned. Compose has no list-style insert() method, so rebuild the
# pipeline with the augmentations spliced in at position 6 (just before
# the final EnsureTyped), where the original code intended to insert them.
augmentations = [
    RandRotated(keys=["image", "label"], range_x=0.3, prob=0.5),
    RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
    RandZoomd(keys=["image", "label"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
]
base_transforms = list(train_transforms.transforms)
train_transforms = Compose(base_transforms[:6] + augmentations + base_transforms[6:])
# The Dataset created earlier still references the old Compose object;
# rebind so the augmentations actually take effect.
train_ds.transform = train_transforms
from monai.networks.nets import SwinUNETR

# Swap the UNet for a SwinUNETR with matching input patch size and channels.
# NOTE(review): img_size is deprecated in recent MONAI releases (input size
# is inferred at runtime) — confirm against the installed version.
model = SwinUNETR(img_size=(96, 96, 64), in_channels=1, out_channels=2, feature_size=48).to(device)
# The Adam optimizer created earlier was bound to the old UNet's parameters;
# without rebuilding it here, training would keep updating the discarded
# model while SwinUNETR never receives gradient steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Gradient scaler for mixed-precision training (pair with torch.autocast).
scaler = torch.cuda.amp.GradScaler()
import pandas as pd
import polars as pl
import pygwalker as pyg
from datetime import datetime  # NOTE(review): unused in the visible code — confirm before removing
# Toy EHR-style table. patient_id 1001 appears twice (a return visit) and
# lab_result contains None entries to exercise missing-value handling.
data = {
"patient_id": [1001, 1002, 1003, 1001, 1004],
"visit_date": ["2023-01-15", "2023-02-20", "2023-01-05", "2023-03-10", "2023-02-28"],
"diagnosis": ["Hypertension", "Diabetes", "Hypertension", "Asthma", "Diabetes"],
"medication": ["Lisinopril", "Metformin", "Amlodipine", "Albuterol", "Insulin"],
"age": [45, 62, 58, 36, 70],
"blood_pressure": ["140/90", "130/85", "150/95", "120/80", "145/88"],
"lab_result": [None, 6.5, 7.1, None, 8.0]
}
# pandas copy with visit_date parsed to datetime64.
df_pd = pd.DataFrame(data)
df_pd["visit_date"] = pd.to_datetime(df_pd["visit_date"])
# polars copy with visit_date parsed to a Date column, then converted back
# to pandas for pygwalker. NOTE(review): this round-trip largely duplicates
# df_pd (with a Date rather than datetime dtype) — presumably here to
# demonstrate both libraries; df_pd itself is never used afterwards.
df_pl = pl.DataFrame(data).with_columns(pl.col("visit_date").str.to_date("%Y-%m-%d"))
df_processed = df_pl.to_pandas()
# Launch the PyGWalker exploratory-analysis UI on the processed frame.
walker = pyg.walk(
df_processed,
spec="./ehr_analysis.json",  # chart state is persisted to / loaded from this file
dark="light",  # NOTE(review): newer pygwalker renames this to `appearance` — confirm installed version
show_cloud_tool=False,  # keep the session fully local
# Override the inferred field roles: treat age as a categorical dimension
# and lab_result as a numeric measure.
field_specs={
"age": {"analyticType": "dimension"},
"lab_result": {"analyticType": "measure"}
}
)