2025 腾讯广告算法大赛 Baseline 项目解析
项目概述
2025 腾讯广告算法大赛 Baseline,一个简单的序列推荐系统,主要用于建模用户和物品的交互序列,并利用多模态特征(文本、图像等 embedding)来提升推荐效果。
核心文件功能
1. main.py - 主训练脚本
- 负责模型训练的整体流程
- 包含参数解析、数据加载、模型初始化、训练循环等
- 支持断点续训和仅推理模式
- 使用 TensorBoard 记录训练日志
main.py 代码
import argparse
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import MyDataset
from model import BaselineModel


def get_args():
    """Parse command-line arguments for training / inference.

    Returns:
        argparse.Namespace with training hyper-parameters, model construction
        options and the list of activated multimodal embedding feature IDs.
    """
    parser = argparse.ArgumentParser()

    # Train params
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)

    # Baseline Model construction
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')

    # MMemb Feature ID (multimodal embedding features 81..86)
    parser.add_argument('--mm_emb_id', nargs='+', default=['81'], type=str,
                        choices=[str(s) for s in range(81, 87)])

    return parser.parse_args()


if __name__ == '__main__':
    # Output locations are injected by the competition platform via env vars.
    Path(os.environ.get('TRAIN_LOG_PATH')).mkdir(parents=True, exist_ok=True)
    Path(os.environ.get('TRAIN_TF_EVENTS_PATH')).mkdir(parents=True, exist_ok=True)
    log_file = open(Path(os.environ.get('TRAIN_LOG_PATH'), 'train.log'), 'w')
    writer = SummaryWriter(os.environ.get('TRAIN_TF_EVENTS_PATH'))

    # global dataset
    data_path = os.environ.get('TRAIN_DATA_PATH')

    args = get_args()
    dataset = MyDataset(data_path, args)
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
    )
    usernum, itemnum = dataset.usernum, dataset.itemnum
    feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types

    model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)

    # Xavier-init every parameter that supports it; 1-D tensors (biases,
    # LayerNorm weights) raise and are deliberately left at their defaults.
    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass

    # Zero out the padding rows (index 0) of every embedding table.
    model.pos_emb.weight.data[0, :] = 0
    model.item_emb.weight.data[0, :] = 0
    model.user_emb.weight.data[0, :] = 0
    for k in model.sparse_emb:
        model.sparse_emb[k].weight.data[0, :] = 0

    epoch_start_idx = 1

    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            # Resume from the epoch encoded in the checkpoint name ("...epoch=N...").
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[: tail.find('.')]) + 1
        except Exception as exc:  # narrowed from a bare `except:`; keep the cause chained
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            raise RuntimeError('failed loading state_dicts, pls check file path!') from exc

    bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    best_val_ndcg, best_val_hr = 0.0, 0.0
    best_test_ndcg, best_test_hr = 0.0, 0.0
    T = 0.0
    t0 = time.time()
    global_step = 0
    print("Start training")

    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        model.train()
        if args.inference_only:
            break
        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(
                neg_logits.shape, device=args.device
            )
            optimizer.zero_grad()
            # Only positions whose next token is an item (type 1) contribute to the loss.
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])

            log_json = json.dumps(
                {'global_step': global_step, 'loss': loss.item(), 'epoch': epoch, 'time': time.time()}
            )
            log_file.write(log_json + '\n')
            log_file.flush()
            print(log_json)

            writer.add_scalar('Loss/train', loss.item(), global_step)

            global_step += 1

            # Optional L2 penalty on the item embedding table; added after logging
            # so the reported loss stays the pure BCE value.
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)

            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss_sum = 0
        # Fix: evaluate under no_grad — the original built autograd graphs for
        # every validation batch, wasting memory and time.
        with torch.no_grad():
            for step, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
                seq = seq.to(args.device)
                pos = pos.to(args.device)
                neg = neg.to(args.device)
                pos_logits, neg_logits = model(
                    seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
                )
                pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(
                    neg_logits.shape, device=args.device
                )
                indices = np.where(next_token_type == 1)
                loss = bce_criterion(pos_logits[indices], pos_labels[indices])
                loss += bce_criterion(neg_logits[indices], neg_labels[indices])
                valid_loss_sum += loss.item()
        valid_loss_sum /= len(valid_loader)
        writer.add_scalar('Loss/valid', valid_loss_sum, global_step)

        save_dir = Path(os.environ.get('TRAIN_CKPT_PATH'), f"global_step{global_step}.valid_loss={valid_loss_sum:.4f}")
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_dir / "model.pt")

    print("Done")
    writer.close()
    log_file.close()

# 2. model.py - 核心模型实现
BaselineModel - 主推荐模型
基于 Transformer 的序列推荐模型,具有以下特点:
模型架构:
- 使用 FlashMultiHeadAttention 实现高效的多头注意力机制
- 采用 PointWiseFeedForward 作为前馈网络
- 支持多种特征类型:稀疏特征、数组特征、连续特征、多模态 embedding 特征
特征处理:
- 用户特征:稀疏特征 (103,104,105,109)、数组特征 (106,107,108,110)
- 物品特征:稀疏特征 (100,117,111 等)、多模态 embedding 特征 (81-86)
- 通过 feat2emb 方法将不同类型特征转换为统一的 embedding 表示
核心方法:
- log2feats:将用户序列转换为特征表示
- forward:训练时计算正负样本的 logits
- predict:推理时生成用户表征
- save_item_emb:保存物品 embedding 用于检索
model.py 代码
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F


