import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
from tqdm import tqdm
class SingleExpert(nn.Module):
    """A three-layer MLP expert: input_dim -> 256 -> 128 -> output_dim.

    Any input is flattened to ``(-1, input_dim)`` before the first layer, so
    image batches such as ``(B, 1, 28, 28)`` work with the default
    ``input_dim=28*28``.
    """

    def __init__(self, input_dim=28*28, output_dim=20):
        super(SingleExpert, self).__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        """Return unnormalised scores of shape ``(N, output_dim)``."""
        # Flatten using the configured input size (previously hard-coded to
        # 28*28); reshape also accepts non-contiguous tensors, unlike view.
        x = x.reshape(-1, self.input_dim)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
class MoE(nn.Module):
    """Mixture of ``num_experts`` SingleExpert MLPs routed by a linear softmax gate.

    ``forward`` samples one expert per example from the gate distribution
    (``torch.multinomial``). The sampling happens in eval mode as well, so
    repeated evaluations of the same input may route differently.
    """

    def __init__(self, input_dim=28*28, output_dim=20, num_experts=4):
        super(MoE, self).__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.experts = nn.ModuleList([SingleExpert(input_dim, output_dim) for _ in range(num_experts)])
        self.gating_network = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Run all experts and the gate on ``x``.

        Returns:
            final_output: ``(B, output_dim)`` output of the sampled expert per row.
            expert_outputs: ``(B, num_experts, output_dim)`` outputs of every expert.
            gate_outputs: ``(B, num_experts)`` softmax gate probabilities.
            expert_indices: ``(B,)`` index of the sampled expert per row.
        """
        # Flatten with the configured input size instead of a hard-coded 28*28.
        x_flat = x.reshape(-1, self.input_dim)
        gate_outputs = torch.softmax(self.gating_network(x_flat), dim=1)
        expert_outputs = torch.stack([expert(x_flat) for expert in self.experts], dim=1)
        # squeeze(1), not squeeze(): keeps shape (B,) even when B == 1, so
        # callers can always iterate over / bincount the indices.
        expert_indices = torch.multinomial(gate_outputs, num_samples=1).squeeze(1)
        # Row indices follow the flattened batch size and live on the same
        # device as the expert outputs.
        rows = torch.arange(x_flat.size(0), device=x_flat.device)
        final_output = expert_outputs[rows, expert_indices]
        return final_output, expert_outputs, gate_outputs, expert_indices
def moe_loss(targets, expert_outputs, gate_outputs):
    """Negative log-likelihood of a Gaussian mixture of experts.

    Per sample: ``-log(sum_i g_i * exp(-0.5 * ||t - o_i||^2))``, averaged over
    the batch.

    The sum is evaluated in log space with ``logsumexp``. The direct
    ``exp`` form underflows for large errors, which saturates the loss near
    ``-log(1e-8)`` and kills the gradient.

    Arguments:
    - targets: target vectors, shape [batch_size, output_dim]
    - expert_outputs: each expert's output, shape [batch_size, num_experts, output_dim]
    - gate_outputs: gate probabilities per expert, shape [batch_size, num_experts]
    Returns:
    - loss: scalar tensor, mean loss over the batch
    """
    # Squared error of every expert against the target: [batch, num_experts].
    errors = torch.sum((expert_outputs - targets.unsqueeze(1)) ** 2, dim=2)
    # clamp_min keeps log() finite if a gate probability underflows to 0;
    # the gradient through clamped entries is simply zero (no NaN).
    log_gates = torch.log(gate_outputs.clamp_min(1e-12))
    loss = -torch.logsumexp(log_gates - 0.5 * errors, dim=1)
    return loss.mean()
def mse_hard_loss(targets, expert_outputs, gate_outputs):
    """Squared error of the gate-weighted blend of all expert outputs.

    The experts' outputs are mixed with the gate probabilities into a single
    prediction per sample. The squared error of that prediction, summed over
    output dimensions, is then averaged over the batch.

    Arguments:
    - targets: target vectors, shape [batch_size, output_dim]
    - expert_outputs: each expert's output, shape [batch_size, num_experts, output_dim]
    - gate_outputs: gate probabilities per expert, shape [batch_size, num_experts]
    Returns:
    - loss: scalar tensor, mean loss over the batch
    """
    # [B, E, 1] * [B, E, D] -> sum over experts -> [B, D]
    blended = (expert_outputs * gate_outputs[..., None]).sum(dim=1)
    per_sample = (blended - targets).pow(2).sum(dim=1)
    return per_sample.mean()
def mse_soft_loss(targets, expert_outputs, gate_outputs):
    """Expected squared error of the experts under the gate distribution.

    Per sample: ``sum_i g_i * ||t - o_i||^2``, averaged over the batch.

    The error is summed over experts (not averaged), so the loss does not
    shrink by a factor of ``num_experts``. This matches how ``moe_loss`` and
    ``mse_hard_loss`` reduce over experts.

    Arguments:
    - targets: target vectors, shape [batch_size, output_dim]
    - expert_outputs: each expert's output, shape [batch_size, num_experts, output_dim]
    - gate_outputs: gate probabilities per expert, shape [batch_size, num_experts]
    Returns:
    - loss: scalar tensor, mean loss over the batch
    """
    # Squared error of every expert against the target: [batch, num_experts].
    errors = torch.sum((expert_outputs - targets.unsqueeze(1)) ** 2, dim=2)
    expected_error = torch.sum(gate_outputs * errors, dim=1)
    return expected_error.mean()
def one_hot_encoding(labels, num_classes=10):
    """Map integer class labels to one-hot rows of length ``num_classes``.

    Each label selects a row of an identity matrix built on the labels'
    device, so the result has shape ``labels.shape + (num_classes,)`` and the
    identity matrix's floating dtype.
    """
    identity = torch.eye(num_classes, device=labels.device)
    return identity[labels]
def train(model, dataloader, optimizer, num_epochs=10, loss_name='moe', num_classes=None):
    """Train an MoE-style model and report loss, accuracy and expert usage.

    Args:
        model: module exposing ``num_experts`` whose forward returns
            ``(final_output, expert_outputs, gate_outputs, expert_indices)``.
        dataloader: iterable of ``(inputs, labels)`` batches with integer labels.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        loss_name: one of ``'moe'``, ``'mse_hard'``, ``'mse_soft'``.
        num_classes: width of the one-hot targets. Defaults to the width of
            the model's ``final_output``.

    Returns:
        Float tensor of length ``model.num_experts`` with how often each
        expert was selected, accumulated over all epochs.

    Raises:
        ValueError: if ``loss_name`` is not one of the supported losses.
    """
    loss_fns = {
        'moe': moe_loss,
        'mse_hard': mse_hard_loss,
        'mse_soft': mse_soft_loss,
    }
    # Fail fast instead of hitting an unbound `loss` inside the loop.
    if loss_name not in loss_fns:
        raise ValueError(f"Unknown loss_name {loss_name!r}; expected one of {sorted(loss_fns)}")
    loss_fn = loss_fns[loss_name]

    # Use the device the model lives on rather than a script-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    expert_selection_count = torch.zeros(model.num_experts, device=device)
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            final_output, expert_outputs, gate_outputs, expert_indices = model(inputs)
            # Vectorised usage count; reshape(-1) also covers 0-d indices.
            expert_selection_count += torch.bincount(
                expert_indices.reshape(-1), minlength=model.num_experts)
            width = num_classes if num_classes is not None else final_output.size(1)
            one_hot_labels = one_hot_encoding(labels, num_classes=width)
            loss = loss_fn(one_hot_labels, expert_outputs, gate_outputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(final_output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # Guard against an empty dataloader.
        mean_loss = running_loss / max(len(dataloader), 1)
        accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%')
    print("\n专家选择次数统计(训练集):")
    for i, count in enumerate(expert_selection_count):
        print(f'专家 {i}: 被选择 {count.item()} 次')
    return expert_selection_count
def test_with_expert_statistics(model, dataloader, dataset_name="", num_classes=20):
    """Evaluate accuracy and count how often each expert is selected.

    Args:
        model: module exposing ``num_experts`` whose forward returns
            ``(final_output, expert_outputs, gate_outputs, expert_indices)``.
        dataloader: iterable of ``(inputs, labels)`` batches with integer labels.
        dataset_name: label used in the printed report.
        num_classes: accepted for API compatibility; not used.

    Returns:
        ``(accuracy_percent, expert_selection_count)``, where the count is a
        float tensor of length ``model.num_experts``. If the model's routing
        is stochastic, repeated calls may give different results.
    """
    # Use the device the model lives on rather than a script-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    correct = 0
    total = 0
    expert_selection_count = torch.zeros(model.num_experts, device=device)
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            final_output, _, _, expert_indices = model(inputs)
            # Vectorised usage count; reshape(-1) also covers 0-d indices.
            expert_selection_count += torch.bincount(
                expert_indices.reshape(-1), minlength=model.num_experts)
            _, predicted = torch.max(final_output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty dataloader.
    accuracy = 100 * correct / total if total else 0.0
    print(f'\n{dataset_name} 测试集准确率:{accuracy:.2f}%')
    print(f"\n{dataset_name} 专家选择次数统计(数据集层面):")
    for i, count in enumerate(expert_selection_count):
        share = 100 * count.item() / total if total else 0.0
        print(f'专家 {i}: 被选择 {count.item()} 次,占比 {share:.2f}%')
    return accuracy, expert_selection_count


# The name starts with "test_" but this is an evaluation helper, not a pytest
# test; stop pytest from collecting it (it would fail on missing fixtures).
test_with_expert_statistics.__test__ = False
if __name__ == "__main__":
    # ---- run configuration ----
    batch_size = 1024
    num_experts = 4
    # Objective selected by index: 0 -> 'moe', 1 -> 'mse_hard', 2 -> 'mse_soft'.
    loss_name = ('moe', 'mse_hard', 'mse_soft')[1]
    # `device` and `combined_classes` stay at module scope in case the helper
    # functions read them as globals.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # ---- data: MNIST keeps labels 0-9, Fashion-MNIST is shifted to 10-19 ----
    dataset_kwargs = dict(root='./data', download=True, transform=transform)
    mnist_train = torchvision.datasets.MNIST(train=True, **dataset_kwargs)
    mnist_test = torchvision.datasets.MNIST(train=False, **dataset_kwargs)
    fashion_mnist_train = torchvision.datasets.FashionMNIST(train=True, **dataset_kwargs)
    fashion_mnist_test = torchvision.datasets.FashionMNIST(train=False, **dataset_kwargs)
    for fashion_split in (fashion_mnist_train, fashion_mnist_test):
        fashion_split.targets = fashion_split.targets + 10

    combined_train_data = ConcatDataset([mnist_train, fashion_mnist_train])
    combined_test_data = ConcatDataset([mnist_test, fashion_mnist_test])
    print(f"训练集样本数:{len(combined_train_data)}")
    print(f"测试集样本数:{len(combined_test_data)}")

    mnist_classes = [str(digit) for digit in range(10)]
    fashion_mnist_classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    combined_classes = mnist_classes + fashion_mnist_classes

    train_loader = DataLoader(combined_train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(combined_test_data, batch_size=batch_size, shuffle=False)

    # ---- model and training ----
    moe_model = MoE(output_dim=len(combined_classes), num_experts=num_experts).to(device)
    optimizer_moe = optim.Adam(moe_model.parameters(), lr=0.001)
    print("\n训练 MoE 模型...")
    train_expert_selection = train(moe_model, train_loader, optimizer_moe, num_epochs=10, loss_name=loss_name)

    # ---- per-dataset evaluation ----
    print("测试 MoE 模型在 MNIST 和 Fashion-MNIST 上...")
    evaluation_plan = (
        ("测试在 MNIST 数据集上...", "MNIST", mnist_test),
        ("测试在 Fashion-MNIST 数据集上...", "Fashion-MNIST", fashion_mnist_test),
    )
    evaluation_results = {}
    for banner, set_name, split in evaluation_plan:
        print(banner)
        split_loader = DataLoader(split, batch_size=batch_size, shuffle=False)
        evaluation_results[set_name] = test_with_expert_statistics(
            moe_model, split_loader, dataset_name=set_name, num_classes=len(combined_classes))
    test_accuracy_mnist, mnist_expert_selection = evaluation_results["MNIST"]
    test_accuracy_fashion_mnist, fashion_mnist_expert_selection = evaluation_results["Fashion-MNIST"]
# Sample output from three separate runs (one run per line):
# Epoch [1/10], Loss: 0.3535, Accuracy: 64.07% Epoch [2/10], Loss: 0.1851, Accuracy: 80.30% Epoch [3/10], Loss: 0.1515, Accuracy: 84.36% Epoch [4/10], Loss: 0.1331, Accuracy: 86.52% Epoch [5/10], Loss: 0.1210, Accuracy: 88.03% Epoch [6/10], Loss: 0.1123, Accuracy: 89.14% Epoch [7/10], Loss: 0.1066, Accuracy: 89.77% Epoch [8/10], Loss: 0.0995, Accuracy: 90.36% Epoch [9/10], Loss: 0.0943, Accuracy: 90.98% Epoch [10/10], Loss: 0.0908, Accuracy: 91.47% 专家选择次数统计(训练集): 专家 0: 被选择 537500.0 次 专家 1: 被选择 167879.0 次 专家 2: 被选择 194856.0 次 专家 3: 被选择 299765.0 次 测试 MoE 模型在 MNIST 和 Fashion-MNIST 上... 测试在 MNIST 数据集上... MNIST 测试集准确率:93.73% MNIST 专家选择次数统计(数据集层面): 专家 0: 被选择 4584.0 次,占比 45.84% 专家 1: 被选择 1927.0 次,占比 19.27% 专家 2: 被选择 1194.0 次,占比 11.94% 专家 3: 被选择 2295.0 次,占比 22.95% 测试在 Fashion-MNIST 数据集上... Fashion-MNIST 测试集准确率:85.36% Fashion-MNIST 专家选择次数统计(数据集层面): 专家 0: 被选择 4290.0 次,占比 42.90% 专家 1: 被选择 1249.0 次,占比 12.49% 专家 2: 被选择 1622.0 次,占比 16.22% 专家 3: 被选择 2839.0 次,占比 28.39%
# Epoch [1/10], Loss: 0.0924, Accuracy: 79.95% Epoch [2/10], Loss: 0.0470, Accuracy: 89.93% Epoch [3/10], Loss: 0.0381, Accuracy: 91.64% Epoch [4/10], Loss: 0.0341, Accuracy: 92.55% Epoch [5/10], Loss: 0.0311, Accuracy: 93.20% Epoch [6/10], Loss: 0.0286, Accuracy: 93.82% Epoch [7/10], Loss: 0.0271, Accuracy: 94.16% Epoch [8/10], Loss: 0.0255, Accuracy: 94.49% Epoch [9/10], Loss: 0.0241, Accuracy: 94.82% Epoch [10/10], Loss: 0.0230, Accuracy: 95.05% 专家选择次数统计(训练集): 专家 0: 被选择 884.0 次 专家 1: 被选择 1861.0 次 专家 2: 被选择 377796.0 次 专家 3: 被选择 819459.0 次 测试 MoE 模型在 MNIST 和 Fashion-MNIST 上... 测试在 MNIST 数据集上... MNIST 测试集准确率:97.57% MNIST 专家选择次数统计(数据集层面): 专家 0: 被选择 0.0 次,占比 0.00% 专家 1: 被选择 0.0 次,占比 0.00% 专家 2: 被选择 14.0 次,占比 0.14% 专家 3: 被选择 9986.0 次,占比 99.86% 测试在 Fashion-MNIST 数据集上... Fashion-MNIST 测试集准确率:88.30% Fashion-MNIST 专家选择次数统计(数据集层面): 专家 0: 被选择 1.0 次,占比 0.01% 专家 1: 被选择 2.0 次,占比 0.02% 专家 2: 被选择 6474.0 次,占比 64.74% 专家 3: 被选择 3523.0 次,占比 35.23%
# Epoch [1/10], Loss: 0.1878, Accuracy: 79.58% Epoch [2/10], Loss: 0.0981, Accuracy: 89.75% Epoch [3/10], Loss: 0.0807, Accuracy: 91.44% Epoch [4/10], Loss: 0.0717, Accuracy: 92.37% Epoch [5/10], Loss: 0.0658, Accuracy: 92.96% Epoch [6/10], Loss: 0.0612, Accuracy: 93.35% Epoch [7/10], Loss: 0.0576, Accuracy: 93.79% Epoch [8/10], Loss: 0.0549, Accuracy: 94.12% Epoch [9/10], Loss: 0.0520, Accuracy: 94.49% Epoch [10/10], Loss: 0.0492, Accuracy: 94.74% 专家选择次数统计(训练集): 专家 0: 被选择 2878.0 次 专家 1: 被选择 132251.0 次 专家 2: 被选择 1063670.0 次 专家 3: 被选择 1201.0 次 测试 MoE 模型在 MNIST 和 Fashion-MNIST 上... 测试在 MNIST 数据集上... MNIST 测试集准确率:97.67% MNIST 专家选择次数统计(数据集层面): 专家 0: 被选择 0.0 次,占比 0.00% 专家 1: 被选择 0.0 次,占比 0.00% 专家 2: 被选择 10000.0 次,占比 100.00% 专家 3: 被选择 0.0 次,占比 0.00% 测试在 Fashion-MNIST 数据集上... Fashion-MNIST 测试集准确率:87.77% Fashion-MNIST 专家选择次数统计(数据集层面): 专家 0: 被选择 0.0 次,占比 0.00% 专家 1: 被选择 2576.0 次,占比 25.76% 专家 2: 被选择 7424.0 次,占比 74.24% 专家 3: 被选择 0.0 次,占比 0.00%