def _true_class_prob(logits, label, multilabel):
    """Return the model's probability for the ground-truth label(s), one value per row.

    Probabilities come from ``sigmoid`` for multilabel tasks and from ``softmax``
    over dim 1 otherwise. Multiplying by ``label`` zeroes every non-target class,
    so the max over dim 1 is the probability of the true class (single-label) or
    of the most confident true label (multilabel). ``keepdim=True`` keeps dim 1
    as a size-1 axis.
    """
    if multilabel:
        probs = torch.sigmoid(logits)
    else:
        probs = F.softmax(logits, dim=1)
    tcp, _ = torch.max(probs * label, dim=1, keepdim=True)
    return tcp


def model_forward(
    i_epoch, model, args, criterion, optimizer, batch, mode='eval', *, device="cuda"
):
    """Run one forward pass for ``args.model`` and compute its loss.

    Args:
        i_epoch: Current epoch index; unused here, kept for the caller's interface.
        model: The network. Its call signature depends on ``args.model``:
            ``bow`` -> ``model(txt)``, ``img`` -> ``model(img)``,
            ``concatbow`` -> ``model(txt, img)``, ``bert`` -> ``model(txt, mask, segment)``,
            ``concatbert``/``mmbt`` -> ``model(txt, mask, segment, img)``,
            ``latefusion_pdf`` -> ``model(txt, mask, segment, img, 'pdf_train')``
            returning ``(out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred)``.
        args: Config namespace; reads ``model``, and for ``latefusion_pdf`` also
            ``task_type`` and ``n_classes``.
        criterion: Classification loss, called as ``criterion(logits, tgt)``.
        optimizer: Unused here, kept for the caller's interface.
        batch: ``(txt, segment, mask, img, tgt, idx)``; ``idx`` is ignored.
        mode: Unused here, kept for the caller's interface.
        device: Target device for the tensors the chosen model consumes. The
            default ``"cuda"`` is equivalent to the former ``.cuda()`` calls.

    Returns:
        ``(loss, out, tgt)``: the total loss, the model output (the fused output
        for ``latefusion_pdf``), and ``tgt`` moved to ``device``.

    Raises:
        AssertionError: if ``args.model`` is not a supported name (via ``assert``,
            so this check is skipped under ``python -O``).
    """
    txt, segment, mask, img, tgt, _ = batch
    tgt = tgt.to(device)

    if args.model == "latefusion_pdf":
        out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred = model(
            txt.to(device), mask.to(device), segment.to(device), img.to(device), 'pdf_train'
        )
        # Only the unimodal heads are supervised; the fused ``out`` is not in the loss.
        clf_loss = criterion(txt_logits, tgt) + criterion(img_logits, tgt)

        multilabel = args.task_type == "multilabel"
        # Multilabel targets are expected to arrive multi-hot (one column per class).
        # F.one_hot only accepts integer class indices, so it must not be applied
        # to them; single-label integer targets are expanded to a one-hot mask.
        if multilabel:
            label = tgt
        else:
            label = F.one_hot(tgt, num_classes=args.n_classes)

        txt_tcp = _true_class_prob(txt_logits, label, multilabel)
        img_tcp = _true_class_prob(img_logits, label, multilabel)
        # Regress each predicted TCP onto the observed one; detach so this term
        # does not push gradients back through the classifier probabilities.
        tcp_pred_loss = (
            F.l1_loss(txt_tcp_pred, txt_tcp.detach(), reduction="mean")
            + F.l1_loss(img_tcp_pred, img_tcp.detach(), reduction="mean")
        )
        return clf_loss + tcp_pred_loss, out, tgt

    if args.model == "bow":
        out = model(txt.to(device))
    elif args.model == "img":
        out = model(img.to(device))
    elif args.model == "concatbow":
        out = model(txt.to(device), img.to(device))
    elif args.model == "bert":
        out = model(txt.to(device), mask.to(device), segment.to(device))
    else:
        # concatbert and mmbt share the same (txt, mask, segment, img) signature.
        assert args.model in ("concatbert", "mmbt"), f"unsupported model: {args.model!r}"
        out = model(txt.to(device), mask.to(device), segment.to(device), img.to(device))

    return criterion(out, tgt), out, tgt