(base85-encoded binary payload omitted)
diff --git a/experiments/robot/aloha/__pycache__/real_env.cpython-310.pyc b/experiments/robot/aloha/__pycache__/real_env.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d9dd0a30981643af847ba936b27492ee956e6db
GIT binary patch
literal 7934
(base85-encoded binary payload omitted)
diff --git a/experiments/robot/aloha/__pycache__/robot_utils.cpython-310.pyc b/experiments/robot/aloha/__pycache__/robot_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..310595bc89a30af0cb8b09acc3e54c411030c617
GIT binary patch
literal 7982
(base85-encoded binary payload omitted)
diff --git a/experiments/robot/aloha/aloha_utils.py b/experiments/robot/aloha/aloha_utils.py
new file mode 100644
index 0000000..002fbbf
--- /dev/null
+++ b/experiments/robot/aloha/aloha_utils.py
@@ -0,0 +1,85 @@
+"""Utils for evaluating policies in real-world ALOHA environments."""
+
+import os
+
+import imageio
+import numpy as np
+from PIL import Image
+
+from experiments.robot.aloha.real_env import make_real_env
+from experiments.robot.robot_utils import (
+ DATE,
+ DATE_TIME,
+)
+
+
+def get_next_task_label(task_label):
+ """Prompt the user to input the next task."""
+ if task_label == "":
+ user_input = ""
+ while user_input == "":
+ user_input = input("Enter the task name: ")
+ task_label = user_input
+ else:
+ user_input = input("Enter the task name (or leave blank to repeat the previous task): ")
+ if user_input == "":
+ pass # Do nothing -> Let task_label be the same
+ else:
+ task_label = user_input
+ print(f"Task: {task_label}")
+ return task_label
+
+
+def get_aloha_env():
+ """Initializes and returns the ALOHA environment."""
+ env = make_real_env(init_node=True)
+ return env
+
+
+def resize_image_for_preprocessing(img):
+ """
+    Takes a numpy array containing a single image and resizes it to 256x256, exactly as done
+    in the ALOHA data preprocessing script that runs before converting the dataset to RLDS.
+ """
+ ALOHA_PREPROCESS_SIZE = 256
+ img = np.array(
+ Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC)
+ ) # BICUBIC is default; specify explicitly to make it clear
+ return img
+
+
+def get_aloha_image(obs):
+ """Extracts third-person image from observations and preprocesses it."""
+ # obs: dm_env._environment.TimeStep
+ img = obs.observation["images"]["cam_high"]
+ img = resize_image_for_preprocessing(img)
+ return img
+
+
+def get_aloha_wrist_images(obs):
+ """Extracts both wrist camera images from observations and preprocesses them."""
+ # obs: dm_env._environment.TimeStep
+ # left_wrist_img = obs.observation["images"]["cam_left_wrist"]
+ right_wrist_img = obs.observation["images"]["cam_right_wrist"]
+ # left_wrist_img = resize_image_for_preprocessing(left_wrist_img)
+ right_wrist_img = resize_image_for_preprocessing(right_wrist_img)
+ return right_wrist_img
+
+
+def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None):
+ """Saves an MP4 replay of an episode."""
+ rollout_dir = f"./rollouts/{DATE}"
+ os.makedirs(rollout_dir, exist_ok=True)
+ processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
+ filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}"
+ if notes is not None:
+ filetag += f"--{notes}"
+ mp4_path = f"{filetag}.mp4"
+ video_writer = imageio.get_writer(mp4_path, fps=25)
+ for img in rollout_images:
+ video_writer.append_data(img)
+ video_writer.close()
+ print(f"Saved rollout MP4 at path {mp4_path}")
+ if log_file is not None:
+ log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
+ return mp4_path
diff --git a/experiments/robot/aloha/constants.py b/experiments/robot/aloha/constants.py
new file mode 100755
index 0000000..b0c0ea5
--- /dev/null
+++ b/experiments/robot/aloha/constants.py
@@ -0,0 +1,128 @@
+### Task parameters
+
+DATA_DIR = '/home/test/data/aloha_real'
+
+TASK_CONFIGS = {
+
+ 'put_capybara_into_the_box': {
+ 'dataset_dir': DATA_DIR + '/put_capybara_into_the_box',
+ 'episode_len': 500,
+ 'camera_names': ['cam_high', 'cam_right_wrist']
+ },
+
+ 'put_NAILONG_into_the_box': {
+ 'dataset_dir': DATA_DIR + '/put_NAILONG_into_the_box',
+ 'episode_len': 500,
+ 'camera_names': ['cam_high', 'cam_right_wrist']
+ },
+
+ 'put_Banana_into_the_box': {
+        'dataset_dir': DATA_DIR + '/put_Banana_into_the_box',
+ 'episode_len': 500,
+ 'camera_names': ['cam_high', 'cam_right_wrist']
+ },
+
+}
+
+### ALOHA fixed constants
+DT = 0.02
+FPS = 50
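+# Note: DT == 1 / FPS (a 20 ms control period at 50 Hz).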
+JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
+START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
+
+# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
+MASTER_GRIPPER_POSITION_OPEN = 0.02417
+MASTER_GRIPPER_POSITION_CLOSE = 0.01244
+PUPPET_GRIPPER_POSITION_OPEN = 0.05800
+PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
+
+# Gripper joint limits (qpos[6])
+MASTER_GRIPPER_JOINT_OPEN = 0.3083
+MASTER_GRIPPER_JOINT_CLOSE = -0.6842
+PUPPET_GRIPPER_JOINT_OPEN = 1.4910
+PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
+
+############################ Helper functions ############################
+
+MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
+PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
+MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
+PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
+MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
+
+MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
+PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
+MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
+
+MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
+PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
+
+MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
+MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
+PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
+PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
+
+MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2
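+
+# Sanity-check example (not used at runtime): the master gripper's mid joint angle
+# maps to the puppet's mid joint angle under the master-to-puppet mapping:
+#   MASTER2PUPPET_JOINT_FN(MASTER_GRIPPER_JOINT_MID)
+#   == (PUPPET_GRIPPER_JOINT_OPEN + PUPPET_GRIPPER_JOINT_CLOSE) / 2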
diff --git a/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_14_50.txt b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_14_50.txt
new file mode 100644
index 0000000..2b22f32
--- /dev/null
+++ b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_14_50.txt
@@ -0,0 +1,7 @@
+Loading local VLA model...
+
+Task: put_capybara_into_the_box
+Starting episode 1...
+Prepare the scene, and then press Enter to begin...
+
+Caught exception: RealEnv.get_observation() got an unexpected keyword argument 't'
diff --git a/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_18_25.txt b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_18_25.txt
new file mode 100644
index 0000000..eb151ea
--- /dev/null
+++ b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_18_25.txt
@@ -0,0 +1 @@
+Loading local VLA model...
diff --git a/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_20_19.txt b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_20_19.txt
new file mode 100644
index 0000000..2b22f32
--- /dev/null
+++ b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_20_19.txt
@@ -0,0 +1,7 @@
+Loading local VLA model...
+
+Task: put_capybara_into_the_box
+Starting episode 1...
+Prepare the scene, and then press Enter to begin...
+
+Caught exception: RealEnv.get_observation() got an unexpected keyword argument 't'
diff --git a/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_23_06.txt b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_23_06.txt
new file mode 100644
index 0000000..e770282
--- /dev/null
+++ b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_23_06.txt
@@ -0,0 +1,5 @@
+Loading local VLA model...
+
+Task: put_capybara_into_the_box
+Starting episode 1...
+Prepare the scene, and then press Enter to begin...
diff --git a/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_27_04.txt b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_27_04.txt
new file mode 100644
index 0000000..eb151ea
--- /dev/null
+++ b/experiments/robot/aloha/experiments/logs/EVAL-LOCAL-openvla-2025_11_08-15_27_04.txt
@@ -0,0 +1 @@
+Loading local VLA model...
diff --git a/experiments/robot/aloha/preprocess_split_aloha_data.py b/experiments/robot/aloha/preprocess_split_aloha_data.py
new file mode 100644
index 0000000..8de07f2
--- /dev/null
+++ b/experiments/robot/aloha/preprocess_split_aloha_data.py
@@ -0,0 +1,260 @@
+"""
+Preprocesses ALOHA dataset(s) and splits them into train/val sets.
+
+Preprocessing includes downsizing images from 480x640 to 256x256.
+Splits happen at the episode level (not step level), which means that
+an episode is treated as an atomic unit that entirely goes to either
+the train set or val set.
+
+Original ALOHA data layout:
+ /PATH/TO/DATASET/dataset_name/
+ - episode_0.hdf5
+ - episode_1.hdf5
+ - ...
+ - episode_N.hdf5
+
+Preprocessed data layout (after running this script):
+ /PATH/TO/PREPROCESSED_DATASETS/dataset_name/
+ - train/
+ - episode_0.hdf5
+ - episode_1.hdf5
+ - ...
+ - episode_M.hdf5
+ - val/
+ - episode_0.hdf5
+ - episode_1.hdf5
+ - ...
+ - episode_K.hdf5
+
+ where N > M > K
+
+Example usage:
+ # "put X into pot" task
+ python experiments/robot/aloha/preprocess_split_aloha_data.py \
+ --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \
+ --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
+ --percent_val 0.05 && \
+ python experiments/robot/aloha/preprocess_split_aloha_data.py \
+ --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \
+ --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
+ --percent_val 0.05 && \
+ python experiments/robot/aloha/preprocess_split_aloha_data.py \
+ --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \
+ --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \
+ --percent_val 0.05
+"""
+
+import argparse
+import glob
+import os
+import random
+
+import h5py
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+def load_hdf5(demo_path):
+ """Loads single episode."""
+ if not os.path.isfile(demo_path):
+ print(f"Dataset does not exist at \n{demo_path}\n")
+ exit()
+
+ print(f"Loading {demo_path}...")
+ with h5py.File(demo_path, "r") as root:
+ is_sim = root.attrs["sim"]
+ qpos = root["/observations/qpos"][()]
+ qvel = root["/observations/qvel"][()]
+ effort = root["/observations/effort"][()]
+ action = root["/action"][()]
+ image_dict = dict()
+ for cam_name in root["/observations/images/"].keys():
+ image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()]
+ print(f"Loading episode complete: {demo_path}")
+
+ return qpos, qvel, effort, action, image_dict, is_sim
+
+
+def load_and_preprocess_all_episodes(demo_paths, out_dataset_dir):
+ """
+ Loads and preprocesses all episodes.
+ Resizes all images in one episode before loading the next, to reduce memory usage.
+ """
+ cam_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
+ idx = 0
+ for demo in tqdm(demo_paths):
+ qpos, qvel, effort, action, image_dict, is_sim = load_hdf5(demo)
+        # Get episode length from the camera image stack
+ episode_len = image_dict["cam_high"].shape[0]
+ # Resize all images
+ print("Resizing images in episode...")
+ for k in cam_names:
+ resized_images = []
+ for i in range(episode_len):
+ resized_images.append(
+ np.array(
+ Image.fromarray(image_dict[k][i]).resize(
+ (args.img_resize_size, args.img_resize_size), resample=Image.BICUBIC
+ )
+ ) # BICUBIC is default; specify explicitly to make it clear
+ )
+ image_dict[k] = np.stack(resized_images)
+ print("Resizing images in episode complete!")
+ # Save preprocessed episode
+ data_dict = dict(
+ qpos=qpos,
+ qvel=qvel,
+ effort=effort,
+ action=action,
+ image_dict=image_dict,
+ is_sim=is_sim,
+ )
+ save_new_hdf5(out_dataset_dir, data_dict, idx)
+ idx += 1
+
+
+def randomly_split(full_qpos, full_qvel, full_effort, full_action, full_image_dict, percent_val):
+ """Randomly splits dataset into train and validation sets."""
+ # Create a list of episode indices
+ num_episodes_total = len(full_qpos)
+ indices = list(range(num_episodes_total))
+ # Shuffle the episode indices
+ random.shuffle(indices)
+ # Create new lists using the shuffled indices
+ shuffled_qpos = [full_qpos[idx] for idx in indices]
+ shuffled_qvel = [full_qvel[idx] for idx in indices]
+ shuffled_effort = [full_effort[idx] for idx in indices]
+ shuffled_action = [full_action[idx] for idx in indices]
+ shuffled_image_dict = {
+ "cam_high": [],
+ "cam_left_wrist": [],
+ "cam_right_wrist": [],
+ }
+ for k in full_image_dict.keys():
+ shuffled_image_dict[k] = [full_image_dict[k][idx] for idx in indices]
+ # Split into train and val sets
+ num_episodes_val = int(num_episodes_total * percent_val)
+ print(f"Total # steps: {num_episodes_total}; using {num_episodes_val} ({percent_val:.2f}%) for val set")
+ num_episodes_train = num_episodes_total - num_episodes_val
+ train_dict = dict(
+ qpos=shuffled_qpos[:num_episodes_train],
+ qvel=shuffled_qvel[:num_episodes_train],
+ effort=shuffled_effort[:num_episodes_train],
+ action=shuffled_action[:num_episodes_train],
+ image_dict=dict(
+ cam_high=shuffled_image_dict["cam_high"][:num_episodes_train],
+ cam_left_wrist=shuffled_image_dict["cam_left_wrist"][:num_episodes_train],
+ cam_right_wrist=shuffled_image_dict["cam_right_wrist"][:num_episodes_train],
+ ),
+ )
+ val_dict = dict(
+ qpos=shuffled_qpos[num_episodes_train:],
+ qvel=shuffled_qvel[num_episodes_train:],
+ effort=shuffled_effort[num_episodes_train:],
+ action=shuffled_action[num_episodes_train:],
+ image_dict=dict(
+ cam_high=shuffled_image_dict["cam_high"][num_episodes_train:],
+ cam_left_wrist=shuffled_image_dict["cam_left_wrist"][num_episodes_train:],
+ cam_right_wrist=shuffled_image_dict["cam_right_wrist"][num_episodes_train:],
+ ),
+ )
+ return train_dict, val_dict
+
+
+def save_new_hdf5(out_dataset_dir, data_dict, episode_idx):
+ """Saves an HDF5 file for a new episode."""
+ camera_names = data_dict["image_dict"].keys()
+ H, W, C = data_dict["image_dict"]["cam_high"][0].shape
+ out_path = os.path.join(out_dataset_dir, f"episode_{episode_idx}.hdf5")
+    # Save HDF5 with the same structure as the original demos (still one HDF5 file per episode)
+ with h5py.File(
+ out_path, "w", rdcc_nbytes=1024**2 * 2
+ ) as root: # Magic constant for rdcc_nbytes comes from ALOHA codebase
+ episode_len = data_dict["qpos"].shape[0]
+ root.attrs["sim"] = data_dict["is_sim"]
+ obs = root.create_group("observations")
+ _ = obs.create_dataset("qpos", (episode_len, 14))
+ _ = obs.create_dataset("qvel", (episode_len, 14))
+ _ = obs.create_dataset("effort", (episode_len, 14))
+ root["/observations/qpos"][...] = data_dict["qpos"]
+ root["/observations/qvel"][...] = data_dict["qvel"]
+ root["/observations/effort"][...] = data_dict["effort"]
+ image = obs.create_group("images")
+ for cam_name in camera_names:
+ _ = image.create_dataset(
+ cam_name,
+ (episode_len, H, W, C),
+ dtype="uint8",
+ chunks=(1, H, W, C),
+ )
+ root[f"/observations/images/{cam_name}"][...] = data_dict["image_dict"][cam_name]
+ _ = root.create_dataset("action", (episode_len, 14))
+ root["/action"][...] = data_dict["action"]
+ # Compute and save *relative* actions as well
+ actions = data_dict["action"]
+ relative_actions = np.zeros_like(actions)
+ relative_actions[:-1] = actions[1:] - actions[:-1] # Relative actions are the changes in joint pos
+ relative_actions[-1] = relative_actions[-2] # Just copy the second-to-last action for the last action
+ _ = root.create_dataset("relative_action", (episode_len, 14))
+ root["/relative_action"][...] = relative_actions
+ print(f"Saved dataset: {out_path}")
+
+
+def main(args):
+ # Create directory to save preprocessed dataset (if it doesn't exist already)
+ os.makedirs(args.out_base_dir, exist_ok=True)
+ out_dataset_dir = os.path.join(args.out_base_dir, os.path.basename(args.dataset_path.rstrip("/")))
+ os.makedirs(out_dataset_dir, exist_ok=True)
+ # Get list of filepaths of all episodes
+ all_demo_paths = glob.glob(os.path.join(args.dataset_path, "*.hdf5")) # List of HDF5 filepaths
+ all_demo_paths.sort()
+ # Create a list of episode indices
+ num_episodes_total = len(all_demo_paths)
+ indices = list(range(num_episodes_total))
+ # Shuffle the episode indices
+ random.shuffle(indices)
+ # Split into train and val sets
+ num_episodes_val = int(num_episodes_total * args.percent_val)
+ print(f"Total # episodes: {num_episodes_total}; using {num_episodes_val} ({args.percent_val:.2f}%) for val set")
+ num_episodes_train = num_episodes_total - num_episodes_val
+ train_indices = indices[:num_episodes_train]
+ val_indices = indices[num_episodes_train:]
+ train_demo_paths = [all_demo_paths[i] for i in train_indices]
+ val_demo_paths = [all_demo_paths[i] for i in val_indices]
+ # Preprocess all episodes and save the result
+ out_dataset_dir_train = os.path.join(out_dataset_dir, "train")
+ out_dataset_dir_val = os.path.join(out_dataset_dir, "val")
+ os.makedirs(out_dataset_dir_train, exist_ok=True)
+ os.makedirs(out_dataset_dir_val, exist_ok=True)
+ load_and_preprocess_all_episodes(train_demo_paths, out_dataset_dir_train)
+ load_and_preprocess_all_episodes(val_demo_paths, out_dataset_dir_val)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset_path",
+ required=True,
+ help="Path to raw ALOHA dataset directory. Example: /PATH/TO/USER/data/aloha_raw/put_green_pepper_into_pot/",
+ )
+ parser.add_argument(
+ "--out_base_dir",
+ required=True,
+ help="Path to directory in which to save preprocessed dataset. Example: /PATH/TO/USER/data/aloha_preprocessed/",
+ )
+ parser.add_argument(
+ "--percent_val",
+ type=float,
+ help="Percent of dataset to use as validation set (measured in episodes, not steps).",
+ default=0.05,
+ )
+ parser.add_argument(
+ "--img_resize_size",
+ type=int,
+ help="Size to resize images to. Final images will be square (img_resize_size x img_resize_size pixels).",
+ default=256,
+ )
+ args = parser.parse_args()
+
+ main(args)
diff --git a/experiments/robot/aloha/real_env.py b/experiments/robot/aloha/real_env.py
new file mode 100644
index 0000000..ddc6b66
--- /dev/null
+++ b/experiments/robot/aloha/real_env.py
@@ -0,0 +1,242 @@
+import time
+import numpy as np
+import collections
+import matplotlib.pyplot as plt
+import dm_env
+from pyquaternion import Quaternion
+
+from experiments.robot.aloha.constants import DT, START_ARM_POSE, MASTER_GRIPPER_JOINT_NORMALIZE_FN, PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN
+from experiments.robot.aloha.constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
+from experiments.robot.aloha.constants import PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
+from experiments.robot.aloha.robot_utils import Recorder, ImageRecorder
+
+# from base_recorder import BaseRecorder
+# from scan_recorder import SCANRecorder
+# from imu_recorder import IMURecorder
+
+from experiments.robot.aloha.robot_utils import setup_master_bot, setup_puppet_bot, move_arms, move_grippers
+from interbotix_xs_modules.arm import InterbotixManipulatorXS
+from interbotix_xs_msgs.msg import JointSingleCommand
+# import pyrealsense2 as rs
+# from dynamixel_client import DynamixelClient
+
+import IPython
+e = IPython.embed
+
+class RealEnv:
+ """
+ Environment for real robot bi-manual manipulation
+ Action space: [right_arm_qpos (6), # absolute joint position
+ right_gripper_positions (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
+
+ Observation space: {"qpos": Concat[ right_arm_qpos (6), # absolute joint position
+ right_gripper_position (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
+ "qvel": Concat[ right_arm_qvel (6), # absolute joint velocity (rad)
+ right_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
+ right_arm_qvel (6), # absolute joint velocity (rad)
+ right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
+ "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
+ "cam_low": (480x640x3), # h, w, c, dtype='uint8'
+ "cam_right_wrist": (480x640x3), # h, w, c, dtype='uint8'
+ "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
+ """
+
+ def __init__(self, init_node, setup_robots=True, setup_base=False):
+ # self.puppet_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
+ # robot_name=f'puppet_right', init_node=init_node)
+ self.puppet_bot_right = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper",
+ robot_name=f'puppet_right', init_node=init_node)
+
+ if setup_robots:
+ self.setup_robots()
+
+ # #if setup_base:
+ # self.setup_base()
+
+
+ self.recorder_right = Recorder('right', init_node=False)
+ # self.base_recorder = BaseRecorder(init_node=False)
+        # self.scan_recorder = SCANRecorder(init_node=False)  # lidar scan
+ # self.imu_recorder = IMURecorder(init_node=False) # imu
+ self.image_recorder = ImageRecorder(init_node=False)
+ self.gripper_command = JointSingleCommand(name="gripper")
+
+ def setup_robots(self):
+ setup_puppet_bot(self.puppet_bot_right)
+
+ def get_qpos(self):
+ right_qpos_raw = self.recorder_right.qpos
+ right_arm_qpos = right_qpos_raw[:6]
+ right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])] # this is position not joint
+ return np.concatenate([right_arm_qpos, right_gripper_qpos])
+
+ def get_qvel(self):
+ right_qvel_raw = self.recorder_right.qvel
+ right_arm_qvel = right_qvel_raw[:6]
+ right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
+ return np.concatenate([right_arm_qvel, right_gripper_qvel])
+
+ def get_effort(self):
+ right_effort_raw = self.recorder_right.effort
+ right_robot_effort = right_effort_raw[:7]
+ return np.concatenate([right_robot_effort])
+
+ # cam
+ def get_images(self):
+ return self.image_recorder.get_images() # noetic
+
+ # -------------------------------------------
+ # def get_base_vel(self):
+ # return self.base_recorder.get_vel()
+
+    # lidar scan
+ # def get_scan_vel(self):
+ # return self.scan_recorder.get_scan_vel()
+
+    # IMU
+ # def get_imu_vel(self):
+ # return self.imu_recorder.get_imu_vel()
+
+ # def get_tracer_vel(self):
+ # linear_vel, angular_vel = self.tracer.GetLinearVelocity(), self.tracer.GetAngularVelocity()
+ # return np.array([linear_vel, angular_vel])
+
+
+ def set_gripper_pose(self, right_gripper_desired_pos_normalized):
+ right_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(right_gripper_desired_pos_normalized)
+ self.gripper_command.cmd = right_gripper_desired_joint
+ self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
+
+ def _reset_joints(self):
+ reset_position = START_ARM_POSE[:6]
+ move_arms([self.puppet_bot_right], [reset_position], move_time=1)
+
+ def _reset_gripper(self):
+ """Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
+ move_grippers([self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN], move_time=0.5)
+ move_grippers([self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_CLOSE] , move_time=1)
+
+ def _get_obs(self):
+ obs = collections.OrderedDict()
+ obs['qpos'] = self.get_qpos()
+ obs['qvel'] = self.get_qvel()
+ obs['effort'] = self.get_effort()
+ obs['images'] = self.get_images()
+ return obs
+
+ def get_observation(self, t=0):
+ step_type = dm_env.StepType.FIRST if t == 0 else dm_env.StepType.MID
+ return dm_env.TimeStep(
+ step_type=step_type,
+ reward=self.get_reward(),
+ discount=None,
+ observation=self._get_obs()
+ )
+
+ def get_reward(self):
+ return 0
+
+ def reset(self, fake=False):
+ if not fake:
+ # Reboot puppet robot gripper motors
+ self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
+ self._reset_joints()
+ self._reset_gripper()
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.FIRST,
+ reward=self.get_reward(),
+ discount=None,
+ observation=self.get_observation())
+
+ def step(self, action, base_action=None, get_tracer_vel=False, get_obs=True):
+        right_action = action  # single right arm: 6 joint positions + 1 gripper
+
+ self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
+ self.set_gripper_pose(right_action[-1])
+ #if base_action is not None:
+ # linear_vel_limit = 1.5
+ # angular_vel_limit = 1.5
+ # base_action_linear = np.clip(base_action[0], -linear_vel_limit, linear_vel_limit)
+ # base_action_angular = np.clip(base_action[1], -angular_vel_limit, angular_vel_limit)
+ # base_action_linear, base_action_angular = base_action
+ # self.tracer.SetMotionCommand(linear_vel=base_action_linear, angular_vel=base_action_angular)
+ # time.sleep(DT)
+ if get_obs:
+            obs = self.get_observation()
+ else:
+ obs = None
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.MID,
+ reward=self.get_reward(),
+ discount=None,
+ observation=obs)
+
+def get_action(master_bot_right):
+    action = np.zeros(7) # 6 joint + 1 gripper, for one arm
+ # Arm actions
+ action[:6] = master_bot_right.dxl.joint_states.position[:6]
+ # Gripper actions
+ action[6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
+
+ return action
+
+# def get_base_action():
+
+
+
+def make_real_env(init_node, setup_robots=True, setup_base=False):
+ env = RealEnv(init_node, setup_robots, setup_base)
+ return env
+
+
+def test_real_teleop():
+ """
+    Test single-arm teleoperation and show image observations onscreen.
+    It first reads joint poses from the master arm,
+    then uses them as actions to step the environment.
+    The environment returns full observations including images.
+
+    An alternative approach is to have separate scripts for teleoperation and observation recording.
+    This script results in higher-fidelity (obs, action) pairs.
+ """
+
+ onscreen_render = True
+ render_cam = 'cam_right_wrist'
+
+ # source of data
+    master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
+                                               robot_name='master_right', init_node=True)
+
+ setup_master_bot(master_bot_right)
+
+
+ # setup the environment
+ env = make_real_env(init_node=False)
+ ts = env.reset(fake=True)
+ episode = [ts]
+ # setup visualization
+ if onscreen_render:
+ ax = plt.subplot()
+ plt_img = ax.imshow(ts.observation['images'][render_cam])
+ plt.ion()
+
+ for t in range(1000):
+ action = get_action(master_bot_right)
+ ts = env.step(action)
+ episode.append(ts)
+
+ if onscreen_render:
+ plt_img.set_data(ts.observation['images'][render_cam])
+ plt.pause(DT)
+ else:
+ time.sleep(DT)
+
+
+if __name__ == '__main__':
+ test_real_teleop()
+
diff --git a/experiments/robot/aloha/requirements_aloha.txt b/experiments/robot/aloha/requirements_aloha.txt
new file mode 100644
index 0000000..c84c6d0
--- /dev/null
+++ b/experiments/robot/aloha/requirements_aloha.txt
@@ -0,0 +1,26 @@
+numpy<2
+draccus
+torchvision
+torch
+pyquaternion
+pyyaml
+rospkg
+pexpect
+mujoco==2.3.7
+dm_control==1.0.14
+opencv-python
+matplotlib
+einops
+packaging
+h5py
+traitlets
+ipdb
+IPython
+modern_robotics
+Pillow
+termcolor
+imageio[ffmpeg]
+uvicorn
+fastapi
+requests
+json_numpy
diff --git a/experiments/robot/aloha/robot_utils.py b/experiments/robot/aloha/robot_utils.py
new file mode 100644
index 0000000..3db3008
--- /dev/null
+++ b/experiments/robot/aloha/robot_utils.py
@@ -0,0 +1,187 @@
+import numpy as np
+import time
+from experiments.robot.aloha.constants import DT
+from interbotix_xs_msgs.msg import JointSingleCommand
+
+import IPython
+e = IPython.embed
+
+class ImageRecorder:
+ def __init__(self, init_node=True, is_debug=False):
+ from collections import deque
+ import rospy
+ from cv_bridge import CvBridge
+ from sensor_msgs.msg import Image
+ self.is_debug = is_debug
+ self.bridge = CvBridge()
+ self.camera_names = ['cam_high', 'cam_right_wrist']
+ if init_node:
+ rospy.init_node('image_recorder', anonymous=True)
+ for cam_name in self.camera_names:
+ setattr(self, f'{cam_name}_image', None)
+ setattr(self, f'{cam_name}_secs', None)
+ setattr(self, f'{cam_name}_nsecs', None)
+ if cam_name == 'cam_high':
+ callback_func = self.image_cb_cam_high
+ elif cam_name == 'cam_low':
+ callback_func = self.image_cb_cam_low
+ elif cam_name == 'cam_left_wrist':
+ callback_func = self.image_cb_cam_left_wrist
+ elif cam_name == 'cam_right_wrist':
+ callback_func = self.image_cb_cam_right_wrist
+ else:
+ raise NotImplementedError
+ rospy.Subscriber(f"/usb_{cam_name}/image_raw", Image, callback_func)
+ if self.is_debug:
+ setattr(self, f'{cam_name}_timestamps', deque(maxlen=50))
+        time.sleep(0.5)  # Give the subscribers time to receive their first images
+
+ def image_cb(self, cam_name, data):
+ setattr(self, f'{cam_name}_image', self.bridge.imgmsg_to_cv2(data, desired_encoding='passthrough'))
+ setattr(self, f'{cam_name}_secs', data.header.stamp.secs)
+ setattr(self, f'{cam_name}_nsecs', data.header.stamp.nsecs)
+ # cv2.imwrite('/home/tonyzhao/Desktop/sample.jpg', cv_image)
+ if self.is_debug:
+            getattr(self, f'{cam_name}_timestamps').append(data.header.stamp.secs + data.header.stamp.nsecs * 1e-9)
+
+ def image_cb_cam_high(self, data):
+ cam_name = 'cam_high'
+ return self.image_cb(cam_name, data)
+
+ def image_cb_cam_low(self, data):
+ cam_name = 'cam_low'
+ return self.image_cb(cam_name, data)
+
+ def image_cb_cam_left_wrist(self, data):
+ cam_name = 'cam_left_wrist'
+ return self.image_cb(cam_name, data)
+
+ def image_cb_cam_right_wrist(self, data):
+ cam_name = 'cam_right_wrist'
+ return self.image_cb(cam_name, data)
+
+ def get_images(self):
+ image_dict = dict()
+ for cam_name in self.camera_names:
+ image_dict[cam_name] = getattr(self, f'{cam_name}_image')
+ return image_dict
+
+ def print_diagnostics(self):
+ def dt_helper(l):
+ l = np.array(l)
+ diff = l[1:] - l[:-1]
+ return np.mean(diff)
+ for cam_name in self.camera_names:
+ image_freq = 1 / dt_helper(getattr(self, f'{cam_name}_timestamps'))
+ print(f'{cam_name} {image_freq=:.2f}')
+ print()
+
+class Recorder:
+ def __init__(self, side, init_node=True, is_debug=False):
+ from collections import deque
+ import rospy
+ from sensor_msgs.msg import JointState
+ from interbotix_xs_msgs.msg import JointGroupCommand, JointSingleCommand
+
+ self.secs = None
+ self.nsecs = None
+ self.qpos = None
+ self.effort = None
+ self.arm_command = None
+ self.gripper_command = None
+ self.is_debug = is_debug
+
+ if init_node:
+ rospy.init_node('recorder', anonymous=True)
+ rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
+ rospy.Subscriber(f"/puppet_{side}/commands/joint_group", JointGroupCommand, self.puppet_arm_commands_cb)
+ rospy.Subscriber(f"/puppet_{side}/commands/joint_single", JointSingleCommand, self.puppet_gripper_commands_cb)
+ if self.is_debug:
+ self.joint_timestamps = deque(maxlen=50)
+ self.arm_command_timestamps = deque(maxlen=50)
+ self.gripper_command_timestamps = deque(maxlen=50)
+        time.sleep(0.1)  # Brief pause so the subscribers can start receiving messages
+
+ def puppet_state_cb(self, data):
+ self.qpos = data.position
+ self.qvel = data.velocity
+ self.effort = data.effort
+ self.data = data
+ if self.is_debug:
+ self.joint_timestamps.append(time.time())
+
+ def puppet_arm_commands_cb(self, data):
+ self.arm_command = data.cmd
+ if self.is_debug:
+ self.arm_command_timestamps.append(time.time())
+
+ def puppet_gripper_commands_cb(self, data):
+ self.gripper_command = data.cmd
+ if self.is_debug:
+ self.gripper_command_timestamps.append(time.time())
+
+ def print_diagnostics(self):
+ def dt_helper(l):
+ l = np.array(l)
+ diff = l[1:] - l[:-1]
+ return np.mean(diff)
+
+ joint_freq = 1 / dt_helper(self.joint_timestamps)
+ arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
+ gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
+
+ print(f'{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n')
+
+def get_arm_joint_positions(bot):
+ return bot.arm.core.joint_states.position[:6]
+
+def get_arm_gripper_positions(bot):
+ joint_position = bot.gripper.core.joint_states.position[6]
+ return joint_position
+
+def move_arms(bot_list, target_pose_list, move_time=1):
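+    # Linearly interpolate each arm from its current joint positions to the target over move_time seconds, commanding at DT intervals.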
+ num_steps = int(move_time / DT)
+ curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
+ traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)]
+ for t in range(num_steps):
+ for bot_id, bot in enumerate(bot_list):
+ bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
+ time.sleep(DT)
+
+def move_grippers(bot_list, target_pose_list, move_time):
+ gripper_command = JointSingleCommand(name="gripper")
+ num_steps = int(move_time / DT)
+ curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
+ traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)]
+ for t in range(num_steps):
+ for bot_id, bot in enumerate(bot_list):
+ gripper_command.cmd = traj_list[bot_id][t]
+ bot.gripper.core.pub_single.publish(gripper_command)
+ time.sleep(DT)
+
+def setup_puppet_bot(bot):
+ bot.dxl.robot_reboot_motors("single", "gripper", True)
+ bot.dxl.robot_set_operating_modes("group", "arm", "position")
+ bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
+ torque_on(bot)
+
+def setup_master_bot(bot):
+ bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
+ bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
+ torque_off(bot)
+
+def set_standard_pid_gains(bot):
+ bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 800)
+ bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0)
+
+def set_low_pid_gains(bot):
+ bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 100)
+ bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0)
+
+def torque_off(bot):
+ bot.dxl.robot_torque_enable("group", "arm", False)
+ bot.dxl.robot_torque_enable("single", "gripper", False)
+
+def torque_on(bot):
+ bot.dxl.robot_torque_enable("group", "arm", True)
+ bot.dxl.robot_torque_enable("single", "gripper", True)
\ No newline at end of file
diff --git a/experiments/robot/aloha/run_aloha_eval.py b/experiments/robot/aloha/run_aloha_eval.py
new file mode 100644
index 0000000..b3da22c
--- /dev/null
+++ b/experiments/robot/aloha/run_aloha_eval.py
@@ -0,0 +1,449 @@
+"""
+run_aloha_eval.py
+
+Evaluates a model in a real-world ALOHA environment with local model deployment.
+"""
+
+import logging
+import os
+import sys
+import time
+from collections import deque
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import draccus
+import tqdm
+import torch
+import json
+import numpy as np
+
+# Append current directory so that interpreter can find experiments.robot
+sys.path.append(".")
+from experiments.robot.aloha.aloha_utils import (
+ get_aloha_env,
+ get_aloha_image,
+ get_aloha_wrist_images,
+ get_next_task_label,
+ save_rollout_video,
+)
+from experiments.robot.openvla_utils import (
+ get_vla,
+ get_vla_action,
+ get_action_head,
+ get_processor,
+ get_proprio_projector,
+)
+from experiments.robot.robot_utils import (
+ DATE_TIME,
+ get_image_resize_size,
+ set_seed_everywhere,
+)
+
+# Set up logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ handlers=[logging.StreamHandler()],
+)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GenerateConfig:
+ # fmt: off
+
+ #################################################################################################################
+ # Model-specific parameters
+ #################################################################################################################
+ model_family: str = "openvla" # Model family
+ pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path
+
+ use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective
+ use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM)
+    num_diffusion_steps: int = 50                    # (When `use_diffusion==True`) Number of diffusion steps for inference
+ use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
+ num_images_in_input: int = 3 # Number of images in the VLA input (default: 3)
+ use_proprio: bool = True # Whether to include proprio state in input
+
+ center_crop: bool = True # Center crop? (if trained w/ random crop image aug)
+ num_open_loop_steps: int = 25 # Number of actions to execute open-loop before requerying policy
+
+ unnorm_key: Union[str, Path] = "" # Action un-normalization key
+
+ load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization
+ load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization
+
+ #################################################################################################################
+ # ALOHA environment-specific parameters
+ #################################################################################################################
+ num_rollouts_planned: int = 50 # Number of test rollouts
+ max_steps: int = 1500 # Max number of steps per rollout
+ use_relative_actions: bool = False # Whether to use relative actions (delta joint angles)
+
+ #################################################################################################################
+ # Utils
+ #################################################################################################################
+ run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
+ local_log_dir: str = "./experiments/logs" # Local directory for eval logs
+
+ seed: int = 7 # Random Seed (for reproducibility)
+
+ save_version: str = "vla-adapter" # version of
+ use_pro_version: bool = True # encourage to use the pro models we released.
+ phase: str = "Inference"
+
+ # fmt: on
+
+
+class LocalVLAModel:
+ """Local VLA model for direct inference without server."""
+
+ def __init__(self, cfg: GenerateConfig):
+ self.cfg = cfg
+
+ # Load model
+ self.vla = get_vla(cfg)
+
+ # Load proprio projector
+ self.proprio_projector = None
+ if cfg.use_proprio:
+            self.proprio_projector = get_proprio_projector(cfg, self.vla.llm_dim, 7)  # PROPRIO_DIM = 7 (single arm)
+
+ # Load continuous action head
+ self.action_head = None
+ if cfg.use_l1_regression or cfg.use_diffusion:
+ self.action_head = get_action_head(cfg, self.vla.llm_dim)
+
+ # Check that the model contains the action un-normalization key
+ assert cfg.unnorm_key in self.vla.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"
+
+ # Get Hugging Face processor
+ self.processor = get_processor(cfg)
+
+ # Get expected image dimensions
+ self.resize_size = get_image_resize_size(cfg)
+
+ logger.info("Local VLA model loaded successfully")
+
+ def get_action(self, observation: Dict[str, Any]) -> np.ndarray:
+ """Get action from local model."""
+ instruction = observation["instruction"]
+
+ action = get_vla_action(
+ self.cfg,
+ self.vla,
+ self.processor,
+ observation,
+ instruction,
+ action_head=self.action_head,
+ proprio_projector=self.proprio_projector,
+ use_film=self.cfg.use_film,
+ )
+
+ return action
+
+
+def validate_config(cfg: GenerateConfig) -> None:
+ """Validate configuration parameters."""
+ assert cfg.pretrained_checkpoint, "Must provide pretrained_checkpoint for local model deployment!"
+ assert os.path.exists(cfg.pretrained_checkpoint), f"Checkpoint path {cfg.pretrained_checkpoint} does not exist!"
+
+
+def setup_logging(cfg: GenerateConfig):
+ """Set up logging to file."""
+ # Create run ID
+ run_id = f"EVAL-LOCAL-{cfg.model_family}-{DATE_TIME}"
+ if cfg.run_id_note is not None:
+ run_id += f"--{cfg.run_id_note}"
+
+ # Set up local logging
+ os.makedirs(cfg.local_log_dir, exist_ok=True)
+ local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
+ log_file = open(local_log_filepath, "w")
+ logger.info(f"Logging to local log file: {local_log_filepath}")
+
+ return log_file, local_log_filepath, run_id
+
+
+def log_message(message: str, log_file=None):
+ """Log a message to console and optionally to a log file."""
+ print(message)
+ logger.info(message)
+ if log_file:
+ log_file.write(message + "\n")
+ log_file.flush()
+
+
+def prepare_observation(obs, resize_size):
+ """Prepare observation for policy input."""
+ # Get preprocessed images
+ img = get_aloha_image(obs)
+ right_wrist_img = get_aloha_wrist_images(obs)
+
+ # Resize images to size expected by model
+ from experiments.robot.openvla_utils import resize_image_for_policy
+ img_resized = resize_image_for_policy(img, resize_size)
+ # left_wrist_img_resized = resize_image_for_policy(left_wrist_img, resize_size)
+ right_wrist_img_resized = resize_image_for_policy(right_wrist_img, resize_size)
+
+ # Prepare observations dict
+ observation = {
+ "full_image": img_resized,
+ # "left_wrist_image": left_wrist_img_resized,
+ "right_wrist_image": right_wrist_img_resized,
+ "state": obs.observation["qpos"],
+ }
+
+ # return observation, img_resized, left_wrist_img_resized, right_wrist_img_resized
+ return observation, img_resized, right_wrist_img_resized
+
+
+def run_episode(
+ cfg: GenerateConfig,
+ env,
+ task_description: str,
+ local_model: LocalVLAModel,
+ resize_size,
+ log_file=None,
+):
+ """Run a single episode in the ALOHA environment."""
+ # Define control frequency
+ STEP_DURATION_IN_SEC = 1.0 / 50.0
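+    # 50 Hz matches the ALOHA controller rate (FPS = 50, DT = 0.02 in experiments/robot/aloha/constants.py).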
+
+ # Reset environment
+ obs = env.reset()
+
+ # Initialize action queue
+ action_queue = deque(maxlen=cfg.num_open_loop_steps)
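+    # The policy predicts a chunk of actions; up to num_open_loop_steps of them are executed
+    # open-loop before the model is queried again (see the requery logic in the loop below).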
+
+ # Setup
+ t = 0
+ curr_state = None
+ replay_images = []
+ replay_images_resized = []
+    replay_images_left_wrist_resized = []  # stays empty in this single-arm (right-wrist-only) setup
+ replay_images_right_wrist_resized = []
+
+ log_message("Prepare the scene, and then press Enter to begin...", log_file)
+ input()
+
+ # Reset environment again to fetch first timestep observation
+ obs = env.reset()
+
+ # Fetch initial robot state (but sleep first so that robot stops moving)
+ time.sleep(2)
+ curr_state = env.get_qpos()
+
+ episode_start_time = time.time()
+ total_model_query_time = 0.0
+
+ try:
+ while t < cfg.max_steps:
+ # Get step start time (used to compute how much to sleep between steps)
+ step_start_time = time.time()
+
+ # Get observation
+ obs = env.get_observation(t=t)
+
+ # Save raw high camera image for replay video
+ replay_images.append(obs.observation["images"]["cam_high"])
+
+ # If action queue is empty, requery model
+ if len(action_queue) == 0:
+ # Prepare observation
+ observation, img_resized, right_wrist_resized = prepare_observation(obs, resize_size)
+ observation["instruction"] = task_description
+
+ # Save processed images for replay
+ replay_images_resized.append(img_resized)
+ # replay_images_left_wrist_resized.append(left_wrist_resized)
+ replay_images_right_wrist_resized.append(right_wrist_resized)
+
+ # Query model to get action
+ log_message("Querying local model...", log_file)
+ model_query_start_time = time.time()
+ actions = local_model.get_action(observation)
+ actions = actions[: cfg.num_open_loop_steps]
+ total_model_query_time += time.time() - model_query_start_time
+ action_queue.extend(actions)
+
+ # Get action from queue
+ action = action_queue.popleft()
+ log_message("-----------------------------------------------------", log_file)
+ log_message(f"t: {t}", log_file)
+ log_message(f"action: {action}", log_file)
+
+ # Execute action in environment
+ if cfg.use_relative_actions:
+ # Get absolute joint angles from relative action
+ rel_action = action
+ target_state = curr_state + rel_action
+ obs = env.step(target_state.tolist())
+ # Update current state (assume it is the commanded target state)
+ curr_state = target_state
+ else:
+ obs = env.step(action.tolist())
+ t += 1
+
+ # Sleep until next timestep
+ step_elapsed_time = time.time() - step_start_time
+ if step_elapsed_time < STEP_DURATION_IN_SEC:
+ time_to_sleep = STEP_DURATION_IN_SEC - step_elapsed_time
+ log_message(f"Sleeping {time_to_sleep} sec...", log_file)
+ time.sleep(time_to_sleep)
+
+ except (KeyboardInterrupt, Exception) as e:
+ if isinstance(e, KeyboardInterrupt):
+ log_message("\nCaught KeyboardInterrupt: Terminating episode early.", log_file)
+ else:
+ log_message(f"\nCaught exception: {e}", log_file)
+
+ episode_end_time = time.time()
+
+ # Get success feedback from user
+ user_input = input("Success? Enter 'y' or 'n': ")
+    success = user_input.lower() == 'y'
+
+ # Calculate episode statistics
+ episode_stats = {
+ "success": success,
+ "total_steps": t,
+ "model_query_time": total_model_query_time,
+ "episode_duration": episode_end_time - episode_start_time,
+ }
+
+ return (
+ episode_stats,
+ replay_images,
+ replay_images_resized,
+ replay_images_left_wrist_resized,
+ replay_images_right_wrist_resized,
+ )
+
+
+def save_episode_videos(
+ replay_images,
+ replay_images_resized,
+ replay_images_left_wrist,
+ replay_images_right_wrist,
+ episode_idx,
+ success,
+ task_description,
+ log_file=None,
+):
+ """Save videos of the episode from different camera angles."""
+ # Save main replay video
+ save_rollout_video(replay_images, episode_idx, success=success, task_description=task_description,
+ log_file=log_file)
+
+ # Save processed view videos
+ save_rollout_video(
+ replay_images_resized,
+ episode_idx,
+ success=success,
+ task_description=task_description,
+ log_file=log_file,
+ notes="resized",
+ )
+ save_rollout_video(
+ replay_images_left_wrist,
+ episode_idx,
+ success=success,
+ task_description=task_description,
+ log_file=log_file,
+ notes="left_wrist_resized",
+ )
+ save_rollout_video(
+ replay_images_right_wrist,
+ episode_idx,
+ success=success,
+ task_description=task_description,
+ log_file=log_file,
+ notes="right_wrist_resized",
+ )
+
+
+@draccus.wrap()
+def eval_aloha_local(cfg: GenerateConfig) -> float:
+ """Main function to evaluate a trained policy in a real-world ALOHA environment with local model."""
+ # Validate configuration
+ validate_config(cfg)
+
+ # Set random seed
+ set_seed_everywhere(cfg.seed)
+
+ # Setup logging
+ log_file, local_log_filepath, run_id = setup_logging(cfg)
+
+ # Load local model
+ log_message("Loading local VLA model...", log_file)
+ local_model = LocalVLAModel(cfg)
+
+ # Get expected image dimensions
+ resize_size = get_image_resize_size(cfg)
+
+ # Get ALOHA environment
+ env = get_aloha_env()
+
+ # Initialize task description
+ task_description = ""
+
+ # Start evaluation
+ num_rollouts_completed, total_successes = 0, 0
+
+ for episode_idx in tqdm.tqdm(range(cfg.num_rollouts_planned)):
+ # Get task description from user
+ task_description = get_next_task_label(task_description)
+ log_message(f"\nTask: {task_description}", log_file)
+
+ log_message(f"Starting episode {num_rollouts_completed + 1}...", log_file)
+
+ # Run episode
+ episode_stats, replay_images, replay_images_resized, replay_images_left_wrist, replay_images_right_wrist = (
+ run_episode(cfg, env, task_description, local_model, resize_size, log_file)
+ )
+
+ # Update counters
+ num_rollouts_completed += 1
+ if episode_stats["success"]:
+ total_successes += 1
+
+ # Save videos
+ save_episode_videos(
+ replay_images,
+ replay_images_resized,
+ replay_images_left_wrist,
+ replay_images_right_wrist,
+ num_rollouts_completed,
+ episode_stats["success"],
+ task_description,
+ log_file,
+ )
+
+ # Log results
+ log_message(f"Success: {episode_stats['success']}", log_file)
+ log_message(f"# episodes completed so far: {num_rollouts_completed}", log_file)
+ log_message(f"# successes: {total_successes} ({total_successes / num_rollouts_completed * 100:.1f}%)", log_file)
+ log_message(f"Total model query time: {episode_stats['model_query_time']:.2f} sec", log_file)
+ log_message(f"Total episode elapsed time: {episode_stats['episode_duration']:.2f} sec", log_file)
+
+ # Calculate final success rate
+ final_success_rate = float(total_successes) / float(num_rollouts_completed) if num_rollouts_completed > 0 else 0.0
+
+ # Log final results
+ log_message("\nFinal results:", log_file)
+ log_message(f"Total episodes: {num_rollouts_completed}", log_file)
+ log_message(f"Total successes: {total_successes}", log_file)
+ log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file)
+
+ # Close log file
+ if log_file:
+ log_file.close()
+
+ return final_success_rate
+
+
+if __name__ == "__main__":
+ eval_aloha_local()
\ No newline at end of file
diff --git a/prismatic/vla/constants.py b/prismatic/vla/constants.py
index 0985a83..f604309 100644
--- a/prismatic/vla/constants.py
+++ b/prismatic/vla/constants.py
@@ -41,8 +41,8 @@ class NormalizationType(str, Enum):
ALOHA_CONSTANTS = {
"NUM_ACTIONS_CHUNK": 25,
- "ACTION_DIM": 14,
- "PROPRIO_DIM": 14,
+ "ACTION_DIM": 7,
+ "PROPRIO_DIM": 7,
"ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS,
}
diff --git a/prismatic/vla/datasets/datasets.py b/prismatic/vla/datasets/datasets.py
index 2f2daff..b21d7aa 100644
--- a/prismatic/vla/datasets/datasets.py
+++ b/prismatic/vla/datasets/datasets.py
@@ -167,7 +167,7 @@ def __init__(
# fmt: off
if "aloha" in self.data_mix:
- load_camera_views = ("primary", "left_wrist", "right_wrist")
+ load_camera_views = ("primary", "right_wrist")
else:
load_camera_views = ("primary", "wrist")
diff --git a/prismatic/vla/datasets/rlds/oxe/configs.py b/prismatic/vla/datasets/rlds/oxe/configs.py
index 1367cd7..28e00b5 100644
--- a/prismatic/vla/datasets/rlds/oxe/configs.py
+++ b/prismatic/vla/datasets/rlds/oxe/configs.py
@@ -37,6 +37,7 @@ class StateEncoding(IntEnum):
POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
JOINT = 3 # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
+ JOINT_SINGLE = 5 # Joint Angles (1 x [ Joint Angles (6) + Gripper Open/Close (1) ])
# fmt: on
@@ -47,6 +48,7 @@ class ActionEncoding(IntEnum):
JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1)
JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
+ JOINT_POS_SINGLE = 5 # Joint Delta Position (1 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
# fmt: on
@@ -720,4 +722,11 @@ class ActionEncoding(IntEnum):
"state_encoding": StateEncoding.JOINT_BIMANUAL,
"action_encoding": ActionEncoding.JOINT_POS_BIMANUAL,
},
+ "aloha_put_x_into_the_box_80_demos": {
+ "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": None, "right_wrist": "right_wrist_image"},
+ "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+ "state_obs_keys": ["state"],
+ "state_encoding": StateEncoding.JOINT_SINGLE,
+ "action_encoding": ActionEncoding.JOINT_POS_SINGLE,
+ },
}
diff --git a/prismatic/vla/datasets/rlds/oxe/materialize.py b/prismatic/vla/datasets/rlds/oxe/materialize.py
index fd4103d..0b3209b 100644
--- a/prismatic/vla/datasets/rlds/oxe/materialize.py
+++ b/prismatic/vla/datasets/rlds/oxe/materialize.py
@@ -29,8 +29,8 @@ def make_oxe_dataset_kwargs(
) -> Dict[str, Any]:
"""Generates config (kwargs) for given dataset from Open-X Embodiment."""
dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name])
- if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL]:
- raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 & JOINT_POS_BIMANUAL actions supported!")
+ if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL, ActionEncoding.JOINT_POS_SINGLE]:
+ raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 & JOINT_POS_BIMANUAL & JOINT_POS_SINGLE actions supported!")
# [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute!
# Normalize all action dimensions *except* the gripper
@@ -43,6 +43,9 @@ def make_oxe_dataset_kwargs(
elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL:
dataset_kwargs["absolute_action_mask"] = [True] * 14
dataset_kwargs["action_normalization_mask"] = [True] * 14
+ elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_SINGLE:
+ dataset_kwargs["absolute_action_mask"] = [True] * 7
+ dataset_kwargs["action_normalization_mask"] = [True] * 7
dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type
# Adjust Loaded Camera Views