2025 腾讯广告算法大赛 Baseline 项目源码解析 | 极客日志
2025 腾讯广告算法大赛 Baseline 项目源码解析
2025 腾讯广告算法大赛 Baseline 项目基于 PyTorch 实现序列推荐系统,利用 Transformer 架构处理用户交互序列。核心特性包括 Flash Attention 加速、多模态特征融合(文本/图像 Embedding)及 RQ-VAE 量化压缩。代码结构清晰,涵盖数据预处理、模型定义、训练循环及推理逻辑,适合学习推荐系统工程化落地。
HadoopMan · 4 次浏览

项目概述
这个 Baseline 项目是一个基于 PyTorch 构建的序列推荐系统,核心目标是建模用户与物品的交互序列。它通过引入多模态特征(如文本、图像的 Embedding)来增强推荐效果,整体架构采用 Transformer 变体,并针对大规模数据场景进行了优化。
核心架构与模块
1. 主训练脚本 (main.py)
这是整个项目的入口,负责串联数据加载、模型初始化及训练循环。代码中集成了参数解析、断点续训支持以及 TensorBoard 日志记录功能。
import argparse
import json
import os
import time
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dataset import MyDataset
from model import BaselineModel
def get_args():
    """Parse command-line hyperparameters for training.

    Returns:
        argparse.Namespace with optimization, model-architecture and
        runtime settings.
    """
    parser = argparse.ArgumentParser()
    # Optimization / data settings
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)
    # Model architecture
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    # Runtime
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')
    # Multi-modal embedding feature ids ('81'..'86')
    parser.add_argument('--mm_emb_id', nargs='+', default=['81'], type=str, choices=[str(s) for s in range(81, 87)])
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    # Prepare output directories supplied by the competition environment.
    Path(os.environ.get('TRAIN_LOG_PATH')).mkdir(parents=True, exist_ok=True)
    Path(os.environ.get('TRAIN_TF_EVENTS_PATH')).mkdir(parents=True, exist_ok=True)
    log_file = open(Path(os.environ.get('TRAIN_LOG_PATH'), 'train.log'), 'w')
    writer = SummaryWriter(os.environ.get('TRAIN_TF_EVENTS_PATH'))
    data_path = os.environ.get('TRAIN_DATA_PATH')

    args = get_args()
    dataset = MyDataset(data_path, args)
    # 90/10 train / validation split.
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
    )
    usernum, itemnum = dataset.usernum, dataset.itemnum
    feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types

    model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)

    # Xavier init for every weight matrix; 1-D tensors (biases, norm params)
    # raise inside xavier_normal_ and are deliberately skipped.
    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass
    # Index 0 is the padding id everywhere: zero its embedding rows.
    model.pos_emb.weight.data[0, :] = 0
    model.item_emb.weight.data[0, :] = 0
    model.user_emb.weight.data[0, :] = 0
    for k in model.sparse_emb:
        model.sparse_emb[k].weight.data[0, :] = 0

    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            # Resume from the epoch encoded in the checkpoint file name.
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except Exception:
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            raise RuntimeError('failed loading state_dicts, pls check file path!')

    bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    best_val_ndcg, best_val_hr = 0.0, 0.0
    best_test_ndcg, best_test_hr = 0.0, 0.0
    T = 0.0
    t0 = time.time()
    global_step = 0
    print("Start training")
    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        model.train()
        if args.inference_only:
            break
        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels = torch.ones(pos_logits.shape, device=args.device)
            neg_labels = torch.zeros(neg_logits.shape, device=args.device)
            optimizer.zero_grad()
            # Only positions whose next token is an item (type == 1) carry loss.
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            log_json = json.dumps(
                {'global_step': global_step, 'loss': loss.item(), 'epoch': epoch, 'time': time.time()}
            )
            log_file.write(log_json + '\n')
            log_file.flush()
            print(log_json)
            writer.add_scalar('Loss/train', loss.item(), global_step)
            global_step += 1
            # L2 regularization applied to the item embedding table only.
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            optimizer.step()

        # Validation pass; no_grad avoids building an unused autograd graph.
        model.eval()
        valid_loss_sum = 0
        with torch.no_grad():
            for step, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
                seq = seq.to(args.device)
                pos = pos.to(args.device)
                neg = neg.to(args.device)
                pos_logits, neg_logits = model(
                    seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
                )
                pos_labels = torch.ones(pos_logits.shape, device=args.device)
                neg_labels = torch.zeros(neg_logits.shape, device=args.device)
                indices = np.where(next_token_type == 1)
                loss = bce_criterion(pos_logits[indices], pos_labels[indices])
                loss += bce_criterion(neg_logits[indices], neg_labels[indices])
                valid_loss_sum += loss.item()
        valid_loss_sum /= len(valid_loader)

        # Checkpoint directory encodes step and validation loss for bookkeeping.
        save_dir = Path(os.environ.get('TRAIN_CKPT_PATH'), f"global_step{global_step}.valid_loss={valid_loss_sum:.4f}")
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_dir / "model.pt")

    print("Done")
    writer.close()
    log_file.close()
