项目概述
2025 腾讯广告算法大赛 Baseline,一个简单的序列推荐系统,主要用于建模用户和物品的交互序列,并利用多模态特征(文本、图像等 embedding)来提升推荐效果。
核心文件功能
1. main.py - 主训练脚本
- 负责模型训练的整体流程
- 包含参数解析、数据加载、模型初始化、训练循环等
- 支持断点续训和仅推理模式
- 使用 TensorBoard 记录训练日志
main.py 代码
import argparse
import json
import os
import time
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dataset import MyDataset
from model import BaselineModel
def get_args(argv=None):
    """Parse command-line hyperparameters for training.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads ``sys.argv[1:]`` — backward compatible with
            the original zero-argument call.

    Returns:
        argparse.Namespace with all training hyperparameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)

    # Baseline model construction
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')

    # Multimodal embedding feature ids; the dataset exposes ids 81-86.
    parser.add_argument(
        '--mm_emb_id', nargs='+', default=['81'], type=str,
        choices=[str(s) for s in range(81, 87)],
    )

    return parser.parse_args(argv)
if __name__ == '__main__':
    # All paths come from the training platform via environment variables.
    # NOTE(review): os.environ.get returns None when a variable is unset,
    # which makes Path(None) raise — assumed always set in the runtime env.
    Path(os.environ.get('TRAIN_LOG_PATH')).mkdir(parents=True, exist_ok=True)
    Path(os.environ.get('TRAIN_TF_EVENTS_PATH')).mkdir(parents=True, exist_ok=True)
    log_file = open(Path(os.environ.get('TRAIN_LOG_PATH'), 'train.log'), 'w')
    writer = SummaryWriter(os.environ.get('TRAIN_TF_EVENTS_PATH'))
    data_path = os.environ.get('TRAIN_DATA_PATH')

    args = get_args()
    dataset = MyDataset(data_path, args)
    # 90%/10% random train/validation split.
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
    )
    usernum, itemnum = dataset.usernum, dataset.itemnum
    feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types

    model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)

    # Best-effort Xavier init: xavier_normal_ rejects tensors with fewer than
    # two dims, so 1-D parameters (biases, LayerNorm weights) are skipped.
    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass

    # Row 0 of every embedding table is the padding index; zero it so padded
    # positions contribute nothing to the sequence representation.
    model.pos_emb.weight.data[0, :] = 0
    model.item_emb.weight.data[0, :] = 0
    model.user_emb.weight.data[0, :] = 0
    for k in model.sparse_emb:
        model.sparse_emb[k].weight.data[0, :] = 0

    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            # Resume from the epoch encoded in the checkpoint path, e.g. ".../epoch=3.pt".
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except Exception as e:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            raise RuntimeError('failed loading state_dicts, pls check file path!') from e

    bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    best_val_ndcg, best_val_hr = 0.0, 0.0
    best_test_ndcg, best_test_hr = 0.0, 0.0
    T = 0.0
    t0 = time.time()
    global_step = 0
    print("Start training")
    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        model.train()
        if args.inference_only:
            break
        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels = torch.ones(pos_logits.shape, device=args.device)
            neg_labels = torch.zeros(neg_logits.shape, device=args.device)
            optimizer.zero_grad()
            # Only positions whose next token is an item (type == 1) contribute.
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            log_json = json.dumps({'global_step': global_step, 'loss': loss.item(), 'epoch': epoch, 'time': time.time()})
            log_file.write(log_json + '\n')
            log_file.flush()
            print(log_json)
            writer.add_scalar('Loss/train', loss.item(), global_step)
            global_step += 1
            # L2 regularization on the item embedding only; added AFTER logging
            # so the logged/plotted loss stays the pure BCE term.
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            optimizer.step()

        # Validation: same masked BCE loss, averaged over batches.
        # no_grad() avoids building autograd graphs for pure evaluation.
        model.eval()
        valid_loss_sum = 0
        with torch.no_grad():
            for step, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
                seq = seq.to(args.device)
                pos = pos.to(args.device)
                neg = neg.to(args.device)
                pos_logits, neg_logits = model(
                    seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
                )
                pos_labels = torch.ones(pos_logits.shape, device=args.device)
                neg_labels = torch.zeros(neg_logits.shape, device=args.device)
                indices = np.where(next_token_type == 1)
                loss = bce_criterion(pos_logits[indices], pos_labels[indices])
                loss += bce_criterion(neg_logits[indices], neg_labels[indices])
                valid_loss_sum += loss.item()
        valid_loss_sum /= len(valid_loader)

        # Checkpoint once per epoch; validation loss is encoded in the dir name.
        save_dir = Path(os.environ.get('TRAIN_CKPT_PATH'), f"global_step{global_step}.valid_loss={valid_loss_sum:.4f}")
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_dir / "model.pt")

    print("Done")
    writer.close()
    log_file.close()
2. model.py - 核心模型实现
BaselineModel - 主推荐模型
基于 Transformer 的序列推荐模型,具有以下特点:
模型架构:
- 使用 FlashMultiHeadAttention 实现高效的多头注意力机制
- 采用 PointWiseFeedForward 作为前馈网络
- 支持多种特征类型:稀疏特征、数组特征、连续特征、多模态 embedding 特征
特征处理:
- 用户特征:稀疏特征 (103,104,105,109)、数组特征 (106,107,108,110)
- 物品特征:稀疏特征 (100,117,111 等)、多模态 embedding 特征 (81-86)
- 通过 feat2emb 方法将不同类型特征转换为统一的 embedding 表示
核心方法:
log2feats:将用户序列转换为特征表示
forward:训练时计算正负样本的 logits
predict:推理时生成用户表征
save_item_emb:保存物品 embedding 用于检索
model.py 代码
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dataset import save_emb
class FlashMultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention using PyTorch SDPA (FlashAttention kernels
    when available), with an explicit softmax(QK^T/sqrt(d))V fallback for
    older PyTorch versions that lack ``scaled_dot_product_attention``.

    NOTE(review): the forward pass was reconstructed from a corrupted dump;
    it follows the standard SDPA-with-fallback pattern implied by the
    surviving calls (`scaled_dot_product_attention`, `masked_fill_`,
    `softmax`, `out_linear`) — confirm against the original source.
    """

    def __init__(self, hidden_units, num_heads, dropout_rate):
        super().__init__()
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.head_dim = hidden_units // num_heads
        self.dropout_rate = dropout_rate
        assert hidden_units % num_heads == 0, "hidden_units must be divisible by num_heads"

        self.q_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.k_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.v_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.out_linear = torch.nn.Linear(hidden_units, hidden_units)

    def forward(self, query, key, value, attn_mask=None):
        """Compute attention.

        Args:
            query/key/value: (batch, seq_len, hidden_units) tensors.
            attn_mask: optional (batch, seq_len, seq_len) boolean mask;
                True marks positions that MAY attend.

        Returns:
            (output, None) — output is (batch, seq_len, hidden_units); the
            second element keeps the (output, weights) tuple interface.
        """
        batch_size, seq_len, _ = query.size()

        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # (B, L, H) -> (B, num_heads, L, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel path; mask is broadcast over heads via unsqueeze(1).
            attn_output = F.scaled_dot_product_attention(
                Q, K, V,
                dropout_p=self.dropout_rate if self.training else 0.0,
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
            )
        else:
            # Manual fallback: scaled dot-product with -inf masking.
            scale = self.head_dim ** -0.5
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            if attn_mask is not None:
                scores.masked_fill_(attn_mask.unsqueeze(1).logical_not(), float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
            attn_output = torch.matmul(attn_weights, V)

        # (B, num_heads, L, head_dim) -> (B, L, H)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_units)
        output = self.out_linear(attn_output)
        return output, None
class PointWiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward network (SASRec style) built from two
    kernel-size-1 Conv1d layers with ReLU and dropout.

    NOTE(review): reconstructed from a corrupted dump; layer names and the
    transpose-wrapped call chain match the surviving skeleton — confirm
    against the original source.
    """

    def __init__(self, hidden_units, dropout_rate):
        super().__init__()
        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        """Apply the FFN to (batch, seq_len, hidden_units) inputs.

        Conv1d expects (batch, channels, length), so the last two dims are
        swapped before and after the convolution stack; output has the same
        shape as the input.
        """
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        return outputs
(torch.nn.Module):
():
.user_num = user_num
.item_num = item_num
.dev = args.device
.norm_first = args.norm_first
.maxlen = args.maxlen
.item_emb = torch.nn.Embedding(.item_num + , args.hidden_units, padding_idx=)
.user_emb = torch.nn.Embedding(.user_num + , args.hidden_units, padding_idx=)
.pos_emb = torch.nn.Embedding( * args.maxlen + , args.hidden_units, padding_idx=)
.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
.sparse_emb = torch.nn.ModuleDict()
.emb_transform = torch.nn.ModuleDict()
.attention_layernorms = torch.nn.ModuleList()
.attention_layers = torch.nn.ModuleList()
.forward_layernorms = torch.nn.ModuleList()
.forward_layers = torch.nn.ModuleList()
._init_feat_info(feat_statistics, feat_types)
userdim = args.hidden_units * ((.USER_SPARSE_FEAT) + + (.USER_ARRAY_FEAT)) + (.USER_CONTINUAL_FEAT)
itemdim = (args.hidden_units * ((.ITEM_SPARSE_FEAT) + + (.ITEM_ARRAY_FEAT)) + (.ITEM_CONTINUAL_FEAT) + args.hidden_units * (.ITEM_EMB_FEAT))
.userdnn = torch.nn.Linear(userdim, args.hidden_units)
.itemdnn = torch.nn.Linear(itemdim, args.hidden_units)
.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=)
_ (args.num_blocks):
new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=)
.attention_layernorms.append(new_attn_layernorm)
new_attn_layer = FlashMultiHeadAttention(args.hidden_units, args.num_heads, args.dropout_rate)
.attention_layers.append(new_attn_layer)
new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=)
.forward_layernorms.append(new_fwd_layernorm)
new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
.forward_layers.append(new_fwd_layer)
k .USER_SPARSE_FEAT:
.sparse_emb[k] = torch.nn.Embedding(.USER_SPARSE_FEAT[k] + , args.hidden_units, padding_idx=)
k .ITEM_SPARSE_FEAT:
.sparse_emb[k] = torch.nn.Embedding(.ITEM_SPARSE_FEAT[k] + , args.hidden_units, padding_idx=)
k .ITEM_ARRAY_FEAT:
.sparse_emb[k] = torch.nn.Embedding(.ITEM_ARRAY_FEAT[k] + , args.hidden_units, padding_idx=)
k .USER_ARRAY_FEAT:
.sparse_emb[k] = torch.nn.Embedding(.USER_ARRAY_FEAT[k] + , args.hidden_units, padding_idx=)
k .ITEM_EMB_FEAT:
.emb_transform[k] = torch.nn.Linear(.ITEM_EMB_FEAT[k], args.hidden_units)
():
.USER_SPARSE_FEAT = {k: feat_statistics[k] k feat_types[]}
.USER_CONTINUAL_FEAT = feat_types[]
.ITEM_SPARSE_FEAT = {k: feat_statistics[k] k feat_types[]}
.ITEM_CONTINUAL_FEAT = feat_types[]
.USER_ARRAY_FEAT = {k: feat_statistics[k] k feat_types[]}
.ITEM_ARRAY_FEAT = {k: feat_statistics[k] k feat_types[]}
EMB_SHAPE_DICT = {: , : , : , : , : , : }
.ITEM_EMB_FEAT = {k: EMB_SHAPE_DICT[k] k feat_types[]}
():
batch_size = (seq_feature)
k .ITEM_ARRAY_FEAT k .USER_ARRAY_FEAT:
max_array_len =
max_seq_len =
i (batch_size):
seq_data = [item[k] item seq_feature[i]]
max_seq_len = (max_seq_len, (seq_data))
max_array_len = (max_array_len, ((item_data) item_data seq_data))
batch_data = np.zeros((batch_size, max_seq_len, max_array_len), dtype=np.int64)
i (batch_size):
seq_data = [item[k] item seq_feature[i]]
j, item_data (seq_data):
actual_len = ((item_data), max_array_len)
batch_data[i, j, :actual_len] = item_data[:actual_len]
torch.from_numpy(batch_data).to(.dev)
:
max_seq_len = ((seq_feature[i]) i (batch_size))
batch_data = np.zeros((batch_size, max_seq_len), dtype=np.int64)
i (batch_size):
seq_data = [item[k] item seq_feature[i]]
batch_data[i] = seq_data
torch.from_numpy(batch_data).to(.dev)
():
seq = seq.to(.dev)
include_user:
user_mask = (mask == ).to(.dev)
item_mask = (mask == ).to(.dev)
user_embedding = .user_emb(user_mask * seq)
item_embedding = .item_emb(item_mask * seq)
item_feat_list = [item_embedding]
user_feat_list = [user_embedding]
:
item_embedding = .item_emb(seq)
item_feat_list = [item_embedding]
all_feat_types = [
(.ITEM_SPARSE_FEAT, , item_feat_list),
(.ITEM_ARRAY_FEAT, , item_feat_list),
(.ITEM_CONTINUAL_FEAT, , item_feat_list),
]
include_user:
all_feat_types.extend([
(.USER_SPARSE_FEAT, , user_feat_list),
(.USER_ARRAY_FEAT, , user_feat_list),
(.USER_CONTINUAL_FEAT, , user_feat_list),
])
feat_dict, feat_type, feat_list all_feat_types:
feat_dict:
k feat_dict:
tensor_feature = .feat2tensor(feature_array, k)
feat_type.endswith():
feat_list.append(.sparse_emb[k](tensor_feature))
feat_type.endswith():
feat_list.append(.sparse_emb[k](tensor_feature).())
feat_type.endswith():
feat_list.append(tensor_feature.unsqueeze())
k .ITEM_EMB_FEAT:
batch_size = (feature_array)
emb_dim = .ITEM_EMB_FEAT[k]
seq_len = (feature_array[])
batch_emb_data = np.zeros((batch_size, seq_len, emb_dim), dtype=np.float32)
i, seq (feature_array):
j, item (seq):
k item:
batch_emb_data[i, j] = item[k]
tensor_feature = torch.from_numpy(batch_emb_data).to(.dev)
item_feat_list.append(.emb_transform[k](tensor_feature))
all_item_emb = torch.cat(item_feat_list, dim=)
all_item_emb = torch.relu(.itemdnn(all_item_emb))
include_user:
all_user_emb = torch.cat(user_feat_list, dim=)
all_user_emb = torch.relu(.userdnn(all_user_emb))
seqs_emb = all_item_emb + all_user_emb
:
seqs_emb = all_item_emb
seqs_emb
():
batch_size = log_seqs.shape[]
maxlen = log_seqs.shape[]
seqs = .feat2emb(log_seqs, seq_feature, mask=mask, include_user=)
seqs *= .item_emb.embedding_dim **
poss = torch.arange(, maxlen + , device=.dev).unsqueeze().expand(batch_size, -).clone()
poss *= log_seqs !=
seqs += .pos_emb(poss)
seqs = .emb_dropout(seqs)
maxlen = seqs.shape[]
ones_matrix = torch.ones((maxlen, maxlen), dtype=torch., device=.dev)
attention_mask_tril = torch.tril(ones_matrix)
attention_mask_pad = (mask != ).to(.dev)
attention_mask = attention_mask_tril.unsqueeze() & attention_mask_pad.unsqueeze()
i ((.attention_layers)):
.norm_first:
x = .attention_layernorms[i](seqs)
mha_outputs, _ = .attention_layers[i](x, x, x, attn_mask=attention_mask)
seqs = seqs + mha_outputs
seqs = seqs + .forward_layers[i](.forward_layernorms[i](seqs))
:
mha_outputs, _ = .attention_layers[i](seqs, seqs, seqs, attn_mask=attention_mask)
seqs = .attention_layernorms[i](seqs + mha_outputs)
seqs = .forward_layernorms[i](seqs + .forward_layers[i](seqs))
log_feats = .last_layernorm(seqs)
log_feats
():
log_feats = .log2feats(user_item, mask, seq_feature)
loss_mask = (next_mask == ).to(.dev)
pos_embs = .feat2emb(pos_seqs, pos_feature, include_user=)
neg_embs = .feat2emb(neg_seqs, neg_feature, include_user=)
pos_logits = (log_feats * pos_embs).(dim=-)
neg_logits = (log_feats * neg_embs).(dim=-)
pos_logits = pos_logits * loss_mask
neg_logits = neg_logits * loss_mask
pos_logits, neg_logits
():
log_feats = .log2feats(log_seqs, mask, seq_feature)
final_feat = log_feats[:, -, :]
final_feat
():
all_embs = []
start_idx tqdm((, (item_ids), batch_size), desc=):
end_idx = (start_idx + batch_size, (item_ids))
item_seq = torch.tensor(item_ids[start_idx:end_idx], device=.dev).unsqueeze()
batch_feat = []
i (start_idx, end_idx):
batch_feat.append(feat_dict[i])
batch_feat = np.array(batch_feat, dtype=)
batch_emb = .feat2emb(item_seq, [batch_feat], include_user=).squeeze()
all_embs.append(batch_emb.detach().cpu().numpy().astype(np.float32))
final_ids = np.array(retrieval_ids, dtype=np.uint64).reshape(-, )
final_embs = np.concatenate(all_embs, axis=)
save_emb(final_embs, Path(save_path, ))
save_emb(final_ids, Path(save_path, ))
3. dataset.py - 数据处理
MyDataset - 训练数据集
- 处理用户行为序列数据,支持用户和物品交替出现的序列格式
- 实现高效的数据加载,使用文件偏移量进行随机访问
- 支持多种特征类型的 padding 和缺失值填充
- 实现负采样机制用于训练
MyTestDataset - 测试数据集
- 继承自训练数据集,专门用于推理阶段
- 处理冷启动问题(训练时未见过的特征值)
dataset.py 代码
import json
import pickle
import struct
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
class MyDataset(torch.utils.data.Dataset):
"""
用户序列数据集
Args:
data_dir: 数据文件目录
args: 全局参数
Attributes:
data_dir: 数据文件目录
maxlen: 最大长度
item_feat_dict: 物品特征字典
mm_emb_ids: 激活的 mm_emb 特征 ID
mm_emb_dict: 多模态特征字典
itemnum: 物品数量
usernum: 用户数量
indexer_i_rev: 物品索引字典 (reid -> item_id)
indexer_u_rev: 用户索引字典 (reid -> user_id)
indexer: 索引字典
feature_default_value: 特征缺省值
feature_types: 特征类型,分为 user 和 item 的 sparse, array, emb, continual 类型
feat_statistics: 特征统计信息,包括 user 和 item 的特征数量
"""
def __init__(self, data_dir, args):
""" 初始化数据集 """
super().__init__()
self.data_dir = Path(data_dir)
self._load_data_and_offsets()
self.maxlen = args.maxlen
self.mm_emb_ids = args.mm_emb_id
self.item_feat_dict = json.load(open(Path(data_dir, "item_feat_dict.json"), 'r'))
self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_ids)
with open(self.data_dir / 'indexer.pkl', 'rb') as ff:
indexer = pickle.load(ff)
self.itemnum = len(indexer['i'])
self.usernum = (indexer[])
.indexer_i_rev = {v: k k, v indexer[].items()}
.indexer_u_rev = {v: k k, v indexer[].items()}
.indexer = indexer
.feature_default_value, .feature_types, .feat_statistics = ._init_feat_info()
():
.data_file = (.data_dir / , )
(Path(.data_dir, ), ) f:
.seq_offsets = pickle.load(f)
():
.data_file.seek(.seq_offsets[uid])
line = .data_file.readline()
data = json.loads(line)
data
():
t = np.random.randint(l, r)
t s (t) .item_feat_dict:
t = np.random.randint(l, r)
t
():
user_sequence = ._load_user_data(uid)
ext_user_sequence = []
record_tuple user_sequence:
u, i, user_feat, item_feat, action_type, _ = record_tuple
u user_feat:
ext_user_sequence.insert(, (u, user_feat, , action_type))
i item_feat:
ext_user_sequence.append((i, item_feat, , action_type))
seq = np.zeros([.maxlen + ], dtype=np.int32)
pos = np.zeros([.maxlen + ], dtype=np.int32)
neg = np.zeros([.maxlen + ], dtype=np.int32)
token_type = np.zeros([.maxlen + ], dtype=np.int32)
next_token_type = np.zeros([.maxlen + ], dtype=np.int32)
next_action_type = np.zeros([.maxlen + ], dtype=np.int32)
seq_feat = np.empty([.maxlen + ], dtype=)
pos_feat = np.empty([.maxlen + ], dtype=)
neg_feat = np.empty([.maxlen + ], dtype=)
nxt = ext_user_sequence[-]
idx = .maxlen
ts = ()
record_tuple ext_user_sequence:
record_tuple[] == record_tuple[]:
ts.add(record_tuple[])
record_tuple (ext_user_sequence[:-]):
i, feat, type_, act_type = record_tuple
next_i, next_feat, next_type, next_act_type = nxt
feat = .fill_missing_feat(feat, i)
next_feat = .fill_missing_feat(next_feat, next_i)
seq[idx] = i
token_type[idx] = type_
next_token_type[idx] = next_type
next_act_type :
next_action_type[idx] = next_act_type
seq_feat[idx] = feat
next_type == next_i != :
pos[idx] = next_i
pos_feat[idx] = next_feat
neg_id = ._random_neq(, .itemnum + , ts)
neg[idx] = neg_id
neg_feat[idx] = .fill_missing_feat(.item_feat_dict[(neg_id)], neg_id)
nxt = record_tuple
idx -=
idx == -:
seq_feat = np.where(seq_feat == , .feature_default_value, seq_feat)
pos_feat = np.where(pos_feat == , .feature_default_value, pos_feat)
neg_feat = np.where(neg_feat == , .feature_default_value, neg_feat)
seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
():
(.seq_offsets)
():
feat_default_value = {}
feat_statistics = {}
feat_types = {}
feat_types[] = [, , , ]
feat_types[] = [, , , , , , , , , , , , , ,]
feat_types[] = []
feat_types[] = [, , , ]
feat_types[] = .mm_emb_ids
feat_types[] = []
feat_types[] = []
feat_id feat_types[]:
feat_default_value[feat_id] =
feat_statistics[feat_id] = (.indexer[][feat_id])
feat_id feat_types[]:
feat_default_value[feat_id] =
feat_statistics[feat_id] = (.indexer[][feat_id])
feat_id feat_types[]:
feat_default_value[feat_id] = []
feat_statistics[feat_id] = (.indexer[][feat_id])
feat_id feat_types[]:
feat_default_value[feat_id] = []
feat_statistics[feat_id] = (.indexer[][feat_id])
feat_id feat_types[]:
feat_default_value[feat_id] =
feat_id feat_types[]:
feat_default_value[feat_id] =
feat_id feat_types[]:
feat_default_value[feat_id] = np.zeros((.mm_emb_dict[feat_id].values())[].shape[], dtype=np.float32)
feat_default_value, feat_types, feat_statistics
():
feat == :
feat = {}
filled_feat = {}
k feat.keys():
filled_feat[k] = feat[k]
all_feat_ids = []
feat_type .feature_types.values():
all_feat_ids.extend(feat_type)
missing_fields = (all_feat_ids) - (feat.keys())
feat_id missing_fields:
filled_feat[feat_id] = .feature_default_value[feat_id]
feat_id .feature_types[]:
item_id != .indexer_i_rev[item_id] .mm_emb_dict[feat_id]:
(.mm_emb_dict[feat_id][.indexer_i_rev[item_id]]) == np.ndarray:
filled_feat[feat_id] = .mm_emb_dict[feat_id][.indexer_i_rev[item_id]]
filled_feat
():
seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = (*batch)
seq = torch.from_numpy(np.array(seq))
pos = torch.from_numpy(np.array(pos))
neg = torch.from_numpy(np.array(neg))
token_type = torch.from_numpy(np.array(token_type))
next_token_type = torch.from_numpy(np.array(next_token_type))
next_action_type = torch.from_numpy(np.array(next_action_type))
seq_feat = (seq_feat)
pos_feat = (pos_feat)
neg_feat = (neg_feat)
seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
():
():
().__init__(data_dir, args)
():
.data_file = (.data_dir / , )
(Path(.data_dir, ), ) f:
.seq_offsets = pickle.load(f)
():
processed_feat = {}
feat_id, feat_value feat.items():
(feat_value) == :
value_list = []
v feat_value:
(v) == :
value_list.append()
:
value_list.append(v)
processed_feat[feat_id] = value_list
(feat_value) == :
processed_feat[feat_id] =
:
processed_feat[feat_id] = feat_value
processed_feat
():
user_sequence = ._load_user_data(uid)
ext_user_sequence = []
record_tuple user_sequence:
u, i, user_feat, item_feat, _, _ = record_tuple
u:
(u) == :
user_id = u
:
user_id = .indexer_u_rev[u]
u user_feat:
(u) == :
u =
user_feat:
user_feat = ._process_cold_start_feat(user_feat)
ext_user_sequence.insert(, (u, user_feat, ))
i item_feat:
i > .itemnum:
i =
item_feat:
item_feat = ._process_cold_start_feat(item_feat)
ext_user_sequence.append((i, item_feat, ))
seq = np.zeros([.maxlen + ], dtype=np.int32)
token_type = np.zeros([.maxlen + ], dtype=np.int32)
seq_feat = np.empty([.maxlen + ], dtype=)
idx = .maxlen
ts = ()
record_tuple ext_user_sequence:
record_tuple[] == record_tuple[]:
ts.add(record_tuple[])
record_tuple (ext_user_sequence[:-]):
i, feat, type_ = record_tuple
feat = .fill_missing_feat(feat, i)
seq[idx] = i
token_type[idx] = type_
seq_feat[idx] = feat
idx -=
idx == -:
seq_feat = np.where(seq_feat == , .feature_default_value, seq_feat)
seq, token_type, seq_feat, user_id
():
(Path(.data_dir, ), ) f:
temp = pickle.load(f)
(temp)
():
seq, token_type, seq_feat, user_id = (*batch)
seq = torch.from_numpy(np.array(seq))
token_type = torch.from_numpy(np.array(token_type))
seq_feat = (seq_feat)
seq, token_type, seq_feat, user_id
():
num_points = emb.shape[]
num_dimensions = emb.shape[]
()
(Path(save_path), ) f:
f.write(struct.pack(, num_points, num_dimensions))
emb.tofile(f)
():
SHAPE_DICT = {: , : , : , : , : , : }
mm_emb_dict = {}
feat_id tqdm(feat_ids, desc=):
shape = SHAPE_DICT[feat_id]
emb_dict = {}
feat_id != :
:
base_path = Path(mm_path, )
json_file base_path.glob():
(json_file, , encoding=) file:
line file:
data_dict_origin = json.loads(line.strip())
insert_emb = data_dict_origin[]
(insert_emb, ):
insert_emb = np.array(insert_emb, dtype=np.float32)
data_dict = {data_dict_origin[]: insert_emb}
emb_dict.update(data_dict)
Exception e:
()
feat_id == :
(Path(mm_path, ), ) f:
emb_dict = pickle.load(f)
mm_emb_dict[feat_id] = emb_dict
()
mm_emb_dict
4. model_rqvae.py - 多模态特征压缩
实现了 RQ-VAE(Residual Quantized Variational AutoEncoder)框架,用于将高维多模态 embedding 转换为离散的语义 ID:
核心组件:
RQEncoder/RQDecoder:编码器和解码器
VQEmbedding:向量量化模块,支持 K-means 初始化
RQ:残差量化器,实现多级量化
RQVAE:完整的 RQ-VAE 模型
量化方法:
- 支持标准 K-means 和平衡 K-means 聚类
- 使用余弦距离或 L2 距离进行向量量化
- 通过残差量化实现更精确的特征表示
model_rqvae.py 代码
""" 选手可参考以下流程,使用提供的 RQ-VAE 框架代码将多模态 emb 数据转换为 Semantic Id:
1. 使用 MmEmbDataset 读取不同特征 ID 的多模态 emb 数据.
2. 训练 RQ-VAE 模型,训练完成后将数据转换为 Semantic Id.
3. 参照 Item Sparse 特征格式处理 Semantic Id,作为新特征加入 Baseline 模型训练.
"""
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
():
km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init=)
data_cpu = data.detach().cpu()
np_data = data_cpu.numpy()
km.fit(np_data)
torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)
(torch.nn.Module):
():
().__init__()
.num_clusters = num_clusters
.kmeans_iters = kmeans_iters
.tolerance = tolerance
.device = device
._codebook =
():
torch.cdist(data, ._codebook)
():
samples_cnt = dist.shape[]
samples_labels = torch.zeros(samples_cnt, dtype=torch.long, device=.device)
clusters_cnt = torch.zeros(.num_clusters, dtype=torch.long, device=.device)
sorted_indices = torch.argsort(dist, dim=-)
i (samples_cnt):
j (.num_clusters):
cluster_idx = sorted_indices[i, j]
clusters_cnt[cluster_idx] < samples_cnt // .num_clusters:
samples_labels[i] = cluster_idx
clusters_cnt[cluster_idx] +=
samples_labels
():
_new_codebook = []
i (.num_clusters):
cluster_data = data[samples_labels == i]
(cluster_data) > :
_new_codebook.append(cluster_data.mean(dim=))
:
_new_codebook.append(._codebook[i])
torch.stack(_new_codebook)
():
num_emb, codebook_emb_dim = data.shape
data = data.to(.device)
indices = torch.randperm(num_emb)[:.num_clusters]
._codebook = data[indices].clone()
_ (.kmeans_iters):
dist = ._compute_distances(data)
samples_labels = ._assign_clusters(dist)
_new_codebook = ._update_codebook(data, samples_labels)
torch.norm(_new_codebook - ._codebook) < .tolerance:
._codebook = _new_codebook
._codebook, samples_labels
():
data = data.to(.device)
dist = ._compute_distances(data)
samples_labels = ._assign_clusters(dist)
samples_labels
(torch.nn.Module):
():
().__init__()
.stages = torch.nn.ModuleList()
in_dim = input_dim
out_dim hidden_channels:
stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
.stages.append(stage)
in_dim = out_dim
.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, latent_dim), torch.nn.ReLU()))
():
stage .stages:
x = stage(x)
x
(torch.nn.Module):
():
().__init__()
.stages = torch.nn.ModuleList()
in_dim = latent_dim
out_dim hidden_channels:
stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
.stages.append(stage)
in_dim = out_dim
.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, output_dim), torch.nn.ReLU()))
():
stage .stages:
x = stage(x)
x
(torch.nn.Embedding):
():
(VQEmbedding, ).__init__(num_clusters, codebook_emb_dim)
.num_clusters = num_clusters
.codebook_emb_dim = codebook_emb_dim
.kmeans_method = kmeans_method
.kmeans_iters = kmeans_iters
.distances_method = distances_method
.device = device
():
.kmeans_method == :
_codebook, _ = kmeans(data, .num_clusters, .kmeans_iters)
.kmeans_method == :
BKmeans = BalancedKmeans(
num_clusters=.num_clusters,
kmeans_iters=.kmeans_iters,
tolerance=,
device=.device
)
_codebook, _ = BKmeans.fit(data)
:
_codebook = torch.randn(.num_clusters, .codebook_emb_dim)
_codebook = _codebook.to(.device)
_codebook.shape == (.num_clusters, .codebook_emb_dim)
.codebook = torch.nn.Parameter(_codebook)
():
_codebook_t = .codebook.t()
_codebook_t.shape == (.codebook_emb_dim, .num_clusters)
data.shape[-] == .codebook_emb_dim
.distances_method == :
data_norm = F.normalize(data, p=, dim=-)
_codebook_t_norm = F.normalize(_codebook_t, p=, dim=)
distances = - torch.mm(data_norm, _codebook_t_norm)
:
data_norm_sq = data.().(dim=-, keepdim=)
_codebook_t_norm_sq = _codebook_t.().(dim=, keepdim=)
distances = torch.addmm(data_norm_sq + _codebook_t_norm_sq, data, _codebook_t, beta=, alpha=-)
distances
():
distances = ._compute_distances(data)
_semantic_id = torch.argmin(distances, dim=-)
_semantic_id
():
update_emb = ().forward(_semantic_id)
update_emb
():
._create_codebook(data)
_semantic_id = ._create_semantic_id(data)
update_emb = ._update_emb(_semantic_id)
update_emb, _semantic_id
(torch.nn.Module):
():
().__init__()
.num_codebooks = num_codebooks
.codebook_size = codebook_size
(.codebook_size) == .num_codebooks
.codebook_emb_dim = codebook_emb_dim
.shared_codebook = shared_codebook
.kmeans_method = kmeans_method
.kmeans_iters = kmeans_iters
.distances_method = distances_method
.loss_beta = loss_beta
.device = device
.shared_codebook:
.vqmodules = torch.nn.ModuleList([
VQEmbedding(
.codebook_size[], .codebook_emb_dim, .kmeans_method, .kmeans_iters, .distances_method, .device,
) _ (.num_codebooks)
])
:
.vqmodules = torch.nn.ModuleList([
VQEmbedding(
.codebook_size[idx], .codebook_emb_dim, .kmeans_method, .kmeans_iters, .distances_method, .device,
) idx (.num_codebooks)
])
():
res_emb = data.detach().clone()
vq_emb_list, res_emb_list = [], []
semantic_id_list = []
vq_emb_aggre = torch.zeros_like(data)
i (.num_codebooks):
vq_emb, _semantic_id = .vqmodules[i](res_emb)
res_emb -= vq_emb
vq_emb_aggre += vq_emb
res_emb_list.append(res_emb)
vq_emb_list.append(vq_emb_aggre)
semantic_id_list.append(_semantic_id.unsqueeze(dim=-))
semantic_id_list = torch.cat(semantic_id_list, dim=-)
vq_emb_list, res_emb_list, semantic_id_list
():
rqvae_loss_list = []
idx, quant (vq_emb_list):
loss1 = (res_emb_list[idx].detach() - quant).().mean()
loss2 = (res_emb_list[idx] - quant.detach()).().mean()
partial_loss = loss1 + .loss_beta * loss2
rqvae_loss_list.append(partial_loss)
rqvae_loss = torch.(torch.stack(rqvae_loss_list))
rqvae_loss
():
vq_emb_list, res_emb_list, semantic_id_list = .quantize(data)
rqvae_loss = ._rqvae_loss(vq_emb_list, res_emb_list)
vq_emb_list, semantic_id_list, rqvae_loss
(torch.nn.Module):
():
().__init__()
.encoder = RQEncoder(input_dim, hidden_channels, latent_dim).to(device)
.decoder = RQDecoder(latent_dim, hidden_channels[::-], input_dim).to(device)
.rq = RQ(
num_codebooks, codebook_size, latent_dim, shared_codebook, kmeans_method, kmeans_iters, distances_method, loss_beta, device,
).to(device)
():
.encoder(x)
():
(z_vq, ):
z_vq = z_vq[-]
.decoder(z_vq)
():
recon_loss = F.mse_loss(x_hat, x_gt, reduction=)
total_loss = recon_loss + rqvae_loss
recon_loss, rqvae_loss, total_loss
():
z_e = .encode(x_gt)
vq_emb_list, semantic_id_list, rqvae_loss = .rq(z_e)
semantic_id_list
():
z_e = .encode(x_gt)
vq_emb_list, semantic_id_list, rqvae_loss = .rq(z_e)
x_hat = .decode(vq_emb_list)
recon_loss, rqvae_loss, total_loss = .compute_loss(x_hat, x_gt, rqvae_loss)
x_hat, semantic_id_list, recon_loss, rqvae_loss, total_loss
5. run.sh - 运行脚本
简单的 bash 脚本,用于启动训练程序。
run.sh 代码
#!/bin/bash
echo ${RUNTIME_SCRIPT_DIR}
cd ${RUNTIME_SCRIPT_DIR}
python -u main.py
技术特点
- 高效注意力机制:使用 Flash Attention 优化计算效率
- 多模态融合:支持文本、图像等多种模态的 embedding 特征
- 特征工程:支持稀疏、密集、数组等多种特征类型
- 序列建模:同时建模用户和物品的交互序列
- 可扩展性:支持大规模物品库的 embedding 保存和检索
数据流程
- 训练阶段:读取用户序列 → 特征 embedding → Transformer 编码 → 计算正负样本 loss
- 推理阶段:生成用户表征 → 保存物品 embedding → 进行向量检索推荐
- 多模态处理:原始 embedding → RQ-VAE 压缩 → 语义 ID → 作为新特征加入模型