import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
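
# Toy sequence-to-sequence Transformer: trains nn.Transformer on random
# token data, so the loss only shows that the pipeline runs end to end;
# there is no real learning signal in the targets.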
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers,
                 num_decoder_layers, dim_feedforward=2048, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional encoding; max_len caps the usable sequence length.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, d_model))
        # batch_first=True so inputs are (batch, seq, d_model); with the
        # default (seq, batch, d_model) layout, the embedded batches below
        # would silently have their batch and sequence dimensions swapped.
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.embedding(src) + self.positional_encoding[:, :src.size(1), :]
        tgt = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]
        # Causal mask so each decoder position attends only to earlier positions;
        # without it the decoder sees future tokens during training.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc_out(output)
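
# Random data: src feeds the encoder, tgt feeds the decoder (teacher
# forcing), and target holds the labels the decoder should predict.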
def generate_data(vocab_size, seq_length, num_samples):
    src = torch.randint(0, vocab_size, (num_samples, seq_length))
    tgt = torch.randint(0, vocab_size, (num_samples, seq_length))
    target = torch.randint(0, vocab_size, (num_samples, seq_length))
    return src, tgt, target
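
# Hyperparameters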
vocab_size = 10000
seq_length = 20
num_samples = 1000
batch_size = 32
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
lr = 0.001
num_epochs = 5
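
# Build the synthetic dataset and batch it.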
src, tgt, target = generate_data(vocab_size, seq_length, num_samples)
dataset = TensorDataset(src, tgt, target)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
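
# Model, token-level cross-entropy loss, and Adam optimizer.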
model = TransformerModel(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
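
# Training loop: one forward/backward pass per batch.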
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for src_batch, tgt_batch, target_batch in dataloader:
        optimizer.zero_grad()
        output = model(src_batch, tgt_batch)  # (batch, seq, vocab_size)
        # Flatten to (batch * seq, vocab_size) logits against (batch * seq,)
        # labels for token-level cross-entropy.
        loss = criterion(output.reshape(-1, vocab_size), target_batch.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
def evaluate(model, src, tgt):
    model.eval()
    with torch.no_grad():
        output = model(src, tgt)
    return output
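
# A minimal greedy-decoding sketch (added for illustration; start_token=0 is
# an assumption, since the random vocabulary defines no BOS token). It feeds
# the model's own predictions back in, one position at a time.
def greedy_decode(model, src, max_len=20, start_token=0):
    model.eval()
    with torch.no_grad():
        # Begin every sequence in the batch with the assumed start token.
        tgt = torch.full((src.size(0), 1), start_token, dtype=torch.long)
        for _ in range(max_len - 1):
            logits = model(src, tgt)  # (batch, cur_len, vocab_size)
            # Append the most likely token at the last position.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
    return tgt

# Smoke test: one forward pass on a single random example.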
test_src, test_tgt, _ = generate_data(vocab_size, seq_length, 1)
output = evaluate(model, test_src, test_tgt)
print("模型输出形状:", output.shape)