2. 核心模型实现 (model.py)
BaselineModel 是整个系统的引擎,基于 Transformer 架构设计。它实现了 FlashMultiHeadAttention 以利用 PyTorch 2.0+ 的 Flash Attention 特性加速计算,同时兼容标准注意力机制作为降级方案。
模型支持多种特征类型:稀疏特征、数组特征、连续特征以及多模态 Embedding 特征。在特征处理上,通过 feat2emb 方法将不同维度的特征统一转换为 Embedding 表示,再拼接输入到全连接层和 Transformer 块中。
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dataset import save_emb
class FlashMultiHeadAttention(torch.nn.Module):
    """Multi-head attention using PyTorch's fused scaled_dot_product_attention
    (Flash Attention path) when available, with a manual softmax-attention
    implementation as fallback for older PyTorch versions.
    """

    def __init__(self, hidden_units, num_heads, dropout_rate):
        """
        Args:
            hidden_units: model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            dropout_rate: attention dropout probability (applied in training only).
        """
        super(FlashMultiHeadAttention, self).__init__()
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.head_dim = hidden_units // num_heads
        self.dropout_rate = dropout_rate
        assert hidden_units % num_heads == 0, "hidden_units must be divisible by num_heads"
        self.q_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.k_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.v_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.out_linear = torch.nn.Linear(hidden_units, hidden_units)

    def forward(self, query, key, value, attn_mask=None):
        """Compute attention.

        Args:
            query, key, value: (batch, seq_len, hidden_units) tensors.
            attn_mask: optional (batch, seq_len, seq_len) bool mask; True = attend.

        Returns:
            (output, None) — the None placeholder keeps the return shape
            compatible with torch.nn.MultiheadAttention (no weights returned).
        """
        batch_size, seq_len, _ = query.size()
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)
        # Split heads: (batch, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        if hasattr(F, 'scaled_dot_product_attention'):
            # Bug fix: previously attn_mask.unsqueeze(1) was called unconditionally,
            # which crashed with AttributeError when attn_mask was None (the default).
            attn_output = F.scaled_dot_product_attention(
                Q,
                K,
                V,
                dropout_p=self.dropout_rate if self.training else 0.0,
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
            )
        else:
            # Manual fallback: scaled dot-product + masked softmax.
            scale = (self.head_dim) ** -0.5
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            if attn_mask is not None:
                scores.masked_fill_(attn_mask.unsqueeze(1).logical_not(), float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
            attn_output = torch.matmul(attn_weights, V)
        # Merge heads back to (batch, seq_len, hidden_units).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_units)
        output = self.out_linear(attn_output)
        return output, None
class PointWiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward block built from 1x1 convolutions.

    Equivalent to two per-position linear projections with a ReLU between
    them; dropout follows each convolution.
    """

    def __init__(self, hidden_units, dropout_rate):
        super(PointWiseFeedForward, self).__init__()
        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1d expects (batch, channels, length): swap the last two dims in,
        # run conv1 -> dropout -> relu -> conv2 -> dropout, then swap back.
        hidden = inputs.transpose(-1, -2)
        hidden = self.dropout1(self.conv1(hidden))
        hidden = self.relu(hidden)
        hidden = self.dropout2(self.conv2(hidden))
        return hidden.transpose(-1, -2)
class BaselineModel(torch.nn.Module):
    """Transformer-based sequential recommendation model.

    Encodes an interleaved user/item interaction sequence with stacked
    self-attention blocks and scores candidate items by dot product with the
    sequence representation. Supports sparse, array, continual and
    pre-computed multi-modal embedding features for users and items.
    """

    def __init__(self, user_num, item_num, feat_statistics, feat_types, args):
        """
        Args:
            user_num: number of users (ids are 1-based; 0 is padding).
            item_num: number of items (ids are 1-based; 0 is padding).
            feat_statistics: feature id -> vocabulary size for sparse/array features.
            feat_types: dict of feature-id lists grouped by feature kind.
            args: parsed command-line hyperparameters.
        """
        # Bug fix: super().__init__() was missing; torch.nn.Module requires it
        # before any sub-module/parameter assignment.
        super().__init__()
        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device
        self.norm_first = args.norm_first
        self.maxlen = args.maxlen
        # Index 0 is reserved for padding in every embedding table.
        self.item_emb = torch.nn.Embedding(self.item_num + 1, args.hidden_units, padding_idx=0)
        self.user_emb = torch.nn.Embedding(self.user_num + 1, args.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(2 * args.maxlen + 1, args.hidden_units, padding_idx=0)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.sparse_emb = torch.nn.ModuleDict()
        self.emb_transform = torch.nn.ModuleDict()
        self.attention_layernorms = torch.nn.ModuleList()
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()
        self._init_feat_info(feat_statistics, feat_types)
        # Width of the concatenated per-token feature vector before projection.
        userdim = args.hidden_units * (len(self.USER_SPARSE_FEAT) + 1 + len(self.USER_ARRAY_FEAT)) + len(
            self.USER_CONTINUAL_FEAT
        )
        itemdim = (
            args.hidden_units * (len(self.ITEM_SPARSE_FEAT) + 1 + len(self.ITEM_ARRAY_FEAT))
            + len(self.ITEM_CONTINUAL_FEAT)
            + args.hidden_units * len(self.ITEM_EMB_FEAT)
        )
        self.userdnn = torch.nn.Linear(userdim, args.hidden_units)
        self.itemdnn = torch.nn.Linear(itemdim, args.hidden_units)
        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)
            new_attn_layer = FlashMultiHeadAttention(args.hidden_units, args.num_heads, args.dropout_rate)
            self.attention_layers.append(new_attn_layer)
            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)
            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)
        for k in self.USER_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.USER_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_EMB_FEAT:
            # Project each raw multi-modal embedding down to hidden_units.
            self.emb_transform[k] = torch.nn.Linear(self.ITEM_EMB_FEAT[k], args.hidden_units)

    def _init_feat_info(self, feat_statistics, feat_types):
        """Split feature ids by kind and record vocabulary sizes / input dims."""
        self.USER_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['user_sparse']}
        self.USER_CONTINUAL_FEAT = feat_types['user_continual']
        self.ITEM_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['item_sparse']}
        self.ITEM_CONTINUAL_FEAT = feat_types['item_continual']
        self.USER_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['user_array']}
        self.ITEM_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['item_array']}
        # Fixed dimensionality of each multi-modal embedding source.
        EMB_SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
        self.ITEM_EMB_FEAT = {k: EMB_SHAPE_DICT[k] for k in feat_types['item_emb']}

    def feat2tensor(self, seq_feature, k):
        """Convert feature k across the batch into a zero-padded LongTensor.

        Args:
            seq_feature: list (batch) of lists (seq) of per-token feature dicts.
            k: feature id to extract.

        Returns:
            (batch, seq) tensor for scalar features, or
            (batch, seq, max_array_len) tensor for array features.
        """
        batch_size = len(seq_feature)
        if k in self.ITEM_ARRAY_FEAT or k in self.USER_ARRAY_FEAT:
            # Array features: pad both the sequence and the per-token arrays.
            max_array_len = 0
            max_seq_len = 0
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                max_seq_len = max(max_seq_len, len(seq_data))
                max_array_len = max(max_array_len, max(len(item_data) for item_data in seq_data))
            batch_data = np.zeros((batch_size, max_seq_len, max_array_len), dtype=np.int64)
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                for j, item_data in enumerate(seq_data):
                    actual_len = min(len(item_data), max_array_len)
                    batch_data[i, j, :actual_len] = item_data[:actual_len]
            return torch.from_numpy(batch_data).to(self.dev)
        else:
            max_seq_len = max(len(seq_feature[i]) for i in range(batch_size))
            batch_data = np.zeros((batch_size, max_seq_len), dtype=np.int64)
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                batch_data[i] = seq_data
            return torch.from_numpy(batch_data).to(self.dev)

    def feat2emb(self, seq, feature_array, mask=None, include_user=False):
        """Fuse id embeddings with all configured features into one embedding.

        Args:
            seq: (batch, seq) id tensor (item ids, or mixed user/item ids).
            feature_array: list (batch) of lists (seq) of feature dicts.
            mask: token-type tensor (2 = user token, 1 = item token); required
                when include_user is True.
            include_user: whether the sequence interleaves user tokens.

        Returns:
            (batch, seq, hidden_units) fused embedding.
        """
        seq = seq.to(self.dev)
        if include_user:
            # Route each position through the matching embedding table;
            # the complementary positions are zeroed out by the mask product.
            user_mask = (mask == 2).to(self.dev)
            item_mask = (mask == 1).to(self.dev)
            user_embedding = self.user_emb(user_mask * seq)
            item_embedding = self.item_emb(item_mask * seq)
            item_feat_list = [item_embedding]
            user_feat_list = [user_embedding]
        else:
            item_embedding = self.item_emb(seq)
            item_feat_list = [item_embedding]
        all_feat_types = [
            (self.ITEM_SPARSE_FEAT, 'item_sparse', item_feat_list),
            (self.ITEM_ARRAY_FEAT, 'item_array', item_feat_list),
            (self.ITEM_CONTINUAL_FEAT, 'item_continual', item_feat_list),
        ]
        if include_user:
            all_feat_types.extend(
                [
                    (self.USER_SPARSE_FEAT, 'user_sparse', user_feat_list),
                    (self.USER_ARRAY_FEAT, 'user_array', user_feat_list),
                    (self.USER_CONTINUAL_FEAT, 'user_continual', user_feat_list),
                ]
            )
        for feat_dict, feat_type, feat_list in all_feat_types:
            if not feat_dict:
                continue
            for k in feat_dict:
                tensor_feature = self.feat2tensor(feature_array, k)
                if feat_type.endswith('sparse'):
                    feat_list.append(self.sparse_emb[k](tensor_feature))
                elif feat_type.endswith('array'):
                    # Sum-pool the per-token array of embeddings.
                    feat_list.append(self.sparse_emb[k](tensor_feature).sum(2))
                elif feat_type.endswith('continual'):
                    feat_list.append(tensor_feature.unsqueeze(2))
        for k in self.ITEM_EMB_FEAT:
            # Multi-modal embeddings arrive as raw vectors in the feature dicts.
            batch_size = len(feature_array)
            emb_dim = self.ITEM_EMB_FEAT[k]
            seq_len = len(feature_array[0])
            batch_emb_data = np.zeros((batch_size, seq_len, emb_dim), dtype=np.float32)
            for i, seq_item in enumerate(feature_array):
                for j, item in enumerate(seq_item):
                    if k in item:
                        batch_emb_data[i, j] = item[k]
            tensor_feature = torch.from_numpy(batch_emb_data).to(self.dev)
            item_feat_list.append(self.emb_transform[k](tensor_feature))
        # Concatenate along the feature axis, then project back to hidden_units.
        all_item_emb = torch.cat(item_feat_list, dim=2)
        all_item_emb = torch.relu(self.itemdnn(all_item_emb))
        if include_user:
            all_user_emb = torch.cat(user_feat_list, dim=2)
            all_user_emb = torch.relu(self.userdnn(all_user_emb))
            seqs_emb = all_item_emb + all_user_emb
        else:
            seqs_emb = all_item_emb
        return seqs_emb

    def log2feats(self, log_seqs, mask, seq_feature):
        """Encode the interaction sequence through the Transformer stack.

        Returns:
            (batch, seq, hidden_units) per-position representation.
        """
        batch_size = log_seqs.shape[0]
        maxlen = log_seqs.shape[1]
        seqs = self.feat2emb(log_seqs, seq_feature, mask=mask, include_user=True)
        # Standard Transformer embedding scaling by sqrt(d).
        seqs *= self.item_emb.embedding_dim ** 0.5
        poss = torch.arange(1, maxlen + 1, device=self.dev).unsqueeze(0).expand(batch_size, -1).clone()
        # Zero the position ids at padding positions (pos_emb padding_idx=0).
        poss *= log_seqs != 0
        seqs += self.pos_emb(poss)
        seqs = self.emb_dropout(seqs)
        maxlen = seqs.shape[1]
        # Causal (lower-triangular) mask combined with the padding mask.
        ones_matrix = torch.ones((maxlen, maxlen), dtype=torch.bool, device=self.dev)
        attention_mask_tril = torch.tril(ones_matrix)
        attention_mask_pad = (mask != 0).to(self.dev)
        attention_mask = attention_mask_tril.unsqueeze(0) & attention_mask_pad.unsqueeze(1)
        for i in range(len(self.attention_layers)):
            if self.norm_first:
                # Pre-norm residual blocks.
                x = self.attention_layernorms[i](seqs)
                mha_outputs, _ = self.attention_layers[i](x, x, x, attn_mask=attention_mask)
                seqs = seqs + mha_outputs
                seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs))
            else:
                # Post-norm residual blocks.
                mha_outputs, _ = self.attention_layers[i](seqs, seqs, seqs, attn_mask=attention_mask)
                seqs = self.attention_layernorms[i](seqs + mha_outputs)
                seqs = self.forward_layernorms[i](seqs + self.forward_layers[i](seqs))
        log_feats = self.last_layernorm(seqs)
        return log_feats

    def forward(
        self, user_item, pos_seqs, neg_seqs, mask, next_mask, next_action_type, seq_feature, pos_feature, neg_feature
    ):
        """Score positive and negative next-item candidates for each position.

        Returns:
            (pos_logits, neg_logits), each (batch, seq); positions whose next
            token is not an item (next_mask != 1) are zeroed.
        """
        log_feats = self.log2feats(user_item, mask, seq_feature)
        loss_mask = (next_mask == 1).to(self.dev)
        pos_embs = self.feat2emb(pos_seqs, pos_feature, include_user=False)
        neg_embs = self.feat2emb(neg_seqs, neg_feature, include_user=False)
        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)
        pos_logits = pos_logits * loss_mask
        neg_logits = neg_logits * loss_mask
        return pos_logits, neg_logits

    def predict(self, log_seqs, seq_feature, mask):
        """Return the representation of the final sequence position
        (the user state used for candidate retrieval)."""
        log_feats = self.log2feats(log_seqs, mask, seq_feature)
        final_feat = log_feats[:, -1, :]
        return final_feat

    def save_item_emb(self, item_ids, retrieval_ids, feat_dict, save_path, batch_size=1024):
        """Embed all items in batches and dump embeddings + ids to disk.

        Args:
            item_ids: re-indexed item ids to embed.
            retrieval_ids: external ids written alongside the embeddings.
            feat_dict: index -> feature dict for each item.
            save_path: output directory for embedding.fbin / id.u64bin.
            batch_size: items per forward pass.
        """
        all_embs = []
        for start_idx in tqdm(range(0, len(item_ids), batch_size), desc="Saving item embeddings"):
            end_idx = min(start_idx + batch_size, len(item_ids))
            # Treat the batch of items as a single pseudo-sequence of length B.
            item_seq = torch.tensor(item_ids[start_idx:end_idx], device=self.dev).unsqueeze(0)
            batch_feat = []
            for i in range(start_idx, end_idx):
                batch_feat.append(feat_dict[i])
            batch_feat = np.array(batch_feat, dtype=object)
            batch_emb = self.feat2emb(item_seq, [batch_feat], include_user=False).squeeze(0)
            all_embs.append(batch_emb.detach().cpu().numpy().astype(np.float32))
        final_ids = np.array(retrieval_ids, dtype=np.uint64).reshape(-1, 1)
        final_embs = np.concatenate(all_embs, axis=0)
        save_emb(final_embs, Path(save_path, 'embedding.fbin'))
        save_emb(final_ids, Path(save_path, 'id.u64bin'))
3. 数据处理 (dataset.py)
数据模块是推荐系统的基石。MyDataset 类负责高效读取用户行为序列,支持文件偏移量随机访问,这对于处理海量数据至关重要。它实现了负采样逻辑,确保正负样本平衡,并能处理冷启动问题(即训练集中未出现的特征值)。测试数据集 MyTestDataset 继承自训练集,专门用于推理阶段的预测。
import json
import pickle
import struct
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
class MyDataset(torch.utils.data.Dataset):
    """Training dataset over user interaction sequences stored as JSON lines.

    Sequences are accessed by precomputed byte offsets so arbitrary users can
    be read without loading the whole file. Each sample yields the input
    sequence, positive/negative next-item targets, token-type masks and the
    corresponding per-token feature dicts.
    """

    def __init__(self, data_dir, args):
        """
        Args:
            data_dir: directory containing seq.jsonl, seq_offsets.pkl,
                indexer.pkl, item_feat_dict.json and the creative_emb folder.
            args: parsed command-line arguments (uses maxlen, mm_emb_id).
        """
        super().__init__()
        self.data_dir = Path(data_dir)
        self._load_data_and_offsets()
        self.maxlen = args.maxlen
        self.mm_emb_ids = args.mm_emb_id
        self.item_feat_dict = json.load(open(Path(data_dir, "item_feat_dict.json"), 'r'))
        self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_ids)
        with open(self.data_dir / 'indexer.pkl', 'rb') as ff:
            indexer = pickle.load(ff)
        self.itemnum = len(indexer['i'])
        self.usernum = len(indexer['u'])
        # Reverse maps: re-indexed id -> original id.
        self.indexer_i_rev = {v: k for k, v in indexer['i'].items()}
        self.indexer_u_rev = {v: k for k, v in indexer['u'].items()}
        self.indexer = indexer
        self.feature_default_value, self.feature_types, self.feat_statistics = self._init_feat_info()

    def _load_data_and_offsets(self):
        # The file handle stays open for the dataset's lifetime; __getitem__
        # seeks into it using per-user byte offsets for O(1) random access.
        self.data_file = open(self.data_dir / "seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _load_user_data(self, uid):
        """Read and decode a single user's interaction sequence."""
        self.data_file.seek(self.seq_offsets[uid])
        line = self.data_file.readline()
        data = json.loads(line)
        return data

    def _random_neq(self, l, r, s):
        """Sample an item id in [l, r) that the user did not interact with
        and that has known features (negative sampling)."""
        t = np.random.randint(l, r)
        while t in s or str(t) not in self.item_feat_dict:
            t = np.random.randint(l, r)
        return t

    def __getitem__(self, uid):
        """Build one training sample for user `uid`.

        Returns:
            seq, pos, neg: (maxlen+1,) id arrays (0 = padding).
            token_type, next_token_type, next_action_type: (maxlen+1,) arrays.
            seq_feat, pos_feat, neg_feat: (maxlen+1,) object arrays of dicts.
        """
        user_sequence = self._load_user_data(uid)
        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, action_type, _ = record_tuple
            if u and user_feat:
                # The user token is placed at the front (token type 2).
                ext_user_sequence.insert(0, (u, user_feat, 2, action_type))
            if i and item_feat:
                ext_user_sequence.append((i, item_feat, 1, action_type))
        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        pos = np.zeros([self.maxlen + 1], dtype=np.int32)
        neg = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_action_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)
        pos_feat = np.empty([self.maxlen + 1], dtype=object)
        neg_feat = np.empty([self.maxlen + 1], dtype=object)
        # NOTE: assumes the loaded sequence is non-empty — TODO confirm data guarantee.
        nxt = ext_user_sequence[-1]
        idx = self.maxlen
        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])
        # Fill right-to-left so the most recent interactions are kept.
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_, act_type = record_tuple
            next_i, next_feat, next_type, next_act_type = nxt
            feat = self.fill_missing_feat(feat, i)
            next_feat = self.fill_missing_feat(next_feat, next_i)
            seq[idx] = i
            token_type[idx] = type_
            next_token_type[idx] = next_type
            if next_act_type is not None:
                next_action_type[idx] = next_act_type
            seq_feat[idx] = feat
            if next_type == 1 and next_i != 0:
                # Position predicts an item: record positive and sample a negative.
                pos[idx] = next_i
                pos_feat[idx] = next_feat
                neg_id = self._random_neq(1, self.itemnum + 1, ts)
                neg[idx] = neg_id
                neg_feat[idx] = self.fill_missing_feat(self.item_feat_dict[str(neg_id)], neg_id)
            nxt = record_tuple
            idx -= 1
            if idx == -1:
                break
        # Element-wise None replacement on object arrays; `== None` is
        # intentional here (numpy broadcasts the comparison per element).
        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)
        pos_feat = np.where(pos_feat == None, self.feature_default_value, pos_feat)
        neg_feat = np.where(neg_feat == None, self.feature_default_value, neg_feat)
        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat

    def __len__(self):
        return len(self.seq_offsets)

    def _init_feat_info(self):
        """Declare feature groupings, per-feature defaults and vocab sizes.

        Returns:
            (feat_default_value, feat_types, feat_statistics)
        """
        feat_default_value = {}
        feat_statistics = {}
        feat_types = {}
        feat_types['user_sparse'] = ['103', '104', '105', '109']
        feat_types['item_sparse'] = [
            '100', '117', '111', '118', '101', '102', '119', '120', '114', '112', '121', '115', '122', '116'
        ]
        feat_types['item_array'] = []
        feat_types['user_array'] = ['106', '107', '108', '110']
        feat_types['item_emb'] = self.mm_emb_ids
        feat_types['user_continual'] = []
        feat_types['item_continual'] = []
        for feat_id in feat_types['user_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_emb']:
            # Default is a zero vector matching the source embedding dimension.
            feat_default_value[feat_id] = np.zeros(
                list(self.mm_emb_dict[feat_id].values())[0].shape[0], dtype=np.float32
            )
        return feat_default_value, feat_types, feat_statistics

    def fill_missing_feat(self, feat, item_id):
        """Return a copy of `feat` completed with defaults for every known
        feature id, plus the item's real multi-modal embedding when available."""
        if feat is None:  # bug fix: identity comparison instead of `== None`
            feat = {}
        filled_feat = {}
        for k in feat.keys():
            filled_feat[k] = feat[k]
        all_feat_ids = []
        for feat_type in self.feature_types.values():
            all_feat_ids.extend(feat_type)
        missing_fields = set(all_feat_ids) - set(feat.keys())
        for feat_id in missing_fields:
            filled_feat[feat_id] = self.feature_default_value[feat_id]
        for feat_id in self.feature_types['item_emb']:
            if item_id != 0 and self.indexer_i_rev[item_id] in self.mm_emb_dict[feat_id]:
                if isinstance(self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]], np.ndarray):
                    filled_feat[feat_id] = self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]
        return filled_feat

    @staticmethod
    def collate_fn(batch):
        """Stack fixed-size id arrays into tensors; feature dicts stay as lists."""
        seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        pos = torch.from_numpy(np.array(pos))
        neg = torch.from_numpy(np.array(neg))
        token_type = torch.from_numpy(np.array(token_type))
        next_token_type = torch.from_numpy(np.array(next_token_type))
        next_action_type = torch.from_numpy(np.array(next_action_type))
        seq_feat = list(seq_feat)
        pos_feat = list(pos_feat)
        neg_feat = list(neg_feat)
        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
