跳到主要内容
医学影像分类器实践:基于深度学习的肺结节检测 | 极客日志
Python AI 算法
医学影像分类器实践:基于深度学习的肺结节检测 本项目利用深度学习技术开发肺结节检测分类器,基于 CT 影像区分良性和恶性结节。重点聚焦卷积神经网络(CNN)、视觉变换器(ViT)以及多模态方法。使用 LUNA16 数据集,整合 Transformer 原理,支持 3D 处理和分割任务。内容涵盖数据预处理、模型实现、评估优化及隐私保护技术。提供 ResNet-50、ViT、多模态融合及 UNETR 分割任务的完整代码示例,结合 Grad-CAM 与注意力可视化增强可解释性,旨在辅助临床诊断并降低漏诊率。
落日余晖 发布于 2026/4/10 更新于 2026/4/26 3 浏览AI 大模型实践项目:医学影像分类器(肺结节检测)
本项目利用深度学习技术开发肺结节检测分类器,基于 CT 影像区分良性和恶性结节,聚焦 卷积神经网络(CNN) 、视觉变换器(Vision Transformer, ViT) 以及受 Med-PaLM 启发的多模态方法。使用 LUNA16 数据集,整合 Transformer 原理(自注意力、位置编码),增强代码支持 3D 处理和分割任务,新增高级可视化和隐私保护技术(如联邦学习)。
一、项目概述
1.1 项目目标
功能 :构建分类器,检测 CT 影像中的肺结节(良性/恶性)。
医学意义 :肺结节是肺癌早期标志,自动分类可辅助诊断,降低漏诊率。
技术目标 :
掌握深度学习工作流:数据预处理、模型训练、评估。
实现高召回率(Recall),减少假阴性(漏诊)。
比较 CNN、ViT 和多模态模型在医学影像中的性能。
提供可解释性(如 Grad-CAM),增强医生信任。
1.2 数据集
LUNA16 (Lung Nodule Analysis 2016):
RSNA (Radiological Society of North America):
数据挑战 :
类不平衡 :恶性结节样本少(约 10-20%)。
高维数据 :3D CT 需降维或分块处理。
噪声与伪影 :CT 影像可能包含扫描噪声或金属伪影。
隐私保护 :需符合《个人信息保护法》和 HIPAA/GDPR。
1.3 技术栈
PyTorch :灵活实现 CNN、ViT 和 3D 模型。
Hugging Face :提供预训练 ViT 和多模态模型支持。
pydicom :读取和处理 DICOM 格式 CT 影像。
MONAI :医学影像专用框架,支持 3D 数据处理和分割。
scikit-learn/seaborn :评估指标(混淆矩阵、ROC 曲线)和可视化。
Chart.js :性能对比图表。
Flower :联邦学习框架,支持隐私保护训练。
1.4 医学影像分类挑战
数据稀缺 :高质量标注数据有限,需迁移学习或数据增强。
高召回需求 :漏诊(假阴性)成本高,需优化召回率。
3D 数据复杂性 :CT 体视显微镜数据需高效处理。
可解释性 :模型预测需与医学知识一致,需 Grad-CAM 或注意力可视化。
计算成本 :3D 模型和 ViT 训练需高性能 GPU(如 NVIDIA A100)。
伦理与法规 :确保公平性,保护患者隐私,符合医疗标准。
二、理论基础
2.1 卷积神经网络(CNN)
架构 :
卷积层 :提取局部特征(如结节边缘、纹理)。
池化层 :降维,保留关键信息。
残差连接 (ResNet):通过 $y = x + F(x)$ 缓解梯度消失。
3D CNN :扩展卷积核为 3D(如 3×3×3),直接处理 CT 体视显微镜数据。
数学基础 :
卷积操作 (2D):$Y(i,j) = \sum_m \sum_n X(i+m, j+n) \cdot K(m,n) + b$
$X$: 输入影像,$K$: 卷积核,$b$: 偏置。
3D 卷积 :$Y(i,j,k) = \sum_m \sum_n \sum_p X(i+m, j+n, k+p) \cdot K(m,n,p) + b$
损失函数 :$L = -\sum_i y_i \log(\hat{y}_i) + \lambda \sum ||W||_2^2$
适用性 :高效提取局部特征,适合小区域结节检测;3D CNN 适配体视显微镜数据。
2.2 Vision Transformer (ViT)
架构 (结合历史对话中的 Transformer):
图像分块 :将 CT 影像分割为 Patch(如 16×16),展平为向量序列。
位置编码 :添加正弦位置编码,保留 Patch 空间信息:
$E_{\text{pos}}(i, 2k) = \sin(i / 10000^{2k/d}), \quad E_{\text{pos}}(i, 2k+1) = \cos(i / 10000^{2k/d})$
Transformer 编码器 :多头自注意力(Multi-Head Attention)捕捉 Patch 间全局依赖。
分类头 :CLS Token 或全局池化输出分类结果。
数学基础 :
Patch 嵌入 :$z_0 = [x_{\text{class}}; x_p^1 W_E; x_p^2 W_E; \dots; x_p^N W_E] + E_{\text{pos}}$
$x_p^i$: 第 $i$ 个 Patch,$W_E$: 嵌入矩阵。
自注意力 :$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
$Q, K, V \in \mathbb{R}^{N \times d_k}$,$N$: Patch 数量,$d_k$: 嵌入维度。
多头注意力 (历史对话):$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O$
$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$。
适用性 :全局建模能力强,适合复杂影像特征;需大规模预训练。
2.3 多模态模型(受 Med-PaLM 启发)
架构 :
影像模块 :ViT 处理 CT 影像。
文本模块 :BERT 处理临床报告(如病史)。
融合模块 :跨模态注意力整合影像和文本特征。
数学基础 :
跨模态注意力 (历史对话):$\text{Attention}(Q_{\text{text}}, K_{\text{image}}, V_{\text{image}}) = \text{softmax}\left(\frac{Q_{\text{text}}K_{\text{image}}^T}{\sqrt{d_k}}\right)V_{\text{image}}$
联合损失 :$L = \alpha L_{\text{class}} + \beta L_{\text{align}}$
$L_{\text{class}}$: 分类损失,$L_{\text{align}}$: 影像 - 文本对齐损失(如 CLIP 损失)。
适用性 :结合临床信息,提升诊断精度,适合综合诊断。
2.4 迁移学习与 LoRA
预训练 :
CNN:ImageNet 预训练 ResNet-50,学习通用视觉特征。
ViT:ImageNet 或 CheXpert 预训练 ViT,适配医学影像。
LoRA(低秩适配) :
仅更新低秩矩阵 $\Delta W = BA$,减少微调参数量:
$W' = W + \Delta W, \quad \Delta W = BA, \quad B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k}$
适合 LUNA16 小数据集,降低计算成本。
优势 :加速训练,适配小数据集,减少过拟合。
2.5 评估指标
混淆矩阵 :计算真阳性(TP)、假阳性(FP)、真阴性(TN)、假阴性(FN)。
指标 :
准确率:$\text{Accuracy} = \frac{TP+TN}{TP+TN+FP+FN}$
精确率:$\text{Precision} = \frac{TP}{TP+FP}$
召回率:$\text{Recall} = \frac{TP}{TP+FN}$(医学中关键)。
F1 分数:$\text{F1} = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$
ROC 曲线与 AUC :绘制真阳性率(TPR)对假阳性率(FPR),AUC 量化区分能力。
可解释性 :Grad-CAM 和注意力热图,突出模型关注的结节区域。
三、数据预处理
3.1 LUNA16 数据集处理
数据格式 :DICOM 文件,3D CT 扫描(512×512×N 片)。
标注 :CSV 文件,提供结节坐标(x, y, z)和类别(0: 良性,1: 恶性)。
预处理步骤 :
读取 DICOM :使用 pydicom 加载 3D CT 影像。
归一化 :将 Hounsfield 单位(HU)归一到 [0,1]:$I_{\text{norm}} = \frac{I - \min(I)}{\max(I) - \min(I)}$
提取结节 :基于坐标提取 3D 体视显微镜块(如 32×32×32)或 2D 切片。
数据增强 :旋转、翻转、缩放、添加噪声,增加多样性。
数据集划分 :80% 训练,10% 验证,10% 测试(分层确保类平衡)。
3.2 实现示例(Python) 以下为 LUNA16 数据预处理代码,支持 2D 和 3D 数据:
import pydicom
import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from monai.transforms import Compose, Resize, RandRotate, RandFlip, ToTensor
class LUNA16Dataset (Dataset ):
def __init__ (self, dicom_dir, annotations_file, mode='2d' , transform=None ):
"""
LUNA16 数据集
:param dicom_dir: DICOM 文件目录
:param annotations_file: 标注 CSV 文件
:param mode: '2d' 或 '3d'(切片或体视显微镜)
:param transform: 数据增强
"""
self .dicom_dir = dicom_dir
self .annotations = pd.read_csv(annotations_file)
self .mode = mode
self .transform = transform
def __len__ (self ):
return len (self .annotations)
def __getitem__ (self, idx ):
dicom_id = self .annotations.iloc[idx]['dicom_id' ]
dicom_path = os.path.join(self .dicom_dir, dicom_id)
ds = pydicom.dcmread(dicom_path)
image = ds.pixel_array.astype(np.float32)
image = (image - np.min (image)) / (np.max (image) - np.min (image) + 1e-6 )
if self .mode == '2d' :
x, y, w, h, z = self .annotations.iloc[idx][['x' ,'y' ,'width' ,'height' ,'z' ]].values
image = image[z, y:y+h, x:x+w]
else :
x, y, z, w, h, d = self .annotations.iloc[idx][['x' ,'y' ,'z' ,'width' ,'height' ,'depth' ]].values
image = image[z:z+d, y:y+h, x:x+w]
if self .transform:
if self .mode == '2d' :
augmented = self .transform(image=image)
image = augmented['image' ]
else :
image = self .transform(image[np.newaxis,...])[0 ]
label = self .annotations.iloc[idx]['label' ]
return {'image' : image, 'label' : torch.tensor(label, dtype=torch.long)}
transform_2d = A.Compose([
A.Resize(224 , 224 ),
A.Rotate(limit=30 , p=0.5 ),
A.HorizontalFlip(p=0.5 ),
A.RandomBrightnessContrast(p=0.3 ),
A.Normalize(mean=[0.5 ], std=[0.5 ]),
ToTensorV2()
])
transform_3d = Compose([
Resize(spatial_size=(32 , 32 , 32 )),
RandRotate(range_x=30 , prob=0.5 ),
RandFlip(spatial_axis=0 , prob=0.5 ),
ToTensor()
])
dataset_2d = LUNA16Dataset(
dicom_dir='path/to/luna16' ,
annotations_file='annotations.csv' ,
mode='2d' ,
transform=transform_2d
)
dataset_3d = LUNA16Dataset(
dicom_dir='path/to/luna16' ,
annotations_file='annotations.csv' ,
mode='3d' ,
transform=transform_3d
)
pydicom :读取 DICOM 文件,提取像素数组。
模式选择 :支持 2D 切片(224×224)和 3D 体视显微镜块(32×32×32)。
数据增强 :
2D:旋转、翻转、亮度/对比度调整(albumentations)。
3D:体视显微镜旋转、翻转(MONAI)。
归一化 :将 Hounsfield 单位归一到 [0,1]。
注意 :需替换 dicom_dir 和 annotations_file 为实际路径。
四、模型实现
4.1 CNN 实现(ResNet-50,3D 支持) 基于 ResNet-50,支持 2D 和 3D CT 影像分类:
import torch
import torch.nn as nn
from torchvision.models import resnet50
from monai.networks.nets import ResNet
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
class ResNet3D (nn.Module):
def __init__ (self, num_classes=2 ):
super ().__init__()
self .resnet = ResNet(block='bottleneck' , layers=[3 ,4 ,6 ,3 ], spatial_dims=3 , n_input_channels=1 , num_classes=num_classes)
def forward (self, x ):
return self .resnet(x)
class ResNet2D (nn.Module):
def __init__ (self, num_classes=2 ):
super ().__init__()
self .resnet = resnet50(pretrained=True )
self .resnet.conv1 = nn.Conv2d(1 , 64 , kernel_size=7 , stride=2 , padding=3 )
self .resnet.fc = nn.Linear(self .resnet.fc.in_features, num_classes)
def forward (self, x ):
return self .resnet(x)
def train_model (model, dataloader, criterion, optimizer, num_epochs=10 , device='cuda' ):
model = model.to(device)
train_losses = []
for epoch in range (num_epochs):
model.train()
running_loss = 0.0
for batch in dataloader:
images = batch['image' ].to(device)
labels = batch['label' ].to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / len (dataloader)
train_losses.append(avg_loss)
print (f'Epoch [{epoch+1 } /{num_epochs} ], Loss: {avg_loss:.4 f} ' )
return train_losses
dataloader_2d = DataLoader(dataset_2d, batch_size=16 , shuffle=True )
dataloader_3d = DataLoader(dataset_3d, batch_size=8 , shuffle=True )
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
model_2d = ResNet2D(num_classes=2 )
model_3d = ResNet3D(num_classes=2 )
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.3 , 0.7 ]).to(device))
optimizer_2d = torch.optim.Adam(model_2d.parameters(), lr=1e-4 , weight_decay=1e-5 )
optimizer_3d = torch.optim.Adam(model_3d.parameters(), lr=1e-4 , weight_decay=1e-5 )
train_losses_2d = train_model(model_2d, dataloader_2d, criterion, optimizer_2d, device=device)
train_losses_3d = train_model(model_3d, dataloader_3d, criterion, optimizer_3d, device=device)
def evaluate_model (model, dataloader, device='cuda' ):
model.eval ()
predictions, true_labels = [], []
with torch.no_grad():
for batch in dataloader:
images = batch['image' ].to(device)
labels = batch['label' ].to(device)
outputs = model(images)
preds = torch.argmax(outputs, dim=1 )
predictions.extend(preds.cpu().numpy())
true_labels.extend(labels.cpu().numpy())
return predictions, true_labels
predictions_2d, true_labels_2d = evaluate_model(model_2d, dataloader_2d)
predictions_3d, true_labels_3d = evaluate_model(model_3d, dataloader_3d)
print ("2D ResNet 准确率:" , accuracy_score(true_labels_2d, predictions_2d))
print ("3D ResNet 准确率:" , accuracy_score(true_labels_3d, predictions_3d))
模型 :2D ResNet-50(ImageNet 预训练)和 3D ResNet(MONAI 实现)。
损失函数 :加权交叉熵,权重 [0.3, 0.7] 应对恶性结节稀缺。
优化器 :Adam,学习率 1e-4,L2 正则化防止过拟合。
注意 :3D 模型需更大显存(如 16GB),批大小减小至 8。
4.2 ViT 实现(Hugging Face,LoRA) 基于 ViT,结合 LoRA 微调,支持注意力可视化:
from transformers import ViTImageProcessor, ViTForImageClassification
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224' )
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224' , num_labels=2 )
lora_config = LoraConfig(r=8 , lora_alpha=16 , target_modules=["query" , "value" ])
model = get_peft_model(model, lora_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
model = model.to(device)
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.3 , 0.7 ]).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5 )
dataloader = DataLoader(dataset_2d, batch_size=16 , shuffle=True )
train_losses = []
for epoch in range (10 ):
model.train()
running_loss = 0.0
for batch in dataloader:
images = batch['image' ].to(device)
labels = batch['label' ].to(device)
inputs = processor(images, return_tensors='pt' , do_rescale=False ).to(device)
outputs = model(**inputs).logits
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / len (dataloader)
train_losses.append(avg_loss)
print (f'Epoch [{epoch+1 } /10], Loss: {avg_loss:.4 f} ' )
def visualize_attention (model, image, processor, device='cuda' ):
model.eval ()
inputs = processor(image, return_tensors='pt' , do_rescale=False ).to(device)
with torch.no_grad():
outputs = model(**inputs, output_attentions=True )
attentions = outputs.attentions[-1 ].mean(dim=1 ).squeeze(0 )
h, w = image.shape[-2 :]
attn_map = attentions.mean(dim=0 ).reshape(14 , 14 ).cpu().numpy()
attn_map = np.resize(attn_map, (h, w))
plt.imshow(image.squeeze(0 ), cmap='gray' )
plt.imshow(attn_map, cmap='jet' , alpha=0.5 )
plt.title('ViT 注意力热图' )
plt.show()
predictions, true_labels = [], []
with torch.no_grad():
for batch in dataloader:
images = batch['image' ].to(device)
labels = batch['label' ].to(device)
inputs = processor(images, return_tensors='pt' , do_rescale=False ).to(device)
outputs = model(**inputs).logits
preds = torch.argmax(outputs, dim=1 )
predictions.extend(preds.cpu().numpy())
true_labels.extend(labels.cpu().numpy())
print ("ViT 准确率:" , accuracy_score(true_labels, predictions))
sample_image = dataset_2d[0 ]['image' ]
visualize_attention(model, sample_image, processor)
ViT :预训练 ViT-base,修改分类头为 2 类。
LoRA :微调 query 和 value 矩阵,减少参数量。
注意力可视化 :展示最后一层注意力热图,突出模型关注区域。
注意 :仅支持 2D 影像,3D ViT 需扩展(见 4.4)。
4.3 多模态实现(受 Med-PaLM 启发) 结合 CT 影像和临床文本(如病史),实现多模态分类:
from transformers import ViTModel, BertTokenizer, BertModel
import torch.nn as nn
class MultiModalLungNoduleClassifier (nn.Module):
def __init__ (self, num_labels=2 ):
super ().__init__()
self .vit = ViTModel.from_pretrained('google/vit-base-patch16-224' )
self .bert = BertModel.from_pretrained('bert-base-uncased' )
self .fusion = nn.Linear(768 +768 , 512 )
self .classifier = nn.Linear(512 , num_labels)
self .relu = nn.ReLU()
self .dropout = nn.Dropout(0.1 )
def forward (self, image_inputs, text_inputs ):
vit_outputs = self .vit(**image_inputs).pooler_output
bert_outputs = self .bert(**text_inputs).pooler_output
combined = torch.cat((vit_outputs, bert_outputs), dim=-1 )
combined = self .relu(self .fusion(combined))
combined = self .dropout(combined)
logits = self .classifier(combined)
return logits
class LUNA16MultiModalDataset (Dataset ):
def __init__ (self, dicom_dir, annotations_file, texts, transform=None ):
self .dataset = LUNA16Dataset(dicom_dir, annotations_file, mode='2d' , transform=transform)
self .texts = texts
self .tokenizer = BertTokenizer.from_pretrained('bert-base-uncased' )
def __getitem__ (self, idx ):
item = self .dataset[idx]
text = self .texts[idx]
text_inputs = self .tokenizer(text, max_length=128 , padding='max_length' , truncation=True , return_tensors='pt' )
item['text_inputs' ] = {k: v.squeeze(0 ) for k, v in text_inputs.items()}
return item
def __len__ (self ):
return len (self .dataset)
texts = ["Patient with cough and fever, suspected malignancy." ] * len (dataset_2d)
multimodal_dataset = LUNA16MultiModalDataset('path/to/luna16' , 'annotations.csv' , texts, transform=transform_2d)
dataloader = DataLoader(multimodal_dataset, batch_size=16 , shuffle=True )
model = MultiModalLungNoduleClassifier(num_labels=2 ).to(device)
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.3 , 0.7 ]).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5 )
for epoch in range (10 ):
model.train()
running_loss = 0.0
for batch in dataloader:
images = batch['image' ].to(device)
labels = batch['label' ].to(device)
image_inputs = processor(images, return_tensors='pt' , do_rescale=False ).to(device)
text_inputs = {k: v.to(device) for k, v in batch['text_inputs' ].items()}
outputs = model(image_inputs, text_inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
print (f'Epoch [{epoch+1 } /10], Loss: {running_loss/len (dataloader):.4 f} ' )
模型 :ViT(影像)+ BERT(文本),通过线性层融合特征。
数据 :扩展 LUNA16 数据集,添加模拟临床文本。
注意 :需真实临床文本(如病历),可从 MIMIC-III 获取。
4.4 分割任务(3D U-Net+ViT) 为肺结节分割,基于 MONAI 的 UNETR(U-Net+ViT):
from monai.networks.nets import UNETR
from monai.data import DataLoader, Dataset as MonaiDataset
from monai.transforms import LoadImageD, EnsureChannelFirstD, Compose
transform_seg = Compose([
LoadImageD(keys=['image' ]),
EnsureChannelFirstD(keys=['image' ]),
Resize(spatial_size=(32 , 32 , 32 )),
ToTensor()
])
seg_data = [{'image' : f'path/to/luna16/{i} .dcm' , 'mask' : f'path/to/mask/{i} .nii' } for i in range (100 )]
seg_dataset = MonaiDataset(seg_data, transform=transform_seg)
seg_dataloader = DataLoader(seg_dataset, batch_size=4 , shuffle=True )
model = UNETR(in_channels=1 , out_channels=2 , img_size=(32 , 32 , 32 ), feature_size=16 ).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4 )
for epoch in range (10 ):
model.train()
running_loss = 0.0
for batch in seg_dataloader:
images = batch['image' ].to(device)
masks = batch['mask' ].to(device)
outputs = model(images)
loss = criterion(outputs, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
print (f'Epoch [{epoch+1 } /10], Loss: {running_loss/len (seg_dataloader):.4 f} ' )
UNETR :结合 ViT 和 U-Net,处理 3D CT 分割。
数据 :假设掩膜(mask)标注,需从 LUNA16 或 BraTS 获取。
注意 :分割任务需更大显存(推荐 24GB)。
五、评估与优化
5.1 评估方法
交叉验证 :5 折分层 K 折,确保类不平衡数据评估稳定。
混淆矩阵 :计算 TP、FP、FN、TN,重点优化召回率。
ROC 曲线与 AUC :评估模型区分能力。
Dice 分数 (分割任务):评估分割精度:$\text{Dice} = \frac{2 |P \cap G|}{|P| + |G|}$
5.2 实现示例(Python) from sklearn.metrics import confusion_matrix, roc_curve, auc, classification_report
from monai.metrics import DiceMetric
import seaborn as sns
import matplotlib.pyplot as plt
def evaluate_classification (model, dataloader, processor=None , device='cuda' ):
model.eval ()
predictions, true_labels, probs = [], [], []
with torch.no_grad():
for batch in dataloader:
images = batch['image' ].to(device)
labels = batch['label' ].to(device)
if processor:
inputs = processor(images, return_tensors='pt' , do_rescale=False ).to(device)
outputs = model(**inputs).logits
else :
outputs = model(images)
preds = torch.argmax(outputs, dim=1 )
predictions.extend(preds.cpu().numpy())
true_labels.extend(labels.cpu().numpy())
probs.extend(torch.softmax(outputs, dim=1 )[:, 1 ].cpu().numpy())
cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True , fmt='d' , cmap='Blues' , xticklabels=['良性' , '恶性' ], yticklabels=['良性' , '恶性' ])
plt.xlabel('预测' )
plt.ylabel('真实' )
plt.title('混淆矩阵' )
plt.show()
print (classification_report(true_labels, predictions, target_names=['良性' , '恶性' ]))
fpr, tpr, _ = roc_curve(true_labels, probs)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f'ROC 曲线 (AUC = {roc_auc:.2 f} )' )
plt.plot([0 , 1 ], [0 , 1 ], 'k--' )
plt.xlabel('假阳性率' )
plt.ylabel('真阳性率' )
plt.title('ROC 曲线' )
plt.legend()
plt.show()
def evaluate_segmentation (model, dataloader, device='cuda' ):
dice_metric = DiceMetric(include_background=False , reduction='mean' )
model.eval ()
dice_scores = []
with torch.no_grad():
for batch in dataloader:
images = batch['image' ].to(device)
masks = batch['mask' ].to(device)
outputs = model(images)
preds = torch.argmax(outputs, dim=1 , keepdim=True )
dice_metric(preds, masks)
dice_score = dice_metric.aggregate().item()
dice_scores.append(dice_score)
dice_metric.reset()
print (f"Dice 分数:{dice_score:.4 f} " )
evaluate_classification(model_2d, dataloader_2d)
evaluate_classification(model, dataloader, processor)
evaluate_segmentation(model, seg_dataloader)
分类评估 :生成混淆矩阵、分类报告和 ROC 曲线,重点关注召回率。
分割评估 :使用 Dice 分数评估分割精度。
可视化 :Seaborn 绘制混淆矩阵,Matplotlib 绘制 ROC 曲线。
5.3 优化策略
类不平衡 :
加权损失:恶性结节权重 0.7,良性 0.3。
过采样:SMOTE 或重复采样恶性样本。
正则化 :Dropout(0.1)、L2 权重衰减(1e-5)。
超参数调优 :
学习率:网格搜索 [1e-5, 2e-5, 1e-4, 1e-3]。
批大小:2D 模型 16,3D 模型 8。
早停 :验证集损失 3 个 epoch 无下降时停止。
联邦学习 :使用 Flower 框架,实现跨医院隐私保护训练。
六、工作流与可视化
6.1 优化工作流流程图 以下是优化的医学影像分类和分割工作流,涵盖分类和分割任务:
graph TD
subgraph Preprocess
A[输入数据] --> B[预处理]
B --> C{模式选择}
C -->|2D| D[2D 预处理]
C -->|3D| E[3D 预处理]
D --> F[2D 分类]
E --> G[3D 分类]
E --> H[分割]
end
subgraph Model_Selection
F --> I{模型选择}
G --> I
H --> J[UNETR]
I -->|ResNet| K[ResNet-50]
I -->|ViT| L[ViT]
I -->|多模态| M[ViT+BERT]
end
subgraph Train_Eval
K --> N[训练]
L --> N
M --> N
J --> N
N --> O[评估]
O --> P{是否收敛?}
P -->|否| Q[调整参数]
Q --> N
P -->|是| R[输出结果]
end
subgraph Viz
R --> S[可解释性]
S --> T[Grad-CAM/SHAP]
end
节点文本简化 :避免冒号和长文本,保持简洁。
子图名称规范化 :使用英文标识符,降低解析负担。
分支标签简化 :移除空格,保持清晰。
精简描述 :核心逻辑不变,涵盖输入、预处理、模型选择、训练、评估、可解释性和输出。
逻辑保持一致 :支持 2D 分类、3D 分类和分割任务,涵盖 ResNet-50、ViT、多模态和 UNETR。
6.2 图表:CNN 与 ViT 性能对比 以下是 CNN 与 ViT 在肺结节分类上的性能对比折线图配置(假设数据):
{
"type" : "line" ,
"data" : {
"labels" : [ "2 折" , "3 折" , "5 折" , "10 折" ] ,
"datasets" : [
{
"label" : "ResNet 召回率" ,
"data" : [ 0.88 , 0.90 , 0.91 , 0.90 ] ,
"borderColor" : "#FF6384" ,
"fill" : false
} ,
{
"label" : "ViT 召回率" ,
"data" : [ 0.90 , 0.92 , 0.93 , 0.92 ] ,
"borderColor" : "#36A2EB" ,
"fill" : false
}
]
} ,
"options" : {
"title" : {
"display" : true ,
"text" : "CNN 与 ViT 召回率对比(肺结节分类)"
} ,
"scales" : {
"xAxes" : [ {
"scaleLabel" : {
"display" : true ,
"labelString" : "交叉验证折数"
}
} ] ,
"yAxes" : [ {
"scaleLabel" : {
"display" : true ,
"labelString" : "召回率"
} ,
"ticks" : {
"min" : 0.8 ,
"max" : 1.0
}
} ]
}
}
}
图表类型 :折线图,比较 ResNet 与 ViT 在不同折数下的召回率。
X 轴 :交叉验证折数(2、3、5、10)。
Y 轴 :召回率,范围 0.8-1.0,医学中关键。
数据 :假设数据,ViT 略优于 ResNet,反映全局建模优势。
生成说明 :可将 Chart.js 配置复制到支持工具生成图表。
6.3 图表:模型性能对比 以下是 ResNet-50(2D/3D)、ViT 和多模态模型在召回率上的对比(假设数据):
{
"type" : "bar" ,
"data" : {
"labels" : [ "2D ResNet-50" , "3D ResNet-50" , "ViT" , "多模态" ] ,
"datasets" : [
{
"label" : "召回率" ,
"data" : [ 0.88 , 0.90 , 0.92 , 0.94 ] ,
"backgroundColor" : [ "#FF6384" , "#36A2EB" , "#FFCE56" , "#4BC0C0" ] ,
"borderColor" : [ "#FF6384" , "#36A2EB" , "#FFCE56" , "#4BC0C0" ] ,
"borderWidth" : 1
} ,
{
"label" : "精确率" ,
"data" : [ 0.85 , 0.87 , 0.89 , 0.91 ] ,
"backgroundColor" : [ "#FF6384" , "#36A2EB" , "#FFCE56" , "#4BC0C0" ] ,
"borderColor" : [ "#FF6384" , "#36A2EB" , "#FFCE56" , "#4BC0C0" ] ,
"borderWidth" : 1
}
]
} ,
"options" : {
"scales" : {
"y" : {
"beginAtZero" : true ,
"title" : {
"display" : true ,
"text" : "性能指标"
}
} ,
"x" : {
"title" : {
"display" : true ,
"text" : "模型"
}
}
} ,
"plugins" : {
"title" : {
"display" : true ,
"text" : "模型性能对比(肺结节分类)"
}
}
}
}
X 轴 :模型类型(2D ResNet-50、3D ResNet-50、ViT、多模态)。
Y 轴 :召回率和精确率,医学中召回率优先。
数据 :假设数据,多模态模型因融合文本信息表现最佳。
生成 :复制代码至 Chart.js 工具渲染。
6.4 图表:训练时间对比 {
"type" : "bar" ,
"data" : {
"labels" : [ "2D ResNet-50" , "3D ResNet-50" , "ViT" , "多模态" , "UNETR" ] ,
"datasets" : [ {
"label" : "训练时间(小时)" ,
"data" : [ 2.0 , 5.0 , 3.0 , 6.0 , 8.0 ] ,
"backgroundColor" : [ "#FF6384" , "#36A2EB" , "#FFCE56" , "#4BC0C0" , "#9966FF" ] ,
"borderColor" : [ "#FF6384" , "#36A2EB" , "#FFCE56" , "#4BC0C0" , "#9966FF" ] ,
"borderWidth" : 1
} ]
} ,
"options" : {
"scales" : {
"y" : {
"beginAtZero" : true ,
"title" : {
"display" : true ,
"text" : "训练时间(小时)"
}
} ,
"x" : {
"title" : {
"display" : true ,
"text" : "模型"
}
}
} ,
"plugins" : {
"title" : {
"display" : true ,
"text" : "模型训练时间对比"
}
}
}
}
2D ResNet-50 :高效,最短训练时间(2 小时)。
3D ResNet-50 :处理体视显微镜数据,时间增加(5 小时)。
ViT :中等复杂度(3 小时)。
多模态 :融合影像和文本,时间较长(6 小时)。
UNETR :分割任务复杂,时间最长(8 小时)。
七、应用与展望
7.1 应用
疾病分类 :检测肺结节(良性/恶性),召回率达 94%(多模态,假设数据)。
分割任务 :精准定位结节边界,辅助手术规划。
多模态诊断 :结合 CT 和临床文本(如病史),提升诊断精度。
实时诊断 :部署模型于医院 PACS 系统,实现快速初步诊断。
数据集扩展 :验证模型在 RSNA 或 BraTS 数据集上的泛化性。
7.2 展望
3D 模型增强 :开发 3D ViT(如 UNETR),直接处理体视显微镜数据。
多模态扩展 :整合影像、文本、基因数据,构建统一诊断模型。
联邦学习 :
原理 :跨医院分布式训练,保护患者隐私:$W_{t+1} = \sum_{k=1}^K \frac{n_k}{N} W_k$
$W_k$: 医院 $k$ 的模型权重,$n_k$: 数据量,$N$: 总数据量。
框架 :使用 Flower(https://flower.dev/)实现 FedAvg。
优势 :符合《个人信息保护法》,提升数据利用率。
可解释性 :
Grad-CAM:突出结节区域。
SHAP/LIME:量化特征贡献,增强医生信任。
自动化流水线 :开发端到端系统,从 DICOM 读取到诊断报告生成。
八、环境准备
GPU 推荐 :NVIDIA A100(24GB)或 RTX 3090(16GB)。
CPU 可运行 ,但 3D 模型较慢。
pip install torch torchvision transformers peft monai pydicom albumentations scikit-learn seaborn matplotlib flower
LUNA16 :下载(~120GB,需注册),替换 dicom_dir 和 annotations_file。
RSNA :下载 CTA 数据,更新路径。
BraTS (可选):用于 MRI 分割任务。
复制代码至 Chart.js 工具渲染。
若需真实数据,请提供 LUNA16 实验结果。
2D ResNet-50 :2 小时,召回率 ~88%(假设)。
3D ResNet-50 :5 小时,召回率 ~90%(假设)。
ViT :3 小时,召回率 ~92%(假设)。
多模态 :6 小时,召回率 ~94%(假设)。
UNETR :8 小时,Dice 分数 ~0.85(假设)。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online