Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 76 additions & 7 deletions task_decomposition/scripts/record_robomimic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_data_to_record(env_name: str):

# Unique to each environment
if env_name == "NutAssemblySquare":
meta_data_to_record = ["nut0_pos", "nut1_pos"]
meta_data_to_record = ["nut0_pos", "peg1_pos"]
elif env_name == "ToolHang":
meta_data_to_record = ["tool_pos", "frame_pos", "stand_pos"]
else:
Expand All @@ -51,6 +51,10 @@ def query_sim_for_data(env, desired_obs):
return env.env.sim.data.body_xpos[env.env.obj_body_id[env.env.nuts[0].name]]
elif desired_obs == "nut1_pos":
return env.env.sim.data.body_xpos[env.env.obj_body_id[env.env.nuts[1].name]]
elif desired_obs == "peg1_pos":
return env.env.sim.data.body_xpos[env.env.peg1_body_id]
elif desired_obs == "peg2_pos":
return env.env.sim.data.body_xpos[env.env.peg2_body_id]
else:
raise ValueError(f"Environment {env.name} has no defined data to record.")
elif env.name == "ToolHang":
Expand Down Expand Up @@ -89,11 +93,33 @@ def extract_trajectory(env, initial_state, states, actions, done_mode):

data_to_record = get_data_to_record(env_name=env.name)
df = pd.DataFrame(columns=get_data_to_record(env_name=env.name))
gt_df = pd.DataFrame(columns=["step", "subtask", "stage"])

traj_len = states.shape[0]
frames = []
for k in range(traj_len):

if env.name == "NutAssemblySquare":
subtask_list = [
"Reach for the Square Nut",
"Grasp the Square Nut",
"Align the Square Nut with the Squre Peg",
"Insert the Square Nut",
]
elif env.name == "ToolHang":
subtask_list = [
"Reach for the Frame",
"Grasp the Frame",
"Align the Frame with the Stand",
"Insert the Frame into the Stand",
"Reach for the Tool",
"Grasp the Tool",
"Align the Tool with the Frame",
"Hang the Tool",
]

stage = 0

for k in range(traj_len):
obs = env.reset_to({"states": states[k]})
frame = obs[IMAGE_VIEW_TO_RECORD]
frames.append(frame)
Expand All @@ -119,14 +145,43 @@ def extract_trajectory(env, initial_state, states, actions, done_mode):
row_data[o] = np.around(actions[k], 2).tolist()
elif o == "robot0_eef_pos":
row_data[o] = np.around(obs[o], 2).tolist()
elif o == "sub_task":
pass
else:
row_data[o] = np.around(
query_sim_for_data(env, desired_obs=o), 2
).tolist()

df.loc[k] = row_data

print(" Done Running Simulation.")
return df, frames
# Advange the stage for ground truth label
if env.name == "NutAssemblySquare":
if stage == 0 and actions[k][6] > 0:
stage = 1
elif stage == 1 and row_data["nut0_pos"][2] > 0.83:
stage = 2
elif stage == 2 and actions[k][6] < 0:
stage = 3
elif env.name == "ToolHang":
if stage == 0 and actions[k][6] > 0:
stage = 1
elif stage == 1 and row_data["frame_pos"][2] > 0.81:
stage = 2
elif stage == 2 and actions[k][6] < 0:
stage = 3
elif stage == 3 and row_data["frame_pos"][2] < 1:
stage = 4
elif stage == 4 and actions[k][6] > 0:
stage = 5
elif stage == 5 and row_data["tool_pos"][2] > 0.81:
stage = 6
elif stage == 6 and actions[k][6] < 0:
stage = 7

gt_row_data = {"step": k, "subtask": subtask_list[stage], "stage": stage}
gt_df.loc[k] = gt_row_data

return df, gt_df, frames


def record_dataset(args):
Expand All @@ -143,6 +198,7 @@ def record_dataset(args):
)

save_txt = True if args.save_txt == 1 else False
save_gt = True if args.save_gt == 1 else False
save_video = True if args.save_video == 1 else False
print("==== Using environment with the following metadata ====")
print(json.dumps(env.serialize(), indent=4))
Expand Down Expand Up @@ -177,7 +233,7 @@ def record_dataset(args):
actions = f["data/{}/actions".format(ep)][()]
timenow = datetime.now().strftime("%Y%m%d-%H%M%S")
idx = timenow + f"_{idx}"
df, frames = extract_trajectory(
df, gt_df, frames = extract_trajectory(
env=env,
initial_state=initial_state,
states=states,
Expand All @@ -187,16 +243,22 @@ def record_dataset(args):

# save data
filename = env.name + f"_{idx}"
print(" Done Running Simulation.")
save_video_fn(frames=frames, filename=filename) if save_video else None
save_df_to_txt(df=df, filename=filename) if save_txt else None
save_df_to_txt(df=df, filename=filename, kind="raw") if save_txt else None
(
save_df_to_txt(df=gt_df, filename=filename + "_gt", kind="gt")
if save_gt
else None
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default=None,
default=None, # Example: "/Users/jonathansalfity/Documents/dev/task_decomposition/task_decomposition/data/robomimic/tool_hang/demo_v141.hdf5",
required=True,
help="path to input hdf5 dataset",
)
Expand All @@ -215,6 +277,13 @@ def record_dataset(args):
help="(Required) but default to True, save txt files",
)

parser.add_argument(
"--save_gt",
type=int,
default=1,
help="(Required) but default to True, save ground truth files",
)

parser.add_argument(
"--save_video",
type=int,
Expand Down
3 changes: 2 additions & 1 deletion task_decomposition/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def make_env(config):
return suite.make(**config)


def save_df_to_txt(df: pd.DataFrame, filename, kind):
def save_df_to_txt(df: pd.DataFrame, filename: str, kind: str):
"""
Dump pandas dataframe to file
kind: "raw" or "gt"
"""
if kind == "raw":
savepath = DATA_RAW_TXT_PATH
Expand Down