class MyTestDataset(MyDataset):
    """Inference-time dataset over the prediction sequences.

    Reuses MyDataset's offset-based random access but reads the predict_*
    files, handles cold-start feature values, and yields (seq, token_type,
    seq_feat, user_id) without training targets.
    """

    def __init__(self, data_dir, args):
        super().__init__(data_dir, args)

    def _load_data_and_offsets(self):
        # Override: read the prediction sequence file instead of seq.jsonl.
        self.data_file = open(self.data_dir / "predict_seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _process_cold_start_feat(self, feat):
        """Map unseen (string-valued) feature values to 0, the default id.

        Values that were not present in training keep their original string
        form in the raw data; the model only knows integer ids.
        """
        processed_feat = {}
        for feat_id, feat_value in feat.items():
            if isinstance(feat_value, list):
                value_list = []
                for v in feat_value:
                    if isinstance(v, str):
                        value_list.append(0)
                    else:
                        value_list.append(v)
                processed_feat[feat_id] = value_list
            elif isinstance(feat_value, str):
                processed_feat[feat_id] = 0
            else:
                processed_feat[feat_id] = feat_value
        return processed_feat

    def __getitem__(self, uid):
        """Build one inference sample for user `uid`.

        Returns:
            seq, token_type: (maxlen+1,) arrays (0 = padding).
            seq_feat: (maxlen+1,) object array of feature dicts.
            user_id: original (external) user id for result lookup.
        """
        user_sequence = self._load_user_data(uid)
        ext_user_sequence = []
        # Robustness fix: user_id was previously unbound (NameError at return)
        # when the sequence contained no user record.
        user_id = 0
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, _, _ = record_tuple
            if u:
                if isinstance(u, str):
                    # Cold-start user: the raw value is already the external id.
                    user_id = u
                else:
                    user_id = self.indexer_u_rev[u]
            if u and user_feat:
                if isinstance(u, str):
                    # Cold-start user gets the padding id 0 in the sequence.
                    u = 0
                if user_feat:
                    user_feat = self._process_cold_start_feat(user_feat)
                ext_user_sequence.insert(0, (u, user_feat, 2))
            if i and item_feat:
                if i > self.itemnum:
                    # Item unseen during training: map to padding id.
                    i = 0
                if item_feat:
                    item_feat = self._process_cold_start_feat(item_feat)
                ext_user_sequence.append((i, item_feat, 1))
        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)
        idx = self.maxlen
        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])
        # Fill right-to-left so the most recent interactions are kept.
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_ = record_tuple
            feat = self.fill_missing_feat(feat, i)
            seq[idx] = i
            token_type[idx] = type_
            seq_feat[idx] = feat
            idx -= 1
            if idx == -1:
                break
        # Element-wise None replacement on the object array (`==` intentional).
        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)
        return seq, token_type, seq_feat, user_id

    def __len__(self):
        # Bug fix: previously re-opened and re-parsed predict_seq_offsets.pkl
        # on every call; the offsets are already loaded by
        # _load_data_and_offsets during __init__.
        return len(self.seq_offsets)

    @staticmethod
    def collate_fn(batch):
        """Stack id arrays into tensors; features and user ids stay as lists/tuples."""
        seq, token_type, seq_feat, user_id = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        token_type = torch.from_numpy(np.array(token_type))
        seq_feat = list(seq_feat)
        return seq, token_type, seq_feat, user_id
