From 6592f70059296f1ab8b4524b35703a4c90515246 Mon Sep 17 00:00:00 2001 From: eHiollo <12532364@mail.sustech.edu.cn> Date: Sat, 11 Oct 2025 17:26:46 +0800 Subject: [PATCH 1/2] tools: add inspect_policy_record.py to pretty-print PolicyRecorder outputs --- tools/inspect_policy_record.py | 98 ++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tools/inspect_policy_record.py diff --git a/tools/inspect_policy_record.py b/tools/inspect_policy_record.py new file mode 100644 index 0000000000..5f281a45e1 --- /dev/null +++ b/tools/inspect_policy_record.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""Inspect policy_records saved by PolicyRecorder. + +Usage: + python tools/inspect_policy_record.py policy_records/step_0.npy + +This prints a nested summary of inputs and outputs, including array shapes and small samples. +""" +from __future__ import annotations + +import sys +import numpy as np +from textwrap import shorten +from typing import Any + + +def load_record(path: str) -> Any: + arr = np.load(path, allow_pickle=True) + if isinstance(arr, np.ndarray) and arr.shape == (): + return arr.item() + return arr + + +def unflatten_if_needed(flat: dict) -> dict: + nested: dict = {} + for key, val in flat.items(): + if isinstance(key, tuple): + parts = list(key) + elif isinstance(key, str) and "/" in key: + parts = key.split("/") + else: + nested[key] = val + continue + + d = nested + for p in parts[:-1]: + if p not in d or not isinstance(d[p], dict): + d[p] = {} + d = d[p] + d[parts[-1]] = val + return nested + + +def sample_repr(x: Any, max_len: int = 240) -> str: + try: + import numpy as _np + + if isinstance(x, _np.ndarray): + s = f"ndarray shape={x.shape} dtype={x.dtype}" + flat = x.ravel() + if flat.size > 0: + sample = flat[:10].tolist() + s += f" sample={sample}" + return s + except Exception: + pass + + try: + if hasattr(x, "shape"): + try: + return f"{type(x).__name__} shape={x.shape}" + except Exception: + pass + return shorten(repr(x), width=max_len) + except Exception: + return "" + + +def print_summary(nested: dict) -> None: + for top in ("inputs", "outputs"): + if top not in nested: + continue + print(f"--- {top} ---") + section = nested[top] + if isinstance(section, dict): + for k, v in section.items(): + print(f"{k}: {sample_repr(v)}") + else: + print(sample_repr(section)) + print() + + +def main(path: str) -> None: + data = load_record(path) + if not isinstance(data, dict): + print("Loaded data is not a dict, type:", type(data)) + print("Raw loaded object:", data) + return + nested = unflatten_if_needed(data) + print("Top-level keys:", list(nested.keys())) + print_summary(nested) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: inspect_policy_record.py ") + sys.exit(1) + main(sys.argv[1]) From 79726a19bfb6eaf1e1a419ec1a0637648e9c4a67 Mon Sep 17 00:00:00 2001 From: eHiollo <12532364@mail.sustech.edu.cn> Date: Sat, 11 Oct 2025 17:48:14 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=9C=A8websocket=5Fpolicy=5Fserver=20?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E4=B8=AD=E6=B7=BB=E5=8A=A0log=E6=89=93?= =?UTF-8?q?=E5=8D=B0=E6=A8=A1=E5=9E=8B=E7=9A=84=E8=BE=93=E5=85=A5=E8=BE=93?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/openpi/serving/websocket_policy_server.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/openpi/serving/websocket_policy_server.py b/src/openpi/serving/websocket_policy_server.py index bdefa98b87..bce0b2e534 100644 --- a/src/openpi/serving/websocket_policy_server.py +++ b/src/openpi/serving/websocket_policy_server.py @@ -1,6 +1,7 @@ import asyncio import http import logging +import os import time import traceback @@ -11,6 +12,9 @@ logger = logging.getLogger(__name__) +# How many elements to print when sampling large arrays. Configure via env OPENPI_SAMPLE_MAX. +SAMPLE_MAX = int(os.environ.get("OPENPI_SAMPLE_MAX", "10")) + class WebsocketPolicyServer: """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation. @@ -56,11 +60,47 @@ async def _handler(self, websocket: _server.ServerConnection): try: start_time = time.monotonic() obs = msgpack_numpy.unpackb(await websocket.recv()) + # Log received observation keys and a small sample for debugging. + try: + logger.info("Received obs from %s keys=%s", websocket.remote_address, list(obs.keys()) if isinstance(obs, dict) else None) + # Small sample at DEBUG level to avoid spamming INFO logs for large arrays. + try: + import numpy as _np + + if isinstance(obs, dict) and "state" in obs: + sample = _np.array(obs.get("state")) + logger.debug( + "obs state shape=%s sample=%s", + getattr(sample, "shape", None), + str(sample.ravel()[:SAMPLE_MAX]), + ) + except Exception: + logger.debug("Could not sample obs values", exc_info=True) + except Exception: + logger.exception("Failed to log incoming obs") infer_time = time.monotonic() action = self._policy.infer(obs) infer_time = time.monotonic() - infer_time + # Log action keys and a small sample for debugging. + try: + logger.info("Sending action to %s keys=%s", websocket.remote_address, list(action.keys()) if isinstance(action, dict) else None) + try: + import numpy as _np + + if isinstance(action, dict) and "actions" in action: + sample = _np.array(action.get("actions")) + logger.debug( + "action['actions'] shape=%s sample=%s", + getattr(sample, "shape", None), + str(sample.ravel()[:SAMPLE_MAX]), + ) + except Exception: + logger.debug("Could not sample action values", exc_info=True) + except Exception: + logger.exception("Failed to log action output") + action["server_timing"] = { "infer_ms": infer_time * 1000, }