-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize_artifact.py
More file actions
72 lines (59 loc) · 2.42 KB
/
visualize_artifact.py
File metadata and controls
72 lines (59 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# visualize_artifact.py
# Load a saved artifact and launch the MuJoCo viewer.
from pathlib import Path
import json
import numpy as np
import mujoco as mj
from mujoco import viewer
from networkx.readwrite import json_graph
from networkx import DiGraph
# Ariel imports
from ariel.body_phenotypes.robogen_lite.constructor import construct_mjspec_from_graph
from ariel.simulation.environments import OlympicArena
from ariel.utils.runners import simple_runner
from ariel.utils.tracker import Tracker
def _rebuild_graph(graph_payload: dict) -> DiGraph:
    """Reconstruct the robot body graph from its serialized form.

    Args:
        graph_payload: Node-link dictionary, as stored under the artifact's
            ``"robot_graph"`` key.

    Returns:
        The graph object produced by networkx's ``node_link_graph`` reader.
    """
    restored = json_graph.node_link_graph(graph_payload)
    return restored
def _launch(robot_graph: DiGraph, cpg_params: np.ndarray, spawn_pos, duration: int = 30):
    """Build the arena with the robot and open the interactive MuJoCo viewer.

    Args:
        robot_graph: Body graph handed to ``construct_mjspec_from_graph``.
        cpg_params: Flat controller parameters laid out as
            ``[frequencies | amplitudes | phases]``; each section holds one
            value per actuator (``model.nu`` values).
        spawn_pos: Position forwarded to ``world.spawn`` as a NumPy array.
        duration: Accepted for interface compatibility but currently unused;
            this function does not time-limit the viewer session.

    Raises:
        ValueError: If ``cpg_params`` holds fewer than ``3 * model.nu`` values.
    """
    # Clear any control callback left over from an earlier simulation before
    # a new model is built.
    mj.set_mjcb_control(None)

    # Build world and model
    world = OlympicArena()
    fresh_core = construct_mjspec_from_graph(robot_graph)
    world.spawn(fresh_core.spec, spawn_position=np.array(spawn_pos))
    model = world.spec.compile()
    data = mj.MjData(model)
    mj.mj_resetData(model, data)
    mj.mj_forward(model, data)

    # Validate the parameter vector once, up front: a short vector would
    # otherwise only fail later, inside the per-step physics callback.
    nu = model.nu
    params = np.asarray(cpg_params)
    if params.size < 3 * nu:
        raise ValueError(
            f"cpg_params has {params.size} values but the model has {nu} "
            f"actuators; expected at least {3 * nu} "
            "(frequencies, amplitudes, phases)."
        )

    # The three sections never change during playback, so slice them once
    # instead of on every control step.
    freqs = params[:nu]
    amps = params[nu:2 * nu]
    phases = params[2 * nu:3 * nu]

    # CPG controller (inline to avoid importing from training file)
    def control_cb(m, d):
        d.ctrl[:] = amps * np.sin(freqs * d.time + phases)

    mj.set_mjcb_control(control_cb)
    try:
        print("[VIS] Launching viewer...")
        viewer.launch(model=model, data=data)
    finally:
        # The control callback is process-global; remove it so it cannot
        # drive a different model once the viewer returns or raises.
        mj.set_mjcb_control(None)
def load_and_visualize(artifact_path: str):
    """Load a saved artifact JSON file and show its robot in the viewer.

    The file is expected to provide the keys ``"meta"``, ``"robot_graph"``
    and ``"best_cpg_params"``. Within ``"meta"``, ``"spawn_pos"`` and
    ``"outer_duration"`` are optional and fall back to ``[-0.8, 0, 0.1]``
    and ``30`` respectively.

    Args:
        artifact_path: Path to the artifact JSON file.

    Raises:
        KeyError: If one of the required top-level keys is missing.
    """
    artifact_file = Path(artifact_path)
    # JSON text is UTF-8 by specification; do not rely on the platform's
    # locale-dependent default encoding.
    with artifact_file.open("r", encoding="utf-8") as f:
        payload = json.load(f)

    meta = payload["meta"]
    robot_graph = _rebuild_graph(payload["robot_graph"])
    cpg_params = np.array(payload["best_cpg_params"], dtype=np.float32)

    spawn_pos = meta.get("spawn_pos", [-0.8, 0, 0.1])
    # Alternative spawn positions for manual experiments; assign one to
    # ``spawn_pos`` here to override the stored value.
    # NOTE(review): the first two entries carry the same "rugged" label;
    # confirm which terrain section each one targets.
    #   [-0.8, 0.0, 0.12]  # rugged
    #   [0.9, 0.0, 0.12]   # rugged
    #   [2.65, 0.0, 0.12]  # hill

    _launch(robot_graph, cpg_params, spawn_pos, duration=meta.get("outer_duration", 30))
if __name__ == "__main__":
    # Command-line entry point. Example usage:
    # python visualize_artifact.py __data__/Nested_evolution_artifacts/artifacts/run-20251008-213000/best_robot.json
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        # No artifact path given: show the usage line and exit with an error code.
        print("Usage: python visualize_artifact.py <path/to/best_robot.json>")
        sys.exit(1)
    load_and_visualize(cli_args[0])