def save_emb(emb, save_path):
    """Serialize a 2-D embedding matrix to a binary file.

    Layout: two little-endian-native uint32 values (num_points,
    num_dimensions) followed by the raw array bytes.
    """
    num_points, num_dimensions = emb.shape[0], emb.shape[1]
    print(f'saving {save_path}')
    with open(Path(save_path), 'wb') as out_file:
        header = struct.pack('II', num_points, num_dimensions)
        out_file.write(header)
        emb.tofile(out_file)
def load_mm_emb(mm_path, feat_ids):
    """Load multi-modal embedding tables for the requested feature ids.

    Feature '81' is stored as a single pickle file; every other id is a
    directory of JSON-lines files mapping anonymous_cid -> embedding vector.

    Returns:
        dict: feat_id -> {anonymous_cid: np.ndarray}
    """
    SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
    mm_emb_dict = {}
    for feat_id in tqdm(feat_ids, desc='Loading mm_emb'):
        shape = SHAPE_DICT[feat_id]
        if feat_id == '81':
            with open(Path(mm_path, f'emb_{feat_id}_{shape}.pkl'), 'rb') as f:
                emb_dict = pickle.load(f)
        else:
            emb_dict = {}
            try:
                base_path = Path(mm_path, f'emb_{feat_id}_{shape}')
                for json_file in base_path.glob('*.json'):
                    with open(json_file, 'r', encoding='utf-8') as file:
                        for line in file:
                            record = json.loads(line.strip())
                            vec = record['emb']
                            if isinstance(vec, list):
                                vec = np.array(vec, dtype=np.float32)
                            emb_dict[record['anonymous_cid']] = vec
            except Exception as e:
                # Best-effort: skip unreadable shards but keep what was loaded.
                print(f"transfer error: {e}")
        mm_emb_dict[feat_id] = emb_dict
        print(f'Loaded #{feat_id} mm_emb')
    return mm_emb_dict
