Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ request = CreateSandboxRequest(
memory_gb=4,
disk_size_gb=20,
gpu_count=0,
gpu_type=None,
timeout_minutes=60,
environment_vars={"ENV": "production"},
secrets={"API_KEY": "your-secret-key"},
Expand All @@ -165,11 +166,19 @@ print(f"Created sandbox: {sandbox.id}")
prime sandbox list [--team_id TEAM] [--status STATUS] [--page N] [--per_page N]

# Create sandbox
prime sandbox create NAME --docker_image IMAGE [OPTIONS]
prime sandbox create IMAGE [OPTIONS]

# Create GPU sandbox
prime sandbox create --gpu-count 1 --gpu-type H100_80GB

# Note: DOCKER_IMAGE is not supported for GPU sandboxes

# With environment variables and secrets:
prime sandbox create python:3.11-slim --env KEY=VALUE --secret API_KEY=secret123

# Run command in sandbox
prime sandbox run SANDBOX_ID -- python script.py

# Get sandbox details
prime sandbox get SANDBOX_ID

Expand Down
3 changes: 3 additions & 0 deletions packages/prime-sandboxes/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dependencies = [
"pydantic>=2.0.0",
"aiofiles>=23.0.0",
"tenacity>=8.0.0",
"connect-python>=0.8.0",
"protobuf>=6.31.1",
]
keywords = ["sandboxes", "remote-execution", "containers", "cloud", "sdk"]
classifiers = [
Expand Down Expand Up @@ -60,6 +62,7 @@ testpaths = ["tests"]
[tool.ruff]
line-length = 100
target-version = "py310"
extend-exclude = ["src/prime_sandboxes/_proto/**/*.py"]

[tool.ruff.lint]
extend-select = ["E", "F", "I"]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Generated protobuf modules used by prime-sandboxes."""
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Generated command_session protobuf messages for sandboxd RPCs."""

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions packages/prime-sandboxes/src/prime_sandboxes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator


class SandboxStatus(str, Enum):
Expand All @@ -13,7 +13,7 @@ class SandboxStatus(str, Enum):
PENDING = "PENDING"
PROVISIONING = "PROVISIONING"
RUNNING = "RUNNING"
STOPPED = "STOPPED"
PAUSED = "PAUSED"
ERROR = "ERROR"
TERMINATED = "TERMINATED"
TIMEOUT = "TIMEOUT"
Expand All @@ -39,6 +39,7 @@ class Sandbox(BaseModel):
disk_size_gb: float = Field(..., alias="diskSizeGB")
disk_mount_path: str = Field(..., alias="diskMountPath")
gpu_count: int = Field(..., alias="gpuCount")
gpu_type: Optional[str] = Field(None, alias="gpuType")
network_access: bool = Field(True, alias="networkAccess")
status: str
timeout_minutes: int = Field(..., alias="timeoutMinutes")
Expand Down Expand Up @@ -83,6 +84,7 @@ class CreateSandboxRequest(BaseModel):
memory_gb: float = 2.0
disk_size_gb: float = 5.0
gpu_count: int = 0
gpu_type: Optional[str] = None
network_access: bool = True
timeout_minutes: int = 60
environment_vars: Optional[Dict[str, str]] = None
Expand All @@ -92,6 +94,12 @@ class CreateSandboxRequest(BaseModel):
advanced_configs: Optional[AdvancedConfigs] = None
registry_credentials_id: Optional[str] = None

@model_validator(mode="after")
def validate_gpu_fields(self) -> "CreateSandboxRequest":
if self.gpu_count > 0 and not self.gpu_type:
raise ValueError("gpu_type is required when gpu_count is greater than 0")
return self


class UpdateSandboxRequest(BaseModel):
"""Update sandbox request model"""
Expand All @@ -103,6 +111,7 @@ class UpdateSandboxRequest(BaseModel):
memory_gb: Optional[float] = None
disk_size_gb: Optional[float] = None
gpu_count: Optional[int] = None
gpu_type: Optional[str] = None
timeout_minutes: Optional[int] = None
environment_vars: Optional[Dict[str, str]] = None
registry_credentials_id: Optional[str] = None
Expand Down
108 changes: 108 additions & 0 deletions packages/prime-sandboxes/src/prime_sandboxes/rpc_command_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Command session Connect RPC helpers."""

from typing import Dict, List, Optional, Protocol, cast

from connectrpc.method import IdempotencyLevel, MethodInfo
from google.protobuf.message import Message

from ._proto.command_session import command_session_pb2


class _CommandSpecLike(Protocol):
cwd: str


class _CommandSpecFactory(Protocol):
def __call__(self, *, cmd: str, args: List[str], envs: Dict[str, str]) -> _CommandSpecLike: ...


class _CommandSessionStartRequestFactory(Protocol):
def __call__(self, *, command: _CommandSpecLike, stdin: bool) -> Message: ...


class _CommandSessionDataEventLike(Protocol):
stdout: bytes
stderr: bytes
pty: bytes

def WhichOneof(self, field_name: str) -> str | None: ...


class _CommandSessionEndEventLike(Protocol):
exit_code: int


class _CommandSessionEventLike(Protocol):
data: _CommandSessionDataEventLike
end: _CommandSessionEndEventLike

def WhichOneof(self, field_name: str) -> str | None: ...


class _CommandSessionStartResponseLike(Protocol):
event: _CommandSessionEventLike

def HasField(self, field_name: str) -> bool: ...


_COMMAND_SESSION_START_REQUEST_TYPE = cast(
type[Message], getattr(command_session_pb2, "StartRequest")
)
_COMMAND_SESSION_START_RESPONSE_TYPE = cast(
type[Message], getattr(command_session_pb2, "StartResponse")
)
_COMMAND_SESSION_START_REQUEST_FACTORY = cast(
_CommandSessionStartRequestFactory, _COMMAND_SESSION_START_REQUEST_TYPE
)
_COMMAND_SPEC_FACTORY = cast(_CommandSpecFactory, getattr(command_session_pb2, "CommandSpec"))


COMMAND_SESSION_START_RPC_METHOD = MethodInfo(
name="Start",
service_name="command_session.CommandSession",
input=_COMMAND_SESSION_START_REQUEST_TYPE,
output=_COMMAND_SESSION_START_RESPONSE_TYPE,
idempotency_level=IdempotencyLevel.UNKNOWN,
)


def build_command_session_start_request(
command: str,
working_dir: Optional[str],
env: Optional[Dict[str, str]],
) -> Message:
command_spec = _COMMAND_SPEC_FACTORY(
cmd="/bin/bash",
args=["-l", "-c", command],
envs=env or {},
)
if working_dir is not None:
command_spec.cwd = working_dir

return _COMMAND_SESSION_START_REQUEST_FACTORY(command=command_spec, stdin=False)


def collect_command_session_start_event(
response: Message,
stdout_parts: List[str],
stderr_parts: List[str],
) -> Optional[int]:
start_response = cast(_CommandSessionStartResponseLike, response)
if not start_response.HasField("event"):
return None

event = start_response.event
event_kind = event.WhichOneof("event")

if event_kind == "data":
data_kind = event.data.WhichOneof("output")
if data_kind == "stdout" and event.data.stdout:
stdout_parts.append(event.data.stdout.decode("utf-8", errors="replace"))
elif data_kind == "stderr" and event.data.stderr:
stderr_parts.append(event.data.stderr.decode("utf-8", errors="replace"))
elif data_kind == "pty" and event.data.pty:
stdout_parts.append(event.data.pty.decode("utf-8", errors="replace"))
elif event_kind == "end":
return int(event.end.exit_code)

return None
Loading