class FlashMultiHeadAttention(torch.nn.Module):
    """Multi-head attention that uses PyTorch 2.0+ scaled_dot_product_attention
    (Flash Attention) when available and falls back to a manual implementation
    on older versions.
    """

    def __init__(self, hidden_units, num_heads, dropout_rate):
        super(FlashMultiHeadAttention, self).__init__()
        assert hidden_units % num_heads == 0, "hidden_units must be divisible by num_heads"
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.head_dim = hidden_units // num_heads
        self.dropout_rate = dropout_rate
        self.q_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.k_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.v_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.out_linear = torch.nn.Linear(hidden_units, hidden_units)

    def forward(self, query, key, value, attn_mask=None):
        """Compute attention.

        Args:
            query/key/value: [batch, seq, hidden] tensors.
            attn_mask: optional bool tensor [batch, seq, seq]; True = may attend.

        Returns:
            (output, None) — None keeps the tuple interface of nn.MultiheadAttention.
        """
        batch_size, seq_len, _ = query.size()

        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Reshape to multi-head layout: [batch, heads, seq, head_dim].
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch 2.0+: built-in Flash Attention.
            # Fix: the original called attn_mask.unsqueeze(1) unconditionally and
            # crashed when attn_mask was None; the fallback branch already allowed None.
            attn_output = F.scaled_dot_product_attention(
                Q,
                K,
                V,
                dropout_p=self.dropout_rate if self.training else 0.0,
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
            )
        else:
            # Standard attention fallback.
            scale = (self.head_dim) ** -0.5
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            if attn_mask is not None:
                scores.masked_fill_(attn_mask.unsqueeze(1).logical_not(), float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
            attn_output = torch.matmul(attn_weights, V)

        # Merge heads back to [batch, seq, hidden] and project.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_units)
        output = self.out_linear(attn_output)
        return output, None


class PointWiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward network implemented with 1x1 convolutions."""

    def __init__(self, hidden_units, dropout_rate):
        super(PointWiseFeedForward, self).__init__()
        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1d requires (N, C, Length), hence the transposes around the stack.
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        return outputs


class BaselineModel(torch.nn.Module):
    """Transformer-based sequential recommender.

    Args:
        user_num: number of users
        item_num: number of items
        feat_statistics: feature cardinalities, keyed by feature ID
        feat_types: feature IDs grouped by type (user/item x sparse, array,
            emb, continual)
        args: global hyper-parameters

    Attributes:
        item_emb / user_emb: ID embedding tables (row 0 is padding)
        pos_emb: positional embedding table
        sparse_emb: per-feature embedding tables for sparse/array features
        emb_transform: linear projections for multimodal embedding features
        userdnn / itemdnn: fusion layers over the concatenated feature embeddings
    """

    def __init__(self, user_num, item_num, feat_statistics, feat_types, args):
        super(BaselineModel, self).__init__()
        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device
        self.norm_first = args.norm_first
        self.maxlen = args.maxlen
        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch

        self.item_emb = torch.nn.Embedding(self.item_num + 1, args.hidden_units, padding_idx=0)
        self.user_emb = torch.nn.Embedding(self.user_num + 1, args.hidden_units, padding_idx=0)
        # Interleaved user/item tokens can be up to 2*maxlen positions long.
        self.pos_emb = torch.nn.Embedding(2 * args.maxlen + 1, args.hidden_units, padding_idx=0)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.sparse_emb = torch.nn.ModuleDict()
        self.emb_transform = torch.nn.ModuleDict()

        self.attention_layernorms = torch.nn.ModuleList()  # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self._init_feat_info(feat_statistics, feat_types)

        # Fusion-layer input widths: one hidden_units slot per sparse/array
        # feature plus the ID embedding, one scalar per continual feature, and
        # one projected hidden_units slot per multimodal embedding feature.
        userdim = args.hidden_units * (len(self.USER_SPARSE_FEAT) + 1 + len(self.USER_ARRAY_FEAT)) + len(
            self.USER_CONTINUAL_FEAT
        )
        itemdim = (
            args.hidden_units * (len(self.ITEM_SPARSE_FEAT) + 1 + len(self.ITEM_ARRAY_FEAT))
            + len(self.ITEM_CONTINUAL_FEAT)
            + args.hidden_units * len(self.ITEM_EMB_FEAT)
        )

        self.userdnn = torch.nn.Linear(userdim, args.hidden_units)
        self.itemdnn = torch.nn.Linear(itemdim, args.hidden_units)

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            # FlashAttention replaces the standard attention layer.
            new_attn_layer = FlashMultiHeadAttention(args.hidden_units, args.num_heads, args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

        # +1 row for padding index 0 in every table.
        for k in self.USER_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.USER_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_EMB_FEAT:
            self.emb_transform[k] = torch.nn.Linear(self.ITEM_EMB_FEAT[k], args.hidden_units)

    def _init_feat_info(self, feat_statistics, feat_types):
        """Group feature statistics by feature type so the embedding tables can
        be declared per group.

        Args:
            feat_statistics: feature cardinalities, keyed by feature ID
            feat_types: feature IDs grouped by type (user/item x sparse, array,
                emb, continual)
        """
        self.USER_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['user_sparse']}
        self.USER_CONTINUAL_FEAT = feat_types['user_continual']
        self.ITEM_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['item_sparse']}
        self.ITEM_CONTINUAL_FEAT = feat_types['item_continual']
        self.USER_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['user_array']}
        self.ITEM_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['item_array']}
        # Dimensionality of each multimodal embedding feature.
        EMB_SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
        self.ITEM_EMB_FEAT = {k: EMB_SHAPE_DICT[k] for k in feat_types['item_emb']}

    def feat2tensor(self, seq_feature, k):
        """Convert one feature across a batch of sequences into a tensor.

        Args:
            seq_feature: list [batch_size][seq_len] of per-step feature dicts
            k: feature ID

        Returns:
            Tensor of shape [batch_size, maxlen] for sparse features, or
            [batch_size, maxlen, max_array_len] for array features.
        """
        batch_size = len(seq_feature)

        if k in self.ITEM_ARRAY_FEAT or k in self.USER_ARRAY_FEAT:
            # Array features need per-element padding before stacking.
            max_array_len = 0
            max_seq_len = 0

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                max_seq_len = max(max_seq_len, len(seq_data))
                max_array_len = max(max_array_len, max(len(item_data) for item_data in seq_data))

            batch_data = np.zeros((batch_size, max_seq_len, max_array_len), dtype=np.int64)
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                for j, item_data in enumerate(seq_data):
                    actual_len = min(len(item_data), max_array_len)
                    batch_data[i, j, :actual_len] = item_data[:actual_len]

            return torch.from_numpy(batch_data).to(self.dev)
        else:
            # Sparse features convert directly.
            max_seq_len = max(len(seq_feature[i]) for i in range(batch_size))
            batch_data = np.zeros((batch_size, max_seq_len), dtype=np.int64)

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                batch_data[i] = seq_data

            return torch.from_numpy(batch_data).to(self.dev)

    def feat2emb(self, seq, feature_array, mask=None, include_user=False):
        """Fuse ID embeddings with all feature embeddings for a sequence.

        Args:
            seq: sequence IDs
            feature_array: list of per-step feature dicts
            mask: token-type mask (1 = item, 2 = user)
            include_user: whether user features are present; off when encoding
                pos/neg item samples or when generating the candidate item pool.

        Returns:
            seqs_emb: fused sequence embedding
        """
        seq = seq.to(self.dev)
        # Pre-compute ID embeddings.
        if include_user:
            user_mask = (mask == 2).to(self.dev)
            item_mask = (mask == 1).to(self.dev)
            # Masking the IDs routes non-matching positions to padding row 0.
            user_embedding = self.user_emb(user_mask * seq)
            item_embedding = self.item_emb(item_mask * seq)
            item_feat_list = [item_embedding]
            user_feat_list = [user_embedding]
        else:
            item_embedding = self.item_emb(seq)
            item_feat_list = [item_embedding]

        # Batch-process all feature types.
        all_feat_types = [
            (self.ITEM_SPARSE_FEAT, 'item_sparse', item_feat_list),
            (self.ITEM_ARRAY_FEAT, 'item_array', item_feat_list),
            (self.ITEM_CONTINUAL_FEAT, 'item_continual', item_feat_list),
        ]
        if include_user:
            all_feat_types.extend(
                [
                    (self.USER_SPARSE_FEAT, 'user_sparse', user_feat_list),
                    (self.USER_ARRAY_FEAT, 'user_array', user_feat_list),
                    (self.USER_CONTINUAL_FEAT, 'user_continual', user_feat_list),
                ]
            )

        for feat_dict, feat_type, feat_list in all_feat_types:
            if not feat_dict:
                continue
            for k in feat_dict:
                tensor_feature = self.feat2tensor(feature_array, k)
                if feat_type.endswith('sparse'):
                    feat_list.append(self.sparse_emb[k](tensor_feature))
                elif feat_type.endswith('array'):
                    # Sum-pool the embeddings of each array's elements.
                    feat_list.append(self.sparse_emb[k](tensor_feature).sum(2))
                elif feat_type.endswith('continual'):
                    feat_list.append(tensor_feature.unsqueeze(2))

        for k in self.ITEM_EMB_FEAT:
            # Collect into one numpy buffer, then batch-convert and transfer once.
            batch_size = len(feature_array)
            emb_dim = self.ITEM_EMB_FEAT[k]
            seq_len = len(feature_array[0])

            batch_emb_data = np.zeros((batch_size, seq_len, emb_dim), dtype=np.float32)
            # Renamed loop variable: the original shadowed the `seq` parameter.
            for i, step_feats in enumerate(feature_array):
                for j, item in enumerate(step_feats):
                    if k in item:
                        batch_emb_data[i, j] = item[k]

            tensor_feature = torch.from_numpy(batch_emb_data).to(self.dev)
            item_feat_list.append(self.emb_transform[k](tensor_feature))

        # Merge features through the fusion MLPs.
        all_item_emb = torch.cat(item_feat_list, dim=2)
        all_item_emb = torch.relu(self.itemdnn(all_item_emb))
        if include_user:
            all_user_emb = torch.cat(user_feat_list, dim=2)
            all_user_emb = torch.relu(self.userdnn(all_user_emb))
            seqs_emb = all_item_emb + all_user_emb
        else:
            seqs_emb = all_item_emb
        return seqs_emb

    def log2feats(self, log_seqs, mask, seq_feature):
        """Encode a user interaction sequence with the transformer stack.

        Args:
            log_seqs: sequence IDs
            mask: token-type mask (1 = item token, 2 = user token)
            seq_feature: list of per-step feature dicts

        Returns:
            log_feats: [batch_size, maxlen, hidden_units] sequence representation
        """
        batch_size = log_seqs.shape[0]
        maxlen = log_seqs.shape[1]
        seqs = self.feat2emb(log_seqs, seq_feature, mask=mask, include_user=True)
        seqs *= self.item_emb.embedding_dim**0.5
        poss = torch.arange(1, maxlen + 1, device=self.dev).unsqueeze(0).expand(batch_size, -1).clone()
        poss *= log_seqs != 0  # padding positions get position 0 (padding_idx)
        seqs += self.pos_emb(poss)
        seqs = self.emb_dropout(seqs)

        maxlen = seqs.shape[1]
        # Causal mask AND padding mask.
        ones_matrix = torch.ones((maxlen, maxlen), dtype=torch.bool, device=self.dev)
        attention_mask_tril = torch.tril(ones_matrix)
        attention_mask_pad = (mask != 0).to(self.dev)
        attention_mask = attention_mask_tril.unsqueeze(0) & attention_mask_pad.unsqueeze(1)

        for i in range(len(self.attention_layers)):
            if self.norm_first:
                x = self.attention_layernorms[i](seqs)
                mha_outputs, _ = self.attention_layers[i](x, x, x, attn_mask=attention_mask)
                seqs = seqs + mha_outputs
                seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs))
            else:
                mha_outputs, _ = self.attention_layers[i](seqs, seqs, seqs, attn_mask=attention_mask)
                seqs = self.attention_layernorms[i](seqs + mha_outputs)
                seqs = self.forward_layernorms[i](seqs + self.forward_layers[i](seqs))

        log_feats = self.last_layernorm(seqs)
        return log_feats

    def forward(
        self, user_item, pos_seqs, neg_seqs, mask, next_mask, next_action_type, seq_feature, pos_feature, neg_feature
    ):
        """Training-time forward pass: logits for positive and negative samples.

        Args:
            user_item: user sequence IDs
            pos_seqs / neg_seqs: positive / negative sample sequence IDs
            mask: token-type mask (1 = item, 2 = user)
            next_mask: type of the next token (1 = item, 2 = user)
            next_action_type: next-token action (0 = exposure, 1 = click)
            seq_feature / pos_feature / neg_feature: per-step feature dicts

        Returns:
            (pos_logits, neg_logits), each [batch_size, maxlen]; positions whose
            next token is not an item are zeroed.
        """
        log_feats = self.log2feats(user_item, mask, seq_feature)
        loss_mask = (next_mask == 1).to(self.dev)

        pos_embs = self.feat2emb(pos_seqs, pos_feature, include_user=False)
        neg_embs = self.feat2emb(neg_seqs, neg_feature, include_user=False)

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)
        pos_logits = pos_logits * loss_mask
        neg_logits = neg_logits * loss_mask

        return pos_logits, neg_logits

    def predict(self, log_seqs, seq_feature, mask):
        """Inference-time user representation: the last position's features.

        Returns:
            final_feat: [batch_size, hidden_units]
        """
        log_feats = self.log2feats(log_seqs, mask, seq_feature)
        final_feat = log_feats[:, -1, :]
        return final_feat

    def save_item_emb(self, item_ids, retrieval_ids, feat_dict, save_path, batch_size=1024):
        """Generate and save the candidate-pool item embeddings for retrieval.

        Args:
            item_ids: candidate item IDs (re-id form)
            retrieval_ids: candidate item IDs (retrieval form, 0-based)
            feat_dict: item feature dicts keyed by feature ID
            save_path: output directory
            batch_size: encoding batch size
        """
        # Deferred imports: tqdm and dataset.save_emb are only needed here, so
        # the module stays importable without them.
        from tqdm import tqdm

        from dataset import save_emb

        all_embs = []

        for start_idx in tqdm(range(0, len(item_ids), batch_size), desc="Saving item embeddings"):
            end_idx = min(start_idx + batch_size, len(item_ids))

            item_seq = torch.tensor(item_ids[start_idx:end_idx], device=self.dev).unsqueeze(0)
            batch_feat = []
            for i in range(start_idx, end_idx):
                batch_feat.append(feat_dict[i])

            batch_feat = np.array(batch_feat, dtype=object)

            batch_emb = self.feat2emb(item_seq, [batch_feat], include_user=False).squeeze(0)
            all_embs.append(batch_emb.detach().cpu().numpy().astype(np.float32))

        # Merge all batches and persist.
        final_ids = np.array(retrieval_ids, dtype=np.uint64).reshape(-1, 1)
        final_embs = np.concatenate(all_embs, axis=0)
        save_emb(final_embs, Path(save_path, 'embedding.fbin'))
        save_emb(final_ids, Path(save_path, 'id.u64bin'))

# 3. dataset.py - 数据处理
MyDataset - 训练数据集
- 处理用户行为序列数据,支持用户和物品交替出现的序列格式
- 实现高效的数据加载,使用文件偏移量进行随机访问
- 支持多种特征类型的 padding 和缺失值填充
- 实现负采样机制用于训练
MyTestDataset - 测试数据集
- 继承自训练数据集,专门用于推理阶段
- 处理冷启动问题(训练时未见过的特征值)
dataset.py 代码
import json
import pickle
import struct
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm


class MyDataset(torch.utils.data.Dataset):
    """
    User-sequence dataset.

    Args:
        data_dir: directory containing the data files
        args: global arguments

    Attributes:
        data_dir: data file directory
        maxlen: maximum sequence length
        item_feat_dict: item feature dictionary
        mm_emb_ids: active multimodal-embedding feature IDs
        mm_emb_dict: multimodal embedding dictionary
        itemnum: number of items
        usernum: number of users
        indexer_i_rev: item index mapping (reid -> item_id)
        indexer_u_rev: user index mapping (reid -> user_id)
        indexer: full index dictionary
        feature_default_value: per-feature default values
        feature_types: feature types, split into user/item sparse, array, emb, continual
        feat_statistics: feature statistics (vocabulary size per feature for user and item)
    """

    def __init__(self, data_dir, args):
        """Initialize the dataset: load offsets, feature dicts and indexers."""
        super().__init__()
        self.data_dir = Path(data_dir)
        self._load_data_and_offsets()
        self.maxlen = args.maxlen
        self.mm_emb_ids = args.mm_emb_id
        self.item_feat_dict = json.load(open(Path(data_dir, "item_feat_dict.json"), 'r'))
        self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_ids)
        with open(self.data_dir / 'indexer.pkl', 'rb') as ff:
            indexer = pickle.load(ff)
        self.itemnum = len(indexer['i'])
        self.usernum = len(indexer['u'])
        # reverse maps: re-id -> original id
        self.indexer_i_rev = {v: k for k, v in indexer['i'].items()}
        self.indexer_u_rev = {v: k for k, v in indexer['u'].items()}
        self.indexer = indexer

        self.feature_default_value, self.feature_types, self.feat_statistics = self._init_feat_info()

    def _load_data_and_offsets(self):
        """
        Load the user-sequence data file and the precomputed per-line file
        offsets, used for fast random access I/O.
        """
        # NOTE(review): the file handle is kept open for the dataset's lifetime
        # and never explicitly closed.
        self.data_file = open(self.data_dir / "seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _load_user_data(self, uid):
        """
        Load a single user's data from the data file.

        Args:
            uid: user ID (reid)

        Returns:
            data: user sequence data in the form
                  [(user_id, item_id, user_feat, item_feat, action_type, timestamp)]
        """
        self.data_file.seek(self.seq_offsets[uid])
        line = self.data_file.readline()
        data = json.loads(line)
        return data

    def _random_neq(self, l, r, s):
        """
        Draw a random integer in [l, r) that is not in sequence s — used for
        negative sampling during training.

        Args:
            l: minimum value (inclusive)
            r: maximum value (exclusive)
            s: set of item ids to exclude

        Returns:
            t: a random integer not in s
        """
        t = np.random.randint(l, r)
        # also reject ids with no entry in item_feat_dict, so the negative
        # sample always has features
        # NOTE(review): loops forever if no valid candidate exists — presumably
        # the item space is large enough that this never happens in practice.
        while t in s or str(t) not in self.item_feat_dict:
            t = np.random.randint(l, r)
        return t

    def __getitem__(self, uid):
        """
        Fetch one user's data, pad it, and build the arrays the model needs.

        Args:
            uid: user ID (reid)

        Returns:
            seq: user sequence IDs
            pos: positive-sample IDs (the next actually-visited item)
            neg: negative-sample IDs
            token_type: token type of each position, 1 = item, 2 = user
            next_token_type: type of the next token, 1 = item, 2 = user
            seq_feat: sequence features; each element is a dict {feature ID: value}
            pos_feat: positive-sample features, same layout
            neg_feat: negative-sample features, same layout
        """
        user_sequence = self._load_user_data(uid)  # lazily load this user's data

        # flatten records into an alternating user/item token sequence;
        # user tokens are pushed to the front, item tokens appended in order
        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, action_type, _ = record_tuple
            if u and user_feat:
                ext_user_sequence.insert(0, (u, user_feat, 2, action_type))
            if i and item_feat:
                ext_user_sequence.append((i, item_feat, 1, action_type))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        pos = np.zeros([self.maxlen + 1], dtype=np.int32)
        neg = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_action_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)
        pos_feat = np.empty([self.maxlen + 1], dtype=object)
        neg_feat = np.empty([self.maxlen + 1], dtype=object)

        nxt = ext_user_sequence[-1]
        idx = self.maxlen

        # collect all item ids in the sequence so negatives never collide
        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        # left-padding: walk the sequence backwards, filling to maxlen+1
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_, act_type = record_tuple
            next_i, next_feat, next_type, next_act_type = nxt
            feat = self.fill_missing_feat(feat, i)
            next_feat = self.fill_missing_feat(next_feat, next_i)
            seq[idx] = i
            token_type[idx] = type_
            next_token_type[idx] = next_type
            if next_act_type is not None:
                next_action_type[idx] = next_act_type
            seq_feat[idx] = feat
            # only item tokens get positive/negative targets
            if next_type == 1 and next_i != 0:
                pos[idx] = next_i
                pos_feat[idx] = next_feat
                neg_id = self._random_neq(1, self.itemnum + 1, ts)
                neg[idx] = neg_id
                neg_feat[idx] = self.fill_missing_feat(self.item_feat_dict[str(neg_id)], neg_id)
            nxt = record_tuple
            idx -= 1
            if idx == -1:
                break

        # replace never-filled (padding) slots with per-feature defaults
        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)
        pos_feat = np.where(pos_feat == None, self.feature_default_value, pos_feat)
        neg_feat = np.where(neg_feat == None, self.feature_default_value, neg_feat)

        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat

    def __len__(self):
        """
        Dataset length, i.e. the number of users.

        Returns:
            usernum: number of users
        """
        return len(self.seq_offsets)

    def _init_feat_info(self):
        """
        Initialize feature metadata: default values and feature types.

        Returns:
            feat_default_value: dict, feature ID -> default value
            feat_types: dict, feature type name -> list of feature IDs
            feat_statistics: dict, feature ID -> vocabulary size
        """
        feat_default_value = {}
        feat_statistics = {}
        feat_types = {}
        feat_types['user_sparse'] = ['103', '104', '105', '109']
        feat_types['item_sparse'] = [
            '100', '117', '111', '118', '101', '102', '119',
            '120', '114', '112', '121', '115', '122', '116',
        ]
        feat_types['item_array'] = []
        feat_types['user_array'] = ['106', '107', '108', '110']
        feat_types['item_emb'] = self.mm_emb_ids
        feat_types['user_continual'] = []
        feat_types['item_continual'] = []

        # sparse features default to id 0; array features to [0];
        # continual features to 0; emb features to a zero vector
        for feat_id in feat_types['user_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_emb']:
            # infer the embedding dimension from any loaded embedding
            feat_default_value[feat_id] = np.zeros(
                list(self.mm_emb_dict[feat_id].values())[0].shape[0], dtype=np.float32
            )

        return feat_default_value, feat_types, feat_statistics

    def fill_missing_feat(self, feat, item_id):
        """
        Fill default values for features missing from the raw data.

        Args:
            feat: feature dict (may be None)
            item_id: item ID (re-id); 0 marks padding / non-item tokens

        Returns:
            filled_feat: feature dict with defaults filled in
        """
        if feat == None:
            feat = {}
        filled_feat = {}
        for k in feat.keys():
            filled_feat[k] = feat[k]

        all_feat_ids = []
        for feat_type in self.feature_types.values():
            all_feat_ids.extend(feat_type)
        missing_fields = set(all_feat_ids) - set(feat.keys())
        for feat_id in missing_fields:
            filled_feat[feat_id] = self.feature_default_value[feat_id]
        # overwrite multimodal emb defaults with the real embedding when one
        # exists for this item
        for feat_id in self.feature_types['item_emb']:
            if item_id != 0 and self.indexer_i_rev[item_id] in self.mm_emb_dict[feat_id]:
                if type(self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]) == np.ndarray:
                    filled_feat[feat_id] = self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]

        return filled_feat

    @staticmethod
    def collate_fn(batch):
        """
        Stack several __getitem__ results into one batch.

        Args:
            batch: list of __getitem__ outputs

        Returns:
            seq: user sequence IDs, torch.Tensor
            pos: positive-sample IDs, torch.Tensor
            neg: negative-sample IDs, torch.Tensor
            token_type: sequence token types, torch.Tensor
            next_token_type: next-token types, torch.Tensor
            next_action_type: next-token action types, torch.Tensor
            seq_feat: sequence features, list
            pos_feat: positive-sample features, list
            neg_feat: negative-sample features, list
        """
        seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        pos = torch.from_numpy(np.array(pos))
        neg = torch.from_numpy(np.array(neg))
        token_type = torch.from_numpy(np.array(token_type))
        next_token_type = torch.from_numpy(np.array(next_token_type))
        next_action_type = torch.from_numpy(np.array(next_action_type))
        seq_feat = list(seq_feat)
        pos_feat = list(pos_feat)
        neg_feat = list(neg_feat)
        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat


class MyTestDataset(MyDataset):
    """
    Test dataset: inherits from the training dataset but reads the prediction
    sequences and handles cold-start feature values.
    """

    def __init__(self, data_dir, args):
        super().__init__(data_dir, args)

    def _load_data_and_offsets(self):
        # same as the parent, but reading the prediction sequence files
        self.data_file = open(self.data_dir / "predict_seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _process_cold_start_feat(self, feat):
        """
        Handle cold-start features. Feature values unseen during training
        arrive as strings and are mapped to 0 by default; a better scheme may
        be substituted here.
        """
        processed_feat = {}
        for feat_id, feat_value in feat.items():
            if type(feat_value) == list:
                value_list = []
                for v in feat_value:
                    if type(v) == str:
                        value_list.append(0)
                    else:
                        value_list.append(v)
                processed_feat[feat_id] = value_list
            elif type(feat_value) == str:
                processed_feat[feat_id] = 0
            else:
                processed_feat[feat_id] = feat_value
        return processed_feat

    def __getitem__(self, uid):
        """
        Fetch one user's data, pad it, and build the arrays the model needs.

        Args:
            uid: the user's line number within self.data_file

        Returns:
            seq: user sequence IDs
            token_type: token types, 1 = item, 2 = user
            seq_feat: sequence features; each element is a dict {feature ID: value}
            user_id: user_id, e.g. user_xxxxxx — kept for matching answers later
        """
        user_sequence = self._load_user_data(uid)  # lazily load this user's data

        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, _, _ = record_tuple
            if u:
                if type(u) == str:  # a string means it is the raw user_id
                    user_id = u
                else:  # an int means it is a re-id; map back to the raw id
                    user_id = self.indexer_u_rev[u]
            if u and user_feat:
                if type(u) == str:
                    u = 0
                if user_feat:
                    user_feat = self._process_cold_start_feat(user_feat)
                ext_user_sequence.insert(0, (u, user_feat, 2))

            if i and item_feat:
                # items unseen during training are not zeroed in the data;
                # they keep their creative_id, which is far larger than the
                # training itemnum — map those to 0 here
                if i > self.itemnum:
                    i = 0
                if item_feat:
                    item_feat = self._process_cold_start_feat(item_feat)
                ext_user_sequence.append((i, item_feat, 1))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)

        idx = self.maxlen

        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        # left-padding: walk the sequence backwards, filling to maxlen+1
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_ = record_tuple
            feat = self.fill_missing_feat(feat, i)
            seq[idx] = i
            token_type[idx] = type_
            seq_feat[idx] = feat
            idx -= 1
            if idx == -1:
                break

        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)

        return seq, token_type, seq_feat, user_id

    def __len__(self):
        """
        Returns:
            len(self.seq_offsets): number of users
        """
        # NOTE(review): re-reads the offsets pickle on every call instead of
        # using the already-loaded self.seq_offsets.
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            temp = pickle.load(f)
        return len(temp)

    @staticmethod
    def collate_fn(batch):
        """
        Stack several __getitem__ results into one batch.

        Args:
            batch: list of __getitem__ outputs

        Returns:
            seq: user sequence IDs, torch.Tensor
            token_type: sequence token types, torch.Tensor
            seq_feat: sequence features, list
            user_id: user_id, str
        """
        seq, token_type, seq_feat, user_id = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        token_type = torch.from_numpy(np.array(token_type))
        seq_feat = list(seq_feat)
        return seq, token_type, seq_feat, user_id


def save_emb(emb, save_path):
    """
    Save an embedding matrix as a binary file.

    Args:
        emb: embedding to save, shape [num_points, num_dimensions]
        save_path: output path
    """
    num_points = emb.shape[0]  # number of data points
    num_dimensions = emb.shape[1]  # vector dimension
    print(f'saving {save_path}')
    with open(Path(save_path), 'wb') as f:
        # header: two uint32 values, then the raw matrix
        f.write(struct.pack('II', num_points, num_dimensions))
        emb.tofile(f)


def load_mm_emb(mm_path, feat_ids):
    """
    Load multimodal feature embeddings.

    Args:
        mm_path: path to the multimodal embedding files
        feat_ids: list of multimodal feature IDs to load

    Returns:
        mm_emb_dict: dict, feature ID -> {item ID: embedding}
    """
    # known embedding dimension per feature ID
    SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
    mm_emb_dict = {}
    for feat_id in tqdm(feat_ids, desc='Loading mm_emb'):
        shape = SHAPE_DICT[feat_id]
        emb_dict = {}
        if feat_id != '81':
            # features 82-86 are stored as JSON-lines shards
            try:
                base_path = Path(mm_path, f'emb_{feat_id}_{shape}')
                for json_file in base_path.glob('*.json'):
                    with open(json_file, 'r', encoding='utf-8') as file:
                        for line in file:
                            data_dict_origin = json.loads(line.strip())
                            insert_emb = data_dict_origin['emb']
                            if isinstance(insert_emb, list):
                                insert_emb = np.array(insert_emb, dtype=np.float32)
                            data_dict = {data_dict_origin['anonymous_cid']: insert_emb}
                            emb_dict.update(data_dict)
            except Exception as e:
                print(f"transfer error: {e}")
        if feat_id == '81':
            # feature 81 is stored as a single pickle
            with open(Path(mm_path, f'emb_{feat_id}_{shape}.pkl'), 'rb') as f:
                emb_dict = pickle.load(f)
        mm_emb_dict[feat_id] = emb_dict
        print(f'Loaded #{feat_id} mm_emb')
    return mm_emb_dict

4. model_rqvae.py - 多模态特征压缩
实现了 RQ-VAE(Residual Quantized Variational AutoEncoder)框架,用于将高维多模态 embedding 转换为离散的语义 ID:
核心组件:
- RQEncoder/RQDecoder:编码器和解码器
- VQEmbedding:向量量化模块,支持 K-means 初始化
- RQ:残差量化器,实现多级量化
- RQVAE:完整的 RQ-VAE 模型
量化方法:
- 支持标准 K-means 和平衡 K-means 聚类
- 使用余弦距离或 L2 距离进行向量量化
- 通过残差量化实现更精确的特征表示
model_rqvae.py 代码
""" 选手可参考以下流程,使用提供的 RQ-VAE 框架代码将多模态emb数据转换为Semantic Id: 1. 使用 MmEmbDataset 读取不同特征 ID 的多模态emb数据. 2. 训练 RQ-VAE 模型, 训练完成后将数据转换为Semantic Id. 3. 参照 Item Sparse 特征格式处理Semantic Id,作为新特征加入Baseline模型训练. """import torch import torch.nn.functional as F from sklearn.cluster import KMeans # class MmEmbDataset(torch.utils.data.Dataset):# """# Build Dataset for RQ-VAE Training# Args:# data_dir = os.environ.get('TRAIN_DATA_PATH')# feature_id = MM emb ID# """# def __init__(self, data_dir, feature_id):# super().__init__()# self.data_dir = Path(data_dir)# self.mm_emb_id = [feature_id]# self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_id)# self.mm_emb = self.mm_emb_dict[self.mm_emb_id[0]]# self.tid_list, self.emb_list = list(self.mm_emb.keys()), list(self.mm_emb.values())# self.emb_list = [torch.tensor(emb, dtype=torch.float32) for emb in self.emb_list]# assert len(self.tid_list) == len(self.emb_list)# self.item_cnt = len(self.tid_list)# def __getitem__(self, index):# tid = torch.tensor(self.tid_list[index], dtype=torch.long)# emb = self.emb_list[index]# return tid, emb# def __len__(self):# return self.item_cnt# @staticmethod# def collate_fn(batch):# tid, emb = zip(*batch)# tid_batch, emb_batch = torch.stack(tid, dim=0), torch.stack(emb, dim=0)# return tid_batch, emb_batch## Kmeansdefkmeans(data, n_clusters, kmeans_iters):""" auto init: n_init = 10 if n_clusters <= 10 else 1 """ km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")# sklearn only support cpu data_cpu = data.detach().cpu() np_data = data_cpu.numpy() km.fit(np_data)return torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)## Balanced KmeansclassBalancedKmeans(torch.nn.Module):def__init__(self, num_clusters:int, kmeans_iters:int, tolerance:float, device:str):super().__init__() self.num_clusters = num_clusters self.kmeans_iters = kmeans_iters self.tolerance = tolerance self.device = device self._codebook =Nonedef_compute_distances(self, data):return torch.cdist(data, 
self._codebook)def_assign_clusters(self, dist): samples_cnt = dist.shape[0] samples_labels = torch.zeros(samples_cnt, dtype=torch.long, device=self.device) clusters_cnt = torch.zeros(self.num_clusters, dtype=torch.long, device=self.device) sorted_indices = torch.argsort(dist, dim=-1)for i inrange(samples_cnt):for j inrange(self.num_clusters): cluster_idx = sorted_indices[i, j]if clusters_cnt[cluster_idx]< samples_cnt // self.num_clusters: samples_labels[i]= cluster_idx clusters_cnt[cluster_idx]+=1breakreturn samples_labels def_update_codebook(self, data, samples_labels): _new_codebook =[]for i inrange(self.num_clusters): cluster_data = data[samples_labels == i]iflen(cluster_data)>0: _new_codebook.append(cluster_data.mean(dim=0))else: _new_codebook.append(self._codebook[i])return torch.stack(_new_codebook)deffit(self, data): num_emb, codebook_emb_dim = data.shape data = data.to(self.device)# initialize codebook indices = torch.randperm(num_emb)[: self.num_clusters] self._codebook = data[indices].clone()for _ inrange(self.kmeans_iters): dist = self._compute_distances(data) samples_labels = self._assign_clusters(dist) _new_codebook = self._update_codebook(data, samples_labels)if torch.norm(_new_codebook - self._codebook)< self.tolerance:break self._codebook = _new_codebook return self._codebook, samples_labels defpredict(self, data): data = data.to(self.device) dist = self._compute_distances(data) samples_labels = self._assign_clusters(dist)return samples_labels ## Base RQVAEclassRQEncoder(torch.nn.Module):def__init__(self, input_dim:int, hidden_channels:list, latent_dim:int):super().__init__() self.stages = torch.nn.ModuleList() in_dim = input_dim for out_dim in hidden_channels: stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU()) self.stages.append(stage) in_dim = out_dim self.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, latent_dim), torch.nn.ReLU()))defforward(self, x):for stage in self.stages: x = stage(x)return x 
class RQDecoder(torch.nn.Module):
    """MLP decoder: latent_dim -> hidden_channels... -> output_dim, ReLU after every layer."""

    def __init__(self, latent_dim: int, hidden_channels: list, output_dim: int):
        super().__init__()

        self.stages = torch.nn.ModuleList()
        in_dim = latent_dim

        for out_dim in hidden_channels:
            stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, output_dim), torch.nn.ReLU()))

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


