"""Minimal non-streaming chat example for Qwen1.5 with Hugging Face transformers.

Loads the model and tokenizer, formats a single-turn conversation with the
tokenizer's chat template, generates a reply and prints it.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path (or hub id) of the model to load; change this to where the model lives.
MODEL_PATH = "Qwen1.5-0.5B-Chat"

# Now you do not need to add "trust_remote_code=True"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# The device to put the inputs on. It is taken from the loaded model instead
# of being hard-coded to "cuda", because device_map="auto" decides the
# placement and the inputs have to follow it.
device = model.device

# Instead of using model.chat(), we directly use model.generate().
# But you need to use tokenizer.apply_chat_template() to format your inputs
# as shown below. The question is asked in Chinese.
prompt = "给我简单介绍一下大型语言模型。"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

# Directly use generate() and tokenizer.batch_decode() to get the output.
# Use `max_new_tokens` to control the maximum output length.
# Unpacking model_inputs passes the attention mask along with the input ids.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
)
# generate() returns prompt + completion; strip the prompt tokens per sequence.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Print the assistant's reply.
print(response)
注:需要把代码中的模型路径 "Qwen1.5-0.5B-Chat" 修改为本地模型实际所在的目录。
执行 py 文件:
python qwen.py
如果运行时提示缺少依赖(例如使用 device_map="auto" 时需要的 accelerate),安装:
conda install conda-forge::accelerate
再重新执行,这时候就可以看到模型的回复了。
如果想要流式的,可以改为以下代码:
"""Streaming chat example for Qwen1.5 with Hugging Face transformers.

Runs model.generate() in a background thread and prints the decoded text
chunks as the TextIteratorStreamer hands them over.
"""
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Path (or hub id) of the model to load; change this to where the model lives.
MODEL_PATH = "Qwen1.5-0.5B-Chat"

# Now you do not need to add "trust_remote_code=True"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# The device to put the inputs on; follows wherever device_map="auto" placed
# the model rather than assuming "cuda".
device = model.device

# Instead of using model.chat(), we directly use model.generate().
# But you need to use tokenizer.apply_chat_template() to format your inputs
# as shown below. The question is asked in Chinese.
prompt = "给我简单介绍一下大型语言模型。"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

# Use `max_new_tokens` to control the maximum output length.
streamer = TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)

# generate() blocks until it is finished, so it runs in a worker thread while
# this thread consumes the streamer.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

chunks = []
for new_text in streamer:
    chunks.append(new_text)
    # Emit each chunk immediately so the reply appears incrementally.
    print(new_text, end="", flush=True)
print()
thread.join()

# The complete reply, kept for any further processing.
generated_text = "".join(chunks)
部署为 API
借助 FastAPI 和 Uvicorn 来实现 API 接口的支持。
FastAPI:一个用于构建 API 的现代、快速(高性能)的 web 框架,使用 Python 并基于标准的 Python 类型提示。
Uvicorn:一个快速的 ASGI(Asynchronous Server Gateway Interface)服务器,用于构建异步 Web 服务。
快速启动
首先安装这两个库:
conda install fastapi uvicorn
新建 web.py,写入代码:
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from argparse import ArgumentParser
app = FastAPI()

# Enable CORS: any origin, method and header is accepted, with credentials.
_cors_options = {
    "allow_origins": ['*'],
    "allow_credentials": True,
    "allow_methods": ['*'],
    "allow_headers": ['*'],
}
app.add_middleware(CORSMiddleware, **_cors_options)
@app.get("/")
async def index():
    """Handle GET / by returning a fixed greeting payload."""
    return {"message": "Hello World"}
def_get_args():
parser = ArgumentParser()
parser.add_argument('--server-port',
type=int,
default=8000,
help='Demo server port.')
parser.add_argument('--server-name',
type=str,
default='127.0.0.1',
help='Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'' If you want other computers to access your server, use 0.0.0.0 instead.',
)
args = parser.parse_args()
return args
if __name__ == '__main__':
    # Script entry point: read the CLI options and serve the app with uvicorn.
    args = _get_args()
    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
运行 web.py:
python web.py
请求接口(例如在浏览器中访问 http://127.0.0.1:8000/),可以看到返回 {"message": "Hello World"}。
接入大模型测试
把前面 Qwen 的代码写进去:
from contextlib import asynccontextmanager
import torch
import uvicorn
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from argparse import ArgumentParser
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook that collects GPU memory.

    Nothing is done before the ``yield``; the statements after it run when
    the context manager exits.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)

# Enable CORS: any origin, method and header is accepted, with credentials.
_cors_options = {
    "allow_origins": ['*'],
    "allow_credentials": True,
    "allow_methods": ['*'],
    "allow_headers": ['*'],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class ChatMessage(BaseModel):
    """One chat message: who said it and the text that was said."""

    # Speaker of the message; restricted to these three values.
    role: Literal['user', 'assistant', 'system']
    # Message text; declared without a default.
    content: Optional[str]
class DeltaMessage(BaseModel):
    """Partial message carried by one streaming chunk; both fields optional."""

    role: Optional[Literal['user', 'assistant', 'system']] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completion endpoint."""

    # Model identifier supplied by the client.
    model: str
    # Conversation so far, in order.
    messages: List[ChatMessage]
    # Whether the client asks for a streamed reply.
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    """One complete (non-streaming) answer inside a response."""

    index: int
    message: ChatMessage
    finish_reason: Literal['stop', 'length']
