【算法岗面试】手撕Self-Attention、Multi-head Attention

【算法岗面试】手撕Self-Attention、Multi-head Attention
输入 X: [B, L, d_model

Q/K/V 权重: [d_model, d_model] (合头写法,拆开后每头是 [d_model, d_k])

多头时:先全量 linear 得 [B, L, d_model],再 view/reshape 成 [B, L, num_heads, d_k],再 permute 成 [B, num_heads, L, d_k]

先用简单的Self-Attention捋一遍数据流动的过程:

import torch import torch.nn as nn import torch.nn.functional as F import math class SelfAttention(nn.Module): def __init__(self, embed_dim,d_k): super().__init__() self.embed_dim = embed_dim self.W_Q = nn.Linear(embed_dim, d_k) self.W_K = nn.Linear(embed_dim, d_k) self.W_V = nn.Linear(embed_dim, d_k) def forward(self, x): # x: [batch_size, seq_len, embed_dim] Q = self.W_Q(x) # [B, L, D] K = self.W_K(x) # [B, L, D] V = self.W_V(x) # [B, L, D] # Attention scores: [B, L, L] score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) attn_weights = F.softmax(score, dim=-1) # [B, L, L] att_output = torch.matmul(attn_weights, V) # [B, L, D] return att_output 

然后再拓展到多头:

import torch import torch.nn as nn import math import torch.nn.functional as F class MultiHeadAttention(nn.Module): #定义参数 def __init__(self,embed_dim,head_num): super().__init__() self.embed_dim=embed_dim self.head_num=head_num self.head_dim=embed_dim//head_num #每个头的维度 #定义好Q,K,V矩阵和最后的输出变换矩阵 self.W_Q=nn.Linear(embed_dim, embed_dim) self.W_K=nn.Linear(embed_dim, embed_dim) self.W_V=nn.Linear(embed_dim, embed_dim) self.W_O=nn.Linear(embed_dim, embed_dim)# 注意力输出后再投回原维度 #前向传播 def forward(self,x): # x维度是BLD,batch_size seq_len embed_dim batch_size,seq_len,embed_dim=x.size() # 先全量投影得到了QKV矩阵再拆头 Q = self.W_Q(x) # [B, L, embed_dim] K = self.W_K(x) # [B, L, embed_dim] V = self.W_V(x) # [B, L, embed_dim] #拆分多头 # 方法:先view,再transpose # 拆分成[B, L, num_heads, head_dim],再变成[B, num_heads, L, head_dim] Q=Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2) K=K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2) V=V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2) # 此时shape均为[B, num_heads, L, head_dim] # Q @ K^T:最后两维做乘法 # K.transpose(-2, -1): [B, num_heads, head_dim, L] score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # [B, num_heads, L, L] attn_weights = F.softmax(score, dim=-1) # [B, num_heads, L, L] # 得到每个头的注意力输出 att_output = torch.matmul(attn_weights, V) # [B, num_heads, L, head_dim] # 变回 [B, L, embed_dim] # 先transpose(1,2): [B, L, num_heads, head_dim] # 然后view为 [B, L, num_heads*head_dim] = [B, L, embed_dim] att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim) output = self.W_O(att_output) # [B, L, embed_dim] return output

为什么要拆分成 (num_heads, head_dim)

  • 背景:你的输入每个 token 有 embed_dim 维(比如 768)。多头注意力机制本质上是把输入特征维度切成 num_heads 块,每块 head_dim 维,分别做自注意力,然后拼回去。
  • 本质:每个头都是一个“小的单头 self-attention”,但只用一部分特征(head_dim = embed_dim // num_heads)。
  • 举例:如果 embed_dim=768, num_heads=12, 每头 head_dim=64。768=12*64。

原始Q的shape

  • Q = [B, L, embed_dim] (batch, sequence, feature维)

目标:希望得到一个 shape = [B, num_heads, L, head_dim]

这样后续每个head可以独立做 Attention(矩阵乘法/softmax/加权 …)。


为什么用 view(B, L, num_heads, head_dim).transpose(1, 2)

Step 1: view(B, L, num_heads, head_dim)

  • 把最后一维 embed_dim 拆成 num_heads * head_dim
  • 假设 embed_dim=768, num_heads=12, head_dim=64,则拆分成 [B, L, 12, 64]

Step 2: transpose(1, 2)

  • 把 head 数移到序列长度前面
  • [B, L, num_heads, head_dim] --> [B, num_heads, L, head_dim]
  • 这样每个 batch 下,对每个头进行独立计算(更方便并行处理多头)

过程可视化

比如有 Q: [2, 10, 768](batch=2, seq=10, 768维)

  • view(2, 10, 12, 64) -- 12个头,每头64维
  • transpose(1, 2) 得到 (2, 12, 10, 64)

为什么顺序不能交换?

如果你写成 view(B, num_heads, L, head_dim),就完全不对了!因为:

  • 原始数据是按 [B, L, embed_dim] 顺序排列的。
  • view 顺序必须是先序列后特征,特征维度用于拆分
  • 而且在 PyTorch、Tensorflow 中,view 后的数据不会自动乱序分配,只是“重新组织 shape”,不会帮助你把循环顺序换掉。
  • transpose(1, 2) 是在 [B, L, num_heads, head_dim] 基础上,把 head 放到序列之前。

如果互换 num_heads、L 顺序,会把 batch 里的时间步和头搞混,后续 Attention 计算也会错。


为什么最终要 [B, num_heads, L, head_dim]

  • 这样每个头彼此独立,并且都遍历了全部 batch 的序列。
  • 方便后续在每个头上分别做 Attention 计算。

总结口诀

view 拆头之前,总是最后一维(embed_dim)先拆成 (num_heads, head_dim),再用 transpose 把 head 移到 L 之前,得到 [B, num_heads, L, head_dim]。不能交换顺序,因为原始数据排列是 batch, seq, feat,再拆 feat。

Read more

【CS创世SD NAND征文】为无人机打造可靠数据仓:工业级存储芯片CSNP32GCR01-AOW在飞控系统中的应用实践

【CS创世SD NAND征文】为无人机打造可靠数据仓:工业级存储芯片CSNP32GCR01-AOW在飞控系统中的应用实践

一、引言:无人机时代的数据存储挑战 在无人机(UAV)技术飞速发展的今天,其应用范畴早已突破消费级航拍的界限,深度渗透至测绘勘察、基础设施巡检、精准农业、安防监控乃至国防军事等工业级领域。每一次精准的自动巡航、每一帧高清图像的实时图传、每一条飞行轨迹的忠实记录,都离不开飞控系统这颗"大脑"的精密运算。然而,大脑的决策依赖于记忆与学习,而承担这一"记忆"任务的存储单元,其可靠性直接决定了飞行任务的成败与数据的价值。一次意外的数据丢失或存储故障,不仅可能导致珍贵的测绘数据付诸东流,造成重大的经济损失,甚至可能引发严重的飞行安全事故。因此,为无人机飞控系统选择一款高性能、高可靠的存储芯片,已成为行业设计中不可或缺的关键一环。 本文将围绕基于全志MR100主控平台与CS创世SD NAND(具体型号:CSNP32GCR01-AOW)构建的新一代无人机飞控存储方案,深入探讨工业级存储芯片如何为高端无人机赋予稳定、可靠的"数据生命线",助力无人机技术在各个领域发挥更大的价值。 二、应用产品介绍:无人机飞控系统——空中机器人的智能核心

By Ne0inhk
使用Flutter导航组件TabBar、AppBar等为鸿蒙应用程序构建完整的应用导航体系

使用Flutter导航组件TabBar、AppBar等为鸿蒙应用程序构建完整的应用导航体系

📖 前言 导航是移动应用中最重要的功能之一,它帮助用户在不同页面和功能之间切换。Flutter 提供了丰富的导航组件,包括 AppBar、BottomNavigationBar、TabBar、Drawer、Scaffold 等,能够构建完整的应用导航体系。 🎯 导航组件概览 Flutter 提供了以下导航组件: 组件名功能说明适用场景Scaffold页面骨架应用页面基础结构AppBar顶部导航栏页面标题、操作按钮BottomNavigationBar底部导航栏主要功能切换TabBar标签栏页面内内容分类TabBarView标签内容区标签对应的内容Drawer侧边抽屉导航菜单、设置BackButton返回按钮页面返回CloseButton关闭按钮关闭对话框 🏗️ Scaffold 组件 Scaffold 是 Material Design 应用的基础结构,提供了页面骨架。 基础用法 Scaffold( appBar:AppBar( title:Text('标题'),), body:Center( child:Text('页面内容'),),) 完整结构

By Ne0inhk
Flutter 三方库 whatsapp_bot_flutter 自动化社交矩阵鸿蒙多维协同适配指引:横向打通设备生态通信拦截管道、打造多模态实体机器人事件分发-适配鸿蒙 HarmonyOS ohos

Flutter 三方库 whatsapp_bot_flutter 自动化社交矩阵鸿蒙多维协同适配指引:横向打通设备生态通信拦截管道、打造多模态实体机器人事件分发-适配鸿蒙 HarmonyOS ohos

欢迎加入开源鸿蒙跨平台社区:https://openharmonycrossplatform.ZEEKLOG.net Flutter 三方库 whatsapp_bot_flutter 自动化社交矩阵鸿蒙多维协同适配指引:横向打通设备生态通信拦截管道、打造多模态实体机器人事件分发极限制化与消息群发堡垒 前言 在 OpenHarmony 的企业级服务助理、自动化通知分发系统或者是个人智能机器人应用中,如何打通全球主流的即时通讯链路是开发者必须跨越的门槛。whatsapp_bot_flutter 库为 Flutter 开发者提供了一套基于协议或 Web 端桥接的自动化社交机器人方案。本文将带大家在鸿蒙端实战适配该库,探索社交自动化的无限可能。 一、原直线性 / 概念介绍 1.1 基础原理/概念介绍 whatsapp_bot_flutter 的核心逻辑是基于 基于流的会话状态机与加密协议握手 (Encryption Protocol Handshake)。它模拟官方客户端的连接逻辑,通过与指定网关建立受保护的 WebSocket 链路,并实时监听业务事件流(消息、

By Ne0inhk
Formality:原语(primitive)的概念

Formality:原语(primitive)的概念

相关阅读 Formalityhttps://blog.ZEEKLOG.net/weixin_45791458/category_12841971.html?spm=1001.2014.3001.5482         原语(primitive)一般指的是语言内置的基本构件,它们代表了基本的逻辑门和构件,通常用于建模电路的基本功能,例如Verilog中的门级建模会使用and、or等关键词表示单元门。Formality也存在原语的概念,这一般出现在对门级网表进行建模时,本文将对此进行详细解释。         假设以例1所示的RTL代码作为参考设计(可以看出添加了// synopsys sync_set_reset综合指令让Design Compiler将其实现为带同步复位端的D触发器),例2所示的综合后网表作为实现设计,其中data_out_reg原语是一个带同步复位端的D触发器(FDS2)。 // 例1 module ref( input clk, input reset, input data_in, output reg data_

By Ne0inhk