多模态动态融合模型Predictive Dynamic Fusion阅读与代码分析运行1-信度概念与基础参数指标

多模态动态融合模型Predictive Dynamic Fusion阅读与代码分析运行1-信度概念与基础参数指标

参考文:Cao B, Xia Y, Ding Y, et al. Predictive Dynamic Fusion[J]. arXiv preprint arXiv:2406.04802, 2024.[2406.04802] Predictive Dynamic Fusion

一、理论

今天就先看看论文中的各个指标含义和多模态训练代码的参数吧

文章中一个比较重要的概念就是置信度的概念了,在论文前段,对置信度的扩展比较多同时没有什么具体说明,不知道概念的话读着还是很混乱的;

置信度

在机器学习中,置信度表示模型对其预测结果“有多确定”。
它刻画的是:模型认为自己预测是正确的程度

例如,在分类任务中:“这是正类的概率是 0.92”,那么 0.92 就可以视为模型对该预测的置信度

在监督学习中,给定输入样本 xxx,模型预测类别为 y^\hat{y}y^​,则置信度通常定义为:

即:模型对预测类别的后验概率估计

置信度 和 不确定性(补充)

文中用来衡量整体不确定性,算是置信度的一种扩展:

关于熵的概念,之前在b站看到的一位up主讲的很生动:https://www.bilibili.com/video/BV15V411W7VB/

置信度高 <=> 熵低

分类评价指标对比

指标含义对照

指标一句话解释
Accuracy模型整体准不准
Precision模型说“是”的时候靠谱吗
Recall真正“是”的有没有被找全
F1Precision 和 Recall 的折中
ROC-AUC正样本排在负样本前面的能力

The Mono-Confidences and Holo-Confidences

该文的目的之一是为了解决模态权重融合的权重问题;也就是,多个模态分别从多个维度评价目标的状态,给出不一样的结果,怎么融合这几个结果的问题。

目前可以确定的是:融合权重 ω 应当与损失 l 呈负相关,并且与其他模态的损失呈正相关。也就是:当前模态越可靠 → 权重越大;其他模态越不可靠 → 当前模态权重越大

对单个模态的模型,权重 ω 是要求的权重,损失loss是:

所以,就有人两个信度指标:

The Mono-ConfidencesHolo-Confidences
当前模态本身有多可靠相对其他模态我有多可靠

将他们统合:

Co-Belief(协同信度)

Mono-Confidence:只看自己;Holo-Confidence:只看别人;但多模态融合需要:既考虑自身可靠性,又考虑整体模态状态。

故有:

再由协同信度确定该模态的权重。

理论先到这里,其他的后面再看;

二、代码

1、运行环境

代码训练环境没有明确说明,但根据结构可以看得出来用的是autodl里的云服务器,Ubuntu20.04+python3.11的版本,卡随便租一个都一样。

论文附带代码只有2mb,明显缺失了很多预训练结构与数据集文件;

2、数据集文件

这里选用了代码中可选的第二个训练集MVSA_Single,需要自己到网站下好转到autodl服务器上:MVSA_Single

训练集之类的划分源代码已有了,自己按要求放到同一目录下即可。

3、词向量文件

源代码缺失了预训练好的词向量文件glove.840B.300d,需要自己使用指令下载到指定目录

wget https://nlp.stanford.edu/data/glove.840B.300d.zip

4、源代码逻辑错误

训练代码中的forward函数存在运行逻辑错误,文本和图像的loss(txt_clf_loss和img_clf_loss)定义在了if之外,会运行不成功;估计是作者没有仔细整理,代码算法逻辑倒没什么问题;

原代码150行左右:

def model_forward(i_epoch, model, args, criterion,optimizer, batch,mode='eval'): txt, segment, mask, img, tgt,idx = batch freeze_img = i_epoch < args.freeze_img freeze_txt = i_epoch < args.freeze_txt if args.model == "bow": txt = txt.cuda() out = model(txt) elif args.model == "img": img = img.cuda() out = model(img) elif args.model == "concatbow": txt, img = txt.cuda(), img.cuda() out = model(txt, img) elif args.model == "bert": txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda() out = model(txt, mask, segment) elif args.model == "concatbert": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) elif args.model == "latefusion_pdf": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() tgt = tgt.cuda() maeloss = nn.L1Loss(reduction='mean') out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred = model(txt, mask,segment,img,'pdf_train') label = F.one_hot(tgt, num_classes=args.n_classes) # [b,c] if args.task_type == "multilabel": txt_pred = torch.sigmoid(txt_logits) img_pred = torch.sigmoid(img_logits) else: txt_pred = torch.nn.functional.softmax(txt_logits, dim=1) img_pred = torch.nn.functional.softmax(img_logits, dim=1) txt_tcp, _ = torch.max(txt_pred * label, dim=1,keepdim=True) img_tcp, _ = torch.max(img_pred * label, dim=1,keepdim=True) tcp_pred_loss = maeloss(txt_tcp_pred, txt_tcp.detach()) + maeloss(img_tcp_pred, img_tcp.detach()) else: assert args.model == "mmbt" for param in model.enc.img_encoder.parameters(): param.requires_grad = not freeze_img for param in model.enc.encoder.parameters(): param.requires_grad = not freeze_txt txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) tgt = tgt.cuda() txt_clf_loss = nn.CrossEntropyLoss()(txt_logits, tgt) img_clf_loss = nn.CrossEntropyLoss()(img_logits, tgt) clf_loss=txt_clf_loss+img_clf_loss+nn.CrossEntropyLoss()(out,tgt) if mode=='train': loss = torch.mean(clf_loss)+torch.mean(tcp_pred_loss) return loss,out,tgt else: loss= torch.mean(clf_loss)+torch.mean(tcp_pred_loss) return loss,out,tgt

