多模态动态融合模型 Predictive Dynamic Fusion 阅读与代码分析
参考论文:Cao B, Xia Y, Ding Y, et al. Predictive Dynamic Fusion[J]. arXiv preprint arXiv:2406.04802, 2024.
一、理论核心
本文主要解析论文中的置信度(Confidence)概念及多模态训练代码的参数配置。
1. 置信度 (Confidence)
在机器学习中,置信度表示模型对其预测结果'有多确定',即模型认为自己预测正确的程度。
例如,在分类任务中,若模型输出'这是正类的概率是 0.92',则 0.92 可视为该预测的置信度。在监督学习中,给定输入样本 $x$,模型预测类别为 $ ilde{y}$,置信度通常定义为模型对预测类别的后验概率估计。
2. 置信度与不确定性
文中使用**熵 (Entropy)**来衡量整体不确定性,可视作置信度的扩展。置信度高对应熵低。
分类评价指标对比
| 指标 | 含义 |
|---|---|
| Accuracy | 模型整体准确率 |
| Precision | 预测为正类的样本中真正的比例 |
| Recall | 真正正类样本中被找出的比例 |
| F1 | Precision 和 Recall 的调和平均 |
| ROC-AUC | 区分正负样本的能力 |

3. 模态融合权重机制
该文旨在解决多模态权重融合问题。多个模态从不同维度评价目标状态,需融合结果。融合权重 $ heta$ 应当与损失 $l$ 呈负相关,并与其他模态的损失呈正相关。即:当前模态越可靠,权重越大;其他模态越不可靠,当前模态权重越大。
Mono-Confidences 与 Holo-Confidences
- Mono-Confidences: 当前模态本身有多可靠。
- Holo-Confidences: 相对其他模态我有多可靠。

Co-Belief (协同信度)
多模态融合需既考虑自身可靠性,又考虑整体模态状态。由协同信度确定该模态的最终权重。

