部署前的准备
Llama 3.1 是一个资源需求较高的模型,因此在部署之前,首先要确保拥有合适的硬件环境。按照文档中的要求,选择了 Llama 3.1 8B 版本进行测试。8B 模型对 GPU 显存的需求为 16GB,因此选择了 NVIDIA RTX 4090 作为实例,并且配置了 60GB 的数据硬盘容量,来满足下载模型和存储相关文件的需求。
在云平台的控制台创建 GPU 云实例非常简单,整个流程仅需几分钟的时间。在实例创建页面中,能够灵活选择 GPU 的数量和型号,平台还提供了便捷的镜像选择功能,省去了大量的环境配置工作。选择了预装 PyTorch 2.4.0 的镜像,确保在后续的部署过程中不需要手动安装繁杂的依赖环境。
创建实例
进入实例管理控制台,点击创建实例。
进入创建页面后,首先在实例配置中选择付费类型,一般短期需求可以选择按量付费或者包日,长期需求可以选择包月套餐。
其次选择 GPU 数量和需求的 GPU 型号,首次创建实例推荐选择:按量付费–GPU 数量 1–NVIDIA-GeForce-RTX-4090,该配置为 60GB 内存,24GB 的显存(本次测试的 LLaMA3.1 8B 版本至少需要 GPU 显存 16G)。
接下来配置数据硬盘的大小,每个实例默认附带了 50GB 的数据硬盘,首次创建可以就选择默认大小 50GB。
继续选择安装的镜像,平台提供了一些基础镜像供快速启动,镜像中安装了对应的基础环境和框架,可通过勾选来筛选框架,这里筛选 PyTorch,选择 PyTorch 2.4.0。
为保证安全登录,创建密钥对,输入自定义的名称,然后选择自动创建并将创建好的私钥保存到自己电脑中并将后缀改为.pem,以便后续本地连接使用。
创建好密钥对后,选择刚刚创建好的密钥对,并点击立即创建,等待一段时间后即可启动成功!
部署与配置 Llama 3.1
实例成功创建后,通过 JupyterLab 的在线登录入口进入了实例的操作界面。在这个环境中,所有的文件路径和资源配置都已经预先设置好,这极大地简化了操作。通过 conda 创建了一个新的环境,并安装了部署 Llama 3.1 所需的依赖库,包括 LangChain、Streamlit、Transformers 和 Accelerate。
以下是安装依赖的关键命令:
pip install langchain==0.1.15
pip install streamlit==1.36.0
pip install transformers==4.44.0
pip install accelerate==0.32.1
依赖安装完成后,平台提供了内网下载 Llama-3.1-8B 模型的功能,下载速度非常快。解压完模型后,编写了一个简单的 Streamlit 脚本,用于启动 Llama 3.1 模型的聊天界面。Streamlit 的使用非常简便,可以快速搭建一个 Web 服务来和模型进行交互。
代码核心部分如下:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
# 创建标题和副标题
st.title("💬 LLaMA3.1 Chatbot")
st.caption("🚀 A streamlit chatbot powered by Self-LLM")
# 定义模型路径
mode_name_or_path = '/root/workspace/Llama-3.1-8B-Instruct'
# 获取模型和 tokenizer
@st.cache_resource
def get_model():
tokenizer = AutoTokenizer.from_pretrained(mode_name_or_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(mode_name_or_path, torch_dtype=torch.bfloat16).cuda()
tokenizer, model
tokenizer, model = get_model()
prompt := st.chat_input():
st.chat_message().write(prompt)
input_ids = tokenizer([prompt], return_tensors=).to()
generated_ids = model.generate(input_ids.input_ids, max_new_tokens=)
response = tokenizer.decode(generated_ids[], skip_special_tokens=)
st.chat_message().write(response)


