跳到主要内容 斯坦福 Octopus v2:端侧运行大模型性能超越 GPT-4 | 极客日志
Python AI 算法
斯坦福 Octopus v2:端侧运行大模型性能超越 GPT-4 斯坦福大学发布 Octopus v2 开源语言模型,参数量 20 亿,专为 Android 设备优化。该模型在端侧运行,无需云端依赖,推理速度比 Llama7B+RAG 快 36 倍,准确率媲美 GPT-4。采用独特的函数 token 策略,支持复杂场景下的函数调用生成。训练基于 Google Gemma-2B,使用 AdamW 优化器及 LoRA 微调技术。数据集包含相关查询、不相关查询及 Gemini 验证。此模型标志着设备端 AI 智能体时代的到来,为边缘计算提供了高效解决方案。
斯坦福 Octopus v2:端侧运行大模型性能超越 GPT-4
在大模型落地应用的过程中,端侧 AI(Edge AI)是非常重要的一个方向。近日,斯坦福大学研究人员推出的 Octopus v2 火了,受到了开发者社区的极大关注,模型一夜下载量超 2k。
模型概述 Octopus-V2-2B 是一个拥有 20 亿参数的开源语言模型,专为 Android API 量身定制,旨在在 Android 设备上无缝运行,并将实用性扩展到从 Android 系统管理到多个设备的编排等各种应用程序。
该模型可以在智能手机、汽车、个人电脑等端侧运行,在准确性和延迟方面超越了 GPT-4,并将上下文长度减少了 95%。此外,Octopus v2 比 Llama7B + RAG 方案快 36 倍。这标志着设备端 AI 智能体的时代可能已经到来。
通常,检索增强生成 (RAG) 方法需要对潜在函数参数进行详细描述(有时需要多达数万个输入 token)。基于此,Octopus-V2-2B 在训练和推理阶段引入了独特的函数 token 策略,不仅使其能够达到与 GPT-4 相当的性能水平,而且还显著提高了推理速度,超越了基于 RAG 的方法,这使得它对边缘计算设备特别有利。
Octopus-V2-2B 能够在各种复杂场景中生成单独的、嵌套的和并行的函数调用,这对于构建复杂的智能体任务至关重要。
数据集构建 为了在训练、验证和测试阶段采用高质量数据集,特别是实现高效训练,研究团队用三个关键阶段创建数据集:
生成相关的查询及其关联的函数调用参数 :确保模型学习正确的映射关系。
由适当的函数组件生成不相关的查询 :增加负样本,提高模型的抗干扰能力。
通过 Google Gemini 实现二进制验证支持 :利用强大的基座模型对生成的数据进行自动化校验,保证数据质量。
研究团队编写了 20 个 Android API 描述,用于训练模型。下面是一个 Android API 描述示例:
def get_trending_news (category=None , region='US' , language='en' , max_results=5 ):
"""
Fetches trending news articles based on category, region, and language.
Parameters:
- category (str, optional): News category to filter by, by default use None for all categories. Optional to provide.
- region (str, optional): ISO 3166-1 alpha-2 country code for region-specific news, by default, uses 'US'. Optional to provide.
- language (str, optional): ISO 636-1 language code for article language, by default uses 'en'. Optional to provide.
- max_results (int, optional): Maximum number of articles to return, by default, uses 5. Optional to provide.
Returns:
- list [str]: A list of strings, each representing an article. Each string contains the article's heading and URL.
"""
模型开发与训练 该研究采用 Google Gemma-2B 模型作为框架中的预训练模型,并采用两种不同的训练方法:完整模型训练和 LoRA 模型训练。
完整模型训练 在完整模型训练中,该研究使用 AdamW 优化器,学习率设置为 5e-5,warm-up 的 step 数设置为 10,采用线性学习率调度器。这种配置有助于模型在大规模数据上稳定收敛。
LoRA 模型训练 LoRA 模型训练采用与完整模型训练相同的优化器和学习率配置。LoRA rank 设置为 16,并将 LoRA 应用于以下模块:q_proj、k_proj、v_proj、o_proj、up_proj、down_proj。其中,LoRA alpha 参数设置为 32。这种方法允许在不更新所有参数的情况下微调模型,大大降低了显存需求和训练成本。
对于两种训练方法,epoch 数均设置为 3。这种设置平衡了过拟合风险和训练充分性。
代码实现示例 使用以下代码,就可以在单个 GPU 上运行 Octopus-V2-2B 模型。请注意,实际部署时需要根据硬件环境调整 device_map 和 torch_dtype。
from transformers import AutoTokenizer, GemmaForCausalLM
import torch
import time
def inference (input_text ):
start_time = time.time()
input_ids = tokenizer(input_text, return_tensors="pt" ).to(model.device)
input_length = input_ids["input_ids" ].shape[1 ]
outputs = model.generate(
input_ids=input_ids["input_ids" ],
max_length=1024 ,
do_sample=False
)
generated_sequence = outputs[:, input_length:].tolist()
res = tokenizer.decode(generated_sequence[0 ])
end_time = time.time()
return {"output" : res, "latency" : end_time - start_time}
model_id = "NexaAIDev/Octopus-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = GemmaForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
input_text = "Take a selfie for me with front camera"
nexa_query = f"Below is the query from the users, please call the correct function and generate the parameters to call the function.\n\nQuery: {input_text} \n\nResponse:"
start_time = time.time()
print ("nexa model result:\n" , inference(nexa_query))
print ("latency:" , time.time() - start_time, "s" )
性能评估 Octopus-V2-2B 在基准测试中表现出卓越的推理速度。在单个 A100 GPU 上,它比「Llama7B + RAG 解决方案」快 36 倍。此外,与依赖集群 A100/H100 GPU 的 GPT-4-turbo 相比,Octopus-V2-2B 速度提高了 168%。这种效率突破归功于 Octopus-V2-2B 的函数性 token 设计,减少了不必要的上下文开销。
Octopus-V2-2B 不仅在速度上表现出色,在准确率上也表现出色。在函数调用准确率上,它超越「Llama7B + RAG 方案」31%。同时,Octopus-V2-2B 实现了与 GPT-4 和 RAG + GPT-3.5 相当的函数调用准确率。这意味着在保持高性能的同时,并未牺牲任务的准确性。
总结与展望 Octopus v2 的发布展示了端侧大模型在功能性和效率上的巨大潜力。通过将参数量控制在 20 亿级别,并结合特定的函数调用优化,该模型成功在资源受限的设备上实现了接近云端大模型的能力。这对于隐私保护、低延迟响应以及离线场景下的 AI 应用具有重要意义。未来,随着端侧算力的提升,此类模型有望成为移动设备和物联网终端的标准配置。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online