import copy
import json
import os
import sys
from collections import OrderedDict
from typing import Any

import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageFile
from habitat import logger, Env
from habitat.config.default import get_agent_config
from habitat_baselines.config.default import get_config as get_habitat_config

from model.stream_video_vln import StreamVLNForCausalLM
from utils.utils import dict_to_cuda
class VLNEvaluator:
    """Evaluate a navigation model on Habitat VLN episodes.

    Episodes are sharded across ``env_num`` workers: within every scene,
    worker ``idx`` handles ``episodes[idx::env_num]`` (see ``eval_action``).
    """

    def __init__(
        self,
        config_path: str,
        split: str = "val_seen",
        env_num: int = 8,
        output_path: str = None,
        model: Any = None,
        tokenizer: Any = None,
        epoch: int = 0,
        args: Any = None,
    ):
        """Load the Habitat config and keep the model/tokenizer for evaluation.

        Args:
            config_path: Config path passed to habitat_baselines ``get_config``.
            split: Dataset split name, stored as ``self.split``.
            env_num: Number of workers the episodes are sharded across.
            output_path: Directory that ``result.json`` is appended to.
            model: Model whose ``eval``/``reset_for_env``/``generate`` methods
                ``eval_action`` calls.
            tokenizer: Tokenizer used to build prompts and decode generations.
            epoch: Stored as ``self.epoch``.
            args: Optional options object; only ``args.save_video`` is read
                here (``False`` when ``args`` is None or lacks the attribute).
        """
        # NOTE(review): eval_action and its helpers also read
        # self.image_processor, self.conversation, self.num_frames,
        # self.num_history, self._min_depth, self._max_depth and
        # self._camera_height, none of which are assigned here -- confirm they
        # are set elsewhere before eval_action runs.
        self.args = args
        self.device = torch.device('cuda')
        self.split = split
        self.env_num = env_num
        self.save_video = getattr(args, 'save_video', False)
        self.output_path = output_path
        self.epoch = epoch
        self.model = model
        self.tokenizer = tokenizer
        self.config_path = config_path
        self.config = get_habitat_config(config_path)
        self.agent_config = get_agent_config(self.config.habitat.simulator)
        self.sim_sensors_config = self.config.habitat.simulator.agents.main_agent.sim_sensors

    @staticmethod
    def _group_episodes_by_scene(episodes):
        """Return ``{scene_id: [episode, ...]}``, keeping each scene's input order."""
        grouped = {}
        for episode in episodes:
            grouped.setdefault(episode.scene_id, []).append(episode)
        return grouped

    @staticmethod
    def _history_slice(window_start, num_history):
        """Select a strided subset of the frame indices in ``[0, window_start)``.

        The stride is ``window_start // num_history`` (roughly ``num_history``
        frames). It is clamped to at least 1, so a window that starts before
        ``num_history`` steps yields every earlier frame. Without the clamp the
        zero-step slice would raise ``ValueError`` when used for indexing.
        """
        return slice(0, window_start, max(1, window_start // num_history))

    def _process_observation(self, observations, env, initial_height, intrinsic_matrix):
        """Turn one Habitat observation into per-frame model inputs.

        Returns:
            ``(pixel_values, depth, pose, intrinsic)``, where:

            - ``pixel_values`` is the RGB frame processed by
              ``self.image_processor``.
            - ``depth`` is the preprocessed depth image as a float tensor.
            - ``pose`` is the camera-to-episodic transform multiplied by
              ``self.get_axis_align_matrix()``.
            - ``intrinsic`` is the resized intrinsic matrix as a float tensor.
        """
        # NOTE(review): `filter_depth` is not imported in this file's visible
        # header -- confirm it is brought into scope elsewhere.
        x, y = observations["gps"]
        camera_yaw = observations["compass"][0]
        depth = observations["depth"]
        depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None)
        # Rescale into [_min_depth, _max_depth], then x1000 so it can be packed
        # into a uint16 image (presumably metres -> millimetres; confirm).
        depth = depth * (self._max_depth - self._min_depth) + self._min_depth
        depth = depth * 1000

        # Camera z is the configured camera height plus the agent's height
        # change since the start of the episode.
        height = env.sim.get_agent_state().position[1] - initial_height
        camera_position = np.array([x, -y, self._camera_height + height])
        tf_camera_to_episodic = self.xyz_yaw_to_tf_matrix(camera_position, camera_yaw)

        image = Image.fromarray(observations["rgb"]).convert('RGB')
        image_size = image.size
        pixel_values = self.image_processor.preprocess(images=image, return_tensors='pt')['pixel_values'][0]
        depth_image, resize_shape = self.preprocess_depth_image(
            Image.fromarray(depth.astype(np.uint16), mode='I;16'), do_depth_scale=True
        )
        intrinsic = self.preprocess_instrinsic(intrinsic_matrix, image_size, resize_shape)
        return (
            pixel_values,
            torch.from_numpy(depth_image).float(),
            torch.from_numpy(tf_camera_to_episodic) @ self.get_axis_align_matrix(),
            torch.from_numpy(intrinsic).float(),
        )

    def _build_input_ids(self, instruction, output_ids):
        """Tokenize the next prompt turn for the model.

        With no previous generation (``output_ids is None``), the prompt is a
        deep copy of ``self.conversation``. In it, ``'<instruction>.'`` is
        replaced by ``instruction``, and it is tokenized with ``add_system=True``.
        Otherwise an empty human/gpt turn is tokenized and appended after
        ``output_ids``.
        """
        if output_ids is None:
            sources = copy.deepcopy(self.conversation)
            sources[0]["value"] = sources[0]["value"].replace('<instruction>.', instruction)
            add_system = True
        else:
            sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}]
            add_system = False
        input_ids, _ = self.preprocess_qwen([sources], self.tokenizer, True, add_system=add_system)
        if output_ids is not None:
            input_ids = torch.cat([output_ids, input_ids.to(output_ids.device)], dim=1)
        return input_ids

    def _append_result(self, result):
        """Append ``result`` as one JSON line to ``<self.output_path>/result.json``."""
        with open(os.path.join(self.output_path, 'result.json'), 'a') as f:
            f.write(json.dumps(result) + "\n")

    def eval_action(self, idx) -> tuple:
        """Run every episode assigned to worker ``idx`` and collect its metrics.

        Episodes are grouped by ``scene_id``. Scenes are visited in sorted
        order, and within each scene this worker takes
        ``episodes[idx::self.env_num]``.

        The model is queried whenever its buffered action list is empty, and
        the first buffered action is executed each step. Every
        ``self.num_frames`` steps, the per-env model state, KV cache and
        generation context are reset. A query made at such a boundary also
        receives a strided subset of the earlier frames as history.

        Each finished episode's metrics are appended as one JSON line to
        ``<self.output_path>/result.json``.

        Args:
            idx: Worker index, also forwarded to the model as ``env_id``.

        Returns:
            ``(success, spl, oracle_success, distance_to_goal, count)``. The
            first four are per-episode tensors and ``count`` is the number of
            episodes evaluated. All five are moved to ``self.device``.
        """
        # NOTE(review): `tqdm` (used below) is not imported in this file's
        # visible header -- confirm it is brought into scope elsewhere.
        env = self.config_env()
        episodes_by_scene = self._group_episodes_by_scene(env.episodes)
        intrinsic_matrix = self.get_intrinsic_matrix(
            self.config.habitat.simulator.agents.main_agent.sim_sensors.rgb_sensor
        )
        sucs, spls, oss, ones = [], [], [], []
        # (scene_id, episode_id, instruction) triples to skip. Nothing fills it
        # yet, so every assigned episode is evaluated.
        done_res = set()
        # Set eval mode once up front rather than on every simulator step.
        self.model.eval()

        for scene in sorted(episodes_by_scene):
            scene_id = scene.split('/')[-2]
            print(f"当前场景 ID = {scene_id}")
            assigned = episodes_by_scene[scene][idx::self.env_num]
            process_bar = tqdm.tqdm(total=len(assigned), desc=f"场景 {scene_id}")
            for episode in assigned:
                episode_instruction = episode.instruction.instruction_text
                episode_id = episode.episode_id
                if (scene_id, episode_id, episode_instruction) in done_res:
                    process_bar.update(1)
                    continue
                self.model.reset_for_env(idx)
                env.current_episode = episode
                observations = env.reset()
                # Agent height at episode start; camera z is measured relative to it.
                initial_height = env.sim.get_agent_state().position[1]

                step_id = 0
                rgb_list, depth_list, pose_list, intrinsic_list, time_ids = [], [], [], [], []
                action_seq = []
                past_key_values = None
                output_ids = None
                while not env.episode_over:
                    time_ids.append(step_id)
                    image, depth, pose, intrinsic = self._process_observation(
                        observations, env, initial_height, intrinsic_matrix
                    )
                    rgb_list.append(image)
                    depth_list.append(depth)
                    pose_list.append(pose)
                    intrinsic_list.append(intrinsic)

                    if len(action_seq) == 0:
                        input_ids = self._build_input_ids(episode_instruction, output_ids)
                        images = rgb_list[-1:]
                        depths = depth_list[-1:]
                        poses = pose_list[-1:]
                        intrinsics = intrinsic_list[-1:]
                        if step_id != 0 and step_id % self.num_frames == 0:
                            # First query of a new window: prepend a strided
                            # sample of the frames before it as history.
                            history = self._history_slice(time_ids[0], self.num_history)
                            images = rgb_list[history] + images
                            depths = depth_list[history] + depths
                            poses = pose_list[history] + poses
                            intrinsics = intrinsic_list[history] + intrinsics
                        input_dict = {
                            'images': torch.stack(images).unsqueeze(0),
                            'depths': torch.stack(depths).unsqueeze(0),
                            'poses': torch.stack(poses).unsqueeze(0),
                            'intrinsics': torch.stack(intrinsics).unsqueeze(0),
                            'inputs': input_ids,
                            'env_id': idx,
                            'time_ids': [time_ids],
                            'task_type': [0],
                        }
                        input_dict = dict_to_cuda(input_dict, self.device)
                        for key in ('images', 'depths', 'poses', 'intrinsics'):
                            if key in input_dict:
                                input_dict[key] = input_dict[key].to(torch.bfloat16)
                        outputs = self.model.generate(
                            **input_dict, do_sample=False, num_beams=1, max_new_tokens=10000,
                            use_cache=True, return_dict_in_generate=True,
                            past_key_values=past_key_values,
                        )
                        output_ids = outputs.sequences
                        past_key_values = outputs.past_key_values
                        llm_outputs = self.tokenizer.batch_decode(
                            output_ids, skip_special_tokens=False
                        )[0].strip()
                        action_seq = self.parse_actions(llm_outputs)
                        if len(action_seq) == 0:
                            # No actions parsed: fall back to action id 0.
                            action_seq = [0]

                    action = action_seq.pop(0)
                    observations = env.step(action)
                    step_id += 1
                    if step_id % self.num_frames == 0:
                        # Start a new window: clear model state and generation context.
                        self.model.reset_for_env(idx)
                        output_ids = None
                        past_key_values = None
                        time_ids = []

                metrics = env.get_metrics()
                sucs.append(metrics['success'])
                spls.append(metrics['spl'])
                oss.append(metrics['oracle_success'])
                ones.append(metrics['distance_to_goal'])
                print(f"场景-episode {scene_id}_{episode_id} 结果:成功={metrics['success']}, SPL={metrics['spl']}")
                self._append_result({
                    "scene_id": scene_id, "episode_id": episode_id, "success": metrics["success"],
                    "spl": metrics["spl"], "os": metrics['oracle_success'], "ne": metrics["distance_to_goal"],
                    "steps": step_id, "episode_instruction": episode_instruction,
                })
                process_bar.update(1)
            process_bar.close()

        return (
            torch.tensor(sucs).to(self.device),
            torch.tensor(spls).to(self.device),
            torch.tensor(oss).to(self.device),
            torch.tensor(ones).to(self.device),
            torch.tensor(len(sucs)).to(self.device),
        )