4. 多模态特征压缩 (model_rqvae.py)
为了降低高维多模态 Embedding 的存储和计算开销,项目引入了 RQ-VAE(Residual Quantized Variational AutoEncoder)框架。该模块可以将连续的向量空间映射为离散的语义 ID,从而作为新的稀疏特征加入模型训练。核心组件包括编码器、解码器、向量量化模块以及残差量化器,支持 K-means 初始化和余弦/L2 距离度量。
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
def kmeans(data, n_clusters, kmeans_iters):
    """Run scikit-learn KMeans on a tensor.

    Returns:
        (cluster_centers, labels) as torch tensors.
    """
    np_data = data.detach().cpu().numpy()
    km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")
    km.fit(np_data)
    centers = torch.tensor(km.cluster_centers_)
    labels = torch.tensor(km.labels_)
    return centers, labels
class BalancedKmeans(torch.nn.Module):
    """K-means variant that assigns a (nearly) equal number of samples to
    every cluster via greedy capacity-constrained assignment."""

    def __init__(self, num_clusters: int, kmeans_iters: int, tolerance: float, device: str):
        """
        Args:
            num_clusters: number of centroids.
            kmeans_iters: maximum update iterations.
            tolerance: early-stop threshold on the codebook shift norm.
            device: torch device string for all intermediate tensors.
        """
        super().__init__()
        self.num_clusters = num_clusters
        self.kmeans_iters = kmeans_iters
        self.tolerance = tolerance
        self.device = device
        self._codebook = None

    def _compute_distances(self, data):
        # Pairwise Euclidean distances between samples and current centroids.
        return torch.cdist(data, self._codebook)

    def _assign_clusters(self, dist):
        """Assign each sample to its nearest cluster that still has capacity.

        Bug fix: capacity is now ceil(samples / clusters). With the previous
        floor division, when the sample count was not divisible by
        num_clusters the leftover samples found every cluster "full" and
        silently kept the default label 0, skewing cluster 0.
        """
        samples_cnt = dist.shape[0]
        samples_labels = torch.zeros(samples_cnt, dtype=torch.long, device=self.device)
        clusters_cnt = torch.zeros(self.num_clusters, dtype=torch.long, device=self.device)
        capacity = (samples_cnt + self.num_clusters - 1) // self.num_clusters
        sorted_indices = torch.argsort(dist, dim=-1)
        for i in range(samples_cnt):
            for j in range(self.num_clusters):
                cluster_idx = sorted_indices[i, j]
                if clusters_cnt[cluster_idx] < capacity:
                    samples_labels[i] = cluster_idx
                    clusters_cnt[cluster_idx] += 1
                    break
        return samples_labels

    def _update_codebook(self, data, samples_labels):
        """Recompute each centroid as the mean of its members; empty clusters
        keep their previous centroid."""
        _new_codebook = []
        for i in range(self.num_clusters):
            cluster_data = data[samples_labels == i]
            if len(cluster_data) > 0:
                _new_codebook.append(cluster_data.mean(dim=0))
            else:
                _new_codebook.append(self._codebook[i])
        return torch.stack(_new_codebook)

    def fit(self, data):
        """Run balanced k-means on (num_samples, dim) data.

        Returns:
            (codebook, labels) for the final iteration.
        """
        num_emb, codebook_emb_dim = data.shape
        data = data.to(self.device)
        # Initialize centroids from a random sample subset.
        indices = torch.randperm(num_emb)[: self.num_clusters]
        self._codebook = data[indices].clone()
        for _ in range(self.kmeans_iters):
            dist = self._compute_distances(data)
            samples_labels = self._assign_clusters(dist)
            _new_codebook = self._update_codebook(data, samples_labels)
            if torch.norm(_new_codebook - self._codebook) < self.tolerance:
                break
            self._codebook = _new_codebook
        return self._codebook, samples_labels

    def predict(self, data):
        """Balanced-assign new data to the fitted codebook."""
        data = data.to(self.device)
        dist = self._compute_distances(data)
        samples_labels = self._assign_clusters(dist)
        return samples_labels
