第八节 LLaVA模型CLI推理构建custom推理代码Demo

第八节 LLaVA模型CLI推理构建custom推理代码Demo

文章目录


前言

我在第七节介绍了cli.py推理源码解读,而我也因项目需要构建了推理demo,我们是用来自动生成标签和推理需要。想了想,我还是用一节将我的代码记录于此,供有需求读者使用。本节,介绍更改cli.py代码,实现一张图像推理、也为需要grounding的读者提供如何在图上给出目标box。


一、parser 参数设定

为什么我要单独介绍参数设定?因为它很重要,正确的设定会减少模型错误概率。我将介绍三个部分设定,一个是使用lora权重,一个是合并权重,最后一个是使用量化方式。

1、lora权重推理

我们训练模型多数使用lora训练,而未将lora训练结果合并的权重加载方式的方法。如果我们是使用自己训练方法,可以使用如下方式给出参数:

 parser.add_argument("--model-path", type=str, default="/extend_disk/disk3/tj/LLaVA/checkpoints/llava-v1.5-13b-lora_vaild_1epoch_clean2/checkpoint-10200") parser.add_argument("--model-base", type=str, default="/extend_disk/disk3/tj/LLaVA/llava_v1.5_lora/vicuna-13b-v1.5") 

如果我们是使用LLaVA自带lora方式,model-base基本不变,只需将model-path="/LLaVA/checkpoint/llava-v1.5-13b-lora",而权重下载我之前文章也介绍。

2、非lora权重推理

我们训练模型使用lora方法保存,想调用非lora方式,就需要将其转换。我们这里不说转换方法,给出非lora的权重加载方式。那这里只介绍官方给出权重加载参数设定,如下:

 parser.add_argument("--model-path", type=str, default="/LLaVA/llava_v1.5_lora/llava-v1.5-13b") parser.add_argument("--model-base", type=str, default=None) 

3、量化权重推理

量化只需打开load-8bit或load-4bit参数,但量化必须是非lora权重加载方式,其代码如下:

 parser.add_argument("--load-8bit", action="store_true") # parser.add_argument("--load-4bit", default=True) parser.add_argument("--load-4bit", action="store_true") 

当然量化显存占用测试,我们以LLAVA-13b量化显存测试:
不量化推理显存占用:28.4G
8bit量化推理显存占用:16.6G
4bit量化推理显存占用:10.6G

4、实验总结

我测试官方提供lora与非lora权重,我发现非lora效果会比lora好。当然这是我测试工程数据得到结论,只做参考。

二、初始化模型

我不在介绍,如下代码:

def llava_init(args): # Model disable_torch_init() model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) return tokenizer, model, image_processor, context_len,model_name 

我想说,每个权重名称需包含v1字符,以便后续对话加载方式。

三、模型推理

模型推理,我将提示改成列表方式,我也对有框目标的文本预测做了图上画框操作。其它基本都是流程,我不在解读了。

四、完整代码Demo

最后,我给出完整的Demo,可以直接复制粘贴即可使用。若还想按照自己custom方式,读者也可根据我提供的方法来修改。其完整带阿米如下:

