forked from vla-safe/openpi
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
160 lines (132 loc) · 5.11 KB
/
evaluate.py
File metadata and controls
160 lines (132 loc) · 5.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import collections
import csv
import datetime
import math
import pickle as pkl
import re
from functools import partial
from pathlib import Path
import fire
import imageio
import matplotlib.pyplot as plt
import numpy as np
from dask.distributed import Client, LocalCluster
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from openpi_client import websocket_client_policy as _websocket_client_policy
from ribs.archives import GridArchive
from ribs.emitters import EvolutionStrategyEmitter
from ribs.schedulers import Scheduler
from ribs.visualize import grid_archive_heatmap
from tqdm import tqdm, trange
# BDDL file describing the evaluated LIBERO task ("custom" task 5:
# pick up the black bowl next to the ramekin and place it on the plate).
task_5_bddl = (
    Path(get_libero_path("bddl_files"))
    / "custom"
    / "pick_up_the_black_bowl_next_to_the_ramekin_and_place_it_on_the_plate.bddl"
)

# Factory for the offscreen-rendered LIBERO environment at 256x256.
# NOTE(review): extra kwargs passed at call time (see `evaluate`:
# params / repair_env / repair_config) are forwarded to OffScreenRenderEnv.
TASK_ENV = partial(
    OffScreenRenderEnv,
    bddl_file_name=task_5_bddl,
    camera_heights=256,
    camera_widths=256,
)

# Rollout / policy-server configuration.
max_steps = 220      # max control steps per trial (after the settle period)
num_steps_wait = 10  # no-op steps at episode start so the physics can settle
host = "0.0.0.0"     # openpi websocket policy server host
port = 8000          # openpi websocket policy server port
replan_steps = 5     # execute this many actions from each predicted chunk before re-querying
def _quat2axisangle(quat):
"""
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def collect_pi0fast_embedding(metadata):
    """Pool per-token pre-logit features into a single embedding vector.

    Mean-pooling over the first (token) axis — the best-performing
    ablation setting from the SAFE paper.

    Args:
        metadata: inference result dict containing a "pre_logits" array.

    Returns:
        The mean of ``metadata["pre_logits"]`` along axis 0.
    """
    token_features = metadata["pre_logits"]
    return np.mean(token_features, axis=0)
def evaluate(params,
             ntrials,
             seed,
             encoder,
             video_logdir=None):
    """Roll out the openpi policy on the LIBERO task and score the solution.

    Runs `ntrials` episodes in an environment built from `params`, querying
    the websocket policy server every `replan_steps` steps, and scores the
    solution by the binary entropy of its success rate (higher = more
    uncertain behavior), with latent measures produced by `encoder`.

    Args:
        params: environment parameters forwarded to TASK_ENV — presumably a
            QD solution vector; verify against the caller.
        ntrials: number of rollout episodes to average success over.
        seed: RNG / environment seed.
        encoder: object with an `encode(all_embeddings)` method returning a
            4-tuple whose last element is the latent measures.
        video_logdir: if given, a timestamped subdirectory is created here.
            NOTE(review): the directory is created but nothing is ever
            written to it in this function — videos appear unimplemented.

    Returns:
        (latent_measures, entropy, all_embeddings) on success.
        NOTE(review): the failure paths below return a 4-tuple
        (1e-6, 0, 0, None) while the success path returns a 3-tuple —
        callers unpacking either shape will break on the other. One of the
        two shapes is stale; confirm the intended contract and unify.
    """
    np.random.seed(seed)
    # Connect to the openpi policy server (module-level host/port).
    openpi_client = _websocket_client_policy.WebsocketClientPolicy(host, port)
    env = TASK_ENV(
        params=params,
        repair_env=True,
        repair_config={
            'time_limit':1500,
            'seed':seed
        }
    )
    env.seed(seed)
    obs = env.reset()
    if obs is None:
        # Env construction/repair failed for this solution.
        # TODO: How to handle solutions that fail to evaluate
        # NOTE(review): returns without closing the websocket client or env.
        return 1e-6, 0, 0, None
    if video_logdir is not None:
        # ID each sol with datetime to prevent overwriting
        sol_logdir = Path(video_logdir) / f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        sol_logdir.mkdir(parents=True)
    # Get success rates by running openpi on env
    success_rate = 0
    all_embeddings = []  # one list of per-chunk embeddings per trial
    for trial_id in trange(ntrials):
        print(f"RUNNING {trial_id}")
        obs = env.reset()
        action_plan = collections.deque()  # queued actions from the last policy call
        trial_embeddings = []
        for t in range(max_steps + num_steps_wait):
            try:
                if t < num_steps_wait:
                    # Do nothing at the start to wait for env to settle
                    obs, reward, done, info = env.step([0.0] * 6 + [-1.0])
                    continue
                # Flip both image axes ([::-1, ::-1]) before sending to the
                # policy — presumably to match the policy's training camera
                # convention; confirm against the openpi client docs.
                img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                if not action_plan:
                    # Plan exhausted: query the policy for a new action chunk.
                    element = {
                        "observation/image": img,
                        "observation/wrist_image": wrist_img,
                        "observation/state": np.concatenate(
                            (
                                obs["robot0_eef_pos"],
                                _quat2axisangle(obs["robot0_eef_quat"]),
                                obs["robot0_gripper_qpos"],
                            )
                        ),
                        "prompt": env.language_instruction,
                    }
                    model_data = openpi_client.infer(element)
                    t_embedding = collect_pi0fast_embedding(model_data)
                    trial_embeddings.append(t_embedding)
                    action_chunk = model_data["actions"]
                    assert (
                        len(action_chunk) >= replan_steps
                    ), f"We want to replan every {replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                    # Only keep the first replan_steps actions of the chunk.
                    action_plan.extend(action_chunk[: replan_steps])
                action = action_plan.popleft()
                obs, reward, done, info = env.step(action.tolist())
                if done:
                    # Episode success: add this trial's share of the rate.
                    success_rate += 1 / ntrials
                    break
            except Exception as e:
                # NOTE(review): any exception (including policy-server
                # errors) aborts the whole evaluation, discarding completed
                # trials, and leaks the env / websocket connection.
                print(e)
                # TODO: How to handle solutions that fail to evaluate
                return 1e-6, 0, 0, None
        all_embeddings.append(trial_embeddings)
    # Encoder maps the rollout embeddings to latent measures (last of 4 outputs).
    _, _, _, latent_measures = encoder.encode(all_embeddings)
    # Maximizes entropy as objective, i.e. we want more uncertain
    # Clip away exact 0/1 so log2 is finite; then binary entropy of success.
    success_rate = np.clip(success_rate, 1e-6, 1 - 1e-6)
    entropy = -success_rate*math.log2(success_rate) - (1-success_rate)*math.log2(1-success_rate)
    # NOTE(review): reaches into a private attribute; prefer a public
    # close() on the client if one exists. The env is never closed.
    openpi_client._ws.close()
    return latent_measures, entropy, all_embeddings