## Generate semantic id
class VQEmbedding(torch.nn.Embedding):
    """
    Vector-quantization codebook. Subclasses torch.nn.Embedding; the codebook
    is (re)built from the incoming data on every forward call.
    """

    def __init__(
        self,
        num_clusters,
        codebook_emb_dim: int,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        device: str,
    ):
        super(VQEmbedding, self).__init__(num_clusters, codebook_emb_dim)

        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device

    def _create_codebook(self, data):
        """Initialize the codebook from data via kmeans / bkmeans / random init."""
        if self.kmeans_method == 'kmeans':
            _codebook, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == 'bkmeans':
            BKmeans = BalancedKmeans(
                num_clusters=self.num_clusters, kmeans_iters=self.kmeans_iters, tolerance=1e-4, device=self.device
            )
            _codebook, _ = BKmeans.fit(data)
        else:
            _codebook = torch.randn(self.num_clusters, self.codebook_emb_dim)
        _codebook = _codebook.to(self.device)
        assert _codebook.shape == (self.num_clusters, self.codebook_emb_dim)
        self.codebook = torch.nn.Parameter(_codebook)

    @torch.no_grad()
    def _compute_distances(self, data):
        """Distances between data and codebook entries (cosine or squared L2)."""
        _codebook_t = self.codebook.t()
        assert _codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim

        if self.distances_method == 'cosine':
            data_norm = F.normalize(data, p=2, dim=-1)
            _codebook_t_norm = F.normalize(_codebook_t, p=2, dim=0)
            distances = 1 - torch.mm(data_norm, _codebook_t_norm)
        # l2
        else:
            # ||a||^2 + ||b||^2 - 2*a.b, computed via addmm
            data_norm_sq = data.pow(2).sum(dim=-1, keepdim=True)
            _codebook_t_norm_sq = _codebook_t.pow(2).sum(dim=0, keepdim=True)
            distances = torch.addmm(data_norm_sq + _codebook_t_norm_sq, data, _codebook_t, beta=1.0, alpha=-2.0)
        return distances

    @torch.no_grad()
    def _create_semantic_id(self, data):
        # semantic id = index of the nearest codebook entry
        distances = self._compute_distances(data)
        _semantic_id = torch.argmin(distances, dim=-1)
        return _semantic_id

    def _update_emb(self, _semantic_id):
        # embedding lookup through the parent class
        update_emb = super().forward(_semantic_id)
        return update_emb

    def forward(self, data):
        # NOTE(review): the codebook is re-created from scratch on every
        # forward pass — confirm this is the intended training behavior.
        self._create_codebook(data)
        _semantic_id = self._create_semantic_id(data)
        update_emb = self._update_emb(_semantic_id)

        return update_emb, _semantic_id