class ChatCompletionResponseStreamChoice(BaseModel):
    """One incremental answer fragment inside a streamed response chunk."""

    index: int
    # The new piece of the message carried by this chunk.
    delta: DeltaMessage
    # May be None; declared without a default, so it must be passed explicitly.
    finish_reason: Optional[Literal['stop', 'length']]
class ChatCompletionResponse(BaseModel):
    """Response envelope used for both full replies and streamed chunks."""

    model: str
    # Discriminates a full reply from a streamed chunk.
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    # Defaults to the current Unix time (whole seconds) at construction.
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/")
async def index():
    """Handle GET / by returning a fixed greeting payload."""
    return {"message": "Hello World"}
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Generate one assistant reply for the submitted conversation.

    Args:
        request: Parsed request body; its last message must come from the
            user.

    Returns:
        A ``ChatCompletionResponse`` of kind ``chat.completion`` holding a
        single choice with the generated text.

    Raises:
        HTTPException: 400 when ``messages`` is empty or does not end with a
            user message.
    """
    # Module-level objects created in the ``__main__`` section.
    global model, tokenizer

    # Basic validation; an empty list is rejected here instead of failing
    # with an IndexError on the [-1] lookup.
    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")

    # Hand the chat template plain dicts rather than pydantic objects.
    conversation = [
        {"role": message.role, "content": message.content}
        for message in request.messages
    ]
    text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Directly use generate() and tokenizer.batch_decode() to get the output.
    # Use `max_new_tokens` to control the maximum output length.
    # Unpacking model_inputs passes the attention mask along with the ids.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )
    # generate() returns prompt + completion; strip the prompt tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True)[0]

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop",
    )
    return ChatCompletionResponse(
        model=request.model, choices=[choice_data], object="chat.completion")
def_get_args():
parser = ArgumentParser()
parser.add_argument(
'-c',
'--checkpoint-path',
type=str,
default='Qwen1.5-0.5B-Chat',
help='Checkpoint name or path, default to %(default)r',
)
parser.add_argument('--server-port',
type=int,
default=8000,
help='Demo server port.')
parser.add_argument('--server-name',
type=str,
default='127.0.0.1',
help='Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'' If you want other computers to access your server, use 0.0.0.0 instead.',
)
args = parser.parse_args()
return args
if __name__ == '__main__':
    # Script entry point: load the model once, then serve the app.
    args = _get_args()

    # Now you do not need to add "trust_remote_code=True"
    # `model` and `tokenizer` are module-level names read by the endpoint.
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path)

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
调用下接口,可以看到接口已经返回内容了。
增加流式支持
需要安装 sse_starlette 的库,来支持流式的返回:
pip install sse_starlette
安装完把代码再改一下,通过参数 stream 来判断是否流式返回:
from contextlib import asynccontextmanager
from threading import Thread
import torch
import uvicorn
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BatchEncoding
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from argparse import ArgumentParser
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook that collects GPU memory.

    Nothing is done before the ``yield``; the statements after it run when
    the context manager exits.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)

# Enable CORS: any origin, method and header is accepted, with credentials.
_cors_options = {
    "allow_origins": ['*'],
    "allow_credentials": True,
    "allow_methods": ['*'],
    "allow_headers": ['*'],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class ChatMessage(BaseModel):
    """One chat message: who said it and the text that was said."""

    # Speaker of the message; restricted to these three values.
    role: Literal['user', 'assistant', 'system']
    # Message text; declared without a default.
    content: Optional[str]
class DeltaMessage(BaseModel):
    """Partial message carried by one streaming chunk; both fields optional."""

    role: Optional[Literal['user', 'assistant', 'system']] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completion endpoint."""

    # Model identifier supplied by the client.
    model: str
    # Conversation so far, in order.
    messages: List[ChatMessage]
    # Whether the client asks for a streamed reply.
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    """One complete (non-streaming) answer inside a response."""

    index: int
    message: ChatMessage
    finish_reason: Literal['stop', 'length']