二、代码实现
1. 运行环境
建议使用 Ubuntu 20.04 + Python 3.11 的云服务器环境。
2. 数据集文件
选用代码中可选的训练集 MVSA_Single,需自行下载至服务器目录。
3. 词向量文件
源代码缺失预训练词向量 glove.840B.300d,需手动下载至指定目录:
wget https://nlp.stanford.edu/data/glove.840B.300d.zip
4. 代码逻辑修复
训练代码中的 forward 函数存在逻辑错误,文本和图像的 loss 定义位置不当,会导致运行失败。以下是修正后的逻辑结构:
def model_forward(i_epoch, model, args, criterion, optimizer, batch, mode='eval'):
txt, segment, mask, img, tgt, idx = batch
tgt = tgt.cuda()
clf_loss = 0.0
tcp_pred_loss = 0.0
# 普通单 / 早期融合模型
if args.model == "bow":
txt = txt.cuda()
out = model(txt)
clf_loss = criterion(out, tgt)
elif args.model == "img":
img = img.cuda()
out = model(img)
clf_loss = criterion(out, tgt)
elif args.model == "concatbow":
txt, img = txt.cuda(), img.cuda()
out = model(txt, img)
clf_loss = criterion(out, tgt)
elif args.model == "bert":
txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda()
out = model(txt, mask, segment)
clf_loss = criterion(out, tgt)
elif args.model == "concatbert":
txt, img = txt.cuda(), img.cuda()
mask, segment = mask.cuda(), segment.cuda()
out = model(txt, mask, segment, img)
clf_loss = criterion(out, tgt)
# late fusion(特例)
elif args.model == "latefusion_pdf":
txt, img = txt.cuda(), img.cuda()
mask, segment = mask.cuda(), segment.cuda()
out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred = \
model(txt, mask, segment, img, 'pdf_train')
# 分类 loss
txt_loss = criterion(txt_logits, tgt)
img_loss = criterion(img_logits, tgt)
clf_loss = txt_loss + img_loss
# TCP loss
maeloss = nn.L1Loss(reduction='mean')
label = F.one_hot(tgt, num_classes=args.n_classes)
if args.task_type == "multilabel":
txt_pred = torch.sigmoid(txt_logits)
img_pred = torch.sigmoid(img_logits)
else:
txt_pred = F.softmax(txt_logits, dim=1)
img_pred = F.softmax(img_logits, dim=1)
txt_tcp, _ = torch.(txt_pred * label, dim=, keepdim=)
img_tcp, _ = torch.(img_pred * label, dim=, keepdim=)
tcp_pred_loss = (
maeloss(txt_tcp_pred, txt_tcp.detach()) + maeloss(img_tcp_pred, img_tcp.detach())
)
:
args.model ==
txt, img = txt.cuda(), img.cuda()
mask, segment = mask.cuda(), segment.cuda()
out = model(txt, mask, segment, img)
clf_loss = criterion(out, tgt)
loss = clf_loss + tcp_pred_loss
loss, out, tgt
三、各训练参数
主要是 get_args 中的参数解释。
训练与优化相关参数
| 参数名 | 默认值 | 含义说明 | 影响阶段 | 备注 |
|---|---|---|---|---|
batch_sz | 128 | 每个 batch 的样本数量 | 训练 | 大 batch 更稳定,但占显存 |
gradient_accumulation_steps | 24 | 梯度累积步数 | 训练 | 等效 batch = batch_sz × steps |
lr | 1e-4 | 初始学习率 | 训练 | BERT 微调常用 1e-5~5e-5 |
weight_decay | 0.0 | 权重衰减系数(L2 正则) | 训练 | 防止过拟合 |
dropout | 0.1 | Dropout 概率 | 模型 | Transformer 常用 0.1 |
max_epochs | 100 | 最大训练轮数 | 训练 | 搭配 early stopping |
patience | 10 | Early stopping 容忍轮数 | 训练 | 验证集无提升时停止 |
warmup | 0.1 | 学习率 warmup 比例 | 训练 | 防止初期梯度震荡 |
lr_factor | 0.5 | 学习率衰减倍率 | 训练 | ReduceLROnPlateau |
lr_patience | 2 | 学习率衰减等待轮数 | 训练 | 验证集不提升则降 lr |
seed | 123 | 随机种子 | 全局 | 保证实验可复现 |
文本模态参数
| 参数名 | 默认值 | 含义说明 | 影响阶段 | 备注 |
|---|---|---|---|---|
bert_model | ./bert-base-uncased | BERT 预训练模型路径 | 模型 | 可换成 large |
freeze_txt | 0 | 是否冻结文本编码器 | 训练 | 1 表示不更新 BERT |
max_seq_len | 512 | 文本最大 token 长度 | 数据 | BERT 上限 |
embed_sz | 300 | 词向量维度 | 模型 | 对应 GloVe |
glove_path | glove.840B.300d.txt | GloVe 文件路径 | 数据 | 300 维 |
hidden_sz | 768 | 文本隐藏层维度 | 模型 | BERT-base 默认 |
图像模态参数
| 参数名 | 默认值 | 含义说明 | 影响阶段 | 备注 |
|---|---|---|---|---|
img_hidden_sz | 2048 | 图像特征维度 | 模型 | ResNet 输出 |
num_image_embeds | 1 | 图像 token 数 | 模型 | MMBT 中常见 |
img_embed_pool_type | avg | 图像特征池化方式 | 模型 | avg / max |
freeze_img | 0 | 是否冻结图像编码器 | 训练 | 1 表示冻结 |
drop_img_percent | 0.0 | 随机丢弃图像比例 | 数据增强 | 模态缺失模拟 |
融合参数
| 参数名 | 默认值 | 含义说明 | 影响阶段 | 备注 |
|---|---|---|---|---|
model | latefusion_pdf | 使用的模型结构 | 模型 | PDF = Predictive Dynamic Fusion |
hidden | [] | 额外隐藏层结构 | 模型 | 如 [512,256] |
include_bn | True | 是否使用 BatchNorm | 模型 | 提高训练稳定性 |
df | True | 是否启用动态融合 | 模型 | PDF 核心开关 |
baseline | None | 对比方法名称 | 实验 | 仅用于记录 |
任务与数据相关参数
| 参数名 | 默认值 | 含义说明 | 影响阶段 | 备注 |
|---|---|---|---|---|
task | MVSA_Single | 使用的数据集 | 数据 | 多模态情绪识别 |
task_type | classification | 任务类型 | 训练 | 单标签 / 多标签 |
weight_classes | 1 | 是否类别加权 | loss | 类别不平衡时用 |
noise | 0.0 | 标签噪声比例 | 数据 | 鲁棒性实验 |
data_path | /path/to/data_dir/ | 数据集路径 | 数据 | 必须配置 |
savedir | /path/to/save_dir/ | 模型保存路径 | 输出 | checkpoint |