## Residual Quantizer
class RQ(torch.nn.Module):
    """
    Args:
        num_codebooks, codebook_size, codebook_emb_dim -> Build codebook
        if_shared_codebook -> If use same codebook
        kmeans_method, kmeans_iters -> Initialize codebook
        distances_method -> Generate semantic_id
        loss_beta -> Calculate RQ-VAE loss
    """

    def __init__(
        self,
        num_codebooks: int,
        codebook_size: list,
        codebook_emb_dim,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook

        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.loss_beta = loss_beta
        self.device = device

        if self.shared_codebook:
            # every level uses the same codebook size (index 0)
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[0],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for _ in range(self.num_codebooks)
                ]
            )
        else:
            # each level gets its own codebook size
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[idx],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for idx in range(self.num_codebooks)
                ]
            )

    def quantize(self, data):
        """
        Exa: i-th quantize: input[i]( i.e. res[i-1] ) = VQ[i] + res[i]

        vq_emb_list: [vq1, vq1+vq2, ...]
        res_emb_list: [res1, res2, ...]
        semantic_id_list: [vq1_sid, vq2_sid, ...]

        Returns:
            vq_emb_list[0] -> [batch_size, codebook_emb_dim]
            semantic_id_list -> [batch_size, num_codebooks]
        """
        res_emb = data.detach().clone()

        vq_emb_list, res_emb_list = [], []
        semantic_id_list = []
        vq_emb_aggre = torch.zeros_like(data)

        # quantize the residual at each level, accumulating reconstructions
        for i in range(self.num_codebooks):
            vq_emb, _semantic_id = self.vqmodules[i](res_emb)

            res_emb -= vq_emb
            vq_emb_aggre += vq_emb

            res_emb_list.append(res_emb)
            vq_emb_list.append(vq_emb_aggre)
            semantic_id_list.append(_semantic_id.unsqueeze(dim=-1))

        semantic_id_list = torch.cat(semantic_id_list, dim=-1)
        return vq_emb_list, res_emb_list, semantic_id_list

    def _rqvae_loss(self, vq_emb_list, res_emb_list):
        """Commitment-style VQ loss summed over quantization levels."""
        rqvae_loss_list = []
        for idx, quant in enumerate(vq_emb_list):
            # stop gradient
            loss1 = (res_emb_list[idx].detach() - quant).pow(2.0).mean()
            loss2 = (res_emb_list[idx] - quant.detach()).pow(2.0).mean()
            partial_loss = loss1 + self.loss_beta * loss2
            rqvae_loss_list.append(partial_loss)

        rqvae_loss = torch.sum(torch.stack(rqvae_loss_list))
        return rqvae_loss

    def forward(self, data):
        vq_emb_list, res_emb_list, semantic_id_list = self.quantize(data)
        rqvae_loss = self._rqvae_loss(vq_emb_list, res_emb_list)

        return vq_emb_list, semantic_id_list, rqvae_loss