class RQEncoder(torch.nn.Module):
    """MLP encoder: input_dim -> hidden_channels... -> latent_dim, with a
    ReLU after every linear layer."""

    def __init__(self, input_dim: int, hidden_channels: list, latent_dim: int):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        dims = [input_dim] + list(hidden_channels) + [latent_dim]
        for prev_dim, next_dim in zip(dims[:-1], dims[1:]):
            self.stages.append(torch.nn.Sequential(torch.nn.Linear(prev_dim, next_dim), torch.nn.ReLU()))

    def forward(self, x):
        """Apply every stage in sequence."""
        for layer in self.stages:
            x = layer(x)
        return x
class RQDecoder(torch.nn.Module):
    """MLP decoder: latent_dim -> hidden_channels... -> output_dim, with a
    ReLU after every linear layer (including the output layer)."""

    def __init__(self, latent_dim: int, hidden_channels: list, output_dim: int):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        dims = [latent_dim] + list(hidden_channels) + [output_dim]
        for prev_dim, next_dim in zip(dims[:-1], dims[1:]):
            self.stages.append(torch.nn.Sequential(torch.nn.Linear(prev_dim, next_dim), torch.nn.ReLU()))

    def forward(self, x):
        """Apply every stage in sequence."""
        for layer in self.stages:
            x = layer(x)
        return x
