From 12e0711c14cbda9e6b4edbf0e1f5968d7f82e515 Mon Sep 17 00:00:00 2001 From: Skyler Date: Wed, 12 Nov 2025 20:39:25 -0500 Subject: [PATCH] feat: added skip norm stats --- scripts/serve_policy.py | 8 +++++++- src/openpi/policies/policy_config.py | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/scripts/serve_policy.py b/scripts/serve_policy.py index 30f121a60b..ccfe0cdadd 100644 --- a/scripts/serve_policy.py +++ b/scripts/serve_policy.py @@ -54,6 +54,9 @@ class Args: # Specifies how to load the policy. If not provided, the default policy for the environment will be used. policy: Checkpoint | Default = dataclasses.field(default_factory=Default) + # Disable normalization statistics + skip_norm_stats: bool = False + # Default checkpoints that should be used for each environment. DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = { @@ -90,7 +93,10 @@ def create_policy(args: Args) -> _policy.Policy: match args.policy: case Checkpoint(): return _policy_config.create_trained_policy( - _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt + _config.get_config(args.policy.config), + args.policy.dir, + default_prompt=args.default_prompt, + skip_norm_stats=args.skip_norm_stats, ) case Default(): return create_default_policy(args.env, default_prompt=args.default_prompt) diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py index 6570df05ed..a9b4c4d638 100644 --- a/src/openpi/policies/policy_config.py +++ b/src/openpi/policies/policy_config.py @@ -21,6 +21,7 @@ def create_trained_policy( sample_kwargs: dict[str, Any] | None = None, default_prompt: str | None = None, norm_stats: dict[str, transforms.NormStats] | None = None, + skip_norm_stats: bool = False, pytorch_device: str | None = None, ) -> _policy.Policy: """Create a policy from a trained checkpoint. @@ -35,6 +36,7 @@ def create_trained_policy( data if it doesn't already exist. norm_stats: The norm stats to use for the policy. 
If not provided, the norm stats will be loaded from the checkpoint directory. + skip_norm_stats: If True, skips loading and using normalization stats (uses identity transform); overrides `norm_stats` if both are given. pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0"). If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu". @@ -56,7 +58,9 @@ def create_trained_policy( else: model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16)) data_config = train_config.data.create(train_config.assets_dirs, train_config.model) - if norm_stats is None: + if skip_norm_stats: + norm_stats = {} + elif norm_stats is None: # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure # that the policy is using the same normalization stats as the original training process. if data_config.asset_id is None: