Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion scripts/serve_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class Args:
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
policy: Checkpoint | Default = dataclasses.field(default_factory=Default)

# Disable normalization statistics
skip_norm_stats: bool = False


# Default checkpoints that should be used for each environment.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
Expand Down Expand Up @@ -90,7 +93,10 @@ def create_policy(args: Args) -> _policy.Policy:
match args.policy:
case Checkpoint():
return _policy_config.create_trained_policy(
_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
_config.get_config(args.policy.config),
args.policy.dir,
default_prompt=args.default_prompt,
skip_norm_stats=args.skip_norm_stats,
)
case Default():
return create_default_policy(args.env, default_prompt=args.default_prompt)
Expand Down
6 changes: 5 additions & 1 deletion src/openpi/policies/policy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def create_trained_policy(
sample_kwargs: dict[str, Any] | None = None,
default_prompt: str | None = None,
norm_stats: dict[str, transforms.NormStats] | None = None,
skip_norm_stats: bool = False,
pytorch_device: str | None = None,
) -> _policy.Policy:
"""Create a policy from a trained checkpoint.
Expand All @@ -35,6 +36,7 @@ def create_trained_policy(
data if it doesn't already exist.
norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
from the checkpoint directory.
skip_norm_stats: If True, skips loading and using normalization stats (uses identity transform).
pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".

Expand All @@ -56,7 +58,9 @@ def create_trained_policy(
else:
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
if norm_stats is None:
if skip_norm_stats:
norm_stats = {}
elif norm_stats is None:
# We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
# that the policy is using the same normalization stats as the original training process.
if data_config.asset_id is None:
Expand Down