class ChatCompletionResponseStreamChoice(BaseModel):
    """One incremental answer fragment inside a streamed response chunk."""

    index: int
    # The new piece of the message carried by this chunk.
    delta: DeltaMessage
    # May be None; declared without a default, so it must be passed explicitly.
    finish_reason: Optional[Literal['stop', 'length']]
class ChatCompletionResponse(BaseModel):
    """Response envelope used for both full replies and streamed chunks."""

    model: str
    # Discriminates a full reply from a streamed chunk.
    object: Literal['chat.completion', 'chat.completion.chunk']
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    # Defaults to the current Unix time (whole seconds) at construction.
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/")
async def index():
    """Handle GET / by returning a fixed greeting payload."""
    return {"message": "Hello World"}
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Generate an assistant reply, streamed or complete, for a conversation.

    Args:
        request: Parsed request body; its last message must come from the
            user. ``request.stream`` selects the response style.

    Returns:
        An ``EventSourceResponse`` fed by ``predict`` when streaming is
        requested, otherwise a ``ChatCompletionResponse`` of kind
        ``chat.completion`` with a single choice.

    Raises:
        HTTPException: 400 when ``messages`` is empty or does not end with a
            user message.
    """
    # Module-level objects created in the ``__main__`` section.
    global model, tokenizer

    # Basic validation; an empty list is rejected here instead of failing
    # with an IndexError on the [-1] lookup.
    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")

    # Hand the chat template plain dicts rather than pydantic objects.
    conversation = [
        {"role": message.role, "content": message.content}
        for message in request.messages
    ]
    text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    if request.stream:
        generate = predict(model_inputs, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # Directly use generate() and tokenizer.batch_decode() to get the output.
    # Use `max_new_tokens` to control the maximum output length.
    # Unpacking model_inputs passes the attention mask along with the ids.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )
    # generate() returns prompt + completion; strip the prompt tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True)[0]

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop",
    )
    return ChatCompletionResponse(
        model=request.model, choices=[choice_data], object="chat.completion")
asyncdefpredict(model_inputs: BatchEncoding, model_id: str):
global model, tokenizer
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
model_inputs, streamer=streamer, max_new_tokens=512)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[
choice_data], object="chat.completion.chunk")
yield"{}".format(chunk.model_dump_json(exclude_unset=True))
thread.start()
for new_text in streamer:
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[
choice_data], object="chat.completion.chunk")
yield"{}".format(chunk.model_dump_json(exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[
choice_data], object="chat.completion.chunk")
yield"{}".format(chunk.model_dump_json(exclude_unset=True))
yield'[DONE]'def_get_args():
parser = ArgumentParser()
parser.add_argument(
'-c',
'--checkpoint-path',
type=str,
default='Qwen1.5-0.5B-Chat',
help='Checkpoint name or path, default to %(default)r',
)
parser.add_argument('--server-port',
type=int,
default=8000,
help='Demo server port.')
parser.add_argument('--server-name',
type=str,
default='127.0.0.1',
help='Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'' If you want other computers to access your server, use 0.0.0.0 instead.',
)
args = parser.parse_args()
return args
if __name__ == '__main__':
    # Script entry point: load the model once, then serve the app.
    args = _get_args()

    # Now you do not need to add "trust_remote_code=True"
    # `model` and `tokenizer` are module-level names read by the endpoints.
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path)

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
再调用接口,流式返回也支持了。
常见问题与优化
显存不足:如果部署过程中遇到 OOM(Out Of Memory),可以尝试减小 batch_size 或 max_new_tokens,或加载量化版本模型(如 INT4)。
依赖冲突:确保虚拟环境中只安装必要的包,避免全局环境干扰。使用 conda list 查看当前环境包列表。
网络问题:下载模型或依赖时若速度慢,请确认 Conda 源已切换至国内镜像。
端口占用:启动 API 服务时若报错端口被占用,可修改 --server-port 参数指定其他端口。
完成上述步骤后,即可在本地 Windows 环境下成功运行 Qwen1.5 大模型并提供 API 服务。