import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" import argparse import torch from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria from PIL import Image import requests from PIL import Image from io import BytesIO from transformers import TextStreamer def img_drawingbox(image,conversation_info,res_img_path=None): from PIL import Image, ImageDraw, ImageFont import re width, height = image.size draw = ImageDraw.Draw(image) box_lst = [] for info in conversation_info['conversations']: value = info['value'] gpt = info['from'] if gpt == 'gpt': result = re.search(r'\[(.*?)\]', value) if result: content_in_brackets = result.group(1) # 将提取的内容转换为浮点数列表 float_list = [float(num) for num in content_in_brackets.split(',')] if float_list not in box_lst: box_lst.append(float_list) if len(box_lst)>0: for b in box_lst: if len(b)==4: x1,y1,x2,y2 = b[0]*width,b[1]*height,b[2]*width,b[3]*height x1,y1,x2,y2=max(0,int(x1)),max(0,int(y1)),min(width,int(x2)),min(y2,height) box=(x1,y1,x2,y2) # 绘制矩形框 draw.rectangle(box, outline="red", width=2) # 红色边框,宽度为2像素 if res_img_path is not None: image.save(res_img_path,encoding="utf-8") return image def load_image(image_file): if image_file.startswith('http://') or image_file.startswith('https://'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def llava_init(args): # Model disable_torch_init() model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) return tokenizer, model, image_processor, context_len,model_name def llava_infer(image,test_prompt,args,tokenizer, model, image_processor, model_name='llava_v1.5'): assert isinstance(test_prompt,list), "test_prompt提示文本必须是问题构成的列表!" if 'llama-2' in model_name.lower(): conv_mode = "llava_llama_2" elif "v1" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt" else: conv_mode = "llava_v0" if args.conv_mode is not None and conv_mode != args.conv_mode: print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) else: args.conv_mode = conv_mode conversations_json = {'conversations':[]} conv = conv_templates[args.conv_mode].copy() if "mpt" in model_name.lower(): roles = ('user', 'assistant') else: roles = conv.roles width, height = image.size # Similar operation in model_worker.py image_tensor = process_images([image], image_processor, model.config) if type(image_tensor) is list: image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] else: image_tensor = image_tensor.to(model.device, dtype=torch.float16) for i ,inp in enumerate(test_prompt): conversations_json['conversations'].append({"from": "human","value":inp}) if i==0: # first message if model.config.mm_use_im_start_end: inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp else: # inp = DEFAULT_IMAGE_TOKEN + '\n' + inp # 走这步变成 <image>\n描述图像内容 conv.append_message(conv.roles[0], inp) else: # later messages # 后面循环对话添加内容 conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] # '</s>' ,这个是每句结束标志 stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # 下面开始走模型 with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor, do_sample=True if args.temperature > 0 else False, temperature=args.temperature, max_new_tokens=args.max_new_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() # ouput_ids中去除input_ids位置prompt conv.messages[-1][-1] = outputs conversations_json['conversations'].append({"from": "gpt","value":outputs.replace('</s>','')}) print(conversations_json) img_drawingbox(image,conversations_json,res_img_path=None) return conversations_json def parse_args(): parser = argparse.ArgumentParser() ## 直接使用合并后的模型进行推理 # parser.add_argument("--model-path", type=str, default="/LLaVA/llava_v1.5_lora/llava-v1.5-13b") # parser.add_argument("--model-base", type=str, default=None) ## lora推理方法 parser.add_argument("--model-path", type=str, default="/LLaVA/checkpoints/llava-v1.5-13b-lora_vaild_1epoch/checkpoint-10200") parser.add_argument("--model-base", type=str, default="/LLaVA/llava_v1.5_lora/vicuna-13b-v1.5") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--conv-mode", type=str, default=None) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--load-8bit", action="store_true") # parser.add_argument("--load-4bit", default=True) parser.add_argument("--load-4bit", action="store_true") parser.add_argument("--debug", action="store_true") args = parser.parse_args() return args if __name__ == "__main__": args=parse_args() tokenizer, model, image_processor, context_len,model_name=llava_init(args) img_path = '/LLaVA/llava/serve/examples/1.jpg' images = load_image(img_path) test_prompt = ["图中是否有城市管理相关目标?若有,请提供相应坐标。"] predect_information_dict = llava_infer(images,test_prompt,args,tokenizer, model, image_processor, model_name) 

Read more

最新电子电气架构(EEA)调研-3

而新一代的强实时性、高确定性,以及满足CAP定理的同步分布式协同技术(SDCT),可以实现替代TSN、DDS的应用,且此技术已经在无人车辆得到验证,同时其低成本学习曲线、无复杂二次开发工作,将开发人员的劳动强度、学习曲线极大降低,使开发人员更多的去完成算法、执行器功能完善。 五、各大车厂的EEA 我们调研策略是从公开信息中获得各大车厂的EEA信息,并在如下中进行展示。 我们集中了华为、特斯拉、大众、蔚来、小鹏、理想、东风(岚图)等有代表领先性的车辆电子电气架构厂商。        1、华为 图12 华为的CCA电子电气架构              (1)华为“计算+通信”CC架构的三个平台                         1)MDC智能驾驶平台;                         2)CDC智能座舱平台                         3)VDC整车控制平台。        联接指的是华为智能网联解决方案,解决车内、车外网络高速连接问题,云服务则是基于云计算提供的服务,如在线车主服务、娱乐和OTA等。 华

By Ne0inhk
Apache IoTDB 架构特性与 Prometheus+Grafana 监控体系部署实践

Apache IoTDB 架构特性与 Prometheus+Grafana 监控体系部署实践

Apache IoTDB 架构特性与 Prometheus+Grafana 监控体系部署实践 文章目录 * Apache IoTDB 架构特性与 Prometheus+Grafana 监控体系部署实践 * Apache IoTDB 核心特性与价值 * Apache IoTDB 监控面板完整部署方案 * 安装步骤 * 步骤一:IoTDB开启监控指标采集 * 步骤二:安装、配置Prometheus * 步骤三:安装grafana并配置数据源 * 步骤四:导入IoTDB Grafana看板 * TimechoDB(基于 Apache IoTDB)增强特性 * 总结与应用场景建议 Apache IoTDB 核心特性与价值 Apache IoTDB 专为物联网场景打造的高性能轻量级时序数据库,以 “设备 - 测点” 原生数据模型贴合物理设备与传感器关系,通过高压缩算法、百万级并发写入能力和毫秒级查询响应优化海量时序数据存储成本与处理效率,同时支持边缘轻量部署、

By Ne0inhk
SQL Server 2019安装教程(超详细图文)

SQL Server 2019安装教程(超详细图文)

SQL Server 介绍) SQL Server 是由 微软(Microsoft) 开发的一款 关系型数据库管理系统(RDBMS),支持结构化查询语言(SQL)进行数据存储、管理和分析。自1989年首次发布以来,SQL Server 已成为企业级数据管理的核心解决方案,广泛应用于金融、电商、ERP、CRM 等业务系统。它提供高可用性、安全性、事务处理(ACID)和商业智能(BI)支持,并支持 Windows 和 Linux 跨平台部署。 一、获取 SQL Server 2019 安装包 1. 官方下载方式 前往微软官网注册账号后,即可下载 SQL Server Developer 版本(

By Ne0inhk