import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
class MoEModel(nn.Module):
    """Mixture-of-experts model with one shared softmax gate and two task heads.

    ``num_experts`` expert MLPs (three Linear layers, ReLU after each, including
    the last) map the input to ``output_experts_dim`` features. A gating MLP
    produces one softmax weight per expert; expert outputs are combined with
    those weights, and the single mixed representation is fed to two task heads,
    each a three-layer MLP ending in a sigmoid.

    NOTE(review): both heads consume the *same* gated mixture (a single gate).
    If task-specific routing (MMoE-style, one gate per task) is intended, a
    second gating network is required.

    Args:
        input_dim: Number of input features (size of the last dim of ``x``).
        experts_hidden1_dim: Width of each expert's first hidden layer.
        experts_hidden2_dim: Width of each expert's second hidden layer.
        output_experts_dim: Width of each expert's output / the mixed features.
        task_hidden1_dim: Width of each task head's first hidden layer.
        task_hidden2_dim: Width of each task head's second hidden layer.
        output_task1_dim: Output size of the task-1 head.
        output_task2_dim: Output size of the task-2 head.
        gate_hidden1_dim: Width of the gating network's first hidden layer.
        gate_hidden2_dim: Width of the gating network's second hidden layer.
        num_experts: Number of experts (and of gate outputs).
    """

    def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.output_experts_dim = output_experts_dim
        # Submodules are created in the same order as before (experts, task1,
        # task2, gate) so a fixed torch seed yields the same initial weights.
        self.experts = nn.ModuleList([
            self._mlp(input_dim, experts_hidden1_dim, experts_hidden2_dim,
                      output_experts_dim, nn.ReLU())
            for _ in range(num_experts)
        ])
        self.task1_head = self._mlp(output_experts_dim, task_hidden1_dim,
                                    task_hidden2_dim, output_task1_dim, nn.Sigmoid())
        self.task2_head = self._mlp(output_experts_dim, task_hidden1_dim,
                                    task_hidden2_dim, output_task2_dim, nn.Sigmoid())
        self.gating_network = self._mlp(input_dim, gate_hidden1_dim, gate_hidden2_dim,
                                        num_experts, nn.Softmax(dim=1))

    @staticmethod
    def _mlp(in_dim, hidden1_dim, hidden2_dim, out_dim, final_activation):
        """Build ``Linear-ReLU-Linear-ReLU-Linear-final_activation``.

        The Linear layers sit at indices 0, 2 and 4 of the ``nn.Sequential``,
        giving ``state_dict`` keys such as ``experts.0.4.weight`` or
        ``task1_head.0.bias``.
        """
        return nn.Sequential(
            nn.Linear(in_dim, hidden1_dim),
            nn.ReLU(),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, out_dim),
            final_activation,
        )

    def forward(self, x):
        """Run all experts, mix them with the gate, and apply both task heads.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tuple ``(task1_outputs, task2_outputs)`` with shapes
            ``(batch, output_task1_dim)`` and ``(batch, output_task2_dim)``;
            values lie in [0, 1] because both heads end in a sigmoid.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected a 2-D (batch, input_dim) input, got shape {tuple(x.shape)}"
            )
        gates = self.gating_network(x)  # (batch, num_experts); each row sums to 1
        # Stacking instead of accumulating into torch.zeros(...) keeps the
        # result on x's device and dtype, so .to(device) / .double() work.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # (batch, num_experts, 1) * (batch, num_experts, output_experts_dim)
        mixed = (gates.unsqueeze(-1) * expert_outputs).sum(dim=1)
        # Both heads share one gated mixture, so it is computed only once.
        return self.task1_head(mixed), self.task2_head(mixed)
# ---- Hyperparameters --------------------------------------------------------
num_experts = 4
experts_hidden1_dim, experts_hidden2_dim, output_experts_dim = 64, 32, 16
gate_hidden1_dim, gate_hidden2_dim = 16, 8
task_hidden1_dim, task_hidden2_dim = 32, 16
output_task1_dim, output_task2_dim = 3, 2

# ---- Synthetic training data ------------------------------------------------
# Seed the global generator before any random draw so the data, the
# DataLoader shuffling and the model's initial weights are reproducible.
torch.manual_seed(42)
input_dim, num_samples = 10, 1024
# Inputs are 0/1 features; targets are uniform in [0, 1) and drawn
# independently of X_train, so there is no real signal to learn.
# NOTE(review): placeholder data — replace with the real dataset.
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)
y_train_task2 = torch.rand(num_samples, output_task2_dim)
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# ---- Model, per-task losses, optimizer --------------------------------------
model = MoEModel(
    input_dim=input_dim,
    experts_hidden1_dim=experts_hidden1_dim,
    experts_hidden2_dim=experts_hidden2_dim,
    output_experts_dim=output_experts_dim,
    task_hidden1_dim=task_hidden1_dim,
    task_hidden2_dim=task_hidden2_dim,
    output_task1_dim=output_task1_dim,
    output_task2_dim=output_task2_dim,
    gate_hidden1_dim=gate_hidden1_dim,
    gate_hidden2_dim=gate_hidden2_dim,
    num_experts=num_experts,
)
criterion_task1, criterion_task2 = nn.MSELoss(), nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100
# Train both tasks jointly: one optimizer step per batch on the summed loss.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for X_batch, y_task1_batch, y_task2_batch in train_loader:
        outputs_task1, outputs_task2 = model(X_batch)
        loss_task1 = criterion_task1(outputs_task1, y_task1_batch)
        loss_task2 = criterion_task2(outputs_task2, y_task2_batch)
        # Unweighted sum of the per-task losses; the shared experts and gate
        # receive gradients from both tasks.
        total_loss = loss_task1 + loss_task2
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        running_loss += total_loss.item()
    # Log every 10th epoch and always the final one, so the last loss is seen
    # (previously only epochs 1, 11, ..., 91 of 100 were reported).
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
# Show the architecture and every parameter's shape.
print(model)
# Build the state dict once; the previous loop rebuilt it twice per entry.
for param_name, param_value in model.state_dict().items():
    print(param_name, "\t", param_value.size())

# Single-sample inference on a random binary input (same distribution as
# X_train), with dropout/batch-norm style layers in eval mode and no autograd.
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, 2, (1, input_dim)).float()
    pred_task1, pred_task2 = model(test_input)
# Labels: "level-1 scenario prediction" / "level-2 scenario prediction".
print(f'一级场景预测结果:{pred_task1}')
print(f'二级场景预测结果:{pred_task2}')