class RQVAE(torch.nn.Module):
    """Full RQ-VAE: encoder -> residual quantizer -> decoder."""

    def __init__(
        self,
        input_dim: int,
        hidden_channels: list,
        latent_dim: int,
        num_codebooks: int,
        codebook_size: list,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.encoder = RQEncoder(input_dim, hidden_channels, latent_dim).to(device)
        # decoder mirrors the encoder's hidden layers in reverse
        self.decoder = RQDecoder(latent_dim, hidden_channels[::-1], input_dim).to(device)
        self.rq = RQ(
            num_codebooks,
            codebook_size,
            latent_dim,
            shared_codebook,
            kmeans_method,
            kmeans_iters,
            distances_method,
            loss_beta,
            device,
        ).to(device)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z_vq):
        # a list means the per-level aggregates; decode the final (full) one
        if isinstance(z_vq, list):
            z_vq = z_vq[-1]
        return self.decoder(z_vq)

    def compute_loss(self, x_hat, x_gt, rqvae_loss):
        """Total loss = MSE reconstruction loss + RQ quantization loss."""
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        total_loss = recon_loss + rqvae_loss
        return recon_loss, rqvae_loss, total_loss

    def _get_codebook(self, x_gt):
        # convenience: return only the semantic ids for the inputs
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        return semantic_id_list

    def forward(self, x_gt):
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_id_list, recon_loss, rqvae_loss, total_loss

5. run.sh - 运行脚本
简单的 bash 脚本,用于启动训练程序。
run.sh 代码
#!/bin/bash
# show ${RUNTIME_SCRIPT_DIR}
echo ${RUNTIME_SCRIPT_DIR}
# enter train workspace
cd ${RUNTIME_SCRIPT_DIR}
# write your code below
# -u: unbuffered stdout/stderr so training logs appear immediately
python -u main.py
技术特点
- 高效注意力机制:使用 Flash Attention 优化计算效率
- 多模态融合:支持文本、图像等多种模态的 embedding 特征
- 特征工程:支持稀疏、密集、数组等多种特征类型
- 序列建模:同时建模用户和物品的交互序列
- 可扩展性:支持大规模物品库的 embedding 保存和检索
数据流程
- 训练阶段:读取用户序列 → 特征 embedding → Transformer 编码 → 计算正负样本 loss
- 推理阶段:生成用户表征 → 保存物品 embedding → 进行向量检索推荐
- 多模态处理:原始 embedding → RQ-VAE 压缩 → 语义 ID → 作为新特征加入模型