修改后:

def model_forward(i_epoch, model, args, criterion, optimizer, batch, mode='eval'): txt, segment, mask, img, tgt, idx = batch tgt = tgt.cuda() clf_loss = 0.0 tcp_pred_loss = 0.0 # ⭐ 先初始化,避免炸 # ---------- 普通单 / 早期融合模型 ---------- if args.model == "bow": txt = txt.cuda() out = model(txt) clf_loss = criterion(out, tgt) elif args.model == "img": img = img.cuda() out = model(img) clf_loss = criterion(out, tgt) elif args.model == "concatbow": txt, img = txt.cuda(), img.cuda() out = model(txt, img) clf_loss = criterion(out, tgt) elif args.model == "bert": txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda() out = model(txt, mask, segment) clf_loss = criterion(out, tgt) elif args.model == "concatbert": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) clf_loss = criterion(out, tgt) # ---------- late fusion(特例) ---------- elif args.model == "latefusion_pdf": txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred = \ model(txt, mask, segment, img, 'pdf_train') # 分类 loss txt_loss = criterion(txt_logits, tgt) img_loss = criterion(img_logits, tgt) clf_loss = txt_loss + img_loss # TCP loss maeloss = nn.L1Loss(reduction='mean') label = F.one_hot(tgt, num_classes=args.n_classes) if args.task_type == "multilabel": txt_pred = torch.sigmoid(txt_logits) img_pred = torch.sigmoid(img_logits) else: txt_pred = F.softmax(txt_logits, dim=1) img_pred = F.softmax(img_logits, dim=1) txt_tcp, _ = torch.max(txt_pred * label, dim=1, keepdim=True) img_tcp, _ = torch.max(img_pred * label, dim=1, keepdim=True) tcp_pred_loss = ( maeloss(txt_tcp_pred, txt_tcp.detach()) + maeloss(img_tcp_pred, img_tcp.detach()) ) # ---------- mmbt ---------- else: assert args.model == "mmbt" txt, img = txt.cuda(), img.cuda() mask, segment = mask.cuda(), segment.cuda() out = model(txt, mask, segment, img) clf_loss = criterion(out, tgt) # ---------- 总 loss ---------- loss = clf_loss + tcp_pred_loss return loss, out, tgt

四、各训练参数

主要是get_args里面的参数解释:

训练与优化相关参数

参数名默认值含义说明影响阶段备注 / 建议
batch_sz128每个 batch 的样本数量训练大 batch 更稳定,但占显存
gradient_accumulation_steps24梯度累积步数训练等效 batch = batch_sz × steps
lr1e-4初始学习率训练BERT 微调常用 1e-5~5e-5
weight_decay0.0权重衰减系数(L2 正则)训练防止过拟合
dropout0.1Dropout 概率模型Transformer 常用 0.1
max_epochs100最大训练轮数训练搭配 early stopping
patience10Early stopping 容忍轮数训练验证集无提升时停止
warmup0.1学习率 warmup 比例训练防止初期梯度震荡
lr_factor0.5学习率衰减倍率训练ReduceLROnPlateau
lr_patience2学习率衰减等待轮数训练验证集不提升则降 lr
seed123随机种子全局保证实验可复现
n_workers12DataLoader 线程数数据加载与 CPU 核数相关

文本模态:

参数名默认值含义说明影响阶段备注
bert_model./bert-base-uncasedBERT 预训练模型路径模型可换成 large
freeze_txt0是否冻结文本编码器训练1 表示不更新 BERT
max_seq_len512文本最大 token 长度数据BERT 上限
embed_sz300词向量维度模型对应 GloVe
glove_pathglove.840B.300d.txtGloVe 文件路径数据300 维
hidden_sz768文本隐藏层维度模型BERT-base 默认

图像模态(Image)相关参数

参数名默认值含义说明影响阶段备注
img_hidden_sz2048图像特征维度模型ResNet 输出
num_image_embeds1图像 token 数模型MMBT 中常见
img_embed_pool_typeavg图像特征池化方式模型avg / max
freeze_img0是否冻结图像编码器训练1 表示冻结
drop_img_percent0.0随机丢弃图像比例数据增强模态缺失模拟

融合参数:

