Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/openpi/serving/websocket_policy_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import http
import logging
import os
import time
import traceback

Expand All @@ -11,6 +12,9 @@

logger = logging.getLogger(__name__)

# How many elements to print when sampling large arrays. Configure via env OPENPI_SAMPLE_MAX.
SAMPLE_MAX = int(os.environ.get("OPENPI_SAMPLE_MAX", "10"))


class WebsocketPolicyServer:
"""Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
Expand Down Expand Up @@ -56,11 +60,47 @@ async def _handler(self, websocket: _server.ServerConnection):
try:
start_time = time.monotonic()
obs = msgpack_numpy.unpackb(await websocket.recv())
# Log received observation keys and a small sample for debugging.
try:
logger.info("Received obs from %s keys=%s", websocket.remote_address, list(obs.keys()) if isinstance(obs, dict) else None)
# Small sample at DEBUG level to avoid spamming INFO logs for large arrays.
try:
import numpy as _np

if isinstance(obs, dict) and "state" in obs:
sample = _np.array(obs.get("state"))
logger.debug(
"obs state shape=%s sample=%s",
getattr(sample, "shape", None),
str(sample.ravel()[:SAMPLE_MAX]),
)
except Exception:
logger.debug("Could not sample obs values", exc_info=True)
except Exception:
logger.exception("Failed to log incoming obs")

infer_time = time.monotonic()
action = self._policy.infer(obs)
infer_time = time.monotonic() - infer_time

# Log action keys and a small sample for debugging.
try:
logger.info("Sending action to %s keys=%s", websocket.remote_address, list(action.keys()) if isinstance(action, dict) else None)
try:
import numpy as _np

if isinstance(action, dict) and "actions" in action:
sample = _np.array(action.get("actions"))
logger.debug(
"action['actions'] shape=%s sample=%s",
getattr(sample, "shape", None),
str(sample.ravel()[:SAMPLE_MAX]),
)
except Exception:
logger.debug("Could not sample action values", exc_info=True)
except Exception:
logger.exception("Failed to log action output")

action["server_timing"] = {
"infer_ms": infer_time * 1000,
}
Expand Down
98 changes: 98 additions & 0 deletions tools/inspect_policy_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""Inspect policy_records saved by PolicyRecorder.

Usage:
python tools/inspect_policy_record.py policy_records/step_0.npy

This prints a nested summary of inputs and outputs, including array shapes and small samples.
"""
from __future__ import annotations

import sys
import numpy as np
from textwrap import shorten
from typing import Any


def load_record(path: str) -> Any:
arr = np.load(path, allow_pickle=True)
if isinstance(arr, np.ndarray) and arr.shape == ():
return arr.item()
return arr


def unflatten_if_needed(flat: dict) -> dict:
nested: dict = {}
for key, val in flat.items():
if isinstance(key, tuple):
parts = list(key)
elif isinstance(key, str) and "/" in key:
parts = key.split("/")
else:
nested[key] = val
continue

d = nested
for p in parts[:-1]:
if p not in d or not isinstance(d[p], dict):
d[p] = {}
d = d[p]
d[parts[-1]] = val
return nested


def sample_repr(x: Any, max_len: int = 240) -> str:
try:
import numpy as _np

if isinstance(x, _np.ndarray):
s = f"ndarray shape={x.shape} dtype={x.dtype}"
flat = x.ravel()
if flat.size > 0:
sample = flat[:10].tolist()
s += f" sample={sample}"
return s
except Exception:
pass

try:
if hasattr(x, "shape"):
try:
return f"{type(x).__name__} shape={x.shape}"
except Exception:
pass
return shorten(repr(x), width=max_len)
except Exception:
return "<unrepresentable>"


def print_summary(nested: dict) -> None:
for top in ("inputs", "outputs"):
if top not in nested:
continue
print(f"--- {top} ---")
section = nested[top]
if isinstance(section, dict):
for k, v in section.items():
print(f"{k}: {sample_repr(v)}")
else:
print(sample_repr(section))
print()


def main(path: str) -> None:
data = load_record(path)
if not isinstance(data, dict):
print("Loaded data is not a dict, type:", type(data))
print("Raw loaded object:", data)
return
nested = unflatten_if_needed(data)
print("Top-level keys:", list(nested.keys()))
print_summary(nested)


if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: inspect_policy_record.py <path-to-step_*.npy>")
sys.exit(1)
main(sys.argv[1])