import html
import os

import streamlit as st
import torch
from PIL import Image
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
@st.cache_resource
def load_model():
    """Load the Qwen3-VL checkpoint and its processor once per session.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    model/processor objects instead of reloading them.

    Returns:
        tuple: ``(model, processor)`` ready for multimodal generation.
    """
    checkpoint = "Qwen/Qwen3-VL-4B-Instruct"
    vl_model = Qwen3VLForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,  # halve memory vs. fp32
        device_map="auto",          # let accelerate pick the device(s)
    )
    vl_processor = AutoProcessor.from_pretrained(checkpoint)
    return vl_model, vl_processor


# Materialize the cached singletons at import time for the rest of the script.
model, processor = load_model()
def load_image(image_file):
    """Open *image_file* (a path or file-like object) as a PIL image."""
    img = Image.open(image_file)
    return img
def resize_image_to_height(image, height=300):
    """Return *image* scaled to *height* pixels, preserving the aspect ratio.

    The width is derived as ``int(width * height / original_height)``; the
    expression order is kept so integer truncation matches the original.
    """
    scaled_width = int(image.width * height / image.height)
    return image.resize((scaled_width, height))
def process_input(messages):
    """Run one multimodal chat turn through the model and return the answer.

    Builds the chat-template prompt, extracts image/video tensors from
    *messages*, generates up to 512 new tokens, and decodes only the newly
    generated portion of each sequence.

    Args:
        messages: Chat messages in Qwen-VL format — a list of dicts with a
            ``role`` and a ``content`` list of typed items (image/video/text).

    Returns:
        str: The generated answer text, or a fallback string when decoding
        produces nothing.
    """
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to(model.device)  # follow the model's actual placement; hard-coding "cuda" breaks CPU-only hosts
    with torch.inference_mode():  # generation needs no autograd state
        generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # Drop large tensors promptly to keep peak GPU memory low across reruns.
    del inputs, generated_ids, generated_ids_trimmed
    if torch.cuda.is_available():  # empty_cache is invalid without CUDA
        torch.cuda.empty_cache()
    return output_text[0] if output_text else "模型未返回结果"
st.title("🧠 Qwen3-VL 多模态智能交互平台")

uploaded_file = st.file_uploader("📤 上传图片或视频", type=["jpg", "jpeg", "png", "mp4"])
if uploaded_file is not None:
    # Persist the upload to disk: the vision pipeline consumes file paths.
    upload_dir = "uploads"
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, uploaded_file.name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    messages = []
    user_input = ""  # pre-bind so the button check below can never NameError
    if uploaded_file.type.startswith("image"):
        img = load_image(file_path)
        img_resized = resize_image_to_height(img, 300)
        st.image(img_resized, caption="已上传图像", use_container_width=False)
        st.subheader("💬 输入你的问题")
        user_input = st.text_input("例如:这张图里有什么?请描述细节。", key="img_input")
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": file_path, "max_pixels": 1024 * 960},
                {"type": "text", "text": user_input}
            ]
        }]
    elif uploaded_file.type.startswith("video"):
        st.video(file_path)
        # st.video has no size parameter; constrain the player via CSS.
        st.markdown(
            """<style>video {height: 300px; width: auto;}</style>""",
            unsafe_allow_html=True
        )
        st.subheader("💬 输入你的问题")
        user_input = st.text_input("例如:这个视频讲了什么?关键事件有哪些?", key="vid_input")
        messages = [{
            "role": "user",
            "content": [
                {"type": "video", "video": file_path, "max_pixels": 960 * 480, "fps": 1.0},
                {"type": "text", "text": user_input}
            ]
        }]

    if st.button("🚀 开始推理") and user_input.strip() and messages:
        with st.spinner("模型正在思考..."):
            result = process_input(messages)
        st.markdown("### ✅ 推理结果:")
        # Escape the model output: interpolating raw generated text into an
        # unsafe_allow_html block would let model-produced HTML render/execute.
        st.markdown(f'<div>{html.escape(result)}</div>', unsafe_allow_html=True)
        # Best-effort cleanup of the temporary upload; failure is non-fatal.
        try:
            os.remove(file_path)
        except OSError as e:
            st.warning(f"临时文件清理失败:{e}")