class VQEmbedding(torch.nn.Embedding):
    """Single vector-quantization codebook built on top of nn.Embedding.

    Maps continuous vectors to discrete semantic ids via nearest-centroid
    search, then looks the ids up to produce quantized embeddings.
    """

    def __init__(self, num_clusters, codebook_emb_dim: int, kmeans_method: str, kmeans_iters: int, distances_method: str, device: str):
        """
        Args:
            num_clusters: codebook size (number of centroids).
            codebook_emb_dim: dimensionality of each codebook entry.
            kmeans_method: 'kmeans', 'bkmeans', or anything else for random init.
            kmeans_iters: iteration budget for the k-means initializer.
            distances_method: 'cosine' for cosine distance, otherwise squared L2.
            device: torch device string for the codebook.
        """
        super(VQEmbedding, self).__init__(num_clusters, codebook_emb_dim)
        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device

    def _create_codebook(self, data):
        # Initialize centroids from the data via the configured method;
        # any unrecognized method falls back to random Gaussian init.
        if self.kmeans_method == 'kmeans':
            _codebook, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == 'bkmeans':
            BKmeans = BalancedKmeans(num_clusters=self.num_clusters, kmeans_iters=self.kmeans_iters, tolerance=1e-4, device=self.device)
            _codebook, _ = BKmeans.fit(data)
        else:
            _codebook = torch.randn(self.num_clusters, self.codebook_emb_dim)
        _codebook = _codebook.to(self.device)
        assert _codebook.shape == (self.num_clusters, self.codebook_emb_dim)
        # NOTE(review): forward() calls this on every pass, so `self.codebook`
        # is re-created as a fresh Parameter each time; meanwhile _update_emb
        # looks ids up in the inherited nn.Embedding weight (self.weight), not
        # in self.codebook. Confirm this codebook/weight split is intentional.
        self.codebook = torch.nn.Parameter(_codebook)

    @torch.no_grad()
    def _compute_distances(self, data):
        # Distances between data rows and every centroid (codebook column-major).
        _codebook_t = self.codebook.t()
        assert _codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim
        if self.distances_method == 'cosine':
            # Cosine distance: 1 - cos(angle) between normalized vectors.
            data_norm = F.normalize(data, p=2, dim=-1)
            _codebook_t_norm = F.normalize(_codebook_t, p=2, dim=0)
            distances = 1 - torch.mm(data_norm, _codebook_t_norm)
        else:
            # Squared L2 via the expansion ||x||^2 + ||c||^2 - 2*x.c (addmm with
            # beta=1, alpha=-2 computes the full expression in one fused call).
            data_norm_sq = data.pow(2).sum(dim=-1, keepdim=True)
            _codebook_t_norm_sq = _codebook_t.pow(2).sum(dim=0, keepdim=True)
            distances = torch.addmm(data_norm_sq + _codebook_t_norm_sq, data, _codebook_t, beta=1.0, alpha=-2.0)
        return distances

    @torch.no_grad()
    def _create_semantic_id(self, data):
        # Semantic id = index of the nearest centroid.
        distances = self._compute_distances(data)
        _semantic_id = torch.argmin(distances, dim=-1)
        return _semantic_id

    def _update_emb(self, _semantic_id):
        # Look the ids up in the inherited embedding table (self.weight).
        update_emb = super().forward(_semantic_id)
        return update_emb

    def forward(self, data):
        """Quantize `data`: returns (quantized embeddings, semantic ids)."""
        self._create_codebook(data)
        _semantic_id = self._create_semantic_id(data)
        update_emb = self._update_emb(_semantic_id)
        return update_emb, _semantic_id
