一、前言:为什么要自己实现神经网络?
很多人刚接触深度学习时,会直接用 TensorFlow、PyTorch 等框架。
但如果没有理解背后的数学原理,很容易出现'只会调库不会调脑'的情况。
因此,从零实现一个简单的神经网络,是入门深度学习最重要的一步。
在本文中,我们将一步步手写一个可以学习 XOR(异或)问题 的神经网络。 不借助任何高级框架,只用 Python + NumPy,彻底理解神经网络的计算过程。
本文使用 Python 和 NumPy 从零实现了一个简单的神经网络,用于解决 XOR 异或问题。文章详细讲解了神经元模型、三层网络结构、前向传播、损失函数计算及反向传播的数学原理。提供了完整的训练代码注释,展示了训练过程中的误差下降曲线,并分析了预测结果。通过手动推导链式法则,帮助读者深入理解神经网络的核心机制,为后续学习深度学习框架打下基础。
很多人刚接触深度学习时,会直接用 TensorFlow、PyTorch 等框架。
但如果没有理解背后的数学原理,很容易出现'只会调库不会调脑'的情况。
因此,从零实现一个简单的神经网络,是入门深度学习最重要的一步。
在本文中,我们将一步步手写一个可以学习 XOR(异或)问题 的神经网络。 不借助任何高级框架,只用 Python + NumPy,彻底理解神经网络的计算过程。
神经网络的灵感来源于人脑的神经元(Neuron)结构。 一个神经元会接收输入信号,通过加权求和后再经过激活函数产生输出:
y = f(w1*x1 + w2*x2 + ... + b)
其中:
xi:输入特征wi:权重(决定输入的重要性)b:偏置项(Bias,控制整体偏移)f:激活函数,用来增加非线性能力激活函数相当于'非线性开关',让网络能学习复杂关系。
我们要构建一个最简单的网络:
输入层 (2 个神经元) ↓ 隐藏层 (3 个神经元) ↓ 输出层 (1 个神经元)
输入是两个值(例如 XOR 的两个比特),输出是一个值(0 或 1)。
数据从输入层传向输出层,计算公式如下:
隐藏层输入:
Z1 = XW1 + b1
隐藏层输出(经过激活函数):
A1 = f(Z1)
输出层输入:
Z2 = A1W2 + b2
输出层输出:
A2 = f(Z2)
这里的 A2 就是最终预测结果。
我们使用最常见的平方误差函数:
L = 1/2 * (y - y_hat)^2
其中:
y:真实值y_hat:预测值核心目标:让误差越来越小。 我们通过梯度下降算法(Gradient Descent)调整权重。
梯度计算公式如下:
W = W - η * ∂L/∂W
其中:
η:学习率(learning rate)∂L/∂W:误差对权重的导数(梯度)环境:Python 3.x 依赖库:numpy
import numpy as np
# 固定随机种子,保证每次运行结果一致
np.random.seed(42)
# ========== 1. 数据集定义 ==========
# XOR 问题的输入与输出
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])
# 目标输出(异或结果)
# ========== 2. 定义激活函数 ==========
def sigmoid(x):
"""Sigmoid 激活函数:将值压缩到 (0,1)"""
return 1 / (1 + np.exp(-x))
def sigmoid_derivative(x):
"""Sigmoid 导数,用于反向传播"""
return x * (1 - x)
# ========== 3. 网络结构参数 ==========
input_size = 2 # 输入层节点数
hidden_size = 3 # 隐藏层节点数
output_size = 1 # 输出层节点数
lr = 0.1 # 学习率
epochs = 10000 # 训练迭代次数
# ========== 4. 权重和偏置初始化 ==========
W1 = np.random.uniform(-1, 1, (input_size, hidden_size))
b1 = np.zeros((1, hidden_size))
W2 = np.random.uniform(-1, 1, (hidden_size, output_size))
b2 = np.zeros((1, output_size))
# ========== 5. 开始训练 ==========
for epoch in range(epochs):
# ---- 前向传播 ----
hidden_input = np.dot(X, W1) + b1
hidden_output = sigmoid(hidden_input)
final_input = np.dot(hidden_output, W2) + b2
final_output = sigmoid(final_input)
# ---- 计算误差 ----
error = y - final_output
# ---- 反向传播 ----
d_output = error * sigmoid_derivative(final_output)
d_hidden = np.dot(d_output, W2.T) * sigmoid_derivative(hidden_output)
# ---- 更新权重与偏置 ----
W2 += np.dot(hidden_output.T, d_output) * lr
b2 += np.sum(d_output, axis=0, keepdims=True) * lr
W1 += np.dot(X.T, d_hidden) * lr
b1 += np.sum(d_hidden, axis=0, keepdims=True) * lr
# 每 1000 次打印一次误差
if epoch % 1000 == 0:
loss = np.mean(np.abs(error))
print(f"Epoch {epoch}, Loss: {loss:.4f}")
# ========== 6. 输出结果 ==========
print("\n训练完成后的预测输出:")
print(final_output)
输出结果示例:
Epoch 0, Loss: 0.51
Epoch 1000, Loss: 0.24
Epoch 2000, Loss: 0.12
Epoch 3000, Loss: 0.07
...
Epoch 9000, Loss: 0.03
训练完成后的预测输出:
[[0.03]
[0.97]
[0.96]
[0.04]]
预测效果如下表:
| 输入 | 期望输出 | 实际输出(约) |
|---|---|---|
| [0, 0] | 0 | 0.03 |
| [0, 1] | 1 | 0.97 |
| [1, 0] | 1 | 0.96 |
| [1, 1] | 0 | 0.04 |
模型已经非常成功地学习到了异或逻辑!
import matplotlib.pyplot as plt
losses = []
for epoch in range(epochs):
hidden_input = np.dot(X, W1) + b1
hidden_output = sigmoid(hidden_input)
final_input = np.dot(hidden_output, W2) + b2
final_output = sigmoid(final_input)
error = y - final_output
losses.append(np.mean(np.abs(error)))
d_output = error * sigmoid_derivative(final_output)
d_hidden = np.dot(d_output, W2.T) * sigmoid_derivative(hidden_output)
W2 += np.dot(hidden_output.T, d_output) * lr
b2 += np.sum(d_output, axis=0, keepdims=True) * lr
W1 += np.dot(X.T, d_hidden) * lr
b1 += np.sum(d_hidden, axis=0, keepdims=True) * lr
plt.plot(losses)
plt.title("Loss 下降曲线")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()
输出图像显示误差从 0.5 逐步降到接近 0,说明模型在不断学习。
反向传播(Backpropagation)的关键是链式法则(Chain Rule)。
举例:
∂L/∂W2 = ∂L/∂A2 · ∂A2/∂Z2 · ∂Z2/∂W2
这其实就是在计算:
通过不断循环这个过程,网络就会'自我纠正',让预测结果越来越接近真实值。
本文我们手写了一个完整的神经网络,从:
都进行了详细讲解。
下一步建议:

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online