"""
推理脚本：基于“万物识别-中文-通用领域”模型进行成人内容检测。
"""
import torch
from torchvision import transforms
from PIL import Image
import os
# Filesystem locations of the model checkpoint and the image to classify.
MODEL_PATH = "/root/model/wwts_model.pth"
IMAGE_PATH = "/root/workspace/bailing.png"

# Prefer a CUDA GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)
print(f"Using device: {DEVICE}")

# Per-channel mean / std used by Normalize (the standard ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to a fixed 224x224, convert the PIL image to a
# float tensor, then normalize each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def load_model():
    """Build a ResNet-50 and load the weights stored at ``MODEL_PATH``.

    The classification head is sized from the checkpoint's ``fc.weight``
    tensor when present (falling back to 1000 outputs), so checkpoints whose
    head was fine-tuned to a different number of classes load correctly.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
    """
    # Use the locally installed torchvision instead of torch.hub: hub would
    # fetch and execute repository code from the network just to construct an
    # architecture that ships with torchvision anyway.
    from torchvision import models

    # Fail fast, before any model construction work.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot run arbitrary code on load.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)

    # Size the head from the checkpoint rather than hard-coding it; a
    # mismatched head would make load_state_dict raise a size-mismatch error.
    fc_weight = state_dict.get("fc.weight")
    num_classes = fc_weight.shape[0] if fc_weight is not None else 1000

    # resnet50() with no arguments builds a randomly initialised network; all
    # weights are then overwritten from the checkpoint.
    model = models.resnet50()
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(state_dict)
    print("✅ Model loaded successfully.")

    model.to(DEVICE)
    model.eval()
    return model
# Chinese display names for the class indices this script knows about.
# NOTE(review): only indices 0-7 are named while the model head may have more
# outputs; confirm this order matches the checkpoint's training labels.
LABELS_ZH = {
    0: "正常风景",
    1: "人物肖像",
    2: "动物世界",
    3: "食物饮品",
    4: "暴露服装",
    5: "亲密行为",
    6: "暴力场面",
    7: "广告营销",
}


def predict(image_path, model, top_k=5):
    """Classify one image and return its highest-probability labels.

    Args:
        image_path: Path to the image file to classify.
        model: Model returned by ``load_model`` (run on ``DEVICE``).
        top_k: Maximum number of predictions to return; clamped to the
            number of classes the model outputs.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts, ordered by
        descending softmax probability, with scores rounded to 4 decimals.
        Indices without a Chinese name are labelled ``未知类别 (<idx>)``.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Add a batch dimension of 1 for the model.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(image_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises if k exceeds the number of classes; clamp it.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return [
        {"label": LABELS_ZH.get(idx, f"未知类别 ({idx})"), "score": round(prob, 4)}
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    ]
if __name__ == "__main__":
    import sys

    # Labels that count as potential violations, and the minimum score at
    # which a prediction carrying one of them is flagged for manual review.
    risk_labels = {"暴露服装", "亲密行为"}
    threshold = 0.7

    # Only model loading and inference are expected to fail. Report the error
    # on stderr and exit non-zero so a calling pipeline can tell a crash apart
    # from a "compliant" verdict instead of seeing exit status 0.
    try:
        model = load_model()
        print("🔍 Starting inference...")
        results = predict(IMAGE_PATH, model, top_k=5)
    except Exception as e:
        print(f"❌ Error during inference: {e}", file=sys.stderr)
        sys.exit(1)

    # Report the number of predictions actually returned.
    print(f"\n📋 Top-{len(results)} Predictions:")
    for r in results:
        print(f" {r['label']} : {r['score']:.4f}")

    is_risky = any(
        r["label"] in risk_labels and r["score"] >= threshold for r in results
    )
    if is_risky:
        print("\n🚨 检测到潜在违规内容!建议人工复核。")
    else:
        print("\n✅ 内容初步判断为合规。")