class RQ(torch.nn.Module):
    """Residual quantizer: a stack of ``num_codebooks`` VQEmbedding stages.

    Each stage quantizes the residual left by the previous stages; the
    aggregate of the stage outputs approximates the input, and the
    concatenated per-stage ids form the item's semantic id.
    """

    def __init__(self, num_codebooks: int, codebook_size: list, codebook_emb_dim, shared_codebook: bool, kmeans_method, kmeans_iters, distances_method, loss_beta: float, device: str):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        # Weight of the commitment term in the per-stage VQ loss.
        self.loss_beta = loss_beta
        self.device = device
        if self.shared_codebook:
            # All stages use the same codebook size (index 0).
            self.vqmodules = torch.nn.ModuleList([VQEmbedding(self.codebook_size[0], self.codebook_emb_dim, self.kmeans_method, self.kmeans_iters, self.distances_method, self.device,) for _ in range(self.num_codebooks)])
        else:
            self.vqmodules = torch.nn.ModuleList([VQEmbedding(self.codebook_size[idx], self.codebook_emb_dim, self.kmeans_method, self.kmeans_iters, self.distances_method, self.device,) for idx in range(self.num_codebooks)])

    def quantize(self, data):
        """Quantize ``data`` stage by stage.

        Returns:
            vq_emb_list: cumulative quantized embedding after each stage.
            res_emb_list: residual remaining after each stage.
            semantic_id_list: (N, num_codebooks) tensor of per-stage code ids.
        """
        res_emb = data.detach().clone()
        vq_emb_list, res_emb_list = [], []
        semantic_id_list = []
        vq_emb_aggre = torch.zeros_like(data)
        for i in range(self.num_codebooks):
            vq_emb, _semantic_id = self.vqmodules[i](res_emb)
            # BUGFIX: the original used in-place `res_emb -= vq_emb` and
            # `vq_emb_aggre += vq_emb`, so every element of res_emb_list /
            # vq_emb_list aliased the same tensor and all stages saw only the
            # final values. Out-of-place updates give each stage its own tensor.
            res_emb = res_emb - vq_emb
            vq_emb_aggre = vq_emb_aggre + vq_emb
            res_emb_list.append(res_emb)
            vq_emb_list.append(vq_emb_aggre)
            semantic_id_list.append(_semantic_id.unsqueeze(dim=-1))
        semantic_id_list = torch.cat(semantic_id_list, dim=-1)
        return vq_emb_list, res_emb_list, semantic_id_list

    def _rqvae_loss(self, vq_emb_list, res_emb_list):
        """Sum, over stages, of codebook loss + beta * commitment loss."""
        rqvae_loss_list = []
        for idx, quant in enumerate(vq_emb_list):
            # Stop-gradient on the residual: pulls the codebook toward the data.
            loss1 = (res_emb_list[idx].detach() - quant).pow(2.0).mean()
            # Stop-gradient on the quantized value: commits the encoder to the codebook.
            loss2 = (res_emb_list[idx] - quant.detach()).pow(2.0).mean()
            partial_loss = loss1 + self.loss_beta * loss2
            rqvae_loss_list.append(partial_loss)
        rqvae_loss = torch.sum(torch.stack(rqvae_loss_list))
        return rqvae_loss

    def forward(self, data):
        """Quantize ``data`` and return (per-stage embeddings, ids, VQ loss)."""
        vq_emb_list, res_emb_list, semantic_id_list = self.quantize(data)
        rqvae_loss = self._rqvae_loss(vq_emb_list, res_emb_list)
        return vq_emb_list, semantic_id_list, rqvae_loss
class RQVAE(torch.nn.Module):
    """Residual-quantized autoencoder: encoder -> residual quantizer -> decoder."""

    def __init__(self, input_dim: int, hidden_channels: list, latent_dim: int, num_codebooks: int, codebook_size: list, shared_codebook: bool, kmeans_method, kmeans_iters, distances_method, loss_beta: float, device: str):
        super().__init__()
        self.encoder = RQEncoder(input_dim, hidden_channels, latent_dim).to(device)
        # The decoder mirrors the encoder, hence the reversed hidden widths.
        self.decoder = RQDecoder(latent_dim, hidden_channels[::-1], input_dim).to(device)
        self.rq = RQ(num_codebooks, codebook_size, latent_dim, shared_codebook, kmeans_method, kmeans_iters, distances_method, loss_beta, device,).to(device)

    def encode(self, x):
        """Project raw features into the latent space."""
        return self.encoder(x)

    def decode(self, z_vq):
        """Reconstruct inputs from a quantized latent.

        Accepts either a single tensor or the per-stage list produced by the
        quantizer, in which case the final (most complete) aggregate is used.
        """
        latent = z_vq[-1] if isinstance(z_vq, list) else z_vq
        return self.decoder(latent)

    def compute_loss(self, x_hat, x_gt, rqvae_loss):
        """Combine mean-squared reconstruction loss with the quantization loss."""
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        return recon_loss, rqvae_loss, recon_loss + rqvae_loss

    def _get_codebook(self, x_gt):
        """Return only the semantic ids for ``x_gt``; losses are discarded."""
        _, semantic_id_list, _ = self.rq(self.encode(x_gt))
        return semantic_id_list

    def forward(self, x_gt):
        """Full pass: encode, quantize, decode, and compute all losses."""
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_id_list, recon_loss, rqvae_loss, total_loss
运行说明
使用简单的 Bash 脚本即可启动训练流程,主要依赖环境变量配置数据路径和日志路径。
#!/bin/bash
echo ${RUNTIME_SCRIPT_DIR}
cd ${RUNTIME_SCRIPT_DIR}
python -u main.py
技术亮点总结
- 高效注意力:集成 Flash Attention 机制,显著提升长序列建模效率。
- 多模态融合:原生支持文本、图像等多模态 Embedding,并通过线性变换对齐维度。
- 特征工程:灵活处理稀疏、数组、连续及多模态特征,适配复杂业务场景。
- 序列建模:基于 Transformer 架构,有效捕捉用户行为的时间依赖性。
- 可扩展性:提供物品 Embedding 保存与检索接口,支持大规模候选集召回。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online