diff --git a/packages/openpi-client/src/openpi_client/action_chunk_broker.py b/packages/openpi-client/src/openpi_client/action_chunk_broker.py
index 8fa9d83d02..59bb451805 100644
--- a/packages/openpi-client/src/openpi_client/action_chunk_broker.py
+++ b/packages/openpi-client/src/openpi_client/action_chunk_broker.py
@@ -17,6 +17,12 @@ class ActionChunkBroker(_base_policy.BasePolicy):
     """
 
     def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
+        """Initialize the ActionChunkBroker with a policy and action horizon.
+
+        Args:
+            policy: The underlying policy to wrap for chunked action delivery.
+            action_horizon: The number of action steps in each chunk returned by the policy.
+        """
         self._policy = policy
         self._action_horizon = action_horizon
         self._cur_step: int = 0
@@ -25,11 +31,20 @@ def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
 
     @override
     def infer(self, obs: Dict) -> Dict:  # noqa: UP006
+        """Return the next action from the current chunk, fetching a new chunk when needed.
+
+        Args:
+            obs: Observation dictionary passed to the underlying policy when a new chunk is fetched.
+
+        Returns:
+            Dictionary containing the action for the current step, sliced out of the chunk.
+        """
         if self._last_results is None:
             self._last_results = self._policy.infer(obs)
             self._cur_step = 0
 
         def slicer(x):
+            """Slice the current step out of array data; return non-array data unchanged."""
             if isinstance(x, np.ndarray):
                 return x[self._cur_step, ...]
             else:
@@ -45,6 +60,7 @@ def slicer(x):
 
     @override
     def reset(self) -> None:
+        """Reset the broker state and the underlying policy."""
         self._policy.reset()
         self._last_results = None
         self._cur_step = 0
diff --git a/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py b/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
index 65227c44da..5c299a98ed 100644
--- a/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
+++ b/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
@@ -8,11 +8,25 @@ class PolicyAgent(_agent.Agent):
     """An agent that uses a policy to determine actions."""
 
     def __init__(self, policy: _base_policy.BasePolicy) -> None:
+        """Initialize the policy agent with a given policy.
+
+        Args:
+            policy: The policy instance used to infer actions from observations.
+        """
         self._policy = policy
 
     @override
     def get_action(self, observation: dict) -> dict:
+        """Get an action by running policy inference on the observation.
+
+        Args:
+            observation: The current observation as a dictionary.
+
+        Returns:
+            The action determined by the policy, as a dictionary.
+        """
         return self._policy.infer(observation)
 
     def reset(self) -> None:
+        """Reset the policy to its initial state."""
         self._policy.reset()
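To make the chunking contract above concrete, here is a minimal usage sketch. It is illustrative only: `DummyChunkPolicy` and the `"actions"`/`"state"` keys are hypothetical stand-ins for a real `BasePolicy` and its observation/action spec.

```python
import numpy as np

from openpi_client import action_chunk_broker


class DummyChunkPolicy:
    """Hypothetical BasePolicy stand-in that returns a 10-step chunk per inference."""

    def infer(self, obs: dict) -> dict:
        return {"actions": np.zeros((10, 7))}  # (action_horizon, action_dim)

    def reset(self) -> None:
        pass


broker = action_chunk_broker.ActionChunkBroker(DummyChunkPolicy(), action_horizon=10)
for _ in range(25):
    # Each call returns one row of the chunk; a new chunk is fetched from the
    # underlying policy every 10 steps (3 fetches over these 25 iterations).
    step = broker.infer({"state": np.zeros(7)})
    assert step["actions"].shape == (7,)
```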
diff --git a/packages/openpi-client/src/openpi_client/websocket_client_policy.py b/packages/openpi-client/src/openpi_client/websocket_client_policy.py
index d6244f0f78..2230eb437a 100644
--- a/packages/openpi-client/src/openpi_client/websocket_client_policy.py
+++ b/packages/openpi-client/src/openpi_client/websocket_client_policy.py
@@ -16,6 +16,13 @@ class WebsocketClientPolicy(_base_policy.BasePolicy):
     """
 
     def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
+        """Initialize the websocket client policy.
+
+        Args:
+            host: The hostname or IP address of the server to connect to.
+            port: The port number to connect to. If None, no port is appended to the URI.
+            api_key: Optional API key for authentication. If provided, it is sent in the Authorization header.
+        """
         self._uri = f"ws://{host}"
         if port is not None:
             self._uri += f":{port}"
@@ -24,9 +31,25 @@ def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: O
         self._ws, self._server_metadata = self._wait_for_server()
 
     def get_server_metadata(self) -> Dict:
+        """Get the metadata received from the server during connection.
+
+        Returns:
+            Dictionary containing the server metadata received during the initial connection.
+        """
         return self._server_metadata
 
     def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
+        """Establish a connection to the server and retrieve its metadata.
+
+        Repeatedly attempts to connect to the server until successful, waiting 5 seconds between attempts.
+        Once connected, receives and unpacks the server metadata.
+
+        Returns:
+            Tuple containing the websocket connection and the server metadata dictionary.
+
+        Raises:
+            Exception: Any error other than ConnectionRefusedError raised while connecting or unpacking metadata; ConnectionRefusedError is caught and triggers a retry.
+        """
         logging.info(f"Waiting for server at {self._uri}...")
         while True:
             try:
@@ -42,6 +65,17 @@ def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dic
 
     @override
     def infer(self, obs: Dict) -> Dict:  # noqa: UP006
+        """Send an observation to the server and receive the inference result.
+
+        Args:
+            obs: Dictionary containing observation data to send to the server.
+
+        Returns:
+            Dictionary containing the inference result from the server.
+
+        Raises:
+            RuntimeError: If the server responds with an error message (a string instead of bytes).
+        """
         data = self._packer.pack(obs)
         self._ws.send(data)
         response = self._ws.recv()
@@ -52,4 +86,8 @@ def infer(self, obs: Dict) -> Dict:  # noqa: UP006
 
     @override
     def reset(self) -> None:
+        """Reset the policy state.
+
+        This is a no-op because the websocket client maintains no local state that needs resetting.
+        """
         pass
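Assuming a policy server is already running and reachable, the client is typically composed with the pieces documented above; the address, horizon, and observation contents below are placeholders, not values mandated by the library.

```python
import numpy as np

from openpi_client import action_chunk_broker, websocket_client_policy
from openpi_client.runtime.agents import policy_agent

# Placeholder address; the constructor blocks, retrying every 5 seconds until
# the server accepts the connection.
client = websocket_client_policy.WebsocketClientPolicy(host="192.168.1.10", port=8000)
print("server metadata:", client.get_server_metadata())

# Wrap the remote policy so one network round-trip yields many control steps.
agent = policy_agent.PolicyAgent(
    policy=action_chunk_broker.ActionChunkBroker(client, action_horizon=10)
)
action = agent.get_action({"state": np.zeros(7)})  # observation keys are illustrative
```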
diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py
index 203b36be8a..9cabef76a8 100644
--- a/src/openpi/models_pytorch/gemma_pytorch.py
+++ b/src/openpi/models_pytorch/gemma_pytorch.py
@@ -10,6 +10,13 @@
 
 
 class PaliGemmaWithExpertModel(nn.Module):
+    """
+    A PyTorch module that combines the PaliGemma vision-language model with a Gemma expert model.
+
+    This model integrates a PaliGemma model for multimodal processing with an additional
+    Gemma expert model for specialized language processing tasks.
+    """
+
     def __init__(
         self,
         vlm_config,
@@ -17,6 +24,16 @@ def __init__(
         use_adarms=None,
         precision: Literal["bfloat16", "float32"] = "bfloat16",
     ):
+        """
+        Initialize the combined PaliGemma and expert model.
+
+        Args:
+            vlm_config: Configuration object for the vision-language model
+            action_expert_config: Configuration object for the action expert model
+            use_adarms: List of two booleans indicating whether to use AdaRMS for each model.
+                Defaults to [False, False] if None
+            precision: Precision type for model parameters, either "bfloat16" or "float32"
+        """
         if use_adarms is None:
             use_adarms = [False, False]
         super().__init__()
@@ -61,6 +78,15 @@ def __init__(
         self.to_bfloat16_for_selected_params(precision)
 
     def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
+        """
+        Convert model parameters to the specified precision while keeping certain parameters in float32.
+
+        Args:
+            precision: Target precision for most parameters, either "bfloat16" or "float32"
+
+        Raises:
+            ValueError: If precision is not "bfloat16" or "float32"
+        """
         if precision == "bfloat16":
             self.to(dtype=torch.bfloat16)
         elif precision == "float32":
@@ -83,9 +109,27 @@ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float3
                 param.data = param.data.to(dtype=torch.float32)
 
     def embed_image(self, image: torch.Tensor):
+        """
+        Extract image features using the PaliGemma vision tower.
+
+        Args:
+            image: Input image tensor
+
+        Returns:
+            torch.Tensor: Image feature embeddings
+        """
         return self.paligemma.model.get_image_features(image)
 
     def embed_language_tokens(self, tokens: torch.Tensor):
+        """
+        Convert language tokens to embeddings using the PaliGemma language model.
+
+        Args:
+            tokens: Input token tensor
+
+        Returns:
+            torch.Tensor: Token embeddings
+        """
         return self.paligemma.language_model.embed_tokens(tokens)
 
     def forward(
@@ -97,6 +141,26 @@ def forward(
         use_cache: bool | None = None,
         adarms_cond: list[torch.Tensor] | None = None,
     ):
+        """
+        Forward pass through the combined PaliGemma and expert models.
+
+        Processes input embeddings through either the PaliGemma language model alone,
+        the Gemma expert model alone, or both models in a combined fashion with
+        shared attention computation.
+
+        Args:
+            attention_mask: Mask to avoid attention on padding tokens
+            position_ids: Position indices for positional embeddings
+            past_key_values: Cached key-value pairs from previous forward passes
+            inputs_embeds: List of two embedding tensors, one for each model
+            use_cache: Whether to return cached key-value pairs
+            adarms_cond: Conditioning tensors for AdaRMS normalization
+
+        Returns:
+            tuple: A tuple containing:
+                - List of output tensors [prefix_output, suffix_output]
+                - Past key values for caching
+        """
         if adarms_cond is None:
             adarms_cond = [None, None]
         if inputs_embeds[1] is None:
@@ -156,6 +220,19 @@ def forward(
 
         # Define the complete layer computation function for gradient checkpointing
        def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
+            """
+            Compute a complete transformer layer for both models with shared attention.
+
+            Args:
+                layer_idx: Index of the current layer
+                inputs_embeds: Input embeddings for both models
+                attention_mask: Attention mask tensor
+                position_ids: Position indices
+                adarms_cond: AdaRMS conditioning tensors
+
+            Returns:
+                list: Output embeddings for both models after layer processing
+            """
             models = [self.paligemma.language_model, self.gemma_expert.model]
 
             query_states = []
@@ -260,6 +337,16 @@ def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_id
         # final norm
         # Define final norm computation function for gradient checkpointing
         def compute_final_norms(inputs_embeds, adarms_cond):
+            """
+            Apply final layer normalization to both model outputs.
+
+            Args:
+                inputs_embeds: Input embeddings for both models
+                adarms_cond: AdaRMS conditioning tensors
+
+            Returns:
+                list: Normalized output embeddings for both models
+            """
             outputs_embeds = []
             for i, hidden_states in enumerate(inputs_embeds):
                 out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
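The selective casting that `to_bfloat16_for_selected_params` documents follows a common mixed-precision pattern: cast everything down, then restore float32 for numerically sensitive parameters. The sketch below shows the pattern in isolation; the substring rule is an assumption for illustration, not the module's actual selection logic (which the hunk does not show).

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)


def cast_with_float32_exceptions(model: nn.Module, keep_fp32: tuple[str, ...]) -> None:
    # Cast every parameter down, then selectively restore float32, mirroring
    # the diff's own `param.data = param.data.to(dtype=torch.float32)` idiom.
    model.to(dtype=torch.bfloat16)
    for name, param in model.named_parameters():
        if any(key in name for key in keep_fp32):  # illustrative name-based rule
            param.data = param.data.to(dtype=torch.float32)


model = TinyBlock()
cast_with_float32_exceptions(model, keep_fp32=("norm",))
assert model.proj.weight.dtype == torch.bfloat16
assert model.norm.weight.dtype == torch.float32
```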
diff --git a/src/openpi/serving/websocket_policy_server.py b/src/openpi/serving/websocket_policy_server.py
index bdefa98b87..8b04827908 100644
--- a/src/openpi/serving/websocket_policy_server.py
+++ b/src/openpi/serving/websocket_policy_server.py
@@ -25,6 +25,14 @@ def __init__(
         port: int | None = None,
         metadata: dict | None = None,
     ) -> None:
+        """Initialize the WebSocket policy server.
+
+        Args:
+            policy: The policy instance to serve over WebSocket
+            host: The host address to bind the server to
+            port: The port number to bind the server to, or None for automatic assignment
+            metadata: Additional metadata to send to clients upon connection
+        """
         self._policy = policy
         self._host = host
         self._port = port
@@ -32,9 +40,11 @@ def __init__(
         self._metadata = metadata or {}
         logging.getLogger("websockets.server").setLevel(logging.INFO)
 
     def serve_forever(self) -> None:
+        """Start the server and run it indefinitely, blocking the calling thread."""
         asyncio.run(self.run())
 
     async def run(self):
+        """Run the WebSocket server asynchronously with the configured handler."""
         async with _server.serve(
             self._handler,
             self._host,
@@ -46,6 +56,14 @@ async def run(self):
             await server.serve_forever()
 
     async def _handler(self, websocket: _server.ServerConnection):
+        """Handle incoming WebSocket connections and process inference requests.
+
+        Sends metadata upon connection, then continuously receives observations,
+        runs inference, and sends back actions with timing information.
+
+        Args:
+            websocket: The WebSocket connection to handle
+        """
         logger.info(f"Connection from {websocket.remote_address} opened")
         packer = msgpack_numpy.Packer()
 
@@ -84,6 +102,15 @@
 
 
 def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
+    """Handle health check requests on the /healthz endpoint.
+
+    Args:
+        connection: The server connection instance
+        request: The incoming HTTP request
+
+    Returns:
+        An HTTP OK response for the /healthz path, or None to continue normal request handling.
+    """
     if request.path == "/healthz":
         return connection.respond(http.HTTPStatus.OK, "OK\n")
     # Continue with the normal request handling.
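For completeness, a minimal serving sketch under stated assumptions: the class is assumed to be exported as `WebsocketPolicyServer` from this module (the class name itself lies outside the hunks above), and `EchoPolicy` is a hypothetical stand-in for any object implementing `infer`/`reset`.

```python
from openpi.serving import websocket_policy_server


class EchoPolicy:
    """Hypothetical minimal policy used only to illustrate the server API."""

    def infer(self, obs: dict) -> dict:
        return {"actions": obs.get("state")}

    def reset(self) -> None:
        pass


server = websocket_policy_server.WebsocketPolicyServer(
    policy=EchoPolicy(),
    host="0.0.0.0",
    port=8000,
    metadata={"description": "example policy"},  # sent to each client on connect
)
server.serve_forever()  # blocks; liveness can be probed via HTTP GET /healthz
```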