参数名默认值含义说明影响阶段备注
modellatefusion_pdf使用的模型结构模型PDF = Predictive Dynamic Fusion
hidden[]额外隐藏层结构模型如 [512,256]
include_bnTrue是否使用 BatchNorm模型提高训练稳定性
dfTrue是否启用动态融合模型PDF 核心开关
baselineNone对比方法名称实验仅用于记录

任务与数据相关参数:

参数名默认值含义说明影响阶段备注
taskMVSA_Single使用的数据集数据多模态情绪识别
task_typeclassification任务类型训练单标签 / 多标签
weight_classes1是否类别加权loss类别不平衡时用
noise0.0标签噪声比例数据鲁棒性实验
data_path/path/to/data_dir/数据集路径数据必须配置
savedir/path/to/save_dir/模型保存路径输出checkpoint

其中,很多任务数据相关参数都需要调整

Read more

公益服务平台信息管理系统源码-SpringBoot后端+Vue前端+MySQL【可直接运行】

公益服务平台信息管理系统源码-SpringBoot后端+Vue前端+MySQL【可直接运行】

摘要 随着社会公益事业的快速发展,公益服务平台的数字化管理需求日益增长。传统的手工记录和分散式管理方式效率低下,难以满足现代公益组织对信息整合、资源共享和高效协作的需求。公益服务平台信息管理系统的开发旨在解决这一问题,通过信息化手段实现公益项目的规范化、透明化和高效化管理。该系统能够整合志愿者、受助者、捐赠资源等多方信息,提升公益服务的可追溯性和协作效率,同时为公益组织提供数据支持,助力其优化资源配置和决策制定。关键词:公益服务、信息管理、数字化、资源整合、高效协作。 本系统采用SpringBoot作为后端框架,结合Vue.js前端技术和MySQL数据库,构建了一套完整的公益服务平台信息管理系统。SpringBoot提供了高效的开发环境和稳定的后端支持,Vue.js实现了动态交互和友好的用户界面,MySQL则确保了数据的安全存储和高效查询。系统功能涵盖用户管理、帮扶信息管理、捐赠项目管理等模块,支持多角色权限控制、数据可视化分析和实时信息更新。通过前后端分离的设计,系统具备良好的扩展性和维护性,能够适应不同规模公益组织的需求。关键词:SpringBoot、Vue.js、MySQL、

突破亚马逊壁垒,Web Unlocker API 助您轻松获取数据

突破亚马逊壁垒,Web Unlocker API 助您轻松获取数据

目录 * 一、Web Unlocker API简介 * 二、开始使用Web Unlocker API * 1、首先进入控制台页面,点击左侧第一个tab键“代理 & 抓取基础设施”,找到“网页解锁器”,开始使用。 * 2、进入网页解锁器页面后,填写通道名称,添加简短描述,点击添加 * 3、直接展示代理基础设施/web_unlocker3的详细信息 * 4、配置网页解锁器 * 5、以Python脚本获取亚马逊平台数据为示例 * 6、结果示例 * 三、Web Scraper * 1、快速使用Web Scraper * 2、通过python获取亚马逊网页数据 * 3、定位具体数据 * 4、运行并保存到csv文件 * 四、SERP API * 五、优惠升级

Google Stitch 2.0 深度解析:AI 驱动的前端革命,从像素到生产力的全栈跨越

Google Stitch 2.0 深度解析:AI 驱动的前端革命,从像素到生产力的全栈跨越

在人工智能迅速蚕食传统开发流程的今天,谷歌推出的 Stitch 2.0 不仅仅是一个简单的 UI 生成工具更新,它标志着前端开发进入了一个全新的“意图驱动”时代。通过将自然语言描述、草图或截图直接转化为生产级别的代码,Stitch 2.0 正在重新定义设计师与开发者之间的协作边界,并让“全栈 AI 编程助手”的概念真正落地。 核心引擎的进化:Gemini 3.0 Pro 带来的视觉推理 Stitch 2.0 的质变源于底层模型的升级。通过默认集成 Gemini 3.0 Pro,该工具在逻辑推理和视觉布局质量上实现了跨越式提升。 从“画饼”到“工程化”的布局生成 不同于早期的 AI 工具只能生成零散的元素,Gemini 3.0 Pro

JavaScript WebAPI 核心操作指南

JavaScript WebAPI 核心操作指南

JavaScript(WebAPI) WebAPI 背景知识 什么是 WebAPI 前面学习的 JS 分成三个大的部分: * ECMAScript:基础语法部分 * DOM API:操作页面结构 * BOM API:操作浏览器 WebAPI 就包含了 DOM + BOM。 这个是 W3C 组织规定的(和制定 ECMAScript 标准的大佬们不是一伙人)。 前面学的 JS 基础语法主要学的是 ECMAScript,这让我们建立基本的编程思维,相当于练武需要先扎马步。但是真正来写一个更加复杂的有交互式的页面,还需要 WebAPI 的支持,相当于各种招式。 什么是 API API 是一个更广义的概念,而 WebAPI 是一个更具体的概念,特指 DOM+BOM。 所谓的 API