A Small Model-Training Project with LLaMA-Factory
The Full Workflow
Part 1: Create an Alibaba Cloud instance; use a local-editing + remote-training workflow
(1) Creating the instance
1. Go to Alibaba Cloud ECS, click Create Instance, pick a suitable region, and choose subscription or pay-as-you-go billing as needed.
2. Choose an instance type as needed; I chose ecs.gn7i-c8g1.2xlarge (NVIDIA A10, 24 GB VRAM, 30 GB RAM).
3. For the image, pick the Ubuntu 22.04 64-bit image with the NVIDIA GPU driver pre-installed.
4. Add cloud disks as needed, pick a suitable traffic billing plan (enabling CDT tiered billing is recommended), enable a key pair, and click Create.
5. The instance is now created.
(2) Connecting Cursor (or another local editor) to the instance
1. Open Cursor, press Ctrl+Shift+P, type Remote-SSH, install the extension, and add a host entry to your SSH config file:
Host aliyun-llm
    HostName 8.137.38.***                                  # your instance's public IP
    User ubuntu
    IdentityFile C:/Users/<your username>/.ssh/aliyun.pem  # path to the key file
    Port 22
    IdentitiesOnly yes

2. Open a remote shell on the instance, then check and apply the following setup:
# Check the GPU
nvidia-smi

# Don't work as root; use the ubuntu user. Check whether it exists. If it doesn't,
# create it; if it does, remember to copy root's public key to it as well:
id ubuntu

# 1. Create the ubuntu user (no password, direct password login disabled)
adduser ubuntu --gecos "" --disabled-password

# 2. Add ubuntu to the sudo group (grants admin rights)
usermod -aG sudo ubuntu

# 3. Create the .ssh directory
mkdir -p /home/ubuntu/.ssh

# 4. [Key step] Copy root's public key to ubuntu (so you can log in with the same .pem file)
cp /root/.ssh/authorized_keys /home/ubuntu/.ssh/authorized_keys

# 5. Fix ownership and permissions (required! SSH refuses the login otherwise)
chown -R ubuntu:ubuntu /home/ubuntu/.ssh
chmod 700 /home/ubuntu/.ssh
chmod 600 /home/ubuntu/.ssh/authorized_keys

# 6. Verify the user exists
id ubuntu

If it prints something like

root@iZ2vcif30v6v9q9p6a9remZ:~# id ubuntu
uid=1000(ubuntu) gid=1000(ubuntu) groups=1000(ubuntu),27(sudo)

then the setup is correct.
3. Back in Cursor, run Remote-SSH: Connect to Host and pick aliyun-llm. You should end up in the following state:

(3) Check the Python and CUDA versions and install the dependencies:
# Verify versions
python3 --version
nvidia-smi       # check the GPU driver
nvcc --version   # check CUDA

1. Create a virtual environment
sudo apt update
sudo apt install -y python3.10-venv
python3 -m venv llama_factory_env
source llama_factory_env/bin/activate

2. Upgrade pip
pip --version
pip install --upgrade pip

3. Install PyTorch
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

4. Clone the LLaMA-Factory repository
git clone https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory

# Install the dependencies
pip install -r requirements.txt

# Install the web UI
pip install gradio

# Register the repo as an editable package; this also installs the CLI
pip install -e .

# Verify the package import
python -c "import llamafactory; print(llamafactory.__file__)"

# Verify the CLI is available
llamafactory-cli --help
llamafactory-cli train --help

5. Launch the web UI with the terminal multiplexer tmux
tmux new -s mysession      # start a new session
llamafactory-cli webui     # detach with Ctrl+B then D; if you hit "ERROR: Exception in ASGI application", try downgrading pydantic to 2.8.2
tmux attach -t mysession   # re-attach to the session later

(4) Download the model and prepare the dataset
1. Prepare the model
mkdir -p ~/models   # directory for model weights

# Use the hf-mirror endpoint
export HF_ENDPOINT=https://hf-mirror.com

# Download Qwen2.5-7B into the models directory
hf download Qwen/Qwen2.5-7B \
    --local-dir ~/models/Qwen2.5-7B

2. Prepare the dataset
First clean and adjust the dataset format to match a suitable template; see LLaMA-Factory/data/README.md for the format each template expects. (The dataset here is a small function-call training set I built myself.)
# Upload from the local machine over SSH
scp -i "your-key.pem" "C:\Users\YourName\Documents\train.json" ubuntu@<public IP>:/home/ubuntu/LLaMA-Factory/data

Following README.md, create dataset_info.json (skip this if it already exists) and register the uploaded dataset in it, for example:
[dataset_info.json]
{
  "traindata": {
    "file_name": "train100fix.json",
    "formatting": "sharegpt",
    "columns": { "messages": "messages" },
    "tags": {
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant",
      "system_tag": "system"
    }
  }
}

This registers my function-call training set (saved here as train100fix.json) under the dataset name traindata. The file_name path is resolved relative to ~/LLaMA-Factory/data/ by default. Once registered, the dataset can be selected and previewed directly in the web UI:
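Before registering the file, it can save a debugging round to sanity-check that every record actually matches the sharegpt layout declared above. A minimal checker sketch; the record shape follows the tags in dataset_info.json, and the file path in the example is an assumption about where you saved the data:

```python
import json

# Role values implied by the "tags" section of dataset_info.json above
VALID_ROLES = {"system", "user", "assistant"}

def check_sharegpt_record(record):
    """Return a list of problems found in one sharegpt-style record."""
    problems = []
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list"]
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role not in VALID_ROLES:
            problems.append(f"message {i}: unknown role {role!r}")
        if not isinstance(msg.get("content"), str):
            problems.append(f"message {i}: 'content' is not a string")
    return problems

def check_dataset(path):
    """Validate a whole dataset file; returns {record index: problems}."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return {i: p for i, rec in enumerate(data) if (p := check_sharegpt_record(rec))}

if __name__ == "__main__":
    # Inline demo record; for the real file use e.g. check_dataset("data/train100fix.json")
    sample = {"messages": [{"role": "user", "content": "鸡蛋壳属于哪种类型的垃圾?"},
                           {"role": "assistant", "content": "<function_call>...</function_call>"}]}
    print(check_sharegpt_record(sample))  # → []
```

An empty problem dict means the file should load cleanly in the web UI's preview.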


(5) Start training
Pick a model that suits your training hardware (I chose Qwen2.5-7B, deliberately not the instruct variant, so the effect of training is easier to see) and a template that matches the model and data (I used the qwen template). The rest are ordinary training hyperparameters, such as the LoRA rank and quantization settings; configure them for your model and data, then start training. With this little data, training finishes quickly, which makes it easy to try different parameter combinations.

(6) Inspecting the training results
Open the Chat tab to try out the fine-tuned model.

Turn the temperature down: we need accuracy here, not creativity.
Output of the original Qwen2.5-7B:

Output of our fine-tuned model:

Training is done. Merge and export the model on the Export tab, then move on to building the front end.
Part 2: Deploy the model service, build a simple front end, and experiment with public tool APIs
(1) Serving the model with vLLM
# Create a virtual environment
python3 -m venv vllm_env
source vllm_env/bin/activate
pip install vllm

# Start the OpenAI-compatible server (adjust the model path and port as needed)
python -m vllm.entrypoints.openai.api_server \
    --model ~/modeltrain \
    --host 0.0.0.0 \
    --port 8000 \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.85 \
    --max-model-len 32768 \
    --trust-remote-code

(2) Building the front and back end
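With the server up, it helps to confirm from another machine that the OpenAI-compatible endpoint answers, and to note the exact model id it reports, since the web server config below must reference it verbatim. A small stdlib-only sketch; the IP placeholder and the /home/ubuntu/modeltrain path are assumptions carried over from this walkthrough:

```python
import json
import urllib.request

BASE_URL = "http://8.137.38.***:8000"  # your instance's public IP + vLLM port

def build_chat_payload(model, question, temperature=0.01, max_tokens=256):
    """Build an OpenAI-style /v1/chat/completions request body."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": question}],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

def post_json(url, payload, timeout=60):
    """POST a JSON body and decode the JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.load(resp)

# Usage against a running server (uncomment once vLLM is up):
# with urllib.request.urlopen(f"{BASE_URL}/v1/models", timeout=10) as resp:
#     print(json.load(resp)["data"][0]["id"])   # this id goes into config.yaml's vllm.model
# reply = post_json(f"{BASE_URL}/v1/chat/completions",
#                   build_chat_payload("/home/ubuntu/modeltrain", "你好"))
# print(reply["choices"][0]["message"]["content"])
```

The /v1/models id is usually the path the model was loaded from, which is why the config below uses the full /home/ubuntu/modeltrain path.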
1. Configuration file: config.yaml
# Configuration for the Qwen3 function-call web server

# Server settings
server:
  port: 7210            # web server port
  host: "0.0.0.0"       # listen address
  log_level: "info"

# vLLM service settings
vllm:
  base_url: "http://8.137.38.***:8000"   # ← cloud vLLM address (public IP + port)
  api_key: "fake-api-key"                # vLLM doesn't verify it, but FastAPI may require one
  model: "/home/ubuntu/modeltrain"       # ← must exactly match the id returned by GET /v1/models!
  timeout: 60                            # keep this generous to avoid timeouts

# Generation parameters
model_params:
  temperature: 0.01     # controls randomness of the generated text
  top_p: 0.9            # controls diversity
  max_tokens: 1024      # maximum number of generated tokens
  function_call: "auto" # function-calling mode
  stop: ["<|im_end|>"]  # stop sequences

# Prompt settings
prompts:
  system_prompt: |
    你是一个AI助手,可以调用函数。以下是可用函数及其详细信息:
    1. get_rubbish_category
    描述:适用于生活垃圾分类时,判断物品属于哪种类型的垃圾?
    参数:{"item":"垃圾名称,用于垃圾分类"}
    必填参数:item
    2. get_song_information
    描述:根据用户提供的歌曲名称,查询歌曲相关信息,包括歌手、时长、专辑名称等。
    参数:{"song_name":"歌曲名称"}
    必填参数:song_name
    3. get_cartoon_information
    描述:根据用户提供的动漫标题,查询该动漫的相关信息。
    参数:{"title":"动漫标题"}
    必填参数:title
    请根据用户需求选择合适的函数进行调用。

# CORS settings
cors:
  allow_origins: ["*"]      # restrict to specific domains in production
  allow_credentials: true
  allow_methods: ["*"]
  allow_headers: ["*"]

# Logging settings
logging:
  enabled: true
  level: "INFO"             # DEBUG, INFO, WARNING, ERROR, CRITICAL
  log_file: "logs/web_server.log"
  max_bytes: 10485760       # max size per log file (10 MB)
  backup_count: 5
  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  log_requests: true        # log API requests
  log_responses: true       # log API responses
  log_function_calls: true  # log function calls
  log_function_results: true # log function execution results
  log_errors: true          # log errors

2. Front-end file: web_interface.html
<!DOCTYPE html><html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>Qwen3 Function Call 智能助手</title><style>/* Mac风格的全局样式 */:root {--primary-color:#0071e3;--secondary-color:#86868b;--background-color:#f5f5f7;--card-background:#ffffff;--text-color:#1d1d1f;--text-secondary:#6e6e73;--border-radius: 12px;--shadow:0 4px 20px rgba(0,0,0,0.05);--transition:all0.2s ease-in-out;}*{ margin:0; padding:0; box-sizing: border-box;-webkit-font-smoothing: antialiased;-moz-osx-font-smoothing: grayscale;} body { font-family:-apple-system, BlinkMacSystemFont,'Segoe UI', Roboto,'Helvetica Neue', Arial, sans-serif; background-color: var(--background-color);min-height: 100vh; padding: 20px; color: var(--text-color);}.container {max-width: 800px; margin:0 auto; background: var(--card-background); border-radius: var(--border-radius); padding:0; box-shadow: var(--shadow); overflow: hidden;}/* Mac风格标题栏 */.title-bar { background-color: var(--card-background); padding: 12px 20px; border-bottom: 1px solid rgba(0,0,0,0.05); display: flex; align-items: center; justify-content: space-between;}.title-bar-left { display: flex; align-items: center; gap: 12px;}.window-controls { display: flex; gap: 8px;}.control-dot { width: 12px; height: 12px; border-radius:50%;}.control-dot.red { background-color:#ff3b30;}.control-dot.yellow { background-color:#ffcc00;}.control-dot.green { background-color:#34c759;}.title-text { font-size: 14px; color: var(--text-secondary); margin-left: 10px;}.header { text-align: center; padding: 30px 20px 20px;}.header h1 { color: var(--text-color); font-size: 2rem; font-weight:600; margin-bottom: 8px;}.header p { color: var(--text-secondary); font-size: 16px; line-height:1.5;}.functions-info { background: rgba(0,113,227,0.05); border-radius: var(--border-radius); padding: 20px; margin:0 20px 20px; border: none;}.functions-info h3 { color: var(--text-color); margin-bottom: 15px; font-size: 17px; 
font-weight:600;}.function-list{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 12px;}.function-item { background: var(--card-background); padding: 15px; border-radius: 8px; border: 1px solid rgba(0,0,0,0.05); transition: var(--transition);}.function-item:hover { transform: translateY(-1px); box-shadow: var(--shadow);}.function-item h4 { color: var(--primary-color); margin-bottom: 8px; font-size: 15px; font-weight:500;}.function-item p { color: var(--text-secondary); font-size: 13px; line-height:1.4;}.examples { background: rgba(0,0,0,0.02); border-radius: var(--border-radius); padding: 20px; margin:0 20px 20px;}.examples h3 { color: var(--text-color); margin-bottom: 12px; font-size: 17px; font-weight:600;}.example-item { background: var(--card-background); padding: 12px 16px; margin: 6px 0; border-radius: 8px; cursor: pointer; transition: var(--transition); border: 1px solid rgba(0,0,0,0.05); font-size: 14px; color: var(--text-secondary);}.example-item:hover { background: rgba(0,113,227,0.05); color: var(--primary-color);}.status { padding: 10px 16px; border-radius: 8px; margin:0 20px 15px; text-align: center; font-size: 14px;}.status.error { background: rgba(255,59,48,0.05); color:#ff3b30; border: 1px solid rgba(255,59,48,0.1);}.status.success { background: rgba(52,199,89,0.05); color:#34c759; border: 1px solid rgba(52,199,89,0.1);}.chat-container {min-height: 400px; border: 1px solid rgba(0,0,0,0.05); border-radius: var(--border-radius); padding: 20px; margin:0 20px 20px; background: rgba(0,0,0,0.01); overflow-y: auto;max-height: 500px;}.message { margin-bottom: 16px; padding: 14px 18px; border-radius: 18px;max-width:85%; word-wrap:break-word; font-size: 15px; line-height:1.5;}.user-message { background: var(--primary-color); color: white; margin-left: auto;}.ai-message { background: var(--card-background); border: 1px solid rgba(0,0,0,0.05); color: var(--text-color);}.function-call { background: rgba(255,204,0,0.05); border: 
1px solid rgba(255,204,0,0.2); color:#ff9500; font-family:-apple-system, BlinkMacSystemFont,'Courier New', monospace; font-size: 13px; margin: 8px 0; padding: 12px 16px; border-radius: 8px;}.function-result { background: rgba(52,199,89,0.05); border: 1px solid rgba(52,199,89,0.2); color:#34c759; margin: 8px 0; padding: 12px 16px; border-radius: 8px; font-size: 13px;}.input-area { display: flex; gap: 12px; margin:0 20px 30px; padding:0 4px;}.input-area input{ flex:1; padding: 14px 18px; border: 1px solid rgba(0,0,0,0.1); border-radius: 25px; font-size: 15px; background-color: var(--background-color); color: var(--text-color); transition: var(--transition); font-family:-apple-system, BlinkMacSystemFont, sans-serif;}.input-area input:focus { outline: none; border-color: var(--primary-color); background-color: var(--card-background);}.input-area button { padding: 14px 24px; background: var(--primary-color); color: white; border: none; border-radius: 25px; font-size: 15px; font-weight:500; cursor: pointer; transition: var(--transition);min-width: 80px;}.input-area button:hover:not(:disabled){ background:#0077ed; transform: scale(1.03);}.input-area button:active:not(:disabled){ transform: scale(0.98);}.input-area button:disabled { background:#d1d1d6; cursor:not-allowed;}.loading { text-align: center; color: var(--primary-color); font-style: normal; margin: 20px 0; font-size: 15px;}.loading::after { content:''; animation: dots 1.5s infinite;}@keyframes dots {0%,20%{ content:'';}40%{ content:'.';}60%{ content:'..';}80%,100%{ content:'...';}}/* 滚动条样式 */.chat-container::-webkit-scrollbar { width: 8px;}.chat-container::-webkit-scrollbar-track { background: transparent;}.chat-container::-webkit-scrollbar-thumb { background: rgba(0,0,0,0.1); border-radius: 4px;}.chat-container::-webkit-scrollbar-thumb:hover { background: rgba(0,0,0,0.2);}/* 响应式设计 */@media(max-width: 768px){ body { padding: 10px;}.container {max-width:100%;}.header h1 { font-size:1.75rem;}.function-list{ 
grid-template-columns: 1fr;}.message {max-width:90%; font-size: 14px;}}</style></head><body><div class="container"><!-- Mac风格标题栏 --><div class="title-bar"><div class="title-bar-left"><div class="window-controls"><div class="control-dot red"></div><div class="control-dot yellow"></div><div class="control-dot green"></div></div><div class="title-text">Qwen3 智能助手</div></div></div><div class="header"><h1>Qwen3 智能助手</h1><p>支持垃圾分类、歌曲信息、动漫信息查询的AI助手</p></div><div class="functions-info"><h3>可用功能</h3><div class="function-list"><div class="function-item"><h4>垃圾分类</h4><p>查询物品属于哪种类型的垃圾</p></div><div class="function-item"><h4>歌曲信息</h4><p>查询歌曲的歌手、时长、专辑等信息</p></div><div class="function-item"><h4>动漫信息</h4><p>查询动漫的类型、语言、上映时间等</p></div></div></div><div class="examples"><h3> 示例问题</h3><div class="example-item" onclick="setInput('鸡蛋壳属于哪种类型的垃圾?')"> 鸡蛋壳属于哪种类型的垃圾? </div><div class="example-item" onclick="setInput('爱在西元前是谁唱的,来自哪张专辑?')"> 爱在西元前是谁唱的,来自哪张专辑? </div><div class="example-item" onclick="setInput('动漫《棋魂》是哪个国家的,什么时候上映的?')"> 动漫《棋魂》是哪个国家的,什么时候上映的? </div></div><div id="status"></div><div class="chat-container"id="chatContainer"><div class="ai-message"><p> 你好!我是Qwen3智能助手,可以帮你查询垃圾分类、歌曲信息和动漫信息。请输入你的问题吧!</p></div></div><div class="input-area"><inputtype="text"id="userInput" placeholder="请输入你的问题..." onkeypress="handleKeyPress(event)"><button id="sendBtn" onclick="sendMessage()">发送</button></div></div><script> const chatContainer = document.getElementById('chatContainer'); const userInput = document.getElementById('userInput'); const sendBtn = document.getElementById('sendBtn'); const statusDiv = document.getElementById('status'); function setInput(text){ userInput.value = text; userInput.focus();} function handleKeyPress(event){if(event.key ==='Enter'){ sendMessage();}} function addMessage(content, isUser = false, isFunction = false, isResult = false){ const messageDiv = document.createElement('div'); messageDiv.className = `message ${isUser ? 'user-message':(isFunction ? 
'function-call':(isResult ? 'function-result':'ai-message'))}`;if(typeof content ==='object'){ messageDiv.innerHTML = `<pre>${JSON.stringify(content, null,2)}</pre>`;}else{ messageDiv.innerHTML = content.replace(/\n/g,'<br>');} chatContainer.appendChild(messageDiv); chatContainer.scrollTop = chatContainer.scrollHeight;} function showStatus(message,type='info'){ statusDiv.innerHTML = `<div class="status ${type}">${message}</div>`; setTimeout(()=>{ statusDiv.innerHTML ='';},5000);} function showLoading(){ const loadingDiv = document.createElement('div'); loadingDiv.className ='loading'; loadingDiv.id='loading'; loadingDiv.textContent ='AI正在思考中'; chatContainer.appendChild(loadingDiv); chatContainer.scrollTop = chatContainer.scrollHeight;} function hideLoading(){ const loadingDiv = document.getElementById('loading');if(loadingDiv){ loadingDiv.remove();}}async function sendMessage(){ const question = userInput.value.trim();if(!question){ showStatus('请输入问题!','error');return;}// 禁用输入和按钮 sendBtn.disabled = true; userInput.disabled = true;// 添加用户消息 addMessage(question, true);// 清空输入框 userInput.value ='';// 显示加载状态 showLoading();try{ const response =await fetch('/api/chat',{ method:'POST', headers:{'Content-Type':'application/json',}, body: JSON.stringify({ question: question })}); hideLoading();if(!response.ok){ throw new Error(`HTTP ${response.status}: ${response.statusText}`);} const result =await response.json();if(result.success){// 添加AI响应 if(result.ai_response){ addMessage(`AI回复:${result.ai_response}`);}// 显示函数调用信息 if(result.function_call){ addMessage(`🔧 调用函数:${result.function_call.name}\n📋 参数:${JSON.stringify(result.function_call.arguments, null,2)}`, false, true);}// 显示函数执行结果 if(result.function_result){ addMessage(`⚙️ 执行结果:\n${result.function_result}`, false, false, true);} showStatus('查询成功!','success');}else{ addMessage(`查询失败:${result.error}`); showStatus(`查询失败:${result.error}`,'error');}} catch (error){ hideLoading(); console.error('Error:', error); 
addMessage(`网络错误:${error.message}`); showStatus(`网络错误:${error.message}`,'error');}finally{// 重新启用输入和按钮 sendBtn.disabled = false; userInput.disabled = false; userInput.focus();}}// 页面加载完成后聚焦输入框 window.onload = function(){ userInput.focus();};</script></body></html>

3. Back-end file: web_server_fastapi.py
import yaml import os import logging from pathlib import Path from logging.handlers import RotatingFileHandler from fastapi import FastAPI, Request, HTTPException from fastapi.responses import JSONResponse, HTMLResponse from fastapi.middleware.cors import CORSMiddleware import json import requests import re from typing import Dict, Any, Optional from datetime import datetime # 读取配置文件defload_config(config_path:str=None)-> Dict[str, Any]:"""加载YAML配置文件"""if config_path isNone:# 自动查找配置文件:先尝试当前目录,再尝试脚本所在目录 script_dir = Path(__file__).parent.absolute() possible_paths =[ Path("config.yaml"),# 当前工作目录 script_dir /"config.yaml",# 脚本所在目录 script_dir.parent /"config.yaml",# 项目根目录]for path in possible_paths:if path.exists(): config_path =str(path)breakif config_path isNone:raise FileNotFoundError(f"找不到配置文件,已尝试: {[str(p)for p in possible_paths]}") config_file = Path(config_path)ifnot config_file.exists():raise FileNotFoundError(f"配置文件不存在: {config_path}")withopen(config_file,"r", encoding="utf-8")as f: config = yaml.safe_load(f)return config # 加载配置try: config = load_config()except Exception as e:print(f" 配置文件加载失败: {e}")print(" 使用默认配置继续运行...")# 使用默认配置 config ={"server":{"port":7210,"host":"0.0.0.0","log_level":"info"},"vllm":{"base_url":"http://localhost:11221","api_key":"fake-api-key","model":"qwen3-4b-functioncall","timeout":30},"model_params":{"temperature":0.1,"top_p":0.9,"max_tokens":512,"function_call":"auto","stop":["<|im_end|"]},"prompts":{"system_prompt":""" 你是一个AI助手,可以调用函数。以下是可用函数及其详细信息: 1. get_rubbish_category 描述:适用于生活垃圾分类时,判断物品属于哪种类型的垃圾? 参数:{"item": "垃圾名称,用于垃圾分类"} 必填参数:item 2. get_song_information 描述:根据用户提供的歌曲名称,查询歌曲相关信息,包括歌手、时长、专辑名称等。 参数:{"song_name": "歌曲名称"} 必填参数:song_name 3. 
get_cartoon_information 描述:根据用户提供的动漫标题,查询该动漫的相关信息。 参数:{"title": "动漫标题"} 必填参数:title 请根据用户需求选择合适的函数进行调用。 """},"cors":{"allow_origins":["*"],"allow_credentials":True,"allow_methods":["*"],"allow_headers":["*"]}}# 从配置中提取常用参数 VLLM_BASE_URL = config["vllm"]["base_url"] VLLM_API_KEY = config["vllm"]["api_key"] WEB_PORT = config["server"]["port"] WEB_HOST = config["server"]["host"] LOG_LEVEL = config["server"]["log_level"] SYSTEM_PROMPT = config["prompts"]["system_prompt"]# 日志配置 LOG_CONFIG = config.get("logging",{"enabled":True,"level":"INFO","log_file":"logs/web_server.log","max_bytes":10485760,"backup_count":5,"format":"%(asctime)s - %(name)s - %(levelname)s - %(message)s","log_requests":True,"log_responses":True,"log_function_calls":True,"log_function_results":True,"log_errors":True})# 初始化日志系统defsetup_logging():"""设置日志系统"""ifnot LOG_CONFIG.get("enabled",True):# 如果日志未启用,设置一个空处理器 logging.basicConfig(level=logging.WARNING)return# 创建日志目录 log_file = LOG_CONFIG.get("log_file","logs/web_server.log") log_dir = os.path.dirname(log_file)if log_dir andnot os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True)# 配置日志格式 log_format = LOG_CONFIG.get("format","%(asctime)s - %(name)s - %(levelname)s - %(message)s") log_level =getattr(logging, LOG_CONFIG.get("level","INFO").upper(), logging.INFO)# 创建logger logger = logging.getLogger("web_server") logger.setLevel(log_level)# 清除已有的处理器 logger.handlers.clear()# 文件处理器(带轮转)if log_file: file_handler = RotatingFileHandler( log_file, maxBytes=LOG_CONFIG.get("max_bytes",10485760), backupCount=LOG_CONFIG.get("backup_count",5), encoding='utf-8') file_handler.setLevel(log_level) file_formatter = logging.Formatter(log_format) file_handler.setFormatter(file_formatter) logger.addHandler(file_handler)# 控制台处理器 console_handler = logging.StreamHandler() console_handler.setLevel(log_level) console_formatter = logging.Formatter(log_format) console_handler.setFormatter(console_formatter) logger.addHandler(console_handler)return logger # 初始化日志 logger = 
setup_logging()# 创建FastAPI应用defcreate_app()-> FastAPI: app = FastAPI(title="Qwen3 Function Call API", version="1.0")# 配置CORS app.add_middleware( CORSMiddleware, allow_origins=config["cors"]["allow_origins"], allow_credentials=config["cors"]["allow_credentials"], allow_methods=config["cors"]["allow_methods"], allow_headers=config["cors"]["allow_headers"],)return app app = create_app()# 工具函数defget_rubbish_category(item):"""垃圾分类查询(真实API)"""try: url =f"https://api.timelessq.com/garbage?keyword={item}" response = requests.request("GET", url)# 检查响应状态if response.status_code !=200:returnf"API请求失败,状态码: {response.status_code}"# 尝试解析JSONtry: data = response.json()except json.JSONDecodeError:returnf"API返回的不是有效JSON格式"# 检查数据结构ifnotisinstance(data,dict)or'data'notin data:returnf"API返回数据格式异常"# 处理嵌套的data结构 data_content = data['data']ifisinstance(data_content,dict)and'data'in data_content:# 这是嵌套结构,取内层的data items_list = data_content['data']elifisinstance(data_content,list):# 这是直接的列表结构 items_list = data_content else:returnf"API返回的data字段格式异常"ifnot items_list:returnf"未找到关于 '{item}' 的垃圾分类信息" output_str_list =[]for garbage_item in items_list:ifisinstance(garbage_item,dict):# 尝试不同的字段名组合 name = garbage_item.get('name','') category = garbage_item.get('category')or garbage_item.get('categroy','')if name and category: output_str_list.append(f"{name}: {category}")return'\n'.join(output_str_list)if output_str_list elsef"无法解析 '{item}' 的垃圾分类信息"except Exception as e:returnf"查询失败: {e}"defget_song_information(song_name):"""歌曲信息查询(真实API)"""try: url =f"https://api.timelessq.com/music/tencent/search?keyword={song_name}" response = requests.request("GET", url) song_infor = response.json()['data']['list'][0] singer =''ifnot song_infor['singer']else song_infor['singer'][0]['name']returnf"歌曲: {song_name}\n歌手: {singer}\n时长: {song_infor['interval']}秒\n专辑名称: {song_infor['albumname']}"except Exception as e:returnf"查询失败: {e}"defget_cartoon_information(title):"""动漫信息查询(真实API)"""try: url 
=f"https://api.timelessq.com/bangumi?title={title}" response = requests.request("GET", url) data = response.json()['data'][0]returnf"标题: {data['title']}\n类型:{data['type']}\n语言:{data['lang']}\n出品方:{data['officialSite']}\n上映时间:{data['begin']}\n完结事件:{data['end']}"except Exception as e:returnf"查询失败: {e}"# 工具映射表 tool_map ={"get_rubbish_category": get_rubbish_category,"get_song_information": get_song_information,"get_cartoon_information": get_cartoon_information }# 参数映射表 - 处理模型输出与函数实际参数的映射 param_mapping ={"get_rubbish_category":{"item":"item","rubbish_name":"item","query":"item","name":"item","keyword":"item"},"get_song_information":{"song_name":"song_name","name":"song_name","query":"song_name"},"get_cartoon_information":{"title":"title","cartoon_name":"title","name":"title","query":"title"}}defmap_function_arguments(function_name:str, raw_arguments: Dict[str, Any])-> Dict[str, Any]:"""将模型输出的参数映射到函数实际需要的参数"""if function_name notin param_mapping:return raw_arguments mapping = param_mapping[function_name] mapped_args ={}for raw_param, actual_param in mapping.items():if raw_param in raw_arguments: mapped_args[actual_param]= raw_arguments[raw_param]# 如果没有找到映射,保留原始参数ifnot mapped_args: mapped_args = raw_arguments return mapped_args defextract_function_call(response_text:str)-> Dict[str, Any]:"""从模型响应中提取function call - 采用对齐的格式""" patterns =[# 格式1: {"name": "func", "arguments": {...}}r'<function_call>\s*({.*?})\s*</function_call>',# 格式2: func_name\n{...}r'<function_call>\s*(\w+)\s*\n\s*({.*?})\s*</function_call>',# 格式3: 只有函数名r'<function_call>\s*(\w+)\s*</function_call>']for i, pattern inenumerate(patterns,1): match = re.search(pattern, response_text, re.DOTALL)if match:try:if i ==1:# 标准JSON格式 function_call = json.loads(match.group(1))# 处理不同的字段名称格式if"function"in function_call and"name"notin function_call: function_call["name"]= function_call.pop("function")return{"success":True,"function_call": function_call}elif i ==2:# 分行格式 function_name = match.group(1) arguments = 
json.loads(match.group(2)) function_call ={"name": function_name,"arguments": arguments}return{"success":True,"function_call": function_call}elif i ==3:# 只有函数名 function_name = match.group(1) function_call ={"name": function_name,"arguments":{}}return{"success":True,"function_call": function_call}except json.JSONDecodeError as e:continuereturn{"success":False,"error":"未找到有效的function_call格式"}defcheck_vllm_server():"""检查vLLM服务器状态"""try: headers ={"Authorization":f"Bearer {VLLM_API_KEY}"} response = requests.get(f"{VLLM_BASE_URL}/v1/models", headers=headers, timeout=5) is_ok = response.status_code ==200ifnot is_ok and LOG_CONFIG.get("log_errors",True): logger.warning(f"vLLM服务器检查失败 | 状态码: {response.status_code}")return is_ok except Exception as e:if LOG_CONFIG.get("log_errors",True): logger.warning(f"vLLM服务器检查异常 | 错误: {str(e)}")returnFalse# API端点@app.get('/', response_class=HTMLResponse, tags=["Web Interface"])asyncdefindex():"""主页 - 返回HTML界面"""# 尝试多个可能的HTML文件路径 script_dir = Path(__file__).parent.absolute() possible_html_paths =[ script_dir /'web_interface.html',# 脚本所在目录 script_dir /'static'/'web_interface.html',# static目录 Path('web_interface.html'),# 当前工作目录] html_file_path =Nonefor path in possible_html_paths:if path.exists(): html_file_path = path breakif html_file_path isNone:raise HTTPException( status_code=404, detail=f"错误:找不到web_interface.html文件,已尝试: {[str(p)for p in possible_html_paths]}")try:withopen(html_file_path,'r', encoding='utf-8')as f:return HTMLResponse(content=f.read())except Exception as e:raise HTTPException(status_code=500, detail=f"读取HTML文件失败: {str(e)}")@app.post('/api/chat', tags=["Chat API"])asyncdefchat(request: Request):"""处理聊天请求""" request_id = datetime.now().strftime("%Y%m%d%H%M%S%f") client_ip = request.client.host if request.client else"unknown"try: data =await request.json()ifnot data or'question'notin data: error_msg ='请求格式错误:缺少question字段'if LOG_CONFIG.get("log_errors",True): logger.warning(f"[{request_id}] {error_msg} | IP: 
{client_ip}")return JSONResponse( status_code=400, content={'success':False,'error': error_msg }) user_question = data['question'].strip()ifnot user_question: error_msg ='问题不能为空'if LOG_CONFIG.get("log_errors",True): logger.warning(f"[{request_id}] {error_msg} | IP: {client_ip}")return JSONResponse( status_code=400, content={'success':False,'error': error_msg })# 记录请求if LOG_CONFIG.get("log_requests",True): logger.info(f"[{request_id}] API请求 | IP: {client_ip} | 问题: {user_question}")# 检查vLLM服务器ifnot check_vllm_server(): error_msg ='vLLM服务器无法连接,请确保服务正在运行'if LOG_CONFIG.get("log_errors",True): logger.error(f"[{request_id}] {error_msg}")return JSONResponse( status_code=503, content={'success':False,'error': error_msg })# 构建请求 messages =[{"role":"system","content": SYSTEM_PROMPT},{"role":"user","content": user_question}]# 函数定义 functions =[{"name":"get_rubbish_category","description":"适用于生活垃圾分类时,判断物品属于哪种类型的垃圾?","parameters":{"type":"object","properties":{"item":{"type":"string","description":"垃圾名称,用于垃圾分类"}},"required":["item"]}},{"name":"get_song_information","description":"根据用户提供的歌曲名称,查询歌曲相关信息,包括歌手、时长、专辑名称等。","parameters":{"type":"object","properties":{"song_name":{"type":"string","description":"歌曲名称"}},"required":["song_name"]}},{"name":"get_cartoon_information","description":"根据用户提供的动漫标题,查询该动漫的相关信息。","parameters":{"type":"object","properties":{"title":{"type":"string","description":"动漫标题"}},"required":["title"]}}]# 调用vLLM API api_url =f"{VLLM_BASE_URL}/v1/chat/completions" headers ={"Authorization":f"Bearer {VLLM_API_KEY}","Content-Type":"application/json"}# 从配置中构建payload payload ={"model": config["vllm"]["model"],"messages": messages,"functions": functions,"function_call": config["model_params"]["function_call"],"temperature": config["model_params"]["temperature"],"top_p": config["model_params"]["top_p"],"max_tokens": config["model_params"]["max_tokens"],"stop": config["model_params"]["stop"]}# 记录vLLM API调用if LOG_CONFIG.get("log_requests",True): 
logger.debug(f"[{request_id}] 调用vLLM API | URL: {api_url} | 模型: {config['vllm']['model']}") response = requests.post(api_url, json=payload, headers=headers, timeout=config["vllm"]["timeout"])if response.status_code !=200: error_msg =f'vLLM API错误 {response.status_code}: {response.text}'if LOG_CONFIG.get("log_errors",True): logger.error(f"[{request_id}] {error_msg}")return JSONResponse( status_code=500, content={'success':False,'error': error_msg }) result = response.json()if"choices"notin result orlen(result["choices"])==0: error_msg ='vLLM API返回格式错误'if LOG_CONFIG.get("log_errors",True): logger.error(f"[{request_id}] {error_msg} | 响应: {json.dumps(result, ensure_ascii=False)[:500]}")return JSONResponse( status_code=500, content={'success':False,'error': error_msg }) assistant_message = result["choices"][0]["message"]["content"]# 记录模型响应if LOG_CONFIG.get("log_responses",True): logger.info(f"[{request_id}] 模型响应 | 内容: {assistant_message[:200]}{'...'iflen(assistant_message)>200else''}")# 提取function call function_call_result = extract_function_call(assistant_message)ifnot function_call_result["success"]:# 如果没有function call,返回原始回复if LOG_CONFIG.get("log_responses",True): logger.info(f"[{request_id}] 无函数调用 | 返回原始回复")return JSONResponse({'success':True,'ai_response': assistant_message,'function_call':None,'function_result':None}) function_call = function_call_result["function_call"] function_name = function_call.get("name") function_args = function_call.get("arguments",{})# 记录函数调用if LOG_CONFIG.get("log_function_calls",True): logger.info(f"[{request_id}] 函数调用 | 函数名: {function_name} | 参数: {json.dumps(function_args, ensure_ascii=False)}")if function_name notin tool_map: error_msg =f'未知的函数: {function_name}'if LOG_CONFIG.get("log_errors",True): logger.warning(f"[{request_id}] {error_msg}")return JSONResponse( status_code=400, content={'success':False,'error': error_msg })# 映射参数 mapped_args = map_function_arguments(function_name, function_args)# 执行函数try: start_time = datetime.now() 
tool_result = tool_map[function_name](**mapped_args) execution_time =(datetime.now()- start_time).total_seconds()# 记录函数执行结果if LOG_CONFIG.get("log_function_results",True): result_preview =str(tool_result)[:200]+('...'iflen(str(tool_result))>200else'') logger.info(f"[{request_id}] 函数执行成功 | 函数名: {function_name} | 执行时间: {execution_time:.3f}s | 结果: {result_preview}")# 记录完整响应if LOG_CONFIG.get("log_responses",True): logger.info(f"[{request_id}] 请求完成 | 成功")return JSONResponse({'success':True,'ai_response': assistant_message,'function_call': function_call,'function_result': tool_result })except Exception as e: error_msg =f'函数执行失败: {str(e)}'if LOG_CONFIG.get("log_errors",True): logger.error(f"[{request_id}] {error_msg} | 函数名: {function_name} | 参数: {json.dumps(mapped_args, ensure_ascii=False)}", exc_info=True)return JSONResponse( status_code=500, content={'success':False,'error': error_msg })except Exception as e: error_msg =f'服务器内部错误: {str(e)}'if LOG_CONFIG.get("log_errors",True): logger.error(f"[{request_id}] {error_msg}", exc_info=True)return JSONResponse( status_code=500, content={'success':False,'error': error_msg })@app.get('/api/status', tags=["System API"])asyncdefstatus():"""检查服务状态""" vllm_status = check_vllm_server()return JSONResponse({'server_status':'running','vllm_status':'connected'if vllm_status else'disconnected','functions':list(tool_map.keys()),'config':{'vllm_base_url': VLLM_BASE_URL,'web_port': WEB_PORT }})# 启动服务器的主函数if __name__ =='__main__':import uvicorn print(" 启动Qwen3 Function Call Web服务器...")print("="*50)print(f" vLLM服务地址: {VLLM_BASE_URL}")print(f" Web服务器端口: {WEB_PORT}")print(f" 支持的函数: {list(tool_map.keys())}")print(f" 模型参数: temperature={config['model_params']['temperature']}, top_p={config['model_params']['top_p']}")# 日志配置信息if LOG_CONFIG.get("enabled",True): log_file = LOG_CONFIG.get("log_file","logs/web_server.log") log_level = LOG_CONFIG.get("level","INFO")print(f" 日志功能: 已启用")print(f" 日志文件: {log_file}")print(f" 日志级别: {log_level}")print(f" 记录请求: 
{LOG_CONFIG.get('log_requests',True)}")print(f" 记录响应: {LOG_CONFIG.get('log_responses',True)}")print(f" 记录函数调用: {LOG_CONFIG.get('log_function_calls',True)}")print(f" 记录函数结果: {LOG_CONFIG.get('log_function_results',True)}")if logger: logger.info("="*50) logger.info("Web服务器启动") logger.info(f"vLLM服务地址: {VLLM_BASE_URL}") logger.info(f"Web服务器端口: {WEB_PORT}") logger.info(f"支持的函数: {list(tool_map.keys())}")else:print(f" 日志功能: 已禁用")# 检查vLLM服务器if check_vllm_server():print(" vLLM服务器连接正常")if logger: logger.info("vLLM服务器连接正常")else:print(" 警告: vLLM服务器无法连接")print(" 请确保vLLM服务正在运行:")print(f" 地址: {VLLM_BASE_URL}")print(f" 模型: {config['vllm']['model']}")if logger: logger.warning(f"vLLM服务器无法连接: {VLLM_BASE_URL}")print("="*50)print(" Web服务器启动中...")print(f" 访问地址: http://localhost:{WEB_PORT}")print("="*50) uvicorn.run(app, host=WEB_HOST, port=WEB_PORT, log_level=LOG_LEVEL)

Start the service, then test the results.
(3) Testing the results

The back end defines three tools: garbage classification, song lookup, and anime lookup. Now enter the question: 告白气球是谁唱的? ("Who sings 告白气球?")
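The same check can be scripted instead of typed into the browser. A minimal client sketch for the back end's /api/chat endpoint; the port (7210) and the response fields (ai_response, function_call, function_result) follow the config and server code above:

```python
import json
import urllib.request

SERVER = "http://localhost:7210"  # web server address from config.yaml

def ask(question):
    """POST a question to /api/chat and return the parsed JSON response."""
    req = urllib.request.Request(
        f"{SERVER}/api/chat",
        data=json.dumps({"question": question}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=90) as resp:
        return json.load(resp)

def summarize(result):
    """Flatten the server response into printable lines."""
    lines = []
    if result.get("function_call"):
        fc = result["function_call"]
        lines.append(f"called {fc['name']} with {fc.get('arguments', {})}")
    if result.get("function_result"):
        lines.append(str(result["function_result"]))
    if result.get("ai_response"):
        lines.append(result["ai_response"])
    return lines

# Usage with the server running:
# print("\n".join(summarize(ask("告白气球是谁唱的?"))))
```

For the example question, a correct run should show a get_song_information call followed by the tool's lookup result.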

The execution result comes back.
