From 6079274f02989c729d4d8ea5193aed28fdb2658e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:27:41 -0800 Subject: [PATCH 01/25] docs: add PRD for zero-boilerplate flash run experience --- PRD.md | 341 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 PRD.md diff --git a/PRD.md b/PRD.md new file mode 100644 index 00000000..a5d9d98f --- /dev/null +++ b/PRD.md @@ -0,0 +1,341 @@ +# Flash SDK: Zero-Boilerplate Experience — Product Requirements Document + +## 1. Problem Statement + +Flash currently forces every project into a FastAPI-first model: + +- Users must create `main.py` with a `FastAPI()` instance +- HTTP routing boilerplate adds no semantic value — the routes simply call `@remote` functions +- No straightforward path for deploying a standalone QB function without wrapping it in a FastAPI app +- The "mothership" concept introduces an implicit coordinator with no clear ownership model +- `flash run` fails unless `main.py` exists with a FastAPI app, blocking the simplest use cases + +## 2. Goals + +- **Zero boilerplate**: a `@remote`-decorated function in any `.py` file is sufficient for `flash run` and `flash deploy` +- **File-system-as-namespace**: the project directory structure maps 1:1 to URL paths on the local dev server +- **Single command**: `flash run` works for all project topologies (one QB function, many files, mixed QB+LB) without any configuration +- **`flash deploy` requires no additional configuration** beyond the `@remote` declarations themselves +- **Peer endpoints**: every `@resource_config` is a first-class endpoint; no implicit coordinator + +## 3. Non-Goals + +- No backward compatibility with `main.py`/FastAPI-first style +- No implicit "mothership" concept; all endpoints are peers +- No changes to the QB runtime (`generic_handler.py`) or QB stub behavior +- No changes to deployed endpoint behavior (RunPod QB/LB APIs are unchanged) + +## 4. Developer Experience Specification + +### 4.1 Minimum viable QB project + +```python +# gpu_worker.py +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def process(input_data: dict) -> dict: + return {"result": "processed", "input": input_data} +``` + +`flash run` → `POST /gpu_worker/run` and `POST /gpu_worker/run_sync` +`flash deploy` → standalone QB endpoint at `api.runpod.ai/v2/{id}/run` + +### 4.2 LB endpoint + +```python +# api/routes.py +from runpod_flash import CpuLiveLoadBalancer, remote + +lb_config = CpuLiveLoadBalancer(name="api_routes") + +@remote(lb_config, method="POST", path="/compute") +async def compute(input_data: dict) -> dict: + return {"result": input_data} +``` + +`flash run` → `POST /api/routes/compute` +`flash deploy` → LB endpoint at `{id}.api.runpod.ai/compute` + +### 4.3 Mixed QB + LB (LB calling QB) + +```python +# api/routes.py (LB) +from runpod_flash import CpuLiveLoadBalancer, remote +from workers.gpu import heavy_compute # QB stub + +lb_config = CpuLiveLoadBalancer(name="api_routes") + +@remote(lb_config, method="POST", path="/process") +async def process_route(data: dict): + return await heavy_compute(data) # dispatches to QB endpoint + +# workers/gpu.py (QB) +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def heavy_compute(data: dict) -> dict: ... +``` + +## 5. URL Path Specification + +### 5.1 File prefix derivation + +The local dev server uses the project directory structure as a URL namespace. Each file's URL prefix is its path relative to the project root with `.py` stripped: + +``` +File Local URL prefix +────────────────────────────── ──────────────────────────── +gpu_worker.py → /gpu_worker +longruns/stage1.py → /longruns/stage1 +preprocess/first_pass.py → /preprocess/first_pass +workers/gpu/inference.py → /workers/gpu/inference +``` + +### 5.2 QB route generation + +| Condition | Routes | +|---|---| +| One `@remote` function in file | `POST {file_prefix}/run` and `POST {file_prefix}/run_sync` | +| Multiple `@remote` functions in file | `POST {file_prefix}/{fn_name}/run` and `POST {file_prefix}/{fn_name}/run_sync` | + +### 5.3 LB route generation + +| Condition | Route | +|---|---| +| `@remote(lb_config, method="POST", path="/compute")` | `POST {file_prefix}/compute` | + +The declared `path=` is appended to the file prefix. The `method=` determines the HTTP verb. + +### 5.4 QB request/response envelope + +Mirrors RunPod's API for consistency: + +``` +POST /gpu_worker/run_sync +Body: {"input": {"key": "value"}} +Response: {"id": "uuid", "status": "COMPLETED", "output": {...}} +``` + +## 6. Deployed Topology Specification + +Each unique resource config gets its own RunPod endpoint: + +| Type | Deployed URL | Example | +|---|---|---| +| QB | `https://api.runpod.ai/v2/{endpoint_id}/run` | `https://api.runpod.ai/v2/uoy3n7hkyb052a/run` | +| QB sync | `https://api.runpod.ai/v2/{endpoint_id}/run_sync` | | +| LB | `https://{endpoint_id}.api.runpod.ai/{declared_path}` | `https://rzlk6lph6gw7dk.api.runpod.ai/compute` | + +## 7. `.flash/` Folder Specification + +All generated artifacts go to `.flash/` in the project root. Auto-created, gitignored, never committed. + +``` +my_project/ +├── gpu_worker.py +├── longruns/ +│ └── stage1.py +└── .flash/ + ├── server.py ← generated by flash run + └── manifest.json ← generated by flash build +``` + +- `.flash/` is added to `.gitignore` automatically on first `flash run` +- `server.py` and `manifest.json` are overwritten on each run/build; other files preserved +- The `.flash/` directory itself is never committed + +### 7.1 Dev server launch + +Uvicorn is launched with `--app-dir .flash/` so `server:app` is importable. The server inserts the project root into `sys.path` so user modules resolve: + +```bash +uvicorn server:app \ + --app-dir .flash/ \ + --reload \ + --reload-dir . \ + --reload-include "*.py" +``` + +## 8. `flash run` Behavior + +1. Scan project for all `@remote` functions (QB and LB) in any `.py` file + - Skip: `.flash/`, `__pycache__`, `*.pyc`, `__init__.py` +2. If none found: print error with usage instructions, exit 1 +3. Generate `.flash/server.py` with routes for all discovered functions +4. Add `.flash/` to `.gitignore` if not already present +5. Start uvicorn with `--reload` watching both `.flash/` and project root +6. Print startup table: local paths → resource names → types +7. Swagger UI available at `http://localhost:{port}/docs` +8. On exit (Ctrl+C or SIGTERM): deprovision all Live Serverless endpoints provisioned during this session + +### 8.1 Startup table format + +``` +Flash Dev Server http://localhost:8888 + + Local path Resource Type + ────────────────────────────────── ─────────────────── ──── + POST /gpu_worker/run gpu_worker QB + POST /gpu_worker/run_sync gpu_worker QB + POST /longruns/stage1/run longruns_stage1 QB + POST /preprocess/first_pass/compute preprocess_first_pass LB + + Visit http://localhost:8888/docs for Swagger UI +``` + +## 9. `flash build` Behavior + +1. Scan project for all `@remote` functions (QB and LB) +2. Build `.flash/manifest.json` with flat resource structure (see §10) +3. For LB resources: generate deployed handler files using `module_path` +4. Package build artifact + +## 10. Manifest Structure + +Resource names are derived from file paths (slashes → underscores): + +```json +{ + "version": "1.0", + "project_name": "my_project", + "resources": { + "gpu_worker": { + "resource_type": "LiveServerless", + "file_path": "gpu_worker.py", + "local_path_prefix": "/gpu_worker", + "module_path": "gpu_worker", + "functions": ["gpu_hello"], + "is_load_balanced": false, + "makes_remote_calls": false + }, + "longruns_stage1": { + "resource_type": "LiveServerless", + "file_path": "longruns/stage1.py", + "local_path_prefix": "/longruns/stage1", + "module_path": "longruns.stage1", + "functions": ["stage1_process"], + "is_load_balanced": false, + "makes_remote_calls": false + }, + "preprocess_first_pass": { + "resource_type": "CpuLiveLoadBalancer", + "file_path": "preprocess/first_pass.py", + "local_path_prefix": "/preprocess/first_pass", + "module_path": "preprocess.first_pass", + "functions": [ + {"name": "first_pass_fn", "http_method": "POST", "http_path": "/compute"} + ], + "is_load_balanced": true, + "makes_remote_calls": true + } + } +} +``` + +## 11. `.flash/server.py` Structure + +```python +"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run.""" +import sys +import uuid +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from fastapi import FastAPI + +# QB imports +from gpu_worker import gpu_hello +from longruns.stage1 import stage1_process + +# LB imports +from preprocess.first_pass import first_pass_fn + +app = FastAPI( + title="Flash Dev Server", + description="Auto-generated by `flash run`. Visit /docs for interactive testing.", +) + +# QB: gpu_worker.py +@app.post("/gpu_worker/run", tags=["gpu_worker [QB]"]) +@app.post("/gpu_worker/run_sync", tags=["gpu_worker [QB]"]) +async def gpu_worker_run(body: dict): + result = await gpu_hello(body.get("input", body)) + return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} + +# QB: longruns/stage1.py +@app.post("/longruns/stage1/run", tags=["longruns/stage1 [QB]"]) +@app.post("/longruns/stage1/run_sync", tags=["longruns/stage1 [QB]"]) +async def longruns_stage1_run(body: dict): + result = await stage1_process(body.get("input", body)) + return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} + +# LB: preprocess/first_pass.py +@app.post("/preprocess/first_pass/compute", tags=["preprocess/first_pass [LB]"]) +async def _route_first_pass_compute(body: dict): + return await first_pass_fn(body) + +# Health +@app.get("/", tags=["health"]) +def home(): + return {"message": "Flash Dev Server", "docs": "/docs"} + +@app.get("/ping", tags=["health"]) +def ping(): + return {"status": "healthy"} +``` + +Subdirectory imports use dotted module paths: `longruns/stage1.py` → `from longruns.stage1 import fn`. + +Multi-function QB files (2+ `@remote` functions) get sub-prefixed routes: +``` +longruns/stage1.py has: stage1_preprocess, stage1_infer +→ POST /longruns/stage1/stage1_preprocess/run +→ POST /longruns/stage1/stage1_preprocess/run_sync +→ POST /longruns/stage1/stage1_infer/run +→ POST /longruns/stage1/stage1_infer/run_sync +``` + +## 12. Acceptance Criteria + +- [ ] A file with one `@remote(QB_config)` function and nothing else is a valid Flash project +- [ ] `flash run` produces a Swagger UI showing all routes grouped by source file +- [ ] QB routes accept `{"input": {...}}` and return `{"id": ..., "status": "COMPLETED", "output": {...}}` +- [ ] Subdirectory files produce URL prefixes matching their relative path +- [ ] Multiple `@remote` functions in one file each get their own sub-prefixed routes +- [ ] LB route handler body executes directly (not dispatched remotely) +- [ ] QB calls inside LB route handler body route to the remote QB endpoint +- [ ] `flash deploy` creates a RunPod endpoint for each resource config +- [ ] `flash build` produces `.flash/manifest.json` with `file_path`, `local_path_prefix`, `module_path` per resource +- [ ] When `flash run` exits, all Live Serverless endpoints provisioned during that session are automatically undeployed + +## 13. Edge Cases + +- **No `@remote` functions found**: Error with clear message and usage instructions +- **Multiple `@remote` functions per file (QB)**: Sub-prefixed routes `/{file_prefix}/{fn_name}/run` +- **`__init__.py` files**: Skipped — not treated as worker files +- **File path with hyphens** (e.g., `my-worker.py`): Resource name sanitized to `my_worker`, URL prefix `/my-worker` (hyphens valid in URLs, underscores in Python identifiers) +- **LB function calling another LB function**: Not supported via `@remote` — emit a warning at build time +- **`.flash/` already exists**: `server.py` and `manifest.json` overwritten; other files preserved +- **`flash deploy` with no LB endpoints**: QB-only deploy +- **Subdirectory `__init__.py`** imports needed: Generator checks and warns if missing + +## 14. Implementation Files + +| File | Change | +|------|--------| +| `flash/main/PRD.md` | This document | +| `src/runpod_flash/client.py` | Passthrough for LB route handlers (`__is_lb_route_handler__`) | +| `cli/commands/run.py` | Unified server generation; `--app-dir .flash/`; file-path-based route discovery | +| `cli/commands/build_utils/scanner.py` | Path utilities; `is_lb_route_handler` field; file-based resource identity | +| `cli/commands/build_utils/manifest.py` | Flat resource structure; `file_path`/`local_path_prefix`/`module_path` fields | +| `cli/commands/build_utils/lb_handler_generator.py` | Import module by `module_path`, walk `__is_lb_route_handler__`, register routes | +| `cli/commands/build.py` | Remove main.py requirement from `validate_project_structure` | +| `core/resources/serverless.py` | Inject `FLASH_MODULE_PATH` env var | +| `flash-examples/.../01_hello_world/` | Rewrite to bare minimum | +| `flash-examples/.../00_standalone_worker/` | New | +| `flash-examples/.../00_multi_resource/` | New | From 13a6757a094c810ecd5ffb2026a63dfd1db6987d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:29:40 -0800 Subject: [PATCH 02/25] feat(client,scanner): LB route handler passthrough and path-aware discovery LB @remote functions (with method= and path=) now return the decorated function unwrapped with __is_lb_route_handler__=True. The function body executes directly on the LB endpoint server rather than being dispatched as a remote stub. QB stubs inside the body are unaffected. Scanner gains three path utilities (file_to_url_prefix, file_to_resource_name, file_to_module_path) that convert file paths to URL prefixes, resource names, and dotted module paths respectively. RemoteFunctionMetadata gains is_lb_route_handler to distinguish LB route handlers from QB remote stubs during discovery. --- .../cli/commands/build_utils/scanner.py | 73 +++++- src/runpod_flash/client.py | 14 ++ .../build_utils/test_path_utilities.py | 217 ++++++++++++++++++ 3 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 tests/unit/cli/commands/build_utils/test_path_utilities.py diff --git a/src/runpod_flash/cli/commands/build_utils/scanner.py b/src/runpod_flash/cli/commands/build_utils/scanner.py index 2215ab9e..d217dcb3 100644 --- a/src/runpod_flash/cli/commands/build_utils/scanner.py +++ b/src/runpod_flash/cli/commands/build_utils/scanner.py @@ -3,6 +3,7 @@ import ast import importlib import logging +import os import re from dataclasses import dataclass, field from pathlib import Path @@ -11,6 +12,61 @@ logger = logging.getLogger(__name__) +def file_to_url_prefix(file_path: Path, project_root: Path) -> str: + """Derive the local dev server URL prefix from a source file path. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + URL prefix starting with "/" (e.g., /longruns/stage1) + + Example: + longruns/stage1.py → /longruns/stage1 + """ + rel = file_path.relative_to(project_root).with_suffix("") + return "/" + str(rel).replace(os.sep, "/") + + +def file_to_resource_name(file_path: Path, project_root: Path) -> str: + """Derive the manifest resource name from a source file path. + + Slashes and hyphens are replaced with underscores to produce a valid + Python identifier suitable for use as a resource name. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + Resource name using underscores (e.g., longruns_stage1) + + Example: + longruns/stage1.py → longruns_stage1 + my-worker.py → my_worker + """ + rel = file_path.relative_to(project_root).with_suffix("") + return str(rel).replace(os.sep, "_").replace("/", "_").replace("-", "_") + + +def file_to_module_path(file_path: Path, project_root: Path) -> str: + """Derive the Python dotted module path from a source file path. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + Dotted module path (e.g., longruns.stage1) + + Example: + longruns/stage1.py → longruns.stage1 + """ + rel = file_path.relative_to(project_root).with_suffix("") + return str(rel).replace(os.sep, ".").replace("/", ".") + + @dataclass class RemoteFunctionMetadata: """Metadata about a @remote decorated function or class.""" @@ -35,6 +91,9 @@ class RemoteFunctionMetadata: called_remote_functions: List[str] = field( default_factory=list ) # Names of @remote functions called + is_lb_route_handler: bool = ( + False # LB @remote with method= and path= — runs directly as HTTP handler + ) class RemoteDecoratorScanner: @@ -62,7 +121,9 @@ def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: rel_path = f.relative_to(self.project_dir) # Check if first part of path is in excluded_root_dirs if rel_path.parts and rel_path.parts[0] not in excluded_root_dirs: - self.py_files.append(f) + # Exclude __init__.py — not valid worker entry points + if f.name != "__init__.py": + self.py_files.append(f) except (ValueError, IndexError): # Include files that can't be made relative self.py_files.append(f) @@ -220,6 +281,15 @@ def _extract_remote_functions( {"is_load_balanced": False, "is_live_resource": False}, ) + # An LB route handler is an LB @remote function that has + # both method= and path= declared. Its body runs directly + # on the LB endpoint — it is NOT a remote dispatch stub. + is_lb_route_handler = ( + flags["is_load_balanced"] + and http_method is not None + and http_path is not None + ) + metadata = RemoteFunctionMetadata( function_name=node.name, module_path=module_path, @@ -235,6 +305,7 @@ def _extract_remote_functions( config_variable=self.resource_variables.get( resource_config_name ), + is_lb_route_handler=is_lb_route_handler, ) functions.append(metadata) diff --git a/src/runpod_flash/client.py b/src/runpod_flash/client.py index ed68bc30..8709cf75 100644 --- a/src/runpod_flash/client.py +++ b/src/runpod_flash/client.py @@ -159,6 +159,20 @@ def decorator(func_or_class): "system_dependencies": system_dependencies, } + # LB route handler passthrough — return the function unwrapped. + # + # When @remote is applied to an LB resource (LiveLoadBalancer, + # CpuLiveLoadBalancer, LoadBalancerSlsResource) with method= and path=, + # the decorated function IS the HTTP route handler. Its body executes + # directly on the LB endpoint server; it is not dispatched to a remote + # process. QB @remote calls inside its body still use their own stubs. + is_lb_route_handler = is_lb_resource and method is not None and path is not None + if is_lb_route_handler: + routing_config["is_lb_route_handler"] = True + func_or_class.__remote_config__ = routing_config + func_or_class.__is_lb_route_handler__ = True + return func_or_class + # Local execution mode - execute without provisioning remote servers if local: func_or_class.__remote_config__ = routing_config diff --git a/tests/unit/cli/commands/build_utils/test_path_utilities.py b/tests/unit/cli/commands/build_utils/test_path_utilities.py new file mode 100644 index 00000000..73ec3557 --- /dev/null +++ b/tests/unit/cli/commands/build_utils/test_path_utilities.py @@ -0,0 +1,217 @@ +"""TDD tests for scanner path utility functions. + +Written first (failing) per the plan's TDD requirement. +These test file_to_url_prefix, file_to_resource_name, file_to_module_path. +""" + +import os + + +from runpod_flash.cli.commands.build_utils.scanner import ( + file_to_module_path, + file_to_resource_name, + file_to_url_prefix, +) + + +class TestFileToUrlPrefix: + """Tests for file_to_url_prefix utility.""" + + def test_root_level_file(self, tmp_path): + """gpu_worker.py → /gpu_worker""" + f = tmp_path / "gpu_worker.py" + assert file_to_url_prefix(f, tmp_path) == "/gpu_worker" + + def test_single_subdir(self, tmp_path): + """longruns/stage1.py → /longruns/stage1""" + f = tmp_path / "longruns" / "stage1.py" + assert file_to_url_prefix(f, tmp_path) == "/longruns/stage1" + + def test_nested_subdir(self, tmp_path): + """preprocess/first_pass.py → /preprocess/first_pass""" + f = tmp_path / "preprocess" / "first_pass.py" + assert file_to_url_prefix(f, tmp_path) == "/preprocess/first_pass" + + def test_deep_nested(self, tmp_path): + """workers/gpu/inference.py → /workers/gpu/inference""" + f = tmp_path / "workers" / "gpu" / "inference.py" + assert file_to_url_prefix(f, tmp_path) == "/workers/gpu/inference" + + def test_hyphenated_filename(self, tmp_path): + """my-worker.py → /my-worker (hyphens valid in URLs)""" + f = tmp_path / "my-worker.py" + assert file_to_url_prefix(f, tmp_path) == "/my-worker" + + def test_starts_with_slash(self, tmp_path): + """Result always starts with /""" + f = tmp_path / "worker.py" + result = file_to_url_prefix(f, tmp_path) + assert result.startswith("/") + + def test_no_py_extension(self, tmp_path): + """Result does not include .py extension""" + f = tmp_path / "worker.py" + result = file_to_url_prefix(f, tmp_path) + assert ".py" not in result + + +class TestFileToResourceName: + """Tests for file_to_resource_name utility.""" + + def test_root_level_file(self, tmp_path): + """gpu_worker.py → gpu_worker""" + f = tmp_path / "gpu_worker.py" + assert file_to_resource_name(f, tmp_path) == "gpu_worker" + + def test_single_subdir(self, tmp_path): + """longruns/stage1.py → longruns_stage1""" + f = tmp_path / "longruns" / "stage1.py" + assert file_to_resource_name(f, tmp_path) == "longruns_stage1" + + def test_nested_subdir(self, tmp_path): + """preprocess/first_pass.py → preprocess_first_pass""" + f = tmp_path / "preprocess" / "first_pass.py" + assert file_to_resource_name(f, tmp_path) == "preprocess_first_pass" + + def test_deep_nested(self, tmp_path): + """workers/gpu/inference.py → workers_gpu_inference""" + f = tmp_path / "workers" / "gpu" / "inference.py" + assert file_to_resource_name(f, tmp_path) == "workers_gpu_inference" + + def test_hyphenated_filename(self, tmp_path): + """my-worker.py → my_worker (hyphens replaced with underscores for Python identifiers)""" + f = tmp_path / "my-worker.py" + assert file_to_resource_name(f, tmp_path) == "my_worker" + + def test_no_py_extension(self, tmp_path): + """Result does not include .py extension""" + f = tmp_path / "worker.py" + result = file_to_resource_name(f, tmp_path) + assert ".py" not in result + + def test_no_path_separators(self, tmp_path): + """Result contains no / or os.sep characters""" + f = tmp_path / "a" / "b" / "worker.py" + result = file_to_resource_name(f, tmp_path) + assert "/" not in result + assert os.sep not in result + + +class TestFileToModulePath: + """Tests for file_to_module_path utility.""" + + def test_root_level_file(self, tmp_path): + """gpu_worker.py → gpu_worker""" + f = tmp_path / "gpu_worker.py" + assert file_to_module_path(f, tmp_path) == "gpu_worker" + + def test_single_subdir(self, tmp_path): + """longruns/stage1.py → longruns.stage1""" + f = tmp_path / "longruns" / "stage1.py" + assert file_to_module_path(f, tmp_path) == "longruns.stage1" + + def test_nested_subdir(self, tmp_path): + """preprocess/first_pass.py → preprocess.first_pass""" + f = tmp_path / "preprocess" / "first_pass.py" + assert file_to_module_path(f, tmp_path) == "preprocess.first_pass" + + def test_deep_nested(self, tmp_path): + """workers/gpu/inference.py → workers.gpu.inference""" + f = tmp_path / "workers" / "gpu" / "inference.py" + assert file_to_module_path(f, tmp_path) == "workers.gpu.inference" + + def test_no_py_extension(self, tmp_path): + """Result does not include .py extension""" + f = tmp_path / "worker.py" + result = file_to_module_path(f, tmp_path) + assert ".py" not in result + + def test_uses_dots_not_slashes(self, tmp_path): + """Result uses dots as separators, not slashes""" + f = tmp_path / "a" / "b" / "worker.py" + result = file_to_module_path(f, tmp_path) + assert "." in result + assert "/" not in result + assert os.sep not in result + + +class TestIsLbRouteHandlerField: + """Tests that RemoteFunctionMetadata.is_lb_route_handler is set correctly.""" + + def test_lb_function_with_method_and_path_is_handler(self, tmp_path): + """An LB @remote function with method= and path= is marked as LB route handler.""" + from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + + (tmp_path / "routes.py").write_text( + """ +from runpod_flash import CpuLiveLoadBalancer, remote + +lb_config = CpuLiveLoadBalancer(name="my_lb") + +@remote(lb_config, method="POST", path="/compute") +async def compute(data: dict) -> dict: + return data +""" + ) + + scanner = RemoteDecoratorScanner(tmp_path) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].is_lb_route_handler is True + + def test_qb_function_is_not_handler(self, tmp_path): + """A QB @remote function is NOT marked as LB route handler.""" + from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + + (tmp_path / "worker.py").write_text( + """ +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def process(data: dict) -> dict: + return data +""" + ) + + scanner = RemoteDecoratorScanner(tmp_path) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].is_lb_route_handler is False + + def test_init_py_files_excluded(self, tmp_path): + """__init__.py files are excluded from scanning.""" + from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + + (tmp_path / "__init__.py").write_text( + """ +from runpod_flash import LiveServerless, remote + +gpu_config = LiveServerless(name="gpu_worker") + +@remote(gpu_config) +async def process(data: dict) -> dict: + return data +""" + ) + (tmp_path / "worker.py").write_text( + """ +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def process(data: dict) -> dict: + return data +""" + ) + + scanner = RemoteDecoratorScanner(tmp_path) + functions = scanner.discover_remote_functions() + + # Only the worker.py function should be discovered, not __init__.py + assert len(functions) == 1 + assert functions[0].file_path.name == "worker.py" From 0782671cab2104ad906f860b991c06264a5bc4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:30:59 -0800 Subject: [PATCH 03/25] refactor(manifest): remove mothership dead code, flat resource structure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove _serialize_routes, _create_mothership_resource, and _create_mothership_from_explicit — all referenced unimported symbols and caused F821 lint errors. The manifest now emits a flat resources dict with file_path, local_path_prefix, and module_path per resource; no is_mothership flag. --- .../cli/commands/build_utils/manifest.py | 205 ++------- .../build_utils/test_manifest_mothership.py | 404 ------------------ 2 files changed, 36 insertions(+), 573 deletions(-) delete mode 100644 tests/unit/cli/commands/build_utils/test_manifest_mothership.py diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index b67ce9bd..af2a283f 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -9,45 +9,17 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from runpod_flash.core.resources.constants import ( - DEFAULT_WORKERS_MAX, - DEFAULT_WORKERS_MIN, - FLASH_CPU_LB_IMAGE, - FLASH_LB_IMAGE, +from .scanner import ( + RemoteFunctionMetadata, + file_to_module_path, + file_to_url_prefix, ) -from .scanner import RemoteFunctionMetadata, detect_explicit_mothership, detect_main_app - logger = logging.getLogger(__name__) RESERVED_PATHS = ["/execute", "/ping"] -def _serialize_routes(routes: List[RemoteFunctionMetadata]) -> List[Dict[str, Any]]: - """Convert RemoteFunctionMetadata to manifest dict format. - - Args: - routes: List of route metadata objects - - Returns: - List of dicts with route information for manifest - """ - return [ - { - "name": route.function_name, - "module": route.module_path, - "is_async": route.is_async, - "is_class": route.is_class, - "is_load_balanced": route.is_load_balanced, - "is_live_resource": route.is_live_resource, - "config_variable": route.config_variable, - "http_method": route.http_method, - "http_path": route.http_path, - } - for route in routes - ] - - @dataclass class ManifestFunction: """Function entry in manifest.""" @@ -213,83 +185,13 @@ def _extract_deployment_config( return config - def _create_mothership_resource(self, main_app_config: dict) -> Dict[str, Any]: - """Create implicit mothership resource from main.py. - - Args: - main_app_config: Dict with 'file_path', 'app_variable', 'has_routes', 'fastapi_routes' keys - - Returns: - Dictionary representing the mothership resource for the manifest - """ - # Extract FastAPI routes if present - fastapi_routes = main_app_config.get("fastapi_routes", []) - functions_list = _serialize_routes(fastapi_routes) - - return { - "resource_type": "CpuLiveLoadBalancer", - "functions": functions_list, - "is_load_balanced": True, - "is_live_resource": True, - "is_mothership": True, - "main_file": main_app_config["file_path"].name, - "app_variable": main_app_config["app_variable"], - "imageName": FLASH_CPU_LB_IMAGE, - "workersMin": DEFAULT_WORKERS_MIN, - "workersMax": DEFAULT_WORKERS_MAX, - } - - def _create_mothership_from_explicit( - self, explicit_config: dict, search_dir: Path - ) -> Dict[str, Any]: - """Create mothership resource from explicit mothership.py configuration. - - Args: - explicit_config: Configuration dict from detect_explicit_mothership() - search_dir: Project directory + def build(self) -> Dict[str, Any]: + """Build the manifest dictionary. - Returns: - Dictionary representing the mothership resource for the manifest + Resources are keyed by resource_config_name for runtime compatibility. + Each resource entry includes file_path, local_path_prefix, and module_path + for the dev server and LB handler generator. """ - # Detect FastAPI app details for handler generation - main_app_config = detect_main_app(search_dir, explicit_mothership_exists=False) - - if not main_app_config: - # No FastAPI app found, use defaults - main_file = "main.py" - app_variable = "app" - fastapi_routes = [] - else: - main_file = main_app_config["file_path"].name - app_variable = main_app_config["app_variable"] - fastapi_routes = main_app_config.get("fastapi_routes", []) - - # Extract FastAPI routes into functions list - functions_list = _serialize_routes(fastapi_routes) - - # Map resource type to image name - resource_type = explicit_config.get("resource_type", "CpuLiveLoadBalancer") - if resource_type == "LiveLoadBalancer": - image_name = FLASH_LB_IMAGE # GPU load balancer - else: - image_name = FLASH_CPU_LB_IMAGE # CPU load balancer - - return { - "resource_type": resource_type, - "functions": functions_list, - "is_load_balanced": True, - "is_live_resource": True, - "is_mothership": True, - "is_explicit": True, # Flag to indicate explicit configuration - "main_file": main_file, - "app_variable": app_variable, - "imageName": image_name, - "workersMin": explicit_config.get("workersMin", DEFAULT_WORKERS_MIN), - "workersMax": explicit_config.get("workersMax", DEFAULT_WORKERS_MAX), - } - - def build(self) -> Dict[str, Any]: - """Build the manifest dictionary.""" # Group functions by resource_config_name resources: Dict[str, List[RemoteFunctionMetadata]] = {} @@ -305,6 +207,9 @@ def build(self) -> Dict[str, Any]: str, Dict[str, str] ] = {} # resource_name -> {route_key -> function_name} + # Determine project root for path derivation + project_root = self.build_dir.parent if self.build_dir else Path.cwd() + for resource_name, functions in sorted(resources.items()): # Use actual resource type from first function in group resource_type = ( @@ -315,6 +220,27 @@ def build(self) -> Dict[str, Any]: is_load_balanced = functions[0].is_load_balanced if functions else False is_live_resource = functions[0].is_live_resource if functions else False + # Derive path fields from the first function's source file. + # All functions in a resource share the same source file per convention. + first_file = functions[0].file_path if functions else None + file_path_str = "" + local_path_prefix = "" + resource_module_path = functions[0].module_path if functions else "" + + if first_file and first_file.exists(): + try: + file_path_str = str(first_file.relative_to(project_root)) + local_path_prefix = file_to_url_prefix(first_file, project_root) + resource_module_path = file_to_module_path(first_file, project_root) + except ValueError: + # File is outside project root — fall back to module_path + file_path_str = str(first_file) + local_path_prefix = "/" + functions[0].module_path.replace(".", "/") + elif first_file: + # File path may be relative (in test scenarios) + file_path_str = str(first_file) + local_path_prefix = "/" + functions[0].module_path.replace(".", "/") + # Validate and collect routing for LB endpoints resource_routes = {} if is_load_balanced: @@ -374,6 +300,9 @@ def build(self) -> Dict[str, Any]: resources_dict[resource_name] = { "resource_type": resource_type, + "file_path": file_path_str, + "local_path_prefix": local_path_prefix, + "module_path": resource_module_path, "functions": functions_list, "is_load_balanced": is_load_balanced, "is_live_resource": is_live_resource, @@ -395,68 +324,6 @@ def build(self) -> Dict[str, Any]: ) function_registry[f.function_name] = resource_name - # === MOTHERSHIP DETECTION (EXPLICIT THEN FALLBACK) === - search_dir = self.build_dir if self.build_dir else Path.cwd() - - # Step 1: Check for explicit mothership.py - explicit_mothership = detect_explicit_mothership(search_dir) - - if explicit_mothership: - # Use explicit configuration - logger.debug("Found explicit mothership configuration in mothership.py") - - # Check for name conflict - mothership_name = explicit_mothership.get("name", "mothership") - if mothership_name in resources_dict: - logger.warning( - f"Project has a @remote resource named '{mothership_name}'. " - f"Using 'mothership-entrypoint' for explicit mothership endpoint." - ) - mothership_name = "mothership-entrypoint" - - # Create mothership resource from explicit config - mothership_resource = self._create_mothership_from_explicit( - explicit_mothership, search_dir - ) - resources_dict[mothership_name] = mothership_resource - - else: - # Step 2: Fallback to auto-detection - main_app_config = detect_main_app( - search_dir, explicit_mothership_exists=False - ) - - if main_app_config and main_app_config["has_routes"]: - logger.warning( - "Auto-detected FastAPI app in main.py (no mothership.py found). " - "Consider running 'flash init' to create explicit mothership configuration." - ) - - # Check for name conflict - if "mothership" in resources_dict: - logger.warning( - "Project has a @remote resource named 'mothership'. " - "Using 'mothership-entrypoint' for auto-generated mothership endpoint." - ) - mothership_name = "mothership-entrypoint" - else: - mothership_name = "mothership" - - # Create mothership resource from auto-detection (legacy behavior) - mothership_resource = self._create_mothership_resource(main_app_config) - resources_dict[mothership_name] = mothership_resource - - # Extract routes from mothership resources - for resource_name, resource in resources_dict.items(): - if resource.get("is_mothership") and resource.get("functions"): - mothership_routes = {} - for func in resource["functions"]: - if func.get("http_method") and func.get("http_path"): - route_key = f"{func['http_method']} {func['http_path']}" - mothership_routes[route_key] = func["name"] - if mothership_routes: - routes_dict[resource_name] = mothership_routes - manifest = { "version": "1.0", "generated_at": datetime.now(timezone.utc) diff --git a/tests/unit/cli/commands/build_utils/test_manifest_mothership.py b/tests/unit/cli/commands/build_utils/test_manifest_mothership.py deleted file mode 100644 index 896eefdf..00000000 --- a/tests/unit/cli/commands/build_utils/test_manifest_mothership.py +++ /dev/null @@ -1,404 +0,0 @@ -"""Tests for mothership resource creation in manifest.""" - -import tempfile -from pathlib import Path -from unittest.mock import patch - -from runpod_flash.cli.commands.build_utils.manifest import ManifestBuilder -from runpod_flash.cli.commands.build_utils.scanner import RemoteFunctionMetadata -from runpod_flash.core.resources.constants import ( - FLASH_CPU_LB_IMAGE, - FLASH_LB_IMAGE, -) - - -class TestManifestMothership: - """Test mothership resource creation in manifest.""" - - def test_manifest_includes_mothership_with_main_py(self): - """Test mothership resource added to manifest when main.py detected.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create a simple function file - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -gpu_config = LiveServerless(name="gpu_worker") - -@remote(resource_config=gpu_config) -def process(data): - return data -""" - ) - - # Change to project directory for detection - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", - remote_functions=[], - ) - manifest = builder.build() - - # Check mothership is in resources - assert "mothership" in manifest["resources"] - mothership = manifest["resources"]["mothership"] - assert mothership["is_mothership"] is True - assert mothership["main_file"] == "main.py" - assert mothership["app_variable"] == "app" - assert mothership["resource_type"] == "CpuLiveLoadBalancer" - assert mothership["imageName"] == FLASH_CPU_LB_IMAGE - - def test_manifest_skips_mothership_without_routes(self): - """Test mothership NOT added if main.py has no routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py without routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() -# No routes defined -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Mothership should NOT be in resources - assert "mothership" not in manifest["resources"] - - def test_manifest_skips_mothership_without_main_py(self): - """Test mothership NOT added if no main.py exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Mothership should NOT be in resources - assert "mothership" not in manifest["resources"] - - def test_manifest_handles_mothership_name_conflict(self): - """Test mothership uses alternate name if conflict with @remote resource.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create a remote function with name "mothership" (conflict) - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -mothership_config = LiveServerless(name="mothership") - -@remote(resource_config=mothership_config) -def process(data): - return data -""" - ) - - # Create remote function metadata with resource named "mothership" - remote_func = RemoteFunctionMetadata( - function_name="process", - module_path="functions", - resource_config_name="mothership", - resource_type="LiveServerless", - is_async=False, - is_class=False, - file_path=func_file, - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", remote_functions=[remote_func] - ) - manifest = builder.build() - - # Original mothership should be in resources - assert "mothership" in manifest["resources"] - # Auto-generated mothership should use alternate name - assert "mothership-entrypoint" in manifest["resources"] - entrypoint = manifest["resources"]["mothership-entrypoint"] - assert entrypoint["is_mothership"] is True - - def test_mothership_resource_config(self): - """Test mothership resource has correct configuration.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - mothership = manifest["resources"]["mothership"] - - # Check all expected fields - assert mothership["resource_type"] == "CpuLiveLoadBalancer" - # Functions should include the FastAPI route - assert len(mothership["functions"]) == 1 - assert mothership["functions"][0]["name"] == "root" - assert mothership["functions"][0]["http_method"] == "GET" - assert mothership["functions"][0]["http_path"] == "/" - assert mothership["is_load_balanced"] is True - assert mothership["is_live_resource"] is True - assert mothership["imageName"] == FLASH_CPU_LB_IMAGE - assert mothership["workersMin"] == 1 - assert mothership["workersMax"] == 1 - - def test_manifest_uses_explicit_mothership_config(self): - """Test explicit mothership.py config takes precedence over auto-detection.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create explicit mothership.py with custom config - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer - -mothership = CpuLiveLoadBalancer( - name="my-api", - workersMin=3, - workersMax=7, -) -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Check explicit config is used - assert "my-api" in manifest["resources"] - mothership = manifest["resources"]["my-api"] - assert mothership["is_explicit"] is True - assert mothership["workersMin"] == 3 - assert mothership["workersMax"] == 7 - - def test_manifest_skips_auto_detect_with_explicit_config(self): - """Test auto-detection is skipped when explicit config exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create explicit mothership.py - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer - -mothership = CpuLiveLoadBalancer( - name="explicit-mothership", - workersMin=2, - workersMax=4, -) -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Check only explicit config is in resources (not auto-detected "mothership") - assert "explicit-mothership" in manifest["resources"] - assert ( - manifest["resources"]["explicit-mothership"]["is_explicit"] is True - ) - assert "mothership" not in manifest["resources"] - - def test_manifest_handles_explicit_mothership_name_conflict(self): - """Test explicit mothership uses alternate name if conflict with @remote.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create explicit mothership.py with name that conflicts with resource - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer - -mothership = CpuLiveLoadBalancer( - name="api", # Will conflict with @remote resource named "api" - workersMin=1, - workersMax=3, -) -""" - ) - - # Create a remote function with name "api" (conflict) - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -api_config = LiveServerless(name="api") - -@remote(resource_config=api_config) -def process(data): - return data -""" - ) - - remote_func = RemoteFunctionMetadata( - function_name="process", - module_path="functions", - resource_config_name="api", - resource_type="LiveServerless", - is_async=False, - is_class=False, - file_path=func_file, - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", remote_functions=[remote_func] - ) - manifest = builder.build() - - # Original resource should be in resources - assert "api" in manifest["resources"] - # Explicit mothership should use alternate name - assert "mothership-entrypoint" in manifest["resources"] - entrypoint = manifest["resources"]["mothership-entrypoint"] - assert entrypoint["is_explicit"] is True - - def test_manifest_explicit_mothership_with_gpu_load_balancer(self): - """Test explicit GPU-based load balancer config.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create explicit mothership.py with GPU load balancer - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import LiveLoadBalancer - -mothership = LiveLoadBalancer( - name="gpu-mothership", - workersMin=1, - workersMax=2, -) -""" - ) - - # Create main.py for FastAPI app - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - mothership = manifest["resources"]["gpu-mothership"] - assert mothership["resource_type"] == "LiveLoadBalancer" - assert mothership["imageName"] == FLASH_LB_IMAGE - assert mothership["is_explicit"] is True From 9b38bd8e5e645f8407ff32e592b95c354fcf9621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:31:42 -0800 Subject: [PATCH 04/25] feat(run): file-system-as-namespace dev server generation flash run now scans the project for all @remote functions, generates .flash/server.py with routes derived from file paths, and starts uvicorn with --app-dir .flash/. Route convention: gpu_worker.py -> /gpu_worker/run and /gpu_worker/run_sync; subdirectory files produce matching URL prefixes. Cleanup on Ctrl+C is fixed: _cleanup_live_endpoints now reads .runpod/resources.pkl written by the uvicorn subprocess and deprovisions all live- prefixed endpoints, removing the dead in-process _SESSION_ENDPOINTS approach which never received data from the subprocess. --- src/runpod_flash/cli/commands/run.py | 581 +++++++++++++------ tests/integration/test_run_auto_provision.py | 337 ----------- tests/unit/cli/test_run.py | 12 +- 3 files changed, 397 insertions(+), 533 deletions(-) delete mode 100644 tests/integration/test_run_auto_provision.py diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 051115a8..faf4f50b 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -5,16 +5,361 @@ import signal import subprocess import sys +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import List -import questionary import typer from rich.console import Console +from rich.table import Table + +from .build_utils.scanner import ( + RemoteDecoratorScanner, + file_to_module_path, + file_to_resource_name, + file_to_url_prefix, +) logger = logging.getLogger(__name__) console = Console() +# Resource state file written by ResourceManager in the uvicorn subprocess. +_RESOURCE_STATE_FILE = Path(".runpod") / "resources.pkl" + + +@dataclass +class WorkerInfo: + """Info about a discovered @remote function for dev server generation.""" + + file_path: Path + url_prefix: str # e.g. /longruns/stage1 + module_path: str # e.g. longruns.stage1 + resource_name: str # e.g. longruns_stage1 + worker_type: str # "QB" or "LB" + functions: List[str] # function names + lb_routes: List[dict] = field(default_factory=list) # [{method, path, fn_name}] + + +def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: + """Scan the project for all @remote decorated functions. + + Walks all .py files (excluding .flash/, __pycache__, __init__.py) and + builds WorkerInfo for each file that contains @remote functions. + + Files with QB functions produce one WorkerInfo per file (QB type). + Files with LB functions produce one WorkerInfo per file (LB type). + A file can have both QB and LB functions (unusual but supported). + + Args: + project_root: Root directory of the Flash project + + Returns: + List of WorkerInfo, one entry per discovered source file + """ + scanner = RemoteDecoratorScanner(project_root) + remote_functions = scanner.discover_remote_functions() + + # Group by file path + by_file: dict[Path, List] = {} + for func in remote_functions: + by_file.setdefault(func.file_path, []).append(func) + + workers: List[WorkerInfo] = [] + for file_path, funcs in sorted(by_file.items()): + url_prefix = file_to_url_prefix(file_path, project_root) + module_path = file_to_module_path(file_path, project_root) + resource_name = file_to_resource_name(file_path, project_root) + + qb_funcs = [f for f in funcs if not f.is_load_balanced] + lb_funcs = [f for f in funcs if f.is_load_balanced and f.is_lb_route_handler] + + if qb_funcs: + workers.append( + WorkerInfo( + file_path=file_path, + url_prefix=url_prefix, + module_path=module_path, + resource_name=resource_name, + worker_type="QB", + functions=[f.function_name for f in qb_funcs], + ) + ) + + if lb_funcs: + lb_routes = [ + { + "method": f.http_method, + "path": f.http_path, + "fn_name": f.function_name, + } + for f in lb_funcs + ] + workers.append( + WorkerInfo( + file_path=file_path, + url_prefix=url_prefix, + module_path=module_path, + resource_name=resource_name, + worker_type="LB", + functions=[f.function_name for f in lb_funcs], + lb_routes=lb_routes, + ) + ) + + return workers + + +def _ensure_gitignore(project_root: Path) -> None: + """Add .flash/ to .gitignore if not already present.""" + gitignore = project_root / ".gitignore" + entry = ".flash/" + + if gitignore.exists(): + content = gitignore.read_text(encoding="utf-8") + if entry in content: + return + # Append with a newline + if not content.endswith("\n"): + content += "\n" + gitignore.write_text(content + entry + "\n", encoding="utf-8") + else: + gitignore.write_text(entry + "\n", encoding="utf-8") + + +def _sanitize_fn_name(name: str) -> str: + """Sanitize a string for use as a Python function name.""" + return name.replace("/", "_").replace(".", "_").replace("-", "_") + + +def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: + """Generate .flash/server.py from the discovered workers. + + Args: + project_root: Root of the Flash project + workers: List of discovered worker infos + + Returns: + Path to the generated server.py + """ + flash_dir = project_root / ".flash" + flash_dir.mkdir(exist_ok=True) + + _ensure_gitignore(project_root) + + lines = [ + '"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run."""', + "import sys", + "import uuid", + "from pathlib import Path", + "sys.path.insert(0, str(Path(__file__).parent.parent))", + "", + "from fastapi import FastAPI", + "", + ] + + # Collect all imports + all_imports: List[str] = [] + for worker in workers: + for fn_name in worker.functions: + all_imports.append(f"from {worker.module_path} import {fn_name}") + + if all_imports: + lines.extend(all_imports) + lines.append("") + + lines += [ + "app = FastAPI(", + ' title="Flash Dev Server",', + ' description="Auto-generated by `flash run`. Visit /docs for interactive testing.",', + ")", + "", + ] + + for worker in workers: + tag = f"{worker.url_prefix.lstrip('/')} [{worker.worker_type}]" + lines.append(f"# {'─' * 60}") + lines.append(f"# {worker.worker_type}: {worker.file_path.name}") + lines.append(f"# {'─' * 60}") + + if worker.worker_type == "QB": + if len(worker.functions) == 1: + fn = worker.functions[0] + handler_name = _sanitize_fn_name(f"{worker.resource_name}_run") + run_path = f"{worker.url_prefix}/run" + sync_path = f"{worker.url_prefix}/run_sync" + lines += [ + f'@app.post("{run_path}", tags=["{tag}"])', + f'@app.post("{sync_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f' result = await {fn}(body.get("input", body))', + ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', + "", + ] + else: + for fn in worker.functions: + handler_name = _sanitize_fn_name(f"{worker.resource_name}_{fn}_run") + run_path = f"{worker.url_prefix}/{fn}/run" + sync_path = f"{worker.url_prefix}/{fn}/run_sync" + lines += [ + f'@app.post("{run_path}", tags=["{tag}"])', + f'@app.post("{sync_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f' result = await {fn}(body.get("input", body))', + ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', + "", + ] + + elif worker.worker_type == "LB": + for route in worker.lb_routes: + method = route["method"].lower() + sub_path = route["path"].lstrip("/") + fn_name = route["fn_name"] + full_path = f"{worker.url_prefix}/{sub_path}" + handler_name = _sanitize_fn_name( + f"_route_{worker.resource_name}_{fn_name}" + ) + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f" return await {fn_name}(body)", + "", + ] + + # Health endpoints + lines += [ + "# Health", + '@app.get("/", tags=["health"])', + "def home():", + ' return {"message": "Flash Dev Server", "docs": "/docs"}', + "", + '@app.get("/ping", tags=["health"])', + "def ping():", + ' return {"status": "healthy"}', + "", + ] + + server_path = flash_dir / "server.py" + server_path.write_text("\n".join(lines), encoding="utf-8") + return server_path + + +def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> None: + """Print the startup table showing local paths, resource names, and types.""" + console.print(f"\n[bold green]Flash Dev Server[/bold green] http://{host}:{port}") + console.print() + + table = Table(show_header=True, header_style="bold") + table.add_column("Local path", style="cyan") + table.add_column("Resource", style="white") + table.add_column("Type", style="yellow") + + for worker in workers: + if worker.worker_type == "QB": + if len(worker.functions) == 1: + table.add_row( + f"POST {worker.url_prefix}/run", + worker.resource_name, + "QB", + ) + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) + else: + for fn in worker.functions: + table.add_row( + f"POST {worker.url_prefix}/{fn}/run", + worker.resource_name, + "QB", + ) + table.add_row( + f"POST {worker.url_prefix}/{fn}/run_sync", + worker.resource_name, + "QB", + ) + elif worker.worker_type == "LB": + for route in worker.lb_routes: + sub_path = route["path"].lstrip("/") + full_path = f"{worker.url_prefix}/{sub_path}" + table.add_row( + f"{route['method']} {full_path}", + worker.resource_name, + "LB", + ) + + console.print(table) + console.print(f"\n Visit [bold]http://{host}:{port}/docs[/bold] for Swagger UI\n") + + +def _cleanup_live_endpoints() -> None: + """Deprovision all Live Serverless endpoints created during this session. + + Reads the resource state file written by the uvicorn subprocess, finds + all endpoints with the 'live-' name prefix, and deprovisions them. + Best-effort: errors per endpoint are logged but do not prevent cleanup + of other endpoints. + """ + if not _RESOURCE_STATE_FILE.exists(): + return + + try: + import asyncio + import cloudpickle + from ...core.utils.file_lock import file_lock + + with open(_RESOURCE_STATE_FILE, "rb") as f: + with file_lock(f, exclusive=False): + data = cloudpickle.load(f) + + if isinstance(data, tuple): + resources, configs = data + else: + resources, configs = data, {} + + live_items = { + key: resource + for key, resource in resources.items() + if hasattr(resource, "name") + and resource.name + and resource.name.startswith("live-") + } + + if not live_items: + return + + async def _do_cleanup(): + for key, resource in live_items.items(): + name = getattr(resource, "name", key) + try: + success = await resource._do_undeploy() + if success: + console.print(f" Deprovisioned: {name}") + else: + logger.warning(f"Failed to deprovision: {name}") + except Exception as e: + logger.warning(f"Error deprovisioning {name}: {e}") + + asyncio.run(_do_cleanup()) + + # Remove live- entries from persisted state so they don't linger. + remaining = {k: v for k, v in resources.items() if k not in live_items} + remaining_configs = {k: v for k, v in configs.items() if k not in live_items} + try: + with open(_RESOURCE_STATE_FILE, "wb") as f: + with file_lock(f, exclusive=True): + cloudpickle.dump((remaining, remaining_configs), f) + except Exception as e: + logger.warning(f"Could not update resource state after cleanup: {e}") + + except Exception as e: + logger.warning(f"Live endpoint cleanup failed: {e}") + + +def _is_reload() -> bool: + """Check if running in uvicorn reload subprocess.""" + return "UVICORN_RELOADER_PID" in os.environ + def run_command( host: str = typer.Option( @@ -33,68 +378,51 @@ def run_command( reload: bool = typer.Option( True, "--reload/--no-reload", help="Enable auto-reload" ), - auto_provision: bool = typer.Option( - False, - "--auto-provision", - help="Auto-provision deployable resources on startup", - ), ): - """Run Flash development server with uvicorn.""" + """Run Flash development server. - # Discover entry point - entry_point = discover_entry_point() - if not entry_point: - console.print("[red]Error:[/red] No entry point found") - console.print("Create main.py with a FastAPI app") - raise typer.Exit(1) + Scans the project for @remote decorated functions, generates a dev server + at .flash/server.py, and starts uvicorn with hot-reload. - # Check if entry point has FastAPI app - app_location = check_fastapi_app(entry_point) - if not app_location: - console.print(f"[red]Error:[/red] No FastAPI app found in {entry_point}") - console.print("Make sure your main.py contains: app = FastAPI()") - raise typer.Exit(1) + No main.py or FastAPI boilerplate required. Any .py file with @remote + decorated functions is a valid Flash project. + """ + project_root = Path.cwd() - # Set flag for all flash run sessions to ensure both auto-provisioned - # and on-the-fly provisioned resources get the live- prefix + # Set flag for live provisioning so stubs get the live- prefix if not _is_reload(): os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true" - # Auto-provision resources if flag is set and not a reload - if auto_provision and not _is_reload(): - try: - resources = _discover_resources(entry_point) + # Discover @remote functions + workers = _scan_project_workers(project_root) - if resources: - # If many resources found, ask for confirmation - if len(resources) > 5: - if not _confirm_large_provisioning(resources): - console.print("[yellow]Auto-provisioning cancelled[/yellow]\n") - else: - _provision_resources(resources) - else: - _provision_resources(resources) - except Exception as e: - logger.error("Auto-provisioning failed", exc_info=True) - console.print( - f"[yellow]Warning:[/yellow] Resource provisioning failed: {e}" - ) - console.print( - "[yellow]Note:[/yellow] Resources will be deployed on-demand when first called" - ) + if not workers: + console.print("[red]Error:[/red] No @remote functions found.") + console.print("Add @remote decorators to your functions to get started.") + console.print("\nExample:") + console.print( + " from runpod_flash import LiveServerless, remote\n" + " gpu_config = LiveServerless(name='my_worker')\n" + "\n" + " @remote(gpu_config)\n" + " async def process(input_data: dict) -> dict:\n" + " return {'result': input_data}" + ) + raise typer.Exit(1) - console.print("\n[green]Starting Flash Server[/green]") - console.print(f"Entry point: [bold]{app_location}[/bold]") - console.print(f"Server: [bold]http://{host}:{port}[/bold]") - console.print(f"Auto-reload: [bold]{'enabled' if reload else 'disabled'}[/bold]") - console.print("\nPress CTRL+C to stop\n") + # Generate .flash/server.py + _generate_flash_server(project_root, workers) - # Build uvicorn command + _print_startup_table(workers, host, port) + + # Build uvicorn command using --app-dir so server:app is importable cmd = [ sys.executable, "-m", "uvicorn", - app_location, + "server:app", + "--app-dir", + ".flash", "--host", host, "--port", @@ -104,13 +432,16 @@ def run_command( ] if reload: - cmd.append("--reload") + cmd += [ + "--reload", + "--reload-dir", + ".", + "--reload-include", + "*.py", + ] - # Run uvicorn with proper process group handling process = None try: - # Create new process group to ensure all child processes can be killed together - # On Unix systems, use process group; on Windows, CREATE_NEW_PROCESS_GROUP if sys.platform == "win32": process = subprocess.Popen( cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP @@ -118,27 +449,21 @@ def run_command( else: process = subprocess.Popen(cmd, preexec_fn=os.setsid) - # Wait for process to complete process.wait() except KeyboardInterrupt: - console.print("\n[yellow]Stopping server and cleaning up processes...[/yellow]") + console.print("\n[yellow]Stopping server and cleaning up...[/yellow]") - # Kill the entire process group to ensure all child processes are terminated if process: try: if sys.platform == "win32": - # Windows: terminate the process process.terminate() else: - # Unix: kill entire process group os.killpg(os.getpgid(process.pid), signal.SIGTERM) - # Wait briefly for graceful shutdown try: process.wait(timeout=2) except subprocess.TimeoutExpired: - # Force kill if didn't terminate gracefully if sys.platform == "win32": process.kill() else: @@ -146,9 +471,9 @@ def run_command( process.wait() except (ProcessLookupError, OSError): - # Process already terminated pass + _cleanup_live_endpoints() console.print("[green]Server stopped[/green]") raise typer.Exit(0) @@ -162,135 +487,5 @@ def run_command( os.killpg(os.getpgid(process.pid), signal.SIGTERM) except (ProcessLookupError, OSError): pass + _cleanup_live_endpoints() raise typer.Exit(1) - - -def discover_entry_point() -> Optional[str]: - """Discover the main entry point file.""" - candidates = ["main.py", "app.py", "server.py"] - - for candidate in candidates: - if Path(candidate).exists(): - return candidate - - return None - - -def check_fastapi_app(entry_point: str) -> Optional[str]: - """ - Check if entry point has a FastAPI app and return the app location. - - Returns: - App location in format "module:app" or None - """ - try: - # Read the file - content = Path(entry_point).read_text() - - # Check for FastAPI app - if "app = FastAPI(" in content or "app=FastAPI(" in content: - # Extract module name from file path - module = entry_point.replace(".py", "").replace("/", ".") - return f"{module}:app" - - return None - - except Exception: - return None - - -def _is_reload() -> bool: - """Check if running in uvicorn reload subprocess. - - Returns: - True if running in a reload subprocess - """ - return "UVICORN_RELOADER_PID" in os.environ - - -def _discover_resources(entry_point: str): - """Discover deployable resources in entry point. - - Args: - entry_point: Path to entry point file - - Returns: - List of discovered DeployableResource instances - """ - from ...core.discovery import ResourceDiscovery - - try: - discovery = ResourceDiscovery(entry_point, max_depth=2) - resources = discovery.discover() - - # Debug: Log what was discovered - if resources: - console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") - for res in resources: - res_name = getattr(res, "name", "Unknown") - res_type = res.__class__.__name__ - console.print(f" [dim]• {res_name} ({res_type})[/dim]") - console.print() - - return resources - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Resource discovery failed: {e}") - return [] - - -def _confirm_large_provisioning(resources) -> bool: - """Show resources and prompt user for confirmation. - - Args: - resources: List of resources to provision - - Returns: - True if user confirms, False otherwise - """ - try: - console.print( - f"\n[yellow]Found {len(resources)} resources to provision:[/yellow]" - ) - - for resource in resources: - name = getattr(resource, "name", "Unknown") - resource_type = resource.__class__.__name__ - console.print(f" • {name} ({resource_type})") - - console.print() - - confirmed = questionary.confirm( - "This may take several minutes. Do you want to proceed?" - ).ask() - - return confirmed if confirmed is not None else False - - except (KeyboardInterrupt, EOFError): - console.print("\n[yellow]Cancelled[/yellow]") - return False - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Confirmation failed: {e}") - return False - - -def _provision_resources(resources): - """Provision resources and wait for completion. - - Args: - resources: List of resources to provision - """ - import asyncio - from ...core.deployment import DeploymentOrchestrator - - try: - console.print(f"\n[bold]Provisioning {len(resources)} resource(s)...[/bold]") - orchestrator = DeploymentOrchestrator(max_concurrent=3) - - # Run provisioning with progress shown - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(orchestrator.deploy_all(resources, show_progress=True)) - loop.close() - - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Provisioning failed: {e}") diff --git a/tests/integration/test_run_auto_provision.py b/tests/integration/test_run_auto_provision.py deleted file mode 100644 index 9478f442..00000000 --- a/tests/integration/test_run_auto_provision.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Integration tests for flash run --auto-provision command.""" - -import pytest -from unittest.mock import patch, MagicMock -from textwrap import dedent -from typer.testing import CliRunner - -from runpod_flash.cli.main import app - -runner = CliRunner() - - -class TestRunAutoProvision: - """Test flash run --auto-provision integration.""" - - @pytest.fixture - def temp_project(self, tmp_path): - """Create temporary Flash project for testing.""" - # Create main.py with FastAPI app - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - app = FastAPI() - - gpu_config = ServerlessResource( - name="test-gpu", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(resource_config=gpu_config) - async def gpu_task(): - return "result" - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - return tmp_path - - @pytest.fixture - def temp_project_many_resources(self, tmp_path): - """Create temporary project with many resources (> 5).""" - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - app = FastAPI() - - # Create 6 resources to trigger confirmation - configs = [ - ServerlessResource( - name=f"endpoint-{i}", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - for i in range(6) - ] - - @remote(resource_config=configs[0]) - async def task1(): pass - - @remote(resource_config=configs[1]) - async def task2(): pass - - @remote(resource_config=configs[2]) - async def task3(): pass - - @remote(resource_config=configs[3]) - async def task4(): pass - - @remote(resource_config=configs[4]) - async def task5(): pass - - @remote(resource_config=configs[5]) - async def task6(): pass - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - return tmp_path - - def test_run_without_auto_provision(self, temp_project, monkeypatch): - """Test that flash run without --auto-provision doesn't deploy resources.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations to prevent hanging - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to track if it was called - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - runner.invoke(app, ["run"]) - - # Discovery should not be called - mock_discover.assert_not_called() - - def test_run_with_auto_provision_single_resource(self, temp_project, monkeypatch): - """Test flash run --auto-provision with single resource.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock deployment orchestrator - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should be called - mock_provision.assert_called_once() - - def test_run_with_auto_provision_skips_reload(self, temp_project, monkeypatch): - """Test that auto-provision is skipped on reload.""" - monkeypatch.chdir(temp_project) - - # Simulate reload environment - monkeypatch.setenv("UVICORN_RELOADER_PID", "12345") - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock provisioning - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should NOT be called on reload - mock_provision.assert_not_called() - - def test_run_with_auto_provision_many_resources_confirmed( - self, temp_project, monkeypatch - ): - """Test auto-provision with > 5 resources and user confirmation.""" - monkeypatch.chdir(temp_project) - - # Create 6 mock resources - mock_resources = [MagicMock(name=f"endpoint-{i}") for i in range(6)] - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to return > 5 resources - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = mock_resources - - # Mock questionary to simulate user confirmation - with patch( - "runpod_flash.cli.commands.run.questionary.confirm" - ) as mock_confirm: - mock_confirm.return_value.ask.return_value = True - - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Should prompt for confirmation - mock_confirm.assert_called_once() - - # Should provision after confirmation - mock_provision.assert_called_once() - - def test_run_with_auto_provision_many_resources_cancelled( - self, temp_project, monkeypatch - ): - """Test auto-provision with > 5 resources and user cancellation.""" - monkeypatch.chdir(temp_project) - - # Create 6 mock resources - mock_resources = [MagicMock(name=f"endpoint-{i}") for i in range(6)] - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to return > 5 resources - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = mock_resources - - # Mock questionary to simulate user cancellation - with patch( - "runpod_flash.cli.commands.run.questionary.confirm" - ) as mock_confirm: - mock_confirm.return_value.ask.return_value = False - - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Should prompt for confirmation - mock_confirm.assert_called_once() - - # Should NOT provision after cancellation - mock_provision.assert_not_called() - - def test_run_auto_provision_discovery_error(self, temp_project, monkeypatch): - """Test that run handles discovery errors gracefully.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to raise exception - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = [] - - runner.invoke(app, ["run", "--auto-provision"]) - - # Server should still start despite discovery error - mock_popen.assert_called_once() - - def test_run_auto_provision_no_resources_found(self, tmp_path, monkeypatch): - """Test auto-provision when no resources are found.""" - monkeypatch.chdir(tmp_path) - - # Create main.py without any @remote decorators - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - - app = FastAPI() - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should not be called - mock_provision.assert_not_called() - - # Server should still start - mock_popen.assert_called_once() diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index a652aa75..cf7eb5fd 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -15,9 +15,15 @@ def runner(): @pytest.fixture def temp_fastapi_app(tmp_path): - """Create minimal FastAPI app for testing.""" - main_file = tmp_path / "main.py" - main_file.write_text("from fastapi import FastAPI\napp = FastAPI()") + """Create minimal Flash project with @remote function for testing.""" + worker_file = tmp_path / "worker.py" + worker_file.write_text( + "from runpod_flash import LiveServerless, remote\n" + "gpu_config = LiveServerless(name='test_worker')\n" + "@remote(gpu_config)\n" + "async def process(data: dict) -> dict:\n" + " return data\n" + ) return tmp_path From 098aafc707493364cd50ef7caf71089c0cbfe7f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:32:30 -0800 Subject: [PATCH 05/25] feat(build,lb_handler_generator): invoke LB handler generator, rglob project validation LBHandlerGenerator is now called from run_build() for all is_load_balanced resources, wiring the build pipeline to the new module_path-based handler generation. validate_project_structure switches from glob to rglob so projects with files only in subdirectories (e.g. 00_multi_resource) are not incorrectly rejected. lb_handler_generator loses the mothership reconciliation lifespan (StateManagerClient, reconcile_children) in favour of a clean startup/shutdown lifespan. --- src/runpod_flash/cli/commands/build.py | 25 ++++----- .../build_utils/lb_handler_generator.py | 54 +------------------ 2 files changed, 11 insertions(+), 68 deletions(-) diff --git a/src/runpod_flash/cli/commands/build.py b/src/runpod_flash/cli/commands/build.py index 44a1000e..1945c1ad 100644 --- a/src/runpod_flash/cli/commands/build.py +++ b/src/runpod_flash/cli/commands/build.py @@ -23,6 +23,7 @@ from runpod_flash.core.resources.constants import MAX_TARBALL_SIZE_MB from ..utils.ignore import get_file_tree, load_ignore_patterns +from .build_utils.lb_handler_generator import LBHandlerGenerator from .build_utils.manifest import ManifestBuilder from .build_utils.scanner import RemoteDecoratorScanner @@ -240,6 +241,9 @@ def run_build( manifest_path = build_dir / "flash_manifest.json" manifest_path.write_text(json.dumps(manifest, indent=2)) + lb_generator = LBHandlerGenerator(manifest, build_dir) + lb_generator.generate_handlers() + flash_dir = project_dir / ".flash" deployment_manifest_path = flash_dir / "flash_manifest.json" shutil.copy2(manifest_path, deployment_manifest_path) @@ -426,28 +430,19 @@ def validate_project_structure(project_dir: Path) -> bool: """ Validate that directory is a Flash project. + A Flash project is any directory containing Python files. The + RemoteDecoratorScanner validates that @remote functions exist. + Args: project_dir: Directory to validate Returns: True if valid Flash project """ - main_py = project_dir / "main.py" - - if not main_py.exists(): - console.print(f"[red]Error:[/red] main.py not found in {project_dir}") + py_files = list(project_dir.rglob("*.py")) + if not py_files: + console.print(f"[red]Error:[/red] No Python files found in {project_dir}") return False - - # Check if main.py has FastAPI app - try: - content = main_py.read_text(encoding="utf-8") - if "FastAPI" not in content: - console.print( - "[yellow]Warning:[/yellow] main.py does not appear to have a FastAPI app" - ) - except Exception: - pass - return True diff --git a/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py b/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py index dcd0845d..a0d28601 100644 --- a/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py +++ b/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py @@ -21,13 +21,10 @@ - Real-time communication patterns """ -import asyncio import logging from contextlib import asynccontextmanager -from pathlib import Path -from typing import Optional -from fastapi import FastAPI, Request +from fastapi import FastAPI from runpod_flash.runtime.lb_handler import create_lb_handler logger = logging.getLogger(__name__) @@ -45,57 +42,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): """Handle application startup and shutdown.""" - # Startup logger.info("Starting {resource_name} endpoint") - - # Check if this is the mothership and run reconciliation - # Note: Resources are now provisioned upfront by the CLI during deployment. - # This background task runs reconciliation on mothership startup to ensure - # all resources are still deployed and in sync with the manifest. - try: - from runpod_flash.runtime.mothership_provisioner import ( - is_mothership, - reconcile_children, - get_mothership_url, - ) - from runpod_flash.runtime.state_manager_client import StateManagerClient - - if is_mothership(): - logger.info("=" * 60) - logger.info("Mothership detected - Starting reconciliation task") - logger.info("Resources are provisioned upfront by the CLI") - logger.info("This task ensures all resources remain in sync") - logger.info("=" * 60) - try: - mothership_url = get_mothership_url() - logger.info(f"Mothership URL: {{mothership_url}}") - - # Initialize State Manager client for reconciliation - state_client = StateManagerClient() - - # Spawn background reconciliation task (non-blocking) - # This will verify all resources from manifest are deployed - manifest_path = Path(__file__).parent / "flash_manifest.json" - task = asyncio.create_task( - reconcile_children(manifest_path, mothership_url, state_client) - ) - # Add error callback to catch and log background task exceptions - task.add_done_callback( - lambda t: logger.error(f"Reconciliation task failed: {{t.exception()}}") - if t.exception() - else None - ) - - except Exception as e: - logger.error(f"Failed to start reconciliation task: {{e}}") - # Don't fail startup - continue serving traffic - - except ImportError: - logger.debug("Mothership provisioning modules not available") - yield - - # Shutdown logger.info("Shutting down {resource_name} endpoint") From ceb099f8dc440d77251d634a5987d4aa2d8903c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:33:32 -0800 Subject: [PATCH 06/25] fix(serverless): resolve flash run runtime bugs is_deployed skips the health check when FLASH_IS_LIVE_PROVISIONING=true. Newly created endpoints can fail RunPod's health API for a few seconds after creation (propagation delay), causing get_or_deploy_resource to trigger a spurious re-deploy on the second request (e.g. /run_sync immediately after /run). _payload_exclude now excludes template when templateId is already set. After first deployment _do_deploy sets templateId on the config object while the set_serverless_template validator has already set template at construction time. Sending both fields in the same payload causes RunPod to return 'You can only provide one of templateId or template.' Also adds _get_module_path helper and injects FLASH_MODULE_PATH into LB endpoint environment at deploy time so the deployed handler can import the correct user module. --- src/runpod_flash/core/resources/serverless.py | 60 ++++++++++++++++++- tests/unit/resources/test_serverless.py | 44 ++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index e7a14403..4793d5f7 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -474,6 +474,14 @@ def is_deployed(self) -> bool: if not self.id: return False + # During flash run, skip the health check. Newly-created endpoints + # can fail health checks due to RunPod propagation delay — the + # endpoint exists but the health API hasn't registered it yet. + # Trusting the cached ID is correct here; actual failures surface + # on the first real run/run_sync call. + if os.getenv("FLASH_IS_LIVE_PROVISIONING", "").lower() == "true": + return True + response = self.endpoint.health() return response is not None except Exception as e: @@ -484,6 +492,10 @@ def _payload_exclude(self) -> Set[str]: # flashEnvironmentId is input-only but must be sent when provided exclude_fields = set(self._input_only or set()) exclude_fields.discard("flashEnvironmentId") + # When templateId is already set, exclude template from the payload. + # RunPod rejects requests that contain both fields simultaneously. + if self.templateId: + exclude_fields.add("template") return exclude_fields @staticmethod @@ -564,12 +576,45 @@ def _check_makes_remote_calls(self) -> bool: ) return True # Safe default on error + def _get_module_path(self) -> Optional[str]: + """Get module_path from build manifest for this resource. + + Returns: + Dotted module path (e.g., 'preprocess.first_pass'), or None if not found. + """ + try: + manifest_path = Path.cwd() / "flash_manifest.json" + if not manifest_path.exists(): + manifest_path = Path("/flash_manifest.json") + if not manifest_path.exists(): + return None + + with open(manifest_path) as f: + manifest_data = json.load(f) + + resources = manifest_data.get("resources", {}) + + lookup_name = self.name + if lookup_name.endswith("-fb"): + lookup_name = lookup_name[:-3] + if lookup_name.startswith(LIVE_PREFIX): + lookup_name = lookup_name[len(LIVE_PREFIX) :] + + resource_config = resources.get(lookup_name) + if not resource_config: + return None + + return resource_config.get("module_path") + + except Exception: + return None + async def _do_deploy(self) -> "DeployableResource": """ Deploys the serverless resource using the provided configuration. - For queue-based endpoints that make remote calls, injects RUNPOD_API_KEY - into environment variables if not already set. + For queue-based endpoints that make remote calls, injects RUNPOD_API_KEY. + For load-balanced endpoints, injects FLASH_MODULE_PATH. Returns a DeployableResource object. """ @@ -604,6 +649,17 @@ async def _do_deploy(self) -> "DeployableResource": self.env = env_dict + # Inject module path for load-balanced endpoints + elif self.type == ServerlessType.LB: + env_dict = self.env or {} + + module_path = self._get_module_path() + if module_path and "FLASH_MODULE_PATH" not in env_dict: + env_dict["FLASH_MODULE_PATH"] = module_path + log.info(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}") + + self.env = env_dict + # Ensure network volume is deployed first await self._ensure_network_volume_deployed() diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index b5e3cbbe..124eb136 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -467,6 +467,32 @@ def test_is_deployed_false_when_no_id(self): assert serverless.is_deployed() is False + def test_is_deployed_skips_health_check_during_live_provisioning(self, monkeypatch): + """During flash run, is_deployed returns True based on ID alone.""" + monkeypatch.setenv("FLASH_IS_LIVE_PROVISIONING", "true") + serverless = ServerlessResource(name="test") + serverless.id = "ep-live-123" + + # health() must NOT be called — no mock needed, any call would raise + assert serverless.is_deployed() is True + + def test_is_deployed_uses_health_check_outside_live_provisioning(self, monkeypatch): + """Outside flash run, is_deployed falls back to health check.""" + monkeypatch.delenv("FLASH_IS_LIVE_PROVISIONING", raising=False) + serverless = ServerlessResource(name="test") + serverless.id = "ep-123" + + mock_endpoint = MagicMock() + mock_endpoint.health.return_value = {"workers": {}} + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + assert serverless.is_deployed() is True + mock_endpoint.health.assert_called_once() + @pytest.mark.asyncio async def test_deploy_already_deployed(self): """Test deploy returns early when already deployed.""" @@ -938,6 +964,24 @@ def test_is_deployed_with_exception(self): assert result is False + def test_payload_exclude_adds_template_when_template_id_set(self): + """_payload_exclude excludes template field when templateId is already set.""" + serverless = ServerlessResource(name="test") + serverless.templateId = "tmpl-123" + + excluded = serverless._payload_exclude() + + assert "template" in excluded + + def test_payload_exclude_does_not_exclude_template_without_template_id(self): + """_payload_exclude does not exclude template when templateId is absent.""" + serverless = ServerlessResource(name="test") + serverless.templateId = None + + excluded = serverless._payload_exclude() + + assert "template" not in excluded + def test_reverse_sync_from_backend_response(self): """Test reverse sync when receiving backend response with gpuIds.""" # This tests the lines 173-176 which convert gpuIds back to gpus list From c4065f30cd0e6fa7b7c7ea2fdddaf8c8fe954331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 12:54:59 -0800 Subject: [PATCH 07/25] fix(run): hot-reload regenerates server.py on route changes Parent process watches project .py files via watchfiles and regenerates .flash/server.py on change. Uvicorn now watches only .flash/server.py instead of the whole project, so it reloads exactly once per change with the updated routes visible. - Add _watch_and_regenerate() background thread using watchfiles - Change --reload-dir from '.' to '.flash', --reload-include to 'server.py' - Start watcher thread when reload=True, stop on KeyboardInterrupt/Exception - Add TestRunCommandHotReload and TestWatchAndRegenerate test classes --- src/runpod_flash/cli/commands/run.py | 52 ++++++- tests/unit/cli/test_run.py | 202 +++++++++++++++++++++++++++ 2 files changed, 252 insertions(+), 2 deletions(-) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index faf4f50b..86fceb59 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -5,6 +5,7 @@ import signal import subprocess import sys +import threading from dataclasses import dataclass, field from pathlib import Path from typing import List @@ -12,6 +13,8 @@ import typer from rich.console import Console from rich.table import Table +from watchfiles import DefaultFilter as _WatchfilesDefaultFilter +from watchfiles import watch as _watchfiles_watch from .build_utils.scanner import ( RemoteDecoratorScanner, @@ -361,6 +364,33 @@ def _is_reload() -> bool: return "UVICORN_RELOADER_PID" in os.environ +def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> None: + """Watch project .py files and regenerate server.py when they change. + + Ignores .flash/ to avoid reacting to our own writes. Runs until + stop_event is set. + """ + watch_filter = _WatchfilesDefaultFilter(ignore_paths=[str(project_root / ".flash")]) + + try: + for changes in _watchfiles_watch( + project_root, + watch_filter=watch_filter, + stop_event=stop_event, + ): + py_changed = [p for _, p in changes if p.endswith(".py")] + if not py_changed: + continue + try: + workers = _scan_project_workers(project_root) + _generate_flash_server(project_root, workers) + logger.debug("server.py regenerated (%d changed)", len(py_changed)) + except Exception as e: + logger.warning("Failed to regenerate server.py: %s", e) + except Exception: + pass # stop_event was set or watchfiles unavailable — both are fine + + def run_command( host: str = typer.Option( "localhost", @@ -435,11 +465,19 @@ def run_command( cmd += [ "--reload", "--reload-dir", - ".", + ".flash", "--reload-include", - "*.py", + "server.py", ] + stop_event = threading.Event() + watcher_thread = threading.Thread( + target=_watch_and_regenerate, + args=(project_root, stop_event), + daemon=True, + name="flash-watcher", + ) + process = None try: if sys.platform == "win32": @@ -449,11 +487,17 @@ def run_command( else: process = subprocess.Popen(cmd, preexec_fn=os.setsid) + if reload: + watcher_thread.start() + process.wait() except KeyboardInterrupt: console.print("\n[yellow]Stopping server and cleaning up...[/yellow]") + stop_event.set() + watcher_thread.join(timeout=2) + if process: try: if sys.platform == "win32": @@ -479,6 +523,10 @@ def run_command( except Exception as e: console.print(f"[red]Error:[/red] {e}") + + stop_event.set() + watcher_thread.join(timeout=2) + if process: try: if sys.platform == "win32": diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index cf7eb5fd..1e0c549a 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -30,6 +30,12 @@ def temp_fastapi_app(tmp_path): class TestRunCommandEnvironmentVariables: """Test flash run command environment variable support.""" + @pytest.fixture(autouse=True) + def patch_watcher(self): + """Prevent the background watcher thread from blocking tests.""" + with patch("runpod_flash.cli.commands.run._watch_and_regenerate"): + yield + def test_port_from_environment_variable( self, runner, temp_fastapi_app, monkeypatch ): @@ -221,3 +227,199 @@ def test_short_port_flag_overrides_environment( assert "--port" in call_args port_index = call_args.index("--port") assert call_args[port_index + 1] == "7000" + + +class TestRunCommandHotReload: + """Test flash run hot-reload behavior.""" + + @pytest.fixture(autouse=True) + def patch_watcher(self): + """Prevent the background watcher thread from blocking tests.""" + with patch("runpod_flash.cli.commands.run._watch_and_regenerate"): + yield + + def _invoke_run(self, runner, monkeypatch, temp_fastapi_app, extra_args=None): + """Helper: invoke flash run and return the Popen call args.""" + monkeypatch.chdir(temp_fastapi_app) + monkeypatch.delenv("FLASH_PORT", raising=False) + monkeypatch.delenv("FLASH_HOST", raising=False) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + runner.invoke(app, ["run"] + (extra_args or [])) + + return mock_popen.call_args[0][0] + + def test_reload_watches_flash_server_py( + self, runner, temp_fastapi_app, monkeypatch + ): + """Uvicorn watches .flash/server.py, not the whole project.""" + cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app) + + assert "--reload" in cmd + assert "--reload-dir" in cmd + reload_dir_index = cmd.index("--reload-dir") + assert cmd[reload_dir_index + 1] == ".flash" + + assert "--reload-include" in cmd + reload_include_index = cmd.index("--reload-include") + assert cmd[reload_include_index + 1] == "server.py" + + def test_reload_does_not_watch_project_root( + self, runner, temp_fastapi_app, monkeypatch + ): + """Uvicorn reload-dir must not be '.' to prevent double-reload.""" + cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app) + + reload_dir_index = cmd.index("--reload-dir") + assert cmd[reload_dir_index + 1] != "." + + def test_no_reload_skips_watcher_thread( + self, runner, temp_fastapi_app, monkeypatch + ): + """--no-reload: neither uvicorn reload args nor watcher thread started.""" + monkeypatch.chdir(temp_fastapi_app) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + with patch( + "runpod_flash.cli.commands.run.threading.Thread" + ) as mock_thread_cls: + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + + runner.invoke(app, ["run", "--no-reload"]) + + cmd = mock_popen.call_args[0][0] + assert "--reload" not in cmd + mock_thread.start.assert_not_called() + + def test_watcher_thread_started_on_reload( + self, runner, temp_fastapi_app, monkeypatch, patch_watcher + ): + """When reload=True, the background watcher thread is started.""" + monkeypatch.chdir(temp_fastapi_app) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + with patch( + "runpod_flash.cli.commands.run.threading.Thread" + ) as mock_thread_cls: + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + + runner.invoke(app, ["run"]) + + mock_thread.start.assert_called_once() + + def test_watcher_thread_stopped_on_keyboard_interrupt( + self, runner, temp_fastapi_app, monkeypatch + ): + """KeyboardInterrupt sets stop_event and joins the watcher thread.""" + monkeypatch.chdir(temp_fastapi_app) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + with patch( + "runpod_flash.cli.commands.run.threading.Thread" + ) as mock_thread_cls: + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + with patch( + "runpod_flash.cli.commands.run.threading.Event" + ) as mock_event_cls: + mock_stop = MagicMock() + mock_event_cls.return_value = mock_stop + + runner.invoke(app, ["run"]) + + mock_stop.set.assert_called_once() + mock_thread.join.assert_called_once_with(timeout=2) + + +class TestWatchAndRegenerate: + """Unit tests for the _watch_and_regenerate background function.""" + + def test_regenerates_server_py_on_py_file_change(self, tmp_path): + """When a .py file changes, server.py is regenerated.""" + import threading + from runpod_flash.cli.commands.run import _watch_and_regenerate + + stop = threading.Event() + + with patch( + "runpod_flash.cli.commands.run._scan_project_workers", return_value=[] + ) as mock_scan: + with patch( + "runpod_flash.cli.commands.run._generate_flash_server" + ) as mock_gen: + with patch( + "runpod_flash.cli.commands.run._watchfiles_watch" + ) as mock_watch: + # Yield one batch of changes then stop + mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) + stop.set() # ensures the loop exits after one iteration + _watch_and_regenerate(tmp_path, stop) + + mock_scan.assert_called_once_with(tmp_path) + mock_gen.assert_called_once() + + def test_ignores_non_py_changes(self, tmp_path): + """Changes to non-.py files do not trigger regeneration.""" + import threading + from runpod_flash.cli.commands.run import _watch_and_regenerate + + stop = threading.Event() + + with patch("runpod_flash.cli.commands.run._scan_project_workers") as mock_scan: + with patch( + "runpod_flash.cli.commands.run._generate_flash_server" + ) as mock_gen: + with patch( + "runpod_flash.cli.commands.run._watchfiles_watch" + ) as mock_watch: + mock_watch.return_value = iter([{(1, "/path/to/README.md")}]) + _watch_and_regenerate(tmp_path, stop) + + mock_scan.assert_not_called() + mock_gen.assert_not_called() + + def test_scan_error_does_not_crash_watcher(self, tmp_path): + """If regeneration raises, the watcher logs a warning and continues.""" + import threading + from runpod_flash.cli.commands.run import _watch_and_regenerate + + stop = threading.Event() + + with patch( + "runpod_flash.cli.commands.run._scan_project_workers", + side_effect=RuntimeError("scan failed"), + ): + with patch("runpod_flash.cli.commands.run._watchfiles_watch") as mock_watch: + mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) + # Should not raise + _watch_and_regenerate(tmp_path, stop) From d8e14b7f461d18a833c5f7aac6abc638aa37439b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 13:12:52 -0800 Subject: [PATCH 08/25] fix(run): suppress watchfiles debug logs from flash run output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit watchfiles emits DEBUG-level messages ("all changes filtered out", "rust notify timeout") that are correct behavior but should not be visible to users. Silence the watchfiles logger at WARNING in _watch_and_regenerate() — scoped to that namespace only. --- src/runpod_flash/cli/commands/run.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 86fceb59..57221c1a 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -370,6 +370,9 @@ def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> No Ignores .flash/ to avoid reacting to our own writes. Runs until stop_event is set. """ + # Suppress watchfiles' internal debug chatter (filter hits, rust timeouts). + logging.getLogger("watchfiles").setLevel(logging.WARNING) + watch_filter = _WatchfilesDefaultFilter(ignore_paths=[str(project_root / ".flash")]) try: From dd2491d04f7a2789fb28375c9275d3277e5f6910 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 14:25:49 -0800 Subject: [PATCH 09/25] fix(run): omit body param from GET/HEAD route handlers FastAPI treats `body: dict` as a required JSON body. GET/HEAD routes must be zero-arg so Swagger UI and browsers do not attempt to send a body, which triggers a fetch TypeError. Split the LB route code generator in _generate_flash_server() on method: get/head emit no-arg handlers; all other methods keep body: dict. --- src/runpod_flash/cli/commands/run.py | 20 ++++++++++----- tests/unit/cli/test_run.py | 38 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 57221c1a..e8732ad9 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -221,12 +221,20 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat handler_name = _sanitize_fn_name( f"_route_{worker.resource_name}_{fn_name}" ) - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict):", - f" return await {fn_name}(body)", - "", - ] + if method in ("get", "head"): + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}():", + f" return await {fn_name}()", + "", + ] + else: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f" return await {fn_name}(body)", + "", + ] # Health endpoints lines += [ diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index 1e0c549a..7fae486a 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -1,10 +1,12 @@ """Unit tests for run CLI command.""" import pytest +from pathlib import Path from unittest.mock import patch, MagicMock from typer.testing import CliRunner from runpod_flash.cli.main import app +from runpod_flash.cli.commands.run import WorkerInfo, _generate_flash_server @pytest.fixture @@ -423,3 +425,39 @@ def test_scan_error_does_not_crash_watcher(self, tmp_path): mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) # Should not raise _watch_and_regenerate(tmp_path, stop) + + +class TestGenerateFlashServer: + """Test _generate_flash_server() route code generation.""" + + def _make_lb_worker(self, tmp_path: Path, method: str) -> WorkerInfo: + return WorkerInfo( + file_path=tmp_path / "api.py", + url_prefix="/api", + module_path="api", + resource_name="api", + worker_type="LB", + functions=["list_routes"], + lb_routes=[ + {"method": method, "path": "/routes/list", "fn_name": "list_routes"} + ], + ) + + def test_get_route_has_no_body_param(self, tmp_path): + """GET handler must omit body: dict to satisfy FastAPI/browser constraints.""" + worker = self._make_lb_worker(tmp_path, "GET") + server_path = _generate_flash_server(tmp_path, [worker]) + content = server_path.read_text() + + # The GET handler must be zero-arg + assert "async def _route_api_list_routes():" in content + # No body parameter on any GET handler + assert "body: dict" not in content + + def test_post_route_keeps_body_param(self, tmp_path): + """POST handler must include body: dict for JSON request body.""" + worker = self._make_lb_worker(tmp_path, "POST") + server_path = _generate_flash_server(tmp_path, [worker]) + content = server_path.read_text() + + assert "async def _route_api_list_routes(body: dict):" in content From 9896c46b7a99ce832551575843e0ca6789442ca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 18:40:02 -0800 Subject: [PATCH 10/25] feat(run): proxy LB routes to deployed endpoints, restore --auto-provision LB route handlers were executing locally in the dev server process instead of forwarding to the deployed LB endpoint. The @remote decorator returns LB handlers unwrapped (passthrough) because in a deployed pod the body IS the HTTP handler, but in flash run there is no deployed pod. Changes: - Add _run_server_helpers.py with lb_proxy() that uses ResourceManager.get_or_deploy_resource() for on-demand provisioning and get_authenticated_httpx_client() for auth headers - Generate proxy handlers for all LB routes (any HTTP method) that forward requests to the deployed endpoint transparently - Import resource config variables (not function bodies) for LB workers so the actual DeployableResource object is passed to lb_proxy - Restore --auto-provision flag dropped in 35cfa6e, using existing ResourceDiscovery and DeploymentOrchestrator to provision all endpoints upfront and eliminate cold-start latency - Replace TestGenerateFlashServer tests with proxy-aware assertions --- .../cli/commands/_run_server_helpers.py | 84 +++++++++++ src/runpod_flash/cli/commands/run.py | 140 +++++++++++++++--- tests/unit/cli/test_run.py | 64 +++++--- 3 files changed, 249 insertions(+), 39 deletions(-) create mode 100644 src/runpod_flash/cli/commands/_run_server_helpers.py diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py new file mode 100644 index 00000000..44b6d5d3 --- /dev/null +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -0,0 +1,84 @@ +"""Helpers for the flash run dev server — loaded inside the generated server.py.""" + +import httpx +from fastapi import HTTPException, Request +from fastapi.responses import Response + +from runpod_flash.core.resources.base import DeployableResource +from runpod_flash.core.resources.resource_manager import ResourceManager +from runpod_flash.core.utils.http import get_authenticated_httpx_client + +_resource_manager = ResourceManager() + + +async def lb_proxy( + resource_config: DeployableResource, path_prefix: str, request: Request +) -> Response: + """Transparent HTTP proxy from flash run dev server to deployed LB endpoint. + + Uses ResourceManager.get_or_deploy_resource() to resolve the endpoint, + which handles provisioning, name prefixing, and caching automatically. + + Args: + resource_config: The resource config object (e.g. LiveLoadBalancer instance) + path_prefix: URL prefix used by the dev server (e.g. "/api") — stripped before proxying + request: The incoming FastAPI request to forward + + Returns: + FastAPI Response with upstream status code and body + + Raises: + HTTPException 503: Endpoint not deployed or has no ID + HTTPException 504: Upstream request timed out + HTTPException 502: Connection error reaching the upstream endpoint + """ + try: + deployed = await _resource_manager.get_or_deploy_resource(resource_config) + endpoint_url = deployed.endpoint_url + except ValueError as e: + raise HTTPException( + status_code=503, + detail=f"Endpoint '{resource_config.name}' not available: {e}", + ) + except Exception as e: + raise HTTPException( + status_code=503, + detail=f"Failed to provision '{resource_config.name}': {e}", + ) + + target_path = request.url.path + if path_prefix and target_path.startswith(path_prefix): + target_path = target_path[len(path_prefix) :] + if not target_path: + target_path = "/" + + target_url = endpoint_url.rstrip("/") + target_path + if request.url.query: + target_url += "?" + request.url.query + + body = await request.body() + skip_headers = {"host", "content-length", "transfer-encoding", "connection"} + headers = { + k: v for k, v in request.headers.items() if k.lower() not in skip_headers + } + + try: + async with get_authenticated_httpx_client(timeout=30.0) as client: + resp = await client.request( + request.method, target_url, content=body, headers=headers + ) + return Response( + content=resp.content, + status_code=resp.status_code, + media_type=resp.headers.get("content-type"), + ) + except httpx.TimeoutException: + raise HTTPException( + status_code=504, + detail=f"Timeout proxying to '{resource_config.name}'.", + ) + except httpx.RequestError as e: + raise HTTPException( + status_code=502, + detail=f"Connection error proxying to '{resource_config.name}': {e}", + ) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index e8732ad9..f6392990 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -94,6 +94,7 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: "method": f.http_method, "path": f.http_path, "fn_name": f.function_name, + "config_variable": f.config_variable, } for f in lb_funcs ] @@ -149,6 +150,8 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat _ensure_gitignore(project_root) + has_lb_workers = any(w.worker_type == "LB" for w in workers) + lines = [ '"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run."""', "import sys", @@ -156,15 +159,36 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat "from pathlib import Path", "sys.path.insert(0, str(Path(__file__).parent.parent))", "", - "from fastapi import FastAPI", - "", ] - # Collect all imports + if has_lb_workers: + lines += [ + "from fastapi import FastAPI, Request", + "from runpod_flash.cli.commands._run_server_helpers import lb_proxy as _lb_proxy", + "", + ] + else: + lines += [ + "from fastapi import FastAPI", + "", + ] + + # Collect imports — QB functions are called directly, LB config variables are + # passed to lb_proxy for on-demand provisioning via ResourceManager. all_imports: List[str] = [] for worker in workers: - for fn_name in worker.functions: - all_imports.append(f"from {worker.module_path} import {fn_name}") + if worker.worker_type == "QB": + for fn_name in worker.functions: + all_imports.append(f"from {worker.module_path} import {fn_name}") + elif worker.worker_type == "LB": + # Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...)) + config_vars = { + r["config_variable"] + for r in worker.lb_routes + if r.get("config_variable") + } + for var in sorted(config_vars): + all_imports.append(f"from {worker.module_path} import {var}") if all_imports: lines.extend(all_imports) @@ -217,24 +241,17 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat method = route["method"].lower() sub_path = route["path"].lstrip("/") fn_name = route["fn_name"] + config_var = route["config_variable"] full_path = f"{worker.url_prefix}/{sub_path}" handler_name = _sanitize_fn_name( f"_route_{worker.resource_name}_{fn_name}" ) - if method in ("get", "head"): - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}():", - f" return await {fn_name}()", - "", - ] - else: - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict):", - f" return await {fn_name}(body)", - "", - ] + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(request: Request):", + f" return await _lb_proxy({config_var}, {worker.url_prefix!r}, request)", + "", + ] # Health endpoints lines += [ @@ -402,6 +419,73 @@ def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> No pass # stop_event was set or watchfiles unavailable — both are fine +def _discover_resources(project_root: Path): + """Discover deployable resources in project files. + + Uses ResourceDiscovery to find all DeployableResource instances by + parsing @remote decorators and importing the referenced config variables. + + Args: + project_root: Root directory of the Flash project + + Returns: + List of discovered DeployableResource instances + """ + from ...core.discovery import ResourceDiscovery + + py_files = sorted( + p + for p in project_root.rglob("*.py") + if not any( + skip in p.parts + for skip in (".flash", ".venv", "venv", "__pycache__", ".git") + ) + ) + + resources = [] + for py_file in py_files: + try: + discovery = ResourceDiscovery(str(py_file), max_depth=0) + resources.extend(discovery.discover()) + except Exception as e: + logger.debug("Discovery failed for %s: %s", py_file, e) + + if resources: + console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") + for res in resources: + res_name = getattr(res, "name", "Unknown") + res_type = res.__class__.__name__ + console.print(f" [dim]- {res_name} ({res_type})[/dim]") + console.print() + + return resources + + +def _provision_resources(resources) -> None: + """Provision resources in parallel and wait for completion. + + Args: + resources: List of DeployableResource instances to provision + """ + import asyncio + + from ...core.deployment import DeploymentOrchestrator + + try: + console.print(f"[bold]Provisioning {len(resources)} resource(s)...[/bold]") + orchestrator = DeploymentOrchestrator(max_concurrent=3) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(orchestrator.deploy_all(resources, show_progress=True)) + loop.close() + except Exception as e: + console.print(f"[yellow]Warning:[/yellow] Provisioning failed: {e}") + console.print( + "[dim]Resources will be provisioned on-demand at first request.[/dim]" + ) + + def run_command( host: str = typer.Option( "localhost", @@ -419,6 +503,11 @@ def run_command( reload: bool = typer.Option( True, "--reload/--no-reload", help="Enable auto-reload" ), + auto_provision: bool = typer.Option( + False, + "--auto-provision", + help="Auto-provision all endpoints on startup (eliminates cold-start on first request)", + ), ): """Run Flash development server. @@ -434,6 +523,19 @@ def run_command( if not _is_reload(): os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true" + # Auto-provision all endpoints upfront (eliminates cold-start) + if auto_provision and not _is_reload(): + try: + resources = _discover_resources(project_root) + if resources: + _provision_resources(resources) + except Exception as e: + logger.error("Auto-provisioning failed", exc_info=True) + console.print(f"[yellow]Warning:[/yellow] Auto-provisioning failed: {e}") + console.print( + "[dim]Resources will be provisioned on-demand at first request.[/dim]" + ) + # Discover @remote functions workers = _scan_project_workers(project_root) diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index 7fae486a..2152b526 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -430,7 +430,7 @@ def test_scan_error_does_not_crash_watcher(self, tmp_path): class TestGenerateFlashServer: """Test _generate_flash_server() route code generation.""" - def _make_lb_worker(self, tmp_path: Path, method: str) -> WorkerInfo: + def _make_lb_worker(self, tmp_path: Path, method: str = "GET") -> WorkerInfo: return WorkerInfo( file_path=tmp_path / "api.py", url_prefix="/api", @@ -439,25 +439,49 @@ def _make_lb_worker(self, tmp_path: Path, method: str) -> WorkerInfo: worker_type="LB", functions=["list_routes"], lb_routes=[ - {"method": method, "path": "/routes/list", "fn_name": "list_routes"} + { + "method": method, + "path": "/routes/list", + "fn_name": "list_routes", + "config_variable": "api_config", + } ], ) - def test_get_route_has_no_body_param(self, tmp_path): - """GET handler must omit body: dict to satisfy FastAPI/browser constraints.""" - worker = self._make_lb_worker(tmp_path, "GET") - server_path = _generate_flash_server(tmp_path, [worker]) - content = server_path.read_text() - - # The GET handler must be zero-arg - assert "async def _route_api_list_routes():" in content - # No body parameter on any GET handler - assert "body: dict" not in content - - def test_post_route_keeps_body_param(self, tmp_path): - """POST handler must include body: dict for JSON request body.""" - worker = self._make_lb_worker(tmp_path, "POST") - server_path = _generate_flash_server(tmp_path, [worker]) - content = server_path.read_text() - - assert "async def _route_api_list_routes(body: dict):" in content + def test_lb_route_generates_proxy_handler(self, tmp_path): + """All LB routes (any method) generate a proxy handler, not a local call.""" + for method in ("GET", "POST", "DELETE", "PUT", "PATCH"): + worker = self._make_lb_worker(tmp_path, method) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "async def _route_api_list_routes(request: Request):" in content + assert "_lb_proxy(" in content + assert "body: dict" not in content + + def test_lb_config_variable_passed_to_proxy(self, tmp_path): + """The resource config variable is passed to lb_proxy, not a string name.""" + worker = self._make_lb_worker(tmp_path) + content = _generate_flash_server(tmp_path, [worker]).read_text() + # Config variable is passed as a Python reference, not a quoted string + assert "_lb_proxy(api_config," in content + assert "from api import api_config" in content + + def test_lb_proxy_import_present_when_lb_routes_exist(self, tmp_path): + """server.py imports _lb_proxy when there are LB workers.""" + worker = self._make_lb_worker(tmp_path) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "_lb_proxy" in content + assert "lb_proxy" in content + + def test_qb_function_still_imported_directly(self, tmp_path): + """QB workers still import and call functions directly.""" + worker = WorkerInfo( + file_path=tmp_path / "worker.py", + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "from worker import process" in content + assert "await process(" in content From 440b4b3a8abec1e44225bc3f702324ab31989ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 18:50:42 -0800 Subject: [PATCH 11/25] fix(run): add project root to sys.path during resource discovery ResourceDiscovery._import_module() uses importlib to execute each file, but cross-module imports (e.g. "from longruns.stage1 import ...") fail when the project root isn't on sys.path. This caused --auto-provision to silently skip LB endpoints whose files import from sibling packages. --- src/runpod_flash/cli/commands/run.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index f6392990..18509abe 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -442,13 +442,24 @@ def _discover_resources(project_root: Path): ) ) + # Add project root to sys.path so cross-module imports resolve + # (e.g. api/routes.py doing "from longruns.stage1 import stage1_process"). + root_str = str(project_root) + added_to_path = root_str not in sys.path + if added_to_path: + sys.path.insert(0, root_str) + resources = [] - for py_file in py_files: - try: - discovery = ResourceDiscovery(str(py_file), max_depth=0) - resources.extend(discovery.discover()) - except Exception as e: - logger.debug("Discovery failed for %s: %s", py_file, e) + try: + for py_file in py_files: + try: + discovery = ResourceDiscovery(str(py_file), max_depth=0) + resources.extend(discovery.discover()) + except Exception as e: + logger.debug("Discovery failed for %s: %s", py_file, e) + finally: + if added_to_path: + sys.path.remove(root_str) if resources: console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") From 1c285ce4f2843a27845fa73347d188a3747af292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 18:54:56 -0800 Subject: [PATCH 12/25] feat(run): show resource count and elapsed time during cleanup Cleanup on server stop now prints a summary line with undeployed count and wall-clock duration, matching the provisioning output format. --- src/runpod_flash/cli/commands/run.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 18509abe..f4d88b2d 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -356,19 +356,30 @@ def _cleanup_live_endpoints() -> None: if not live_items: return + import time + async def _do_cleanup(): + undeployed = 0 for key, resource in live_items.items(): name = getattr(resource, "name", key) try: success = await resource._do_undeploy() if success: console.print(f" Deprovisioned: {name}") + undeployed += 1 else: logger.warning(f"Failed to deprovision: {name}") except Exception as e: logger.warning(f"Error deprovisioning {name}: {e}") + return undeployed - asyncio.run(_do_cleanup()) + t0 = time.monotonic() + undeployed = asyncio.run(_do_cleanup()) + elapsed = time.monotonic() - t0 + console.print( + f" Cleanup completed: {undeployed}/{len(live_items)} " + f"resource(s) undeployed in {elapsed:.1f}s" + ) # Remove live- entries from persisted state so they don't linger. remaining = {k: v for k, v in resources.items() if k not in live_items} From 98d471f53bdd8279dfe77ba8d39bacf2b29a2b75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 18 Feb 2026 21:37:52 -0800 Subject: [PATCH 13/25] fix(run): route LB calls through LoadBalancerSlsStub instead of HTTP proxy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace lb_proxy (transparent HTTP forwarding) with lb_execute which uses LoadBalancerSlsStub's /execute dispatch path. This fixes 404s on CpuLiveLoadBalancer resources where the remote container has no user routes — only the /execute endpoint that accepts serialized function code. - Change isinstance check from LiveLoadBalancer to LiveServerlessMixin so all live resource types (including CpuLiveLoadBalancer) use /execute - Add explicit CpuLiveLoadBalancer singledispatch registration in registry - Generate server.py imports for both config var and function reference - Clean up redundant URL debug logs in resource_manager --- .../cli/commands/_run_server_helpers.py | 86 ++++++------------- src/runpod_flash/cli/commands/run.py | 10 ++- src/runpod_flash/core/api/runpod.py | 4 +- .../resources/load_balancer_sls_resource.py | 4 +- .../core/resources/resource_manager.py | 4 - src/runpod_flash/stubs/load_balancer_sls.py | 6 +- src/runpod_flash/stubs/registry.py | 26 ++++++ tests/unit/cli/test_run.py | 22 ++--- 8 files changed, 74 insertions(+), 88 deletions(-) diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py index 44b6d5d3..8de16732 100644 --- a/src/runpod_flash/cli/commands/_run_server_helpers.py +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -1,84 +1,46 @@ """Helpers for the flash run dev server — loaded inside the generated server.py.""" -import httpx from fastapi import HTTPException, Request -from fastapi.responses import Response -from runpod_flash.core.resources.base import DeployableResource from runpod_flash.core.resources.resource_manager import ResourceManager -from runpod_flash.core.utils.http import get_authenticated_httpx_client +from runpod_flash.stubs.load_balancer_sls import LoadBalancerSlsStub _resource_manager = ResourceManager() -async def lb_proxy( - resource_config: DeployableResource, path_prefix: str, request: Request -) -> Response: - """Transparent HTTP proxy from flash run dev server to deployed LB endpoint. +async def lb_execute(resource_config, func, request: Request): + """Execute LB function on deployed endpoint via LoadBalancerSlsStub. - Uses ResourceManager.get_or_deploy_resource() to resolve the endpoint, - which handles provisioning, name prefixing, and caching automatically. - - Args: - resource_config: The resource config object (e.g. LiveLoadBalancer instance) - path_prefix: URL prefix used by the dev server (e.g. "/api") — stripped before proxying - request: The incoming FastAPI request to forward - - Returns: - FastAPI Response with upstream status code and body - - Raises: - HTTPException 503: Endpoint not deployed or has no ID - HTTPException 504: Upstream request timed out - HTTPException 502: Connection error reaching the upstream endpoint + Uses the same /execute dispatch path that works on main — provisions + the endpoint, serializes the function via cloudpickle, and POSTs to + /execute on the deployed container. """ try: deployed = await _resource_manager.get_or_deploy_resource(resource_config) - endpoint_url = deployed.endpoint_url - except ValueError as e: - raise HTTPException( - status_code=503, - detail=f"Endpoint '{resource_config.name}' not available: {e}", - ) except Exception as e: raise HTTPException( status_code=503, detail=f"Failed to provision '{resource_config.name}': {e}", ) - target_path = request.url.path - if path_prefix and target_path.startswith(path_prefix): - target_path = target_path[len(path_prefix) :] - if not target_path: - target_path = "/" - - target_url = endpoint_url.rstrip("/") + target_path - if request.url.query: - target_url += "?" + request.url.query + stub = LoadBalancerSlsStub(deployed) - body = await request.body() - skip_headers = {"host", "content-length", "transfer-encoding", "connection"} - headers = { - k: v for k, v in request.headers.items() if k.lower() not in skip_headers - } + # Parse HTTP request into function kwargs + if request.method in ("POST", "PUT", "PATCH"): + try: + kwargs = await request.json() + if not isinstance(kwargs, dict): + kwargs = {"input": kwargs} + except Exception: + kwargs = {} + else: + kwargs = dict(request.query_params) try: - async with get_authenticated_httpx_client(timeout=30.0) as client: - resp = await client.request( - request.method, target_url, content=body, headers=headers - ) - return Response( - content=resp.content, - status_code=resp.status_code, - media_type=resp.headers.get("content-type"), - ) - except httpx.TimeoutException: - raise HTTPException( - status_code=504, - detail=f"Timeout proxying to '{resource_config.name}'.", - ) - except httpx.RequestError as e: - raise HTTPException( - status_code=502, - detail=f"Connection error proxying to '{resource_config.name}': {e}", - ) + return await stub(func, None, None, False, **kwargs) + except TimeoutError as e: + raise HTTPException(status_code=504, detail=str(e)) + except ConnectionError as e: + raise HTTPException(status_code=502, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index f4d88b2d..1c5b117c 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -164,7 +164,7 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat if has_lb_workers: lines += [ "from fastapi import FastAPI, Request", - "from runpod_flash.cli.commands._run_server_helpers import lb_proxy as _lb_proxy", + "from runpod_flash.cli.commands._run_server_helpers import lb_execute as _lb_execute", "", ] else: @@ -173,8 +173,8 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat "", ] - # Collect imports — QB functions are called directly, LB config variables are - # passed to lb_proxy for on-demand provisioning via ResourceManager. + # Collect imports — QB functions are called directly, LB config variables and + # functions are passed to lb_execute for dispatch via LoadBalancerSlsStub. all_imports: List[str] = [] for worker in workers: if worker.worker_type == "QB": @@ -189,6 +189,8 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat } for var in sorted(config_vars): all_imports.append(f"from {worker.module_path} import {var}") + for fn_name in worker.functions: + all_imports.append(f"from {worker.module_path} import {fn_name}") if all_imports: lines.extend(all_imports) @@ -249,7 +251,7 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat lines += [ f'@app.{method}("{full_path}", tags=["{tag}"])', f"async def {handler_name}(request: Request):", - f" return await _lb_proxy({config_var}, {worker.url_prefix!r}, request)", + f" return await _lb_execute({config_var}, {fn_name}, request)", "", ] diff --git a/src/runpod_flash/core/api/runpod.py b/src/runpod_flash/core/api/runpod.py index bc30219a..478428b8 100644 --- a/src/runpod_flash/core/api/runpod.py +++ b/src/runpod_flash/core/api/runpod.py @@ -202,7 +202,7 @@ async def save_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: variables = {"input": input_data} - log.debug(f"Saving endpoint with GraphQL: {input_data.get('name', 'unnamed')}") + log.debug(f"GraphQL saveEndpoint: {input_data.get('name', 'unnamed')}") result = await self._execute_graphql(mutation, variables) @@ -211,7 +211,7 @@ async def save_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: endpoint_data = result["saveEndpoint"] log.debug( - f"Saved endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}" + f"GraphQL response: {endpoint_data.get('id', 'unknown')} ({endpoint_data.get('name', 'unnamed')})" ) return endpoint_data diff --git a/src/runpod_flash/core/resources/load_balancer_sls_resource.py b/src/runpod_flash/core/resources/load_balancer_sls_resource.py index eb664ed0..df84d622 100644 --- a/src/runpod_flash/core/resources/load_balancer_sls_resource.py +++ b/src/runpod_flash/core/resources/load_balancer_sls_resource.py @@ -259,10 +259,10 @@ async def _do_deploy(self) -> "LoadBalancerSlsResource": self.env["FLASH_IS_MOTHERSHIP"] = "true" # Call parent deploy (creates endpoint via RunPod API) - log.debug(f"Deploying LB endpoint {self.name}...") + log.info(f"Deploying LB endpoint: {self.name}") deployed = await super()._do_deploy() - log.debug(f"LB endpoint {self.name} ({deployed.id}) deployed successfully") + log.info(f"Deployed: {self.name} ({deployed.url})") return deployed except Exception as e: diff --git a/src/runpod_flash/core/resources/resource_manager.py b/src/runpod_flash/core/resources/resource_manager.py index 0cd18f51..82eebe2b 100644 --- a/src/runpod_flash/core/resources/resource_manager.py +++ b/src/runpod_flash/core/resources/resource_manager.py @@ -245,7 +245,6 @@ async def get_or_deploy_resource( deployed_resource = await self._deploy_with_error_context( config ) - log.debug(f"URL: {deployed_resource.url}") self._add_resource(resource_key, deployed_resource) return deployed_resource except Exception: @@ -278,7 +277,6 @@ async def get_or_deploy_resource( deployed_resource = await self._deploy_with_error_context( config ) - log.debug(f"URL: {deployed_resource.url}") self._add_resource(resource_key, deployed_resource) return deployed_resource except Exception: @@ -292,13 +290,11 @@ async def get_or_deploy_resource( raise # Config unchanged, reuse existing - log.info(f"URL: {existing.url}") return existing # No existing resource, deploy new one try: deployed_resource = await self._deploy_with_error_context(config) - log.debug(f"URL: {deployed_resource.url}") self._add_resource(resource_key, deployed_resource) return deployed_resource except Exception: diff --git a/src/runpod_flash/stubs/load_balancer_sls.py b/src/runpod_flash/stubs/load_balancer_sls.py index d30a9ce5..f44bccbb 100644 --- a/src/runpod_flash/stubs/load_balancer_sls.py +++ b/src/runpod_flash/stubs/load_balancer_sls.py @@ -75,10 +75,10 @@ def _should_use_execute_endpoint(self, func: Callable[..., Any]) -> bool: Returns: True if /execute should be used, False if user route should be used """ - from ..core.resources.live_serverless import LiveLoadBalancer + from ..core.resources.live_serverless import LiveServerlessMixin - # Always use /execute for LiveLoadBalancer (local development) - if isinstance(self.server, LiveLoadBalancer): + # Always use /execute for live resources (local development) + if isinstance(self.server, LiveServerlessMixin): log.debug(f"Using /execute endpoint for LiveLoadBalancer: {func.__name__}") return True diff --git a/src/runpod_flash/stubs/registry.py b/src/runpod_flash/stubs/registry.py index 674e0085..bbea9243 100644 --- a/src/runpod_flash/stubs/registry.py +++ b/src/runpod_flash/stubs/registry.py @@ -3,6 +3,7 @@ from functools import singledispatch from ..core.resources import ( + CpuLiveLoadBalancer, CpuLiveServerless, CpuServerlessEndpoint, LiveLoadBalancer, @@ -209,3 +210,28 @@ async def stubbed_resource( ) return stubbed_resource + + +@stub_resource.register(CpuLiveLoadBalancer) +def _(resource, **extra): + """Create stub for CpuLiveLoadBalancer (HTTP-based execution, local testing).""" + stub = LoadBalancerSlsStub(resource) + + async def stubbed_resource( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) -> dict: + return await stub( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) + + return stubbed_resource diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index 2152b526..8e92ff0c 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -448,29 +448,29 @@ def _make_lb_worker(self, tmp_path: Path, method: str = "GET") -> WorkerInfo: ], ) - def test_lb_route_generates_proxy_handler(self, tmp_path): - """All LB routes (any method) generate a proxy handler, not a local call.""" + def test_lb_route_generates_execute_handler(self, tmp_path): + """All LB routes (any method) generate a stub-based execute handler.""" for method in ("GET", "POST", "DELETE", "PUT", "PATCH"): worker = self._make_lb_worker(tmp_path, method) content = _generate_flash_server(tmp_path, [worker]).read_text() assert "async def _route_api_list_routes(request: Request):" in content - assert "_lb_proxy(" in content + assert "_lb_execute(" in content assert "body: dict" not in content - def test_lb_config_variable_passed_to_proxy(self, tmp_path): - """The resource config variable is passed to lb_proxy, not a string name.""" + def test_lb_config_and_function_passed_to_execute(self, tmp_path): + """Both config variable and function are passed to lb_execute.""" worker = self._make_lb_worker(tmp_path) content = _generate_flash_server(tmp_path, [worker]).read_text() - # Config variable is passed as a Python reference, not a quoted string - assert "_lb_proxy(api_config," in content + assert "_lb_execute(api_config, list_routes, request)" in content assert "from api import api_config" in content + assert "from api import list_routes" in content - def test_lb_proxy_import_present_when_lb_routes_exist(self, tmp_path): - """server.py imports _lb_proxy when there are LB workers.""" + def test_lb_execute_import_present_when_lb_routes_exist(self, tmp_path): + """server.py imports _lb_execute when there are LB workers.""" worker = self._make_lb_worker(tmp_path) content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "_lb_proxy" in content - assert "lb_proxy" in content + assert "_lb_execute" in content + assert "lb_execute" in content def test_qb_function_still_imported_directly(self, tmp_path): """QB workers still import and call functions directly.""" From 76e6927d533a008ab98dde742f459125cb4f0099 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 19 Feb 2026 08:18:20 -0800 Subject: [PATCH 14/25] fix(run): revert LB to remote dispatch, remove QB /run route MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restore lb_execute to dispatch through LoadBalancerSlsStub instead of calling functions locally — LB resources require Live Serverless containers and cannot execute on a local machine. Keep _map_body_to_params and body: dict signatures for OpenAPI/Swagger compatibility while dispatching remotely via the stub's /execute path. Remove /run from generated QB routes, retaining only /run_sync since the dev server executes synchronously. --- .../cli/commands/_run_server_helpers.py | 50 +++++++---- src/runpod_flash/cli/commands/run.py | 41 +++++---- tests/unit/cli/test_run.py | 85 +++++++++++++++++-- 3 files changed, 128 insertions(+), 48 deletions(-) diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py index 8de16732..70391bbd 100644 --- a/src/runpod_flash/cli/commands/_run_server_helpers.py +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -1,6 +1,8 @@ """Helpers for the flash run dev server — loaded inside the generated server.py.""" -from fastapi import HTTPException, Request +import inspect + +from fastapi import HTTPException from runpod_flash.core.resources.resource_manager import ResourceManager from runpod_flash.stubs.load_balancer_sls import LoadBalancerSlsStub @@ -8,12 +10,36 @@ _resource_manager = ResourceManager() -async def lb_execute(resource_config, func, request: Request): - """Execute LB function on deployed endpoint via LoadBalancerSlsStub. +def _map_body_to_params(func, body): + """Map an HTTP request body to function parameters. - Uses the same /execute dispatch path that works on main — provisions - the endpoint, serializes the function via cloudpickle, and POSTs to - /execute on the deployed container. + If the body is a dict whose keys match the function's parameter names, + spread it as kwargs. Otherwise pass the whole body as the value of the + first parameter (mirrors how FastAPI maps a JSON body to a single param). + """ + sig = inspect.signature(func) + param_names = set(sig.parameters.keys()) + + if isinstance(body, dict) and body.keys() <= param_names: + return body + + first_param = next(iter(sig.parameters), None) + if first_param is None: + return {} + return {first_param: body} + + +async def lb_execute(resource_config, func, body: dict): + """Dispatch an LB route to the deployed endpoint via LoadBalancerSlsStub. + + Provisions the endpoint via ResourceManager, maps the HTTP body to + function kwargs, then dispatches through the stub's /execute path + which serializes the function via cloudpickle to the remote container. + + Args: + resource_config: The resource config object (e.g. LiveLoadBalancer instance). + func: The @remote LB route handler function. + body: Parsed request body (from FastAPI's automatic JSON parsing). """ try: deployed = await _resource_manager.get_or_deploy_resource(resource_config) @@ -24,17 +50,7 @@ async def lb_execute(resource_config, func, request: Request): ) stub = LoadBalancerSlsStub(deployed) - - # Parse HTTP request into function kwargs - if request.method in ("POST", "PUT", "PATCH"): - try: - kwargs = await request.json() - if not isinstance(kwargs, dict): - kwargs = {"input": kwargs} - except Exception: - kwargs = {} - else: - kwargs = dict(request.query_params) + kwargs = _map_body_to_params(func, body) try: return await stub(func, None, None, False, **kwargs) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 1c5b117c..8bc94a22 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -213,11 +213,9 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat if worker.worker_type == "QB": if len(worker.functions) == 1: fn = worker.functions[0] - handler_name = _sanitize_fn_name(f"{worker.resource_name}_run") - run_path = f"{worker.url_prefix}/run" + handler_name = _sanitize_fn_name(f"{worker.resource_name}_run_sync") sync_path = f"{worker.url_prefix}/run_sync" lines += [ - f'@app.post("{run_path}", tags=["{tag}"])', f'@app.post("{sync_path}", tags=["{tag}"])', f"async def {handler_name}(body: dict):", f' result = await {fn}(body.get("input", body))', @@ -226,11 +224,11 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat ] else: for fn in worker.functions: - handler_name = _sanitize_fn_name(f"{worker.resource_name}_{fn}_run") - run_path = f"{worker.url_prefix}/{fn}/run" + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{fn}_run_sync" + ) sync_path = f"{worker.url_prefix}/{fn}/run_sync" lines += [ - f'@app.post("{run_path}", tags=["{tag}"])', f'@app.post("{sync_path}", tags=["{tag}"])', f"async def {handler_name}(body: dict):", f' result = await {fn}(body.get("input", body))', @@ -248,12 +246,21 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat handler_name = _sanitize_fn_name( f"_route_{worker.resource_name}_{fn_name}" ) - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(request: Request):", - f" return await _lb_execute({config_var}, {fn_name}, request)", - "", - ] + has_body = method in ("post", "put", "patch", "delete") + if has_body: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f" return await _lb_execute({config_var}, {fn_name}, body)", + "", + ] + else: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(request: Request):", + f" return await _lb_execute({config_var}, {fn_name}, dict(request.query_params))", + "", + ] # Health endpoints lines += [ @@ -286,11 +293,6 @@ def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> Non for worker in workers: if worker.worker_type == "QB": if len(worker.functions) == 1: - table.add_row( - f"POST {worker.url_prefix}/run", - worker.resource_name, - "QB", - ) table.add_row( f"POST {worker.url_prefix}/run_sync", worker.resource_name, @@ -298,11 +300,6 @@ def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> Non ) else: for fn in worker.functions: - table.add_row( - f"POST {worker.url_prefix}/{fn}/run", - worker.resource_name, - "QB", - ) table.add_row( f"POST {worker.url_prefix}/{fn}/run_sync", worker.resource_name, diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index 8e92ff0c..d13abd12 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -448,20 +448,28 @@ def _make_lb_worker(self, tmp_path: Path, method: str = "GET") -> WorkerInfo: ], ) - def test_lb_route_generates_execute_handler(self, tmp_path): - """All LB routes (any method) generate a stub-based execute handler.""" - for method in ("GET", "POST", "DELETE", "PUT", "PATCH"): + def test_post_lb_route_generates_body_param(self, tmp_path): + """POST/PUT/PATCH/DELETE LB routes use body: dict for OpenAPI docs.""" + for method in ("POST", "PUT", "PATCH", "DELETE"): worker = self._make_lb_worker(tmp_path, method) content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "async def _route_api_list_routes(request: Request):" in content - assert "_lb_execute(" in content - assert "body: dict" not in content + assert "async def _route_api_list_routes(body: dict):" in content + assert "_lb_execute(api_config, list_routes, body)" in content - def test_lb_config_and_function_passed_to_execute(self, tmp_path): - """Both config variable and function are passed to lb_execute.""" + def test_get_lb_route_uses_query_params(self, tmp_path): + """GET LB routes pass query params as a dict.""" + worker = self._make_lb_worker(tmp_path, "GET") + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "async def _route_api_list_routes(request: Request):" in content + assert ( + "_lb_execute(api_config, list_routes, dict(request.query_params))" + in content + ) + + def test_lb_config_var_and_function_imported(self, tmp_path): + """LB config vars and functions are both imported for remote dispatch.""" worker = self._make_lb_worker(tmp_path) content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "_lb_execute(api_config, list_routes, request)" in content assert "from api import api_config" in content assert "from api import list_routes" in content @@ -485,3 +493,62 @@ def test_qb_function_still_imported_directly(self, tmp_path): content = _generate_flash_server(tmp_path, [worker]).read_text() assert "from worker import process" in content assert "await process(" in content + + +class TestMapBodyToParams: + """Tests for _map_body_to_params — maps HTTP body to function arguments.""" + + def test_body_keys_match_params_spreads_as_kwargs(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def process(name: str, value: int): + pass + + result = _map_body_to_params(process, {"name": "test", "value": 42}) + assert result == {"name": "test", "value": 42} + + def test_body_keys_mismatch_wraps_in_first_param(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def run_pipeline(input_data: dict): + pass + + body = {"text": "hello", "mode": "fast"} + result = _map_body_to_params(run_pipeline, body) + assert result == {"input_data": {"text": "hello", "mode": "fast"}} + + def test_non_dict_body_wraps_in_first_param(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def run_pipeline(input_data): + pass + + result = _map_body_to_params(run_pipeline, [1, 2, 3]) + assert result == {"input_data": [1, 2, 3]} + + def test_no_params_returns_empty(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def no_args(): + pass + + result = _map_body_to_params(no_args, {"key": "val"}) + assert result == {} + + def test_partial_key_match_wraps_in_first_param(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def process(name: str, value: int): + pass + + result = _map_body_to_params(process, {"name": "test", "extra": "bad"}) + assert result == {"name": {"name": "test", "extra": "bad"}} + + def test_empty_dict_body_spreads_as_empty_kwargs(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def run_pipeline(input_data: dict): + pass + + result = _map_body_to_params(run_pipeline, {}) + assert result == {} From e0874e08a9ec846a9e0332d261b3cb736c992d2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 19 Feb 2026 12:26:44 -0800 Subject: [PATCH 15/25] refactor(init): simplify skeleton to flat worker files for flash run Replace the old multi-directory skeleton (main.py, mothership.py, workers/) with three flat files: gpu_worker.py, cpu_worker.py, and lb_worker.py. flash run auto-discovers @remote functions so the FastAPI boilerplate and router structure are no longer needed. - Remove main.py, mothership.py, workers/, .ruff_cache from skeleton - Add gpu_worker.py (QB GPU), cpu_worker.py (QB CPU), lb_worker.py (LB) - Simplify pyproject.toml deps (drop fastapi/uvicorn) - Add .flash/ to .gitignore - Rewrite README with uv setup, QB/LB examples, GpuType reference - Update init command panel output and next steps - Add Ctrl+C cleanup hint to flash run startup output - Update skeleton tests for new file structure --- src/runpod_flash/cli/commands/init.py | 89 +++-- src/runpod_flash/cli/commands/run.py | 5 +- .../cli/utils/skeleton_template/.gitignore | 1 + .../cli/utils/skeleton_template/README.md | 304 +++++++----------- .../cli/utils/skeleton_template/cpu_worker.py | 17 + .../cli/utils/skeleton_template/gpu_worker.py | 27 ++ .../cli/utils/skeleton_template/lb_worker.py | 24 ++ .../cli/utils/skeleton_template/main.py | 44 --- .../cli/utils/skeleton_template/mothership.py | 55 ---- .../utils/skeleton_template/pyproject.toml | 47 +-- .../skeleton_template/workers/__init__.py | 0 .../skeleton_template/workers/cpu/__init__.py | 19 -- .../skeleton_template/workers/cpu/endpoint.py | 36 --- .../skeleton_template/workers/gpu/__init__.py | 19 -- .../skeleton_template/workers/gpu/endpoint.py | 61 ---- tests/unit/test_skeleton.py | 51 ++- 16 files changed, 273 insertions(+), 526 deletions(-) create mode 100644 src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py create mode 100644 src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py create mode 100644 src/runpod_flash/cli/utils/skeleton_template/lb_worker.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/main.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/mothership.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/workers/__init__.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py delete mode 100644 src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py diff --git a/src/runpod_flash/cli/commands/init.py b/src/runpod_flash/cli/commands/init.py index 15a96d3d..eabd7583 100644 --- a/src/runpod_flash/cli/commands/init.py +++ b/src/runpod_flash/cli/commands/init.py @@ -5,6 +5,8 @@ import typer from rich.console import Console +from rich.panel import Panel +from rich.table import Table from ..utils.skeleton import create_project_skeleton, detect_file_conflicts @@ -19,70 +21,99 @@ def init_command( ): """Create new Flash project with Flash Server and GPU workers.""" + # Determine target directory and initialization mode if project_name is None or project_name == ".": + # Initialize in current directory project_dir = Path.cwd() is_current_dir = True + # Use current directory name as project name actual_project_name = project_dir.name else: + # Create new directory project_dir = Path(project_name) is_current_dir = False actual_project_name = project_name + # Create project directory if needed if not is_current_dir: project_dir.mkdir(parents=True, exist_ok=True) + # Check for file conflicts in target directory conflicts = detect_file_conflicts(project_dir) - should_overwrite = force + should_overwrite = force # Start with force flag value if conflicts and not force: + # Show warning and prompt user console.print( - "[yellow]Warning:[/yellow] The following files will be overwritten:\n" + Panel( + "[yellow]Warning: The following files will be overwritten:[/yellow]\n\n" + + "\n".join(f" • {conflict}" for conflict in conflicts), + title="File Conflicts Detected", + expand=False, + ) ) - for conflict in conflicts: - console.print(f" {conflict}") - console.print() + # Prompt user for confirmation proceed = typer.confirm("Continue and overwrite these files?", default=False) if not proceed: - console.print("[yellow]Cancelled[/yellow]") + console.print("[yellow]Initialization aborted.[/yellow]") raise typer.Exit(0) + # User confirmed, so we should overwrite should_overwrite = True + # Create project skeleton status_msg = ( - "Initializing Flash project..." + "Initializing Flash project in current directory..." if is_current_dir else f"Creating Flash project '{project_name}'..." ) with console.status(status_msg): create_project_skeleton(project_dir, should_overwrite) - console.print(f"[green]Created[/green] [bold]{actual_project_name}[/bold]\n") - - prefix = "./" if is_current_dir else f"{actual_project_name}/" - console.print(f" {prefix}") - console.print(" ├── main.py FastAPI server") - console.print(" ├── mothership.py Mothership config") - console.print(" ├── pyproject.toml") - console.print(" ├── workers/") - console.print(" │ ├── gpu/") - console.print(" │ └── cpu/") - console.print(" ├── .env.example") - console.print(" ├── requirements.txt") - console.print(" └── README.md") - + # Success output + if is_current_dir: + panel_content = f"Flash project '[bold]{actual_project_name}[/bold]' initialized in current directory!\n\n" + panel_content += "Project structure:\n" + panel_content += " ./\n" + else: + panel_content = f"Flash project '[bold]{actual_project_name}[/bold]' created successfully!\n\n" + panel_content += "Project structure:\n" + panel_content += f" {actual_project_name}/\n" + + panel_content += " ├── gpu_worker.py # GPU serverless worker\n" + panel_content += " ├── cpu_worker.py # CPU serverless worker\n" + panel_content += " ├── lb_worker.py # CPU load-balanced API\n" + panel_content += " ├── pyproject.toml\n" + panel_content += " ├── .env.example\n" + panel_content += " ├── requirements.txt\n" + panel_content += " └── README.md\n" + + title = "Project Initialized" if is_current_dir else "Project Created" + console.print(Panel(panel_content, title=title, expand=False)) + + # Next steps console.print("\n[bold]Next steps:[/bold]") + steps_table = Table(show_header=False, box=None, padding=(0, 1)) + steps_table.add_column("Step", style="bold cyan") + steps_table.add_column("Description") + step_num = 1 if not is_current_dir: - console.print(f" {step_num}. cd {actual_project_name}") + steps_table.add_row(f"{step_num}.", f"cd {actual_project_name}") step_num += 1 - console.print(f" {step_num}. pip install -r requirements.txt") + + steps_table.add_row(f"{step_num}.", "pip install -r requirements.txt") + step_num += 1 + steps_table.add_row(f"{step_num}.", "cp .env.example .env") step_num += 1 - console.print(f" {step_num}. cp .env.example .env && add RUNPOD_API_KEY") + steps_table.add_row(f"{step_num}.", "Add your RUNPOD_API_KEY to .env") step_num += 1 - console.print(f" {step_num}. flash run") + steps_table.add_row(f"{step_num}.", "flash run") - console.print( - "\n [dim]API keys: https://docs.runpod.io/get-started/api-keys[/dim]" - ) - console.print(" [dim]Docs: http://localhost:8888/docs (after running)[/dim]") + console.print(steps_table) + + console.print("\n[bold]Get your API key:[/bold]") + console.print(" https://docs.runpod.io/get-started/api-keys") + console.print("\nVisit http://localhost:8888/docs after running") + console.print("\nCheck out the README.md for more") diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 8bc94a22..b3000a20 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -316,7 +316,10 @@ def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> Non ) console.print(table) - console.print(f"\n Visit [bold]http://{host}:{port}/docs[/bold] for Swagger UI\n") + console.print(f"\n Visit [bold]http://{host}:{port}/docs[/bold] for Swagger UI") + console.print( + " Press [bold]Ctrl+C[/bold] to stop — provisioned endpoints are cleaned up automatically\n" + ) def _cleanup_live_endpoints() -> None: diff --git a/src/runpod_flash/cli/utils/skeleton_template/.gitignore b/src/runpod_flash/cli/utils/skeleton_template/.gitignore index f0673581..0e3b93d7 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/.gitignore +++ b/src/runpod_flash/cli/utils/skeleton_template/.gitignore @@ -36,6 +36,7 @@ wheels/ .env.local # Flash +.flash/ .runpod/ dist/ diff --git a/src/runpod_flash/cli/utils/skeleton_template/README.md b/src/runpod_flash/cli/utils/skeleton_template/README.md index 6c4801e5..f30adf00 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/README.md +++ b/src/runpod_flash/cli/utils/skeleton_template/README.md @@ -1,243 +1,176 @@ # {{project_name}} -Flash application demonstrating distributed GPU and CPU computing on Runpod's serverless infrastructure. - -## About This Template - -This project was generated using `flash init`. The `{{project_name}}` placeholder is automatically replaced with your actual project name during initialization. +Runpod Flash application with GPU and CPU workers on Runpod serverless infrastructure. ## Quick Start -### 1. Install Dependencies +Install [uv](https://docs.astral.sh/uv/getting-started/installation/) (recommended Python package manager): ```bash -pip install -r requirements.txt +curl -LsSf https://astral.sh/uv/install.sh | sh ``` -### 2. Configure Environment - -Create `.env` file: +Set up the project: ```bash -RUNPOD_API_KEY=your_api_key_here +uv venv && source .venv/bin/activate +uv sync +cp .env.example .env # Add your RUNPOD_API_KEY +flash run ``` -Get your API key from [Runpod Settings](https://www.runpod.io/console/user/settings). - -### 3. Run Locally +Or with pip: ```bash -# Standard run +python -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt +cp .env.example .env # Add your RUNPOD_API_KEY flash run - -# Faster development: pre-provision endpoints (eliminates cold-start delays) -flash run --auto-provision ``` -Server starts at **http://localhost:8000** +Server starts at **http://localhost:8888**. Visit **http://localhost:8888/docs** for interactive Swagger UI. -With `--auto-provision`, all serverless endpoints deploy before testing begins. This is much faster for development because endpoints are cached and reused across server restarts. Subsequent runs skip deployment and start immediately. +Use `flash run --auto-provision` to pre-deploy all endpoints on startup, eliminating cold-start delays on first request. Provisioned endpoints are cached and reused across restarts. -### 4. Test the API +When you stop the server with Ctrl+C, all endpoints provisioned during the session are automatically cleaned up. -```bash -# Health check -curl http://localhost:8000/ping +Get your API key from [Runpod Settings](https://www.runpod.io/console/user/settings). +Learn more about it from our [Documentation](https://docs.runpod.io/get-started/api-keys). + +## Test the API -# GPU worker -curl -X POST http://localhost:8000/gpu/hello \ +```bash +# Queue-based GPU worker +curl -X POST http://localhost:8888/gpu_worker/run_sync \ -H "Content-Type: application/json" \ -d '{"message": "Hello GPU!"}' -# CPU worker -curl -X POST http://localhost:8000/cpu/hello \ +# Queue-based CPU worker +curl -X POST http://localhost:8888/cpu_worker/run_sync \ -H "Content-Type: application/json" \ -d '{"message": "Hello CPU!"}' -``` - -Visit **http://localhost:8000/docs** for interactive API documentation. -## What This Demonstrates - -### GPU Worker (`workers/gpu/`) -Simple GPU-based serverless function: -- Remote execution with `@remote` decorator -- GPU resource configuration -- Automatic scaling (0-3 workers) -- No external dependencies required - -```python -@remote( - resource_config=LiveServerless( - name="gpu_worker", - gpus=[GpuGroup.ADA_24], # RTX 4090 - workersMin=0, - workersMax=3, - ) -) -async def gpu_hello(input_data: dict) -> dict: - # Your GPU code here - return {"status": "success", "message": "Hello from GPU!"} -``` - -### CPU Worker (`workers/cpu/`) -Simple CPU-based serverless function: -- CPU-only execution (no GPU overhead) -- CpuLiveServerless configuration -- Efficient for API endpoints -- Automatic scaling (0-5 workers) +# Load-balanced HTTP endpoint +curl -X POST http://localhost:8888/lb_worker/process \ + -H "Content-Type: application/json" \ + -d '{"input": "test"}' -```python -@remote( - resource_config=CpuLiveServerless( - name="cpu_worker", - instanceIds=[CpuInstanceType.CPU3G_2_8], # 2 vCPU, 8GB RAM - workersMin=0, - workersMax=5, - ) -) -async def cpu_hello(input_data: dict) -> dict: - # Your CPU code here - return {"status": "success", "message": "Hello from CPU!"} +# Load-balanced health check +curl http://localhost:8888/lb_worker/health ``` ## Project Structure ``` {{project_name}}/ -├── main.py # FastAPI application -├── workers/ -│ ├── gpu/ # GPU worker -│ │ ├── __init__.py # FastAPI router -│ │ └── endpoint.py # @remote decorated function -│ └── cpu/ # CPU worker -│ ├── __init__.py # FastAPI router -│ └── endpoint.py # @remote decorated function -├── .env # Environment variables -├── requirements.txt # Dependencies -└── README.md # This file +├── gpu_worker.py # GPU serverless worker (queue-based) +├── cpu_worker.py # CPU serverless worker (queue-based) +├── lb_worker.py # CPU load-balanced HTTP endpoint +├── .env.example # Environment variable template +├── requirements.txt # Python dependencies +└── README.md ``` -## Key Concepts - -### Remote Execution -The `@remote` decorator transparently executes functions on serverless infrastructure: -- Code runs locally during development -- Automatically deploys to Runpod when configured -- Handles serialization, dependencies, and resource management - -### Resource Scaling -Both workers scale to zero when idle to minimize costs: -- **idleTimeout**: Seconds before scaling down (default: 60) -- **workersMin**: 0 = completely scales to zero -- **workersMax**: Maximum concurrent workers - -### GPU Types -Available GPU options for `LiveServerless`: -- `GpuGroup.ADA_24` - RTX 4090 (24GB) -- `GpuGroup.ADA_48_PRO` - RTX 6000 Ada, L40 (48GB) -- `GpuGroup.AMPERE_80` - A100 (80GB) -- `GpuGroup.ANY` - Any available GPU - -### CPU Types -Available CPU options for `CpuLiveServerless`: -- `CpuInstanceType.CPU3G_2_8` - 2 vCPU, 8GB RAM (General Purpose) -- `CpuInstanceType.CPU3C_4_8` - 4 vCPU, 8GB RAM (Compute Optimized) -- `CpuInstanceType.CPU5G_4_16` - 4 vCPU, 16GB RAM (Latest Gen) -- `CpuInstanceType.ANY` - Any available GPU - -## Development Workflow - -### Test Workers Locally -```bash -# Test GPU worker -python -m workers.gpu.endpoint +## Worker Types -# Test CPU worker -python -m workers.cpu.endpoint -``` +### Queue-Based (QB) Workers -### Run the Application -```bash -flash run -``` - -### Deploy to Production -```bash -# Build and deploy in one step -flash deploy - -# Or deploy to a specific environment -flash deploy --env production -``` - -## Adding New Workers - -### Add a GPU Worker +QB workers process jobs from a queue. Each call to `/run_sync` sends a job and waits +for the result. Use QB for compute-heavy tasks that may take seconds to minutes. -1. Create `workers/my_worker/endpoint.py`: +**gpu_worker.py** — GPU serverless function: ```python -from runpod_flash import remote, LiveServerless +from runpod_flash import GpuType, LiveServerless, remote -config = LiveServerless(name="my_worker") +gpu_config = LiveServerless( + name="gpu_worker", + gpus=[GpuType.ANY], +) -@remote(resource_config=config, dependencies=["torch"]) -async def my_function(data: dict) -> dict: +@remote(resource_config=gpu_config, dependencies=["torch"]) +async def gpu_hello(input_data: dict) -> dict: import torch - # Your code here - return {"result": "success"} + gpu_available = torch.cuda.is_available() + gpu_name = torch.cuda.get_device_name(0) if gpu_available else "No GPU detected" + return {"message": gpu_name} ``` -2. Create `workers/my_worker/__init__.py`: +**cpu_worker.py** — CPU serverless function: ```python -from fastapi import APIRouter -from .endpoint import my_function +from runpod_flash import CpuLiveServerless, remote -router = APIRouter() +cpu_config = CpuLiveServerless(name="cpu_worker") -@router.post("/process") -async def handler(data: dict): - return await my_function(data) +@remote(resource_config=cpu_config) +async def cpu_hello(input_data: dict = {}) -> dict: + return {"message": "Hello from CPU!"} + input_data ``` -3. Add to `main.py`: -```python -from workers.my_worker import router as my_router -app.include_router(my_router, prefix="/my_worker") -``` +### Load-Balanced (LB) Workers -### Add a CPU Worker +LB workers expose standard HTTP endpoints (GET, POST, etc.) behind a load balancer. +Use LB for low-latency API endpoints that need horizontal scaling. -Same pattern but use `CpuLiveServerless`: +**lb_worker.py** — HTTP endpoints on a load-balanced container: ```python -from runpod_flash import remote, CpuLiveServerless, CpuInstanceType +from runpod_flash import CpuLiveLoadBalancer, remote -config = CpuLiveServerless( - name="my_cpu_worker", - instanceIds=[CpuInstanceType.CPU3G_2_8] +api_config = CpuLiveLoadBalancer( + name="lb_worker", + workersMin=1, ) -@remote(resource_config=config, dependencies=["requests"]) -async def fetch_data(url: str) -> dict: - import requests - return requests.get(url).json() +@remote(resource_config=api_config, method="POST", path="/process") +async def process(input_data: dict) -> dict: + return {"status": "success", "echo": input_data} + +@remote(resource_config=api_config, method="GET", path="/health") +async def health() -> dict: + return {"status": "healthy"} ``` -## Adding Dependencies +## Adding New Workers + +Create a new `.py` file with a `@remote` function. `flash run` auto-discovers all +`@remote` functions in the project. -Specify dependencies in the `@remote` decorator: ```python -@remote( - resource_config=config, - dependencies=["torch>=2.0.0", "transformers"], # Python packages - system_dependencies=["ffmpeg"] # System packages -) -async def my_function(data: dict) -> dict: - # Dependencies are automatically installed - import torch - import transformers +# my_worker.py +from runpod_flash import LiveServerless, GpuType, remote + +config = LiveServerless(name="my_worker", gpus=[GpuType.NVIDIA_GEFORCE_RTX_4090]) + +@remote(resource_config=config, dependencies=["transformers"]) +async def predict(input_data: dict) -> dict: + from transformers import pipeline + pipe = pipeline("sentiment-analysis") + return pipe(input_data["text"])[0] ``` +Then run `flash run` — the new worker appears automatically. + +## GPU Types + +| Config | Hardware | VRAM | +|--------|----------|------| +| `GpuType.ANY` | Any available GPU | varies | +| `GpuType.NVIDIA_GEFORCE_RTX_4090` | RTX 4090 | 24 GB | +| `GpuType.NVIDIA_GEFORCE_RTX_5090` | RTX 5090 | 32 GB | +| `GpuType.NVIDIA_RTX_6000_ADA_GENERATION` | RTX 6000 Ada | 48 GB | +| `GpuType.NVIDIA_L4` | L4 | 24 GB | +| `GpuType.NVIDIA_A100_80GB_PCIe` | A100 PCIe | 80 GB | +| `GpuType.NVIDIA_A100_SXM4_80GB` | A100 SXM4 | 80 GB | +| `GpuType.NVIDIA_H100_80GB_HBM3` | H100 | 80 GB | +| `GpuType.NVIDIA_H200` | H200 | 141 GB | + +## CPU Types + +| Config | vCPU | RAM | +|--------|------|-----| +| `CpuInstanceType.CPU3G_2_8` | 2 | 8 GB | +| `CpuInstanceType.CPU3C_4_8` | 4 | 8 GB | +| `CpuInstanceType.CPU5G_4_16` | 4 | 16 GB | + ## Environment Variables ```bash @@ -245,16 +178,13 @@ async def my_function(data: dict) -> dict: RUNPOD_API_KEY=your_api_key # Optional -FLASH_HOST=localhost # Host to bind the server to (default: localhost) -FLASH_PORT=8888 # Port to bind the server to (default: 8888) -LOG_LEVEL=INFO # Logging level (default: INFO) +FLASH_HOST=localhost # Server host (default: localhost) +FLASH_PORT=8888 # Server port (default: 8888) +LOG_LEVEL=INFO # Logging level (default: INFO) ``` -## Next Steps +## Deploy -- Add your ML models or processing logic -- Configure GPU/CPU resources based on your needs -- Add authentication to your endpoints -- Implement error handling and retries -- Add monitoring and logging -- Deploy to production with `flash deploy` +```bash +flash deploy +``` diff --git a/src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py b/src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py new file mode 100644 index 00000000..aee4b5a3 --- /dev/null +++ b/src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py @@ -0,0 +1,17 @@ +from runpod_flash import CpuLiveServerless, remote + +cpu_config = CpuLiveServerless(name="cpu_worker") + + +@remote(resource_config=cpu_config) +async def cpu_hello(input_data: dict) -> dict: + """CPU worker — lightweight processing without GPU.""" + import platform + from datetime import datetime + + return { + "message": input_data.get("message", "Hello from CPU worker!"), + "timestamp": datetime.now().isoformat(), + "platform": platform.system(), + "python_version": platform.python_version(), + } diff --git a/src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py b/src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py new file mode 100644 index 00000000..d787e1e9 --- /dev/null +++ b/src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py @@ -0,0 +1,27 @@ +from runpod_flash import GpuType, LiveServerless, remote + +gpu_config = LiveServerless( + name="gpu_worker", + gpus=[GpuType.ANY], +) + + +@remote(resource_config=gpu_config, dependencies=["torch"]) +async def gpu_hello(input_data: dict) -> dict: + """GPU worker — detects available GPU hardware.""" + import platform + + try: + import torch + + gpu_available = torch.cuda.is_available() + gpu_name = torch.cuda.get_device_name(0) if gpu_available else "No GPU detected" + except Exception as e: + gpu_available = False + gpu_name = f"Error: {e}" + + return { + "message": input_data.get("message", "Hello from GPU worker!"), + "gpu": {"available": gpu_available, "name": gpu_name}, + "python_version": platform.python_version(), + } diff --git a/src/runpod_flash/cli/utils/skeleton_template/lb_worker.py b/src/runpod_flash/cli/utils/skeleton_template/lb_worker.py new file mode 100644 index 00000000..1b40ed0c --- /dev/null +++ b/src/runpod_flash/cli/utils/skeleton_template/lb_worker.py @@ -0,0 +1,24 @@ +from runpod_flash import CpuLiveLoadBalancer, remote + +api_config = CpuLiveLoadBalancer( + name="lb_worker", + workersMin=1, +) + + +@remote(resource_config=api_config, method="POST", path="/process") +async def process(input_data: dict) -> dict: + """Process input data on a load-balanced CPU endpoint.""" + from datetime import datetime + + return { + "status": "success", + "echo": input_data, + "timestamp": datetime.now().isoformat(), + } + + +@remote(resource_config=api_config, method="GET", path="/health") +async def health() -> dict: + """Health check for the load-balanced endpoint.""" + return {"status": "healthy"} diff --git a/src/runpod_flash/cli/utils/skeleton_template/main.py b/src/runpod_flash/cli/utils/skeleton_template/main.py deleted file mode 100644 index ad3ce717..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/main.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -import os - -from fastapi import FastAPI - -from workers.cpu import cpu_router -from workers.gpu import gpu_router - -logger = logging.getLogger(__name__) - - -app = FastAPI( - title="Flash Application", - description="Distributed GPU and CPU computing with Runpod Flash", - version="0.1.0", -) - -# Include routers -app.include_router(gpu_router, prefix="/gpu", tags=["GPU Workers"]) -app.include_router(cpu_router, prefix="/cpu", tags=["CPU Workers"]) - - -@app.get("/") -def home(): - return { - "message": "Flash Application", - "docs": "/docs", - "endpoints": {"gpu_hello": "/gpu/hello", "cpu_hello": "/cpu/hello"}, - } - - -@app.get("/ping") -def ping(): - return {"status": "healthy"} - - -if __name__ == "__main__": - import uvicorn - - host = os.getenv("FLASH_HOST", "localhost") - port = int(os.getenv("FLASH_PORT", 8888)) - logger.info(f"Starting Flash server on {host}:{port}") - - uvicorn.run(app, host=host, port=port) diff --git a/src/runpod_flash/cli/utils/skeleton_template/mothership.py b/src/runpod_flash/cli/utils/skeleton_template/mothership.py deleted file mode 100644 index 85779bfc..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/mothership.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Mothership Endpoint Configuration - -The mothership endpoint serves your FastAPI application routes. -It is automatically deployed as a CPU-optimized load-balanced endpoint. - -To customize this configuration: -- Modify worker scaling: change workersMin and workersMax values -- Use GPU load balancer: import LiveLoadBalancer instead of CpuLiveLoadBalancer -- Change endpoint name: update the 'name' parameter - -To disable mothership deployment: -- Delete this file, or -- Comment out the 'mothership' variable below - -Documentation: https://docs.runpod.io/flash/mothership -""" - -from runpod_flash import CpuLiveLoadBalancer - -# Mothership endpoint configuration -# This serves your FastAPI app routes from main.py -mothership = CpuLiveLoadBalancer( - name="mothership", - workersMin=1, - workersMax=1, -) - -# Examples of customization: - -# Increase scaling for high traffic -# mothership = CpuLiveLoadBalancer( -# name="mothership", -# workersMin=2, -# workersMax=10, -# ) - -# Use GPU-based load balancer instead of CPU -# (requires importing LiveLoadBalancer) -# from runpod_flash import LiveLoadBalancer -# mothership = LiveLoadBalancer( -# name="mothership", -# gpus=[GpuGroup.ANY], -# ) - -# Custom endpoint name -# mothership = CpuLiveLoadBalancer( -# name="my-api-gateway", -# workersMin=1, -# workersMax=1, -# ) - -# To disable mothership: -# - Delete this entire file, or -# - Comment out the 'mothership' variable above diff --git a/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml b/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml index 7987ad22..a58ae558 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml +++ b/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml @@ -5,54 +5,9 @@ build-backend = "setuptools.build_meta" [project] name = "{{project_name}}" version = "0.1.0" -description = "Flash serverless application" +description = "Runpod Flash Serverless Application" readme = "README.md" requires-python = ">=3.11" dependencies = [ "runpod-flash", - "fastapi>=0.104.0", - "uvicorn>=0.24.0", -] - -[project.optional-dependencies] -dev = [ - "pytest>=7.0", - "pytest-asyncio>=0.21", - "pytest-cov>=4.0", - "ruff>=0.1", - "mypy>=1.0", -] - -[tool.ruff] -line-length = 100 -target-version = "py311" - -[tool.ruff.lint] -select = ["E", "F", "I", "N", "W"] -ignore = ["E501"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py", "*_test.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -asyncio_mode = "auto" - -[tool.mypy] -python_version = "3.11" -warn_return_any = false -warn_unused_configs = true -disallow_untyped_defs = false - -[tool.coverage.run] -source = ["src"] -omit = ["*/tests/*"] - -[tool.coverage.report] -exclude_lines = [ - "pragma: no cover", - "def __repr__", - "raise AssertionError", - "raise NotImplementedError", - "if __name__ == .__main__.:", ] diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/__init__.py b/src/runpod_flash/cli/utils/skeleton_template/workers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py b/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py deleted file mode 100644 index aef10a1a..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter -from pydantic import BaseModel - -from .endpoint import cpu_hello - -cpu_router = APIRouter() - - -class MessageRequest(BaseModel): - """Request model for CPU worker.""" - - message: str = "Hello from CPU!" - - -@cpu_router.post("/hello") -async def hello(request: MessageRequest): - """Simple CPU worker endpoint.""" - result = await cpu_hello({"message": request.message}) - return result diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py b/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py deleted file mode 100644 index 8161e5a7..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py +++ /dev/null @@ -1,36 +0,0 @@ -from runpod_flash import CpuLiveServerless, remote - -cpu_config = CpuLiveServerless( - name="cpu_worker", - workersMin=0, - workersMax=1, - idleTimeout=60, -) - - -@remote(resource_config=cpu_config) -async def cpu_hello(input_data: dict) -> dict: - """Simple CPU worker example.""" - import platform - from datetime import datetime - - message = input_data.get("message", "Hello from CPU worker!") - - return { - "status": "success", - "message": message, - "worker_type": "CPU", - "timestamp": datetime.now().isoformat(), - "platform": platform.system(), - "python_version": platform.python_version(), - } - - -# Test locally with: python -m workers.cpu.endpoint -if __name__ == "__main__": - import asyncio - - test_payload = {"message": "Testing CPU worker"} - print(f"Testing CPU worker with payload: {test_payload}") - result = asyncio.run(cpu_hello(test_payload)) - print(f"Result: {result}") diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py b/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py deleted file mode 100644 index a6a3caad..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter -from pydantic import BaseModel - -from .endpoint import gpu_hello - -gpu_router = APIRouter() - - -class MessageRequest(BaseModel): - """Request model for GPU worker.""" - - message: str = "Hello from GPU!" - - -@gpu_router.post("/hello") -async def hello(request: MessageRequest): - """Simple GPU worker endpoint.""" - result = await gpu_hello({"message": request.message}) - return result diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py b/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py deleted file mode 100644 index f3c4466c..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py +++ /dev/null @@ -1,61 +0,0 @@ -from runpod_flash import GpuGroup, LiveServerless, remote - -gpu_config = LiveServerless( - name="gpu_worker", - gpus=[GpuGroup.ANY], - workersMin=0, - workersMax=1, - idleTimeout=60, -) - - -@remote(resource_config=gpu_config, dependencies=["torch"]) -async def gpu_hello(input_data: dict) -> dict: - """Simple GPU worker example with GPU detection.""" - import platform - from datetime import datetime - - try: - import torch - - gpu_available = torch.cuda.is_available() - if gpu_available: - gpu_name = torch.cuda.get_device_name(0) - gpu_count = torch.cuda.device_count() - gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) - else: - gpu_name = "No GPU detected" - gpu_count = 0 - gpu_memory = 0 - except Exception as e: - gpu_available = False - gpu_name = f"Error detecting GPU: {str(e)}" - gpu_count = 0 - gpu_memory = 0 - - message = input_data.get("message", "Hello from GPU worker!") - - return { - "status": "success", - "message": message, - "worker_type": "GPU", - "gpu_info": { - "available": gpu_available, - "name": gpu_name, - "count": gpu_count, - "memory_gb": round(gpu_memory, 2) if gpu_memory else 0, - }, - "timestamp": datetime.now().isoformat(), - "platform": platform.system(), - "python_version": platform.python_version(), - } - - -# Test locally with: python -m workers.gpu.endpoint -if __name__ == "__main__": - import asyncio - - test_payload = {"message": "Testing GPU worker"} - print(f"Testing GPU worker with payload: {test_payload}") - result = asyncio.run(gpu_hello(test_payload)) - print(f"Result: {result}") diff --git a/tests/unit/test_skeleton.py b/tests/unit/test_skeleton.py index 0c3962b4..ea788bd6 100644 --- a/tests/unit/test_skeleton.py +++ b/tests/unit/test_skeleton.py @@ -85,13 +85,13 @@ def test_detect_no_conflicts_empty_directory(self, tmp_path): def test_detect_conflict_with_existing_file(self, tmp_path): """Test that existing files are detected as conflicts.""" # Create a file that exists in the template - (tmp_path / "main.py").write_text("# existing file") + (tmp_path / "gpu_worker.py").write_text("# existing file") conflicts = detect_file_conflicts(tmp_path) - # Should detect main.py as a conflict + # Should detect gpu_worker.py as a conflict conflict_names = [str(c) for c in conflicts] - assert "main.py" in conflict_names + assert "gpu_worker.py" in conflict_names def test_detect_conflict_with_hidden_file(self, tmp_path): """Test that existing hidden files are detected as conflicts.""" @@ -138,7 +138,9 @@ def test_create_skeleton_in_empty_directory(self, tmp_path): assert len(created_files) > 0 # Check that key files exist - assert (tmp_path / "main.py").exists() + assert (tmp_path / "gpu_worker.py").exists() + assert (tmp_path / "cpu_worker.py").exists() + assert (tmp_path / "lb_worker.py").exists() assert (tmp_path / "README.md").exists() assert (tmp_path / "requirements.txt").exists() @@ -147,13 +149,6 @@ def test_create_skeleton_in_empty_directory(self, tmp_path): assert (tmp_path / ".gitignore").exists() assert (tmp_path / ".flashignore").exists() - # Check that workers directory structure exists - assert (tmp_path / "workers").is_dir() - assert (tmp_path / "workers" / "cpu").is_dir() - assert (tmp_path / "workers" / "gpu").is_dir() - assert (tmp_path / "workers" / "cpu" / "__init__.py").exists() - assert (tmp_path / "workers" / "gpu" / "__init__.py").exists() - def test_create_skeleton_with_project_name_substitution(self, tmp_path): """Test that {{project_name}} placeholder is replaced.""" project_dir = tmp_path / "my_test_project" @@ -169,14 +164,14 @@ def test_create_skeleton_with_project_name_substitution(self, tmp_path): def test_create_skeleton_skips_existing_files_without_force(self, tmp_path): """Test that existing files are not overwritten without force flag.""" # Create an existing file with specific content - existing_content = "# This is my custom main.py" - (tmp_path / "main.py").write_text(existing_content) + existing_content = "# This is my custom gpu_worker.py" + (tmp_path / "gpu_worker.py").write_text(existing_content) # Create skeleton without force create_project_skeleton(tmp_path, force=False) # Existing file should not be overwritten - assert (tmp_path / "main.py").read_text() == existing_content + assert (tmp_path / "gpu_worker.py").read_text() == existing_content # But other files should be created assert (tmp_path / ".env.example").exists() @@ -184,16 +179,16 @@ def test_create_skeleton_skips_existing_files_without_force(self, tmp_path): def test_create_skeleton_overwrites_with_force(self, tmp_path): """Test that existing files are overwritten with force=True.""" # Create an existing file - existing_content = "# This is my custom main.py" - (tmp_path / "main.py").write_text(existing_content) + existing_content = "# This is my custom gpu_worker.py" + (tmp_path / "gpu_worker.py").write_text(existing_content) # Create skeleton with force create_project_skeleton(tmp_path, force=True) # Existing file should be overwritten - new_content = (tmp_path / "main.py").read_text() + new_content = (tmp_path / "gpu_worker.py").read_text() assert new_content != existing_content - assert "# This is my custom main.py" not in new_content + assert "# This is my custom gpu_worker.py" not in new_content def test_create_skeleton_ignores_pycache(self, tmp_path): """Test that __pycache__ directories are not copied.""" @@ -225,7 +220,7 @@ def test_create_skeleton_creates_parent_directories(self, tmp_path): # All parent directories should exist assert project_dir.exists() - assert (project_dir / "main.py").exists() + assert (project_dir / "gpu_worker.py").exists() def test_create_skeleton_returns_created_files_list(self, tmp_path): """Test that function returns list of created files.""" @@ -236,14 +231,14 @@ def test_create_skeleton_returns_created_files_list(self, tmp_path): assert all(isinstance(f, str) for f in created_files) # Should contain expected files - assert "main.py" in created_files + assert "gpu_worker.py" in created_files assert ".env.example" in created_files assert "README.md" in created_files def test_create_skeleton_handles_readonly_files_gracefully(self, tmp_path): """Test handling of read-only files during creation.""" # Create a read-only file - readonly_file = tmp_path / "main.py" + readonly_file = tmp_path / "gpu_worker.py" readonly_file.write_text("# readonly") readonly_file.chmod(0o444) @@ -287,7 +282,9 @@ def test_full_init_workflow_in_place(self, tmp_path): # Verify all expected files exist expected_files = [ - "main.py", + "gpu_worker.py", + "cpu_worker.py", + "lb_worker.py", "README.md", "requirements.txt", ".env.example", @@ -297,14 +294,10 @@ def test_full_init_workflow_in_place(self, tmp_path): for filename in expected_files: assert (tmp_path / filename).exists(), f"{filename} should exist" - # Verify workers structure - assert (tmp_path / "workers" / "cpu" / "endpoint.py").exists() - assert (tmp_path / "workers" / "gpu" / "endpoint.py").exists() - def test_full_init_workflow_with_conflicts(self, tmp_path): """Test complete workflow when conflicts exist.""" # Create some existing files - (tmp_path / "main.py").write_text("# my custom main") + (tmp_path / "gpu_worker.py").write_text("# my custom worker") (tmp_path / ".env.example").write_text("MY_VAR=123") # Detect conflicts @@ -312,14 +305,14 @@ def test_full_init_workflow_with_conflicts(self, tmp_path): assert len(conflicts) == 2 conflict_names = [str(c) for c in conflicts] - assert "main.py" in conflict_names + assert "gpu_worker.py" in conflict_names assert ".env.example" in conflict_names # Create skeleton without force (should preserve existing) create_project_skeleton(tmp_path, force=False) # Check that existing files were preserved - assert (tmp_path / "main.py").read_text() == "# my custom main" + assert (tmp_path / "gpu_worker.py").read_text() == "# my custom worker" assert (tmp_path / ".env.example").read_text() == "MY_VAR=123" # But new files should be created From bed42ef4daca7e0832aaa7604c3c451a98a84ea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 19 Feb 2026 15:45:28 -0800 Subject: [PATCH 16/25] fix(run): handle numeric-prefixed directories in server codegen Directory names starting with digits (e.g. 01_getting_started/) produce invalid Python when used in import statements and function names. - Add _flash_import helper to generated server.py that uses importlib.import_module() with scoped sys.path so sibling imports (e.g. `from cpu_worker import ...`) resolve to the correct directory - Prefix generated function names with '_' when they start with a digit - Scope sys.path per-import to prevent name collisions when multiple directories contain files with the same name (e.g. cpu_worker.py) --- src/runpod_flash/cli/commands/run.py | 88 +++++++++- tests/unit/cli/test_run.py | 251 ++++++++++++++++++++++++++- 2 files changed, 332 insertions(+), 7 deletions(-) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index b3000a20..bf3bbd74 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -131,8 +131,53 @@ def _ensure_gitignore(project_root: Path) -> None: def _sanitize_fn_name(name: str) -> str: - """Sanitize a string for use as a Python function name.""" - return name.replace("/", "_").replace(".", "_").replace("-", "_") + """Sanitize a string for use as a Python function name. + + Replaces non-identifier characters with underscores and prepends '_' + if the result starts with a digit (Python identifiers cannot start + with digits). + """ + result = name.replace("/", "_").replace(".", "_").replace("-", "_") + if result and result[0].isdigit(): + result = "_" + result + return result + + +def _has_numeric_module_segments(module_path: str) -> bool: + """Check if any segment in a dotted module path starts with a digit. + + Python identifiers cannot start with digits, so ``from 01_foo import bar`` + is a SyntaxError. Callers should use ``importlib.import_module()`` instead. + """ + return any(seg and seg[0].isdigit() for seg in module_path.split(".")) + + +def _module_parent_subdir(module_path: str) -> str | None: + """Return the parent sub-directory for a dotted module path, or None for top-level. + + Example: ``01_getting_started.03_mixed.pipeline`` → ``01_getting_started/03_mixed`` + """ + parts = module_path.rsplit(".", 1) + if len(parts) == 1: + return None + return parts[0].replace(".", "/") + + +def _make_import_line(module_path: str, name: str) -> str: + """Build an import statement for *name* from *module_path*. + + Uses a regular ``from … import …`` when the module path is a valid + Python identifier chain. Falls back to ``_flash_import()`` (a generated + helper in server.py) when any segment starts with a digit. The helper + temporarily scopes ``sys.path`` so sibling imports in the target module + resolve to the correct directory. + """ + if _has_numeric_module_segments(module_path): + subdir = _module_parent_subdir(module_path) + if subdir: + return f'{name} = _flash_import("{module_path}", "{name}", "{subdir}")' + return f'{name} = _flash_import("{module_path}", "{name}")' + return f"from {module_path} import {name}" def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: @@ -157,10 +202,41 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat "import sys", "import uuid", "from pathlib import Path", - "sys.path.insert(0, str(Path(__file__).parent.parent))", + "_project_root = Path(__file__).parent.parent", + "sys.path.insert(0, str(_project_root))", "", ] + # When modules live in directories with numeric prefixes (e.g. 01_hello/), + # we cannot use ``from … import …`` — Python identifiers cannot start with + # digits. Instead we emit a small ``_flash_import`` helper that uses + # ``importlib.import_module()`` *and* temporarily scopes ``sys.path`` so + # that sibling imports inside the loaded module (e.g. ``from cpu_worker + # import …``) resolve to the correct directory rather than a same-named + # file from a different example subdirectory. + needs_importlib = any(_has_numeric_module_segments(w.module_path) for w in workers) + + if needs_importlib: + lines += [ + "import importlib as _importlib", + "", + "", + "def _flash_import(module_path, name, subdir=None):", + ' """Import *name* from *module_path* with scoped sys.path for sibling imports."""', + " _path = str(_project_root / subdir) if subdir else None", + " if _path:", + " sys.path.insert(0, _path)", + " try:", + " return getattr(_importlib.import_module(module_path), name)", + " finally:", + " if _path:", + " try:", + " sys.path.remove(_path)", + " except ValueError:", + " pass", + "", + ] + if has_lb_workers: lines += [ "from fastapi import FastAPI, Request", @@ -179,7 +255,7 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat for worker in workers: if worker.worker_type == "QB": for fn_name in worker.functions: - all_imports.append(f"from {worker.module_path} import {fn_name}") + all_imports.append(_make_import_line(worker.module_path, fn_name)) elif worker.worker_type == "LB": # Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...)) config_vars = { @@ -188,9 +264,9 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat if r.get("config_variable") } for var in sorted(config_vars): - all_imports.append(f"from {worker.module_path} import {var}") + all_imports.append(_make_import_line(worker.module_path, var)) for fn_name in worker.functions: - all_imports.append(f"from {worker.module_path} import {fn_name}") + all_imports.append(_make_import_line(worker.module_path, fn_name)) if all_imports: lines.extend(all_imports) diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index d13abd12..6014e9a3 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -6,7 +6,14 @@ from typer.testing import CliRunner from runpod_flash.cli.main import app -from runpod_flash.cli.commands.run import WorkerInfo, _generate_flash_server +from runpod_flash.cli.commands.run import ( + WorkerInfo, + _generate_flash_server, + _has_numeric_module_segments, + _make_import_line, + _module_parent_subdir, + _sanitize_fn_name, +) @pytest.fixture @@ -495,6 +502,248 @@ def test_qb_function_still_imported_directly(self, tmp_path): assert "await process(" in content +class TestSanitizeFnName: + """Test _sanitize_fn_name handles leading-digit identifiers.""" + + def test_normal_name_unchanged(self): + assert _sanitize_fn_name("worker_run_sync") == "worker_run_sync" + + def test_leading_digit_gets_underscore_prefix(self): + assert _sanitize_fn_name("01_hello_run_sync") == "_01_hello_run_sync" + + def test_slashes_replaced(self): + assert _sanitize_fn_name("a/b/c") == "a_b_c" + + def test_dots_and_hyphens_replaced(self): + assert _sanitize_fn_name("a.b-c") == "a_b_c" + + def test_numeric_after_slash(self): + assert _sanitize_fn_name("01_foo/02_bar") == "_01_foo_02_bar" + + +class TestHasNumericModuleSegments: + """Test _has_numeric_module_segments detects digit-prefixed segments.""" + + def test_normal_module_path(self): + assert _has_numeric_module_segments("worker") is False + + def test_dotted_normal(self): + assert _has_numeric_module_segments("longruns.stage1") is False + + def test_leading_digit_first_segment(self): + assert _has_numeric_module_segments("01_hello.worker") is True + + def test_leading_digit_nested_segment(self): + assert _has_numeric_module_segments("getting_started.01_hello.worker") is True + + def test_digit_in_middle_not_leading(self): + assert _has_numeric_module_segments("stage1.worker") is False + + +class TestModuleParentSubdir: + """Test _module_parent_subdir extracts parent directory from dotted path.""" + + def test_top_level_returns_none(self): + assert _module_parent_subdir("worker") is None + + def test_single_parent(self): + assert _module_parent_subdir("01_hello.gpu_worker") == "01_hello" + + def test_nested_parent(self): + assert ( + _module_parent_subdir("01_getting_started.03_mixed.pipeline") + == "01_getting_started/03_mixed" + ) + + +class TestMakeImportLine: + """Test _make_import_line generates correct import syntax.""" + + def test_normal_module_uses_from_import(self): + result = _make_import_line("worker", "process") + assert result == "from worker import process" + + def test_numeric_module_uses_flash_import(self): + result = _make_import_line("01_hello.gpu_worker", "gpu_hello") + assert ( + result + == 'gpu_hello = _flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' + ) + + def test_nested_numeric_includes_full_subdir(self): + result = _make_import_line( + "01_getting_started.01_hello.gpu_worker", "gpu_hello" + ) + assert '"01_getting_started/01_hello"' in result + + def test_top_level_numeric_module_no_subdir(self): + result = _make_import_line("01_worker", "process") + assert result == 'process = _flash_import("01_worker", "process")' + + +class TestGenerateFlashServerNumericDirs: + """Test _generate_flash_server with numeric-prefixed directory names.""" + + def test_qb_numeric_dir_uses_flash_import(self, tmp_path): + """QB workers in numeric dirs use _flash_import with scoped sys.path.""" + worker = WorkerInfo( + file_path=tmp_path / "01_hello" / "gpu_worker.py", + url_prefix="/01_hello/gpu_worker", + module_path="01_hello.gpu_worker", + resource_name="01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + # Must NOT contain invalid 'from 01_hello...' import + assert "from 01_hello" not in content + # Must have _flash_import helper and importlib + assert "import importlib as _importlib" in content + assert "def _flash_import(" in content + assert ( + '_flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' in content + ) + + def test_qb_numeric_dir_function_name_prefixed(self, tmp_path): + """QB handler function names starting with digits get '_' prefix.""" + worker = WorkerInfo( + file_path=tmp_path / "01_hello" / "gpu_worker.py", + url_prefix="/01_hello/gpu_worker", + module_path="01_hello.gpu_worker", + resource_name="01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + # Function name must start with '_', not a digit + assert "async def _01_hello_gpu_worker_run_sync(body: dict):" in content + + def test_lb_numeric_dir_uses_flash_import(self, tmp_path): + """LB workers in numeric dirs use _flash_import for config and function imports.""" + worker = WorkerInfo( + file_path=tmp_path / "03_advanced" / "05_lb" / "cpu_lb.py", + url_prefix="/03_advanced/05_lb/cpu_lb", + module_path="03_advanced.05_lb.cpu_lb", + resource_name="03_advanced_05_lb_cpu_lb", + worker_type="LB", + functions=["validate_data"], + lb_routes=[ + { + "method": "POST", + "path": "/validate", + "fn_name": "validate_data", + "config_variable": "cpu_config", + } + ], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + assert "from 03_advanced" not in content + assert ( + '_flash_import("03_advanced.05_lb.cpu_lb", "cpu_config", "03_advanced/05_lb")' + in content + ) + assert ( + '_flash_import("03_advanced.05_lb.cpu_lb", "validate_data", "03_advanced/05_lb")' + in content + ) + + def test_mixed_numeric_and_normal_dirs(self, tmp_path): + """Normal modules use 'from' imports, numeric modules use _flash_import.""" + normal_worker = WorkerInfo( + file_path=tmp_path / "worker.py", + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + ) + numeric_worker = WorkerInfo( + file_path=tmp_path / "01_hello" / "gpu_worker.py", + url_prefix="/01_hello/gpu_worker", + module_path="01_hello.gpu_worker", + resource_name="01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + content = _generate_flash_server( + tmp_path, [normal_worker, numeric_worker] + ).read_text() + + # Normal worker uses standard import + assert "from worker import process" in content + # Numeric worker uses scoped _flash_import + assert ( + '_flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' in content + ) + + def test_no_importlib_when_all_normal_dirs(self, tmp_path): + """importlib and _flash_import are not emitted when no numeric dirs exist.""" + worker = WorkerInfo( + file_path=tmp_path / "worker.py", + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "importlib" not in content + assert "_flash_import" not in content + + def test_scoped_import_includes_subdir(self, tmp_path): + """_flash_import calls pass the subdirectory for sibling import scoping.""" + worker = WorkerInfo( + file_path=tmp_path / "01_getting_started" / "03_mixed" / "pipeline.py", + url_prefix="/01_getting_started/03_mixed/pipeline", + module_path="01_getting_started.03_mixed.pipeline", + resource_name="01_getting_started_03_mixed_pipeline", + worker_type="LB", + functions=["classify"], + lb_routes=[ + { + "method": "POST", + "path": "/classify", + "fn_name": "classify", + "config_variable": "pipeline_config", + } + ], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + # Must scope to correct subdirectory, not add all dirs to sys.path + assert '"01_getting_started/03_mixed"' in content + # No global sys.path additions for subdirs — only the project root + # line at the top and the one inside _flash_import helper body + lines = content.split("\n") + global_sys_path_lines = [ + line + for line in lines + if "sys.path.insert" in line and not line.startswith(" ") + ] + assert len(global_sys_path_lines) == 1 + + def test_generated_server_is_valid_python(self, tmp_path): + """Generated server.py with numeric dirs must be parseable Python.""" + worker = WorkerInfo( + file_path=tmp_path / "01_getting_started" / "01_hello" / "gpu_worker.py", + url_prefix="/01_getting_started/01_hello/gpu_worker", + module_path="01_getting_started.01_hello.gpu_worker", + resource_name="01_getting_started_01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + server_path = _generate_flash_server(tmp_path, [worker]) + content = server_path.read_text() + + # Must parse without SyntaxError + import ast + + ast.parse(content) + + class TestMapBodyToParams: """Tests for _map_body_to_params — maps HTTP body to function arguments.""" From 5480404187a09aff190d274d162c5f5c2110a5cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 19 Feb 2026 16:37:14 -0800 Subject: [PATCH 17/25] fix(ci): update validate-wheel.sh for flat skeleton template The skeleton template was replaced with flat worker files (cpu_worker.py, gpu_worker.py, lb_worker.py, pyproject.toml) but the wheel validation script still expected the old multi-directory structure (main.py, workers/**). This caused the Build Package CI check to fail. --- scripts/validate-wheel.sh | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/scripts/validate-wheel.sh b/scripts/validate-wheel.sh index 7e4dd517..a38db25f 100755 --- a/scripts/validate-wheel.sh +++ b/scripts/validate-wheel.sh @@ -21,14 +21,12 @@ REQUIRED_TEMPLATE_FILES=( "runpod_flash/cli/utils/skeleton_template/.env.example" "runpod_flash/cli/utils/skeleton_template/.gitignore" "runpod_flash/cli/utils/skeleton_template/.flashignore" - "runpod_flash/cli/utils/skeleton_template/main.py" + "runpod_flash/cli/utils/skeleton_template/cpu_worker.py" + "runpod_flash/cli/utils/skeleton_template/gpu_worker.py" + "runpod_flash/cli/utils/skeleton_template/lb_worker.py" + "runpod_flash/cli/utils/skeleton_template/pyproject.toml" "runpod_flash/cli/utils/skeleton_template/README.md" "runpod_flash/cli/utils/skeleton_template/requirements.txt" - "runpod_flash/cli/utils/skeleton_template/workers/__init__.py" - "runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py" - "runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py" - "runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py" - "runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py" ) MISSING_IN_WHEEL=0 @@ -77,7 +75,7 @@ flash init test_project > /dev/null 2>&1 # Verify critical files exist echo "" echo "Verifying created files..." -REQUIRED_FILES=(".env.example" ".gitignore" ".flashignore" "main.py" "README.md" "requirements.txt") +REQUIRED_FILES=(".env.example" ".gitignore" ".flashignore" "cpu_worker.py" "gpu_worker.py" "lb_worker.py" "pyproject.toml" "README.md" "requirements.txt") MISSING_IN_OUTPUT=0 for file in "${REQUIRED_FILES[@]}"; do @@ -94,15 +92,6 @@ for file in "${REQUIRED_FILES[@]}"; do fi done -# Verify workers directory structure -if [ -d "test_project/workers/cpu" ] && [ -d "test_project/workers/gpu" ]; then - echo "[OK] workers/cpu/" - echo "[OK] workers/gpu/" -else - echo "[MISSING] workers directory structure" - MISSING_IN_OUTPUT=$((MISSING_IN_OUTPUT + 1)) -fi - # Cleanup deactivate cd - > /dev/null From 8b1e19f5de168cfc4e34e224c0f6323529846abe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 19 Feb 2026 17:13:14 -0800 Subject: [PATCH 18/25] fix: address PR 208 review feedback - Guard watcher_thread.join() with is_alive() check for --no-reload - Wrap watchfiles import in try/except for missing dependency - Fix debug log to show actual type instead of hardcoded class name - Fix invalid dict addition in skeleton README example - Fix PRD spec to match actual /run_sync-only behavior --- PRD.md | 2 +- src/runpod_flash/cli/commands/run.py | 24 +++++++++++++++---- .../cli/utils/skeleton_template/README.md | 2 +- src/runpod_flash/stubs/load_balancer_sls.py | 6 ++++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/PRD.md b/PRD.md index a5d9d98f..2df30adc 100644 --- a/PRD.md +++ b/PRD.md @@ -40,7 +40,7 @@ async def process(input_data: dict) -> dict: return {"result": "processed", "input": input_data} ``` -`flash run` → `POST /gpu_worker/run` and `POST /gpu_worker/run_sync` +`flash run` → `POST /gpu_worker/run_sync` `flash deploy` → standalone QB endpoint at `api.runpod.ai/v2/{id}/run` ### 4.2 LB endpoint diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index bf3bbd74..da96af7e 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -13,8 +13,22 @@ import typer from rich.console import Console from rich.table import Table -from watchfiles import DefaultFilter as _WatchfilesDefaultFilter -from watchfiles import watch as _watchfiles_watch + +try: + from watchfiles import DefaultFilter as _WatchfilesDefaultFilter + from watchfiles import watch as _watchfiles_watch +except ModuleNotFoundError: + + def _watchfiles_watch(*_a, **_kw): # type: ignore[misc] + raise ModuleNotFoundError( + "watchfiles is required for flash run --reload. " + "Install it with: pip install watchfiles" + ) + + class _WatchfilesDefaultFilter: # type: ignore[no-redef] + def __init__(self, **_kw): + pass + from .build_utils.scanner import ( RemoteDecoratorScanner, @@ -709,7 +723,8 @@ def run_command( console.print("\n[yellow]Stopping server and cleaning up...[/yellow]") stop_event.set() - watcher_thread.join(timeout=2) + if watcher_thread.is_alive(): + watcher_thread.join(timeout=2) if process: try: @@ -738,7 +753,8 @@ def run_command( console.print(f"[red]Error:[/red] {e}") stop_event.set() - watcher_thread.join(timeout=2) + if watcher_thread.is_alive(): + watcher_thread.join(timeout=2) if process: try: diff --git a/src/runpod_flash/cli/utils/skeleton_template/README.md b/src/runpod_flash/cli/utils/skeleton_template/README.md index f30adf00..328a8ab3 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/README.md +++ b/src/runpod_flash/cli/utils/skeleton_template/README.md @@ -103,7 +103,7 @@ cpu_config = CpuLiveServerless(name="cpu_worker") @remote(resource_config=cpu_config) async def cpu_hello(input_data: dict = {}) -> dict: - return {"message": "Hello from CPU!"} + input_data + return {"message": "Hello from CPU!", **input_data} ``` ### Load-Balanced (LB) Workers diff --git a/src/runpod_flash/stubs/load_balancer_sls.py b/src/runpod_flash/stubs/load_balancer_sls.py index f44bccbb..d08a0c5a 100644 --- a/src/runpod_flash/stubs/load_balancer_sls.py +++ b/src/runpod_flash/stubs/load_balancer_sls.py @@ -79,7 +79,11 @@ def _should_use_execute_endpoint(self, func: Callable[..., Any]) -> bool: # Always use /execute for live resources (local development) if isinstance(self.server, LiveServerlessMixin): - log.debug(f"Using /execute endpoint for LiveLoadBalancer: {func.__name__}") + log.debug( + "Using /execute endpoint for live resource %s (type=%s)", + func.__name__, + type(self.server).__name__, + ) return True # Check if function has routing metadata From 277cd7a5c661059b1cb46ef7523c842016f6e82c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 09:12:34 -0800 Subject: [PATCH 19/25] fix(run): match generated QB/LB route calls to @remote function signatures The dev server codegen always generated `await fn(body.get("input", body))` regardless of actual function signature. This crashed zero-param functions with TypeError and incorrectly passed a dict to multi-param functions. Scanner changes: - Extract param_names from function AST nodes (excluding self) - Extract class_method_params per public method for @remote classes Codegen changes: - 0 params: `await fn()` with no `body: dict` in handler signature - 1 param: `await fn(body.get("input", body))` (preserves current behavior) - 2+ params: `await fn(**body.get("input", body))` (kwargs spread) - LB GET routes with path params (e.g. `/images/{file_id}`) now declare typed parameters in handler signature for proper Swagger UI rendering - LB POST routes with path params merge body and path params --- .../cli/commands/build_utils/scanner.py | 36 ++ src/runpod_flash/cli/commands/run.py | 214 +++++-- .../cli/commands/build_utils/test_scanner.py | 249 ++++++++ tests/unit/cli/commands/test_run.py | 583 ++++++++++++++++++ 4 files changed, 1047 insertions(+), 35 deletions(-) create mode 100644 tests/unit/cli/commands/test_run.py diff --git a/src/runpod_flash/cli/commands/build_utils/scanner.py b/src/runpod_flash/cli/commands/build_utils/scanner.py index d217dcb3..892f3bcf 100644 --- a/src/runpod_flash/cli/commands/build_utils/scanner.py +++ b/src/runpod_flash/cli/commands/build_utils/scanner.py @@ -94,6 +94,15 @@ class RemoteFunctionMetadata: is_lb_route_handler: bool = ( False # LB @remote with method= and path= — runs directly as HTTP handler ) + class_methods: List[str] = field( + default_factory=list + ) # Public methods for @remote classes + param_names: List[str] = field( + default_factory=list + ) # Function params excluding self + class_method_params: Dict[str, List[str]] = field( + default_factory=dict + ) # method_name -> param_names (for classes) class RemoteDecoratorScanner: @@ -290,6 +299,30 @@ def _extract_remote_functions( and http_path is not None ) + # Extract public methods for @remote classes + class_methods: List[str] = [] + class_method_params: Dict[str, List[str]] = {} + if is_class: + for n in node.body: + if isinstance( + n, (ast.FunctionDef, ast.AsyncFunctionDef) + ) and not n.name.startswith("_"): + class_methods.append(n.name) + class_method_params[n.name] = [ + arg.arg + for arg in n.args.args + if arg.arg != "self" + ] + + # Extract param names for functions (not classes) + param_names: List[str] = [] + if not is_class and isinstance( + node, (ast.FunctionDef, ast.AsyncFunctionDef) + ): + param_names = [ + arg.arg for arg in node.args.args if arg.arg != "self" + ] + metadata = RemoteFunctionMetadata( function_name=node.name, module_path=module_path, @@ -306,6 +339,9 @@ def _extract_remote_functions( resource_config_name ), is_lb_route_handler=is_lb_route_handler, + class_methods=class_methods, + param_names=param_names, + class_method_params=class_method_params, ) functions.append(metadata) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index da96af7e..87fddac9 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -2,6 +2,7 @@ import logging import os +import re import signal import subprocess import sys @@ -54,7 +55,13 @@ class WorkerInfo: resource_name: str # e.g. longruns_stage1 worker_type: str # "QB" or "LB" functions: List[str] # function names + class_remotes: List[dict] = field( + default_factory=list + ) # [{name, methods, method_params}] lb_routes: List[dict] = field(default_factory=list) # [{method, path, fn_name}] + function_params: dict[str, list[str]] = field( + default_factory=dict + ) # fn_name -> param_names def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: @@ -87,10 +94,11 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: module_path = file_to_module_path(file_path, project_root) resource_name = file_to_resource_name(file_path, project_root) - qb_funcs = [f for f in funcs if not f.is_load_balanced] + qb_funcs = [f for f in funcs if not f.is_load_balanced and not f.is_class] + qb_classes = [f for f in funcs if not f.is_load_balanced and f.is_class] lb_funcs = [f for f in funcs if f.is_load_balanced and f.is_lb_route_handler] - if qb_funcs: + if qb_funcs or qb_classes: workers.append( WorkerInfo( file_path=file_path, @@ -99,6 +107,15 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: resource_name=resource_name, worker_type="QB", functions=[f.function_name for f in qb_funcs], + class_remotes=[ + { + "name": c.function_name, + "methods": c.class_methods, + "method_params": c.class_method_params, + } + for c in qb_classes + ], + function_params={f.function_name: f.param_names for f in qb_funcs}, ) ) @@ -194,6 +211,37 @@ def _make_import_line(module_path: str, name: str) -> str: return f"from {module_path} import {name}" +_PATH_PARAM_RE = re.compile(r"\{(\w+)\}") + + +def _extract_path_params(path: str) -> list[str]: + """Extract path parameter names from a FastAPI-style route path. + + Example: "/images/{file_id}" -> ["file_id"] + """ + return _PATH_PARAM_RE.findall(path) + + +def _build_call_expr(callable_name: str, params: list[str] | None) -> tuple[str, bool]: + """Build an async call expression based on parameter count. + + Args: + callable_name: Fully qualified callable (e.g. "fn" or "instance.method") + params: List of param names, or None if unknown (backward compat) + + Returns: + Tuple of (call_expression, needs_body). needs_body is False when the + handler signature should omit the ``body: dict`` parameter. + """ + if params is not None and len(params) == 0: + return f"await {callable_name}()", False + elif params is not None and len(params) >= 2: + return f'await {callable_name}(**body.get("input", body))', True + else: + # 1 param or unknown (None) — preserve current behavior + return f'await {callable_name}(body.get("input", body))', True + + def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: """Generate .flash/server.py from the discovered workers. @@ -270,6 +318,10 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat if worker.worker_type == "QB": for fn_name in worker.functions: all_imports.append(_make_import_line(worker.module_path, fn_name)) + for cls_info in worker.class_remotes: + all_imports.append( + _make_import_line(worker.module_path, cls_info["name"]) + ) elif worker.worker_type == "LB": # Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...)) config_vars = { @@ -294,6 +346,15 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat "", ] + # Module-level instance creation for @remote classes + for worker in workers: + for cls_info in worker.class_remotes: + cls_name = cls_info["name"] + lines.append(f"_instance_{cls_name} = {cls_name}()") + # Add blank line if any instances were created + if any(worker.class_remotes for worker in workers): + lines.append("") + for worker in workers: tag = f"{worker.url_prefix.lstrip('/')} [{worker.worker_type}]" lines.append(f"# {'─' * 60}") @@ -301,27 +362,67 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat lines.append(f"# {'─' * 60}") if worker.worker_type == "QB": - if len(worker.functions) == 1: - fn = worker.functions[0] - handler_name = _sanitize_fn_name(f"{worker.resource_name}_run_sync") - sync_path = f"{worker.url_prefix}/run_sync" + # Total callable count: functions + sum of class methods + total_class_methods = sum(len(c["methods"]) for c in worker.class_remotes) + total_callables = len(worker.functions) + total_class_methods + use_multi = total_callables > 1 + + # Function-based routes + for fn in worker.functions: + if use_multi: + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{fn}_run_sync" + ) + sync_path = f"{worker.url_prefix}/{fn}/run_sync" + else: + handler_name = _sanitize_fn_name(f"{worker.resource_name}_run_sync") + sync_path = f"{worker.url_prefix}/run_sync" + params = worker.function_params.get(fn) + call_expr, needs_body = _build_call_expr(fn, params) + handler_sig = ( + f"async def {handler_name}(body: dict):" + if needs_body + else f"async def {handler_name}():" + ) lines += [ f'@app.post("{sync_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict):", - f' result = await {fn}(body.get("input", body))', + handler_sig, + f" result = {call_expr}", ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', "", ] - else: - for fn in worker.functions: - handler_name = _sanitize_fn_name( - f"{worker.resource_name}_{fn}_run_sync" + + # Class-based routes + for cls_info in worker.class_remotes: + cls_name = cls_info["name"] + methods = cls_info["methods"] + method_params = cls_info.get("method_params", {}) + instance_var = f"_instance_{cls_name}" + + for method in methods: + if use_multi: + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{cls_name}_{method}_run_sync" + ) + sync_path = f"{worker.url_prefix}/{method}/run_sync" + else: + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{cls_name}_run_sync" + ) + sync_path = f"{worker.url_prefix}/run_sync" + params = method_params.get(method) + call_expr, needs_body = _build_call_expr( + f"{instance_var}.{method}", params + ) + handler_sig = ( + f"async def {handler_name}(body: dict):" + if needs_body + else f"async def {handler_name}():" ) - sync_path = f"{worker.url_prefix}/{fn}/run_sync" lines += [ f'@app.post("{sync_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict):", - f' result = await {fn}(body.get("input", body))', + handler_sig, + f" result = {call_expr}", ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', "", ] @@ -336,21 +437,44 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat handler_name = _sanitize_fn_name( f"_route_{worker.resource_name}_{fn_name}" ) + path_params = _extract_path_params(full_path) has_body = method in ("post", "put", "patch", "delete") if has_body: - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict):", - f" return await _lb_execute({config_var}, {fn_name}, body)", - "", - ] + # POST/PUT/PATCH/DELETE: body + optional path params + if path_params: + param_sig = ", ".join(f"{p}: str" for p in path_params) + param_dict = ", ".join(f'"{p}": {p}' for p in path_params) + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict, {param_sig}):", + f" return await _lb_execute({config_var}, {fn_name}, {{**body, {param_dict}}})", + "", + ] + else: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f" return await _lb_execute({config_var}, {fn_name}, body)", + "", + ] else: - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(request: Request):", - f" return await _lb_execute({config_var}, {fn_name}, dict(request.query_params))", - "", - ] + # GET/etc: path params + query params + if path_params: + param_sig = ", ".join(f"{p}: str" for p in path_params) + param_dict = ", ".join(f'"{p}": {p}' for p in path_params) + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}({param_sig}, request: Request):", + f" return await _lb_execute({config_var}, {fn_name}, {{**dict(request.query_params), {param_dict}}})", + "", + ] + else: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(request: Request):", + f" return await _lb_execute({config_var}, {fn_name}, dict(request.query_params))", + "", + ] # Health endpoints lines += [ @@ -382,19 +506,39 @@ def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> Non for worker in workers: if worker.worker_type == "QB": - if len(worker.functions) == 1: - table.add_row( - f"POST {worker.url_prefix}/run_sync", - worker.resource_name, - "QB", - ) - else: - for fn in worker.functions: + total_class_methods = sum(len(c["methods"]) for c in worker.class_remotes) + total_callables = len(worker.functions) + total_class_methods + use_multi = total_callables > 1 + + for fn in worker.functions: + if use_multi: table.add_row( f"POST {worker.url_prefix}/{fn}/run_sync", worker.resource_name, "QB", ) + else: + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) + + for cls_info in worker.class_remotes: + methods = cls_info["methods"] + for method in methods: + if use_multi: + table.add_row( + f"POST {worker.url_prefix}/{method}/run_sync", + worker.resource_name, + "QB", + ) + else: + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) elif worker.worker_type == "LB": for route in worker.lb_routes: sub_path = route["path"].lstrip("/") diff --git a/tests/unit/cli/commands/build_utils/test_scanner.py b/tests/unit/cli/commands/build_utils/test_scanner.py index 0572b7ce..3b5227c4 100644 --- a/tests/unit/cli/commands/build_utils/test_scanner.py +++ b/tests/unit/cli/commands/build_utils/test_scanner.py @@ -878,3 +878,252 @@ async def process_data(): assert routes[0].http_path == "/api/process" assert routes[0].is_async is True assert routes[0].http_method == "POST" + + +def test_class_methods_extraction(): + """Test that public methods are extracted from @remote classes.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "gpu_worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +gpu_config = LiveServerless(name="gpu_worker") + +@remote(gpu_config) +class SimpleSD: + def __init__(self): + self.model = None + + def generate_image(self, prompt): + return {"image": "base64..."} + + def upscale(self, image): + return {"image": "upscaled..."} + + def _load_model(self): + pass + + def __repr__(self): + return "SimpleSD" +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + meta = functions[0] + assert meta.function_name == "SimpleSD" + assert meta.is_class is True + assert meta.class_methods == ["generate_image", "upscale"] + + +def test_class_methods_excludes_private_and_dunder(): + """Test that _private and __dunder__ methods are excluded from class_methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class MyWorker: + def __init__(self): + pass + + def __repr__(self): + return "MyWorker" + + def _internal_helper(self): + pass + + async def process(self, data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].class_methods == ["process"] + + +def test_class_with_no_public_methods(): + """Test @remote class with only private/dunder methods has empty class_methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class EmptyWorker: + def __init__(self): + pass + + def __call__(self, data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].class_methods == [] + + +def test_function_has_empty_class_methods(): + """Test that regular @remote functions have empty class_methods list.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def my_function(data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].is_class is False + assert functions[0].class_methods == [] + + +def test_param_names_single_param(): + """Test that param_names extracts a single parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def process(data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].param_names == ["data"] + + +def test_param_names_zero_params(): + """Test that param_names is empty for zero-parameter functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def list_images() -> dict: + return {"images": []} +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].param_names == [] + + +def test_param_names_multiple_params(): + """Test that param_names extracts multiple parameters.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def transform(text: str, operation: str = "uppercase") -> dict: + return {"result": text} +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].param_names == ["text", "operation"] + + +def test_class_method_params_extraction(): + """Test that class_method_params extracts params for each public method.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class ImageProcessor: + def __init__(self): + pass + + def generate(self, prompt: str, width: int = 512): + return {} + + def list_models(self): + return [] + + def _internal(self): + pass +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + meta = functions[0] + assert meta.is_class is True + assert meta.class_methods == ["generate", "list_models"] + assert meta.class_method_params == { + "generate": ["prompt", "width"], + "list_models": [], + } + # Classes should have empty param_names + assert meta.param_names == [] diff --git a/tests/unit/cli/commands/test_run.py b/tests/unit/cli/commands/test_run.py new file mode 100644 index 00000000..db24cd29 --- /dev/null +++ b/tests/unit/cli/commands/test_run.py @@ -0,0 +1,583 @@ +"""Tests for flash run dev server generation.""" + +import tempfile +from pathlib import Path + +from runpod_flash.cli.commands.run import ( + WorkerInfo, + _generate_flash_server, + _scan_project_workers, +) + + +def test_scan_separates_classes_from_functions(): + """Test that _scan_project_workers puts classes in class_remotes, not functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "gpu_worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="gpu_worker") + +@remote(config) +async def process(data): + return data + +@remote(config) +class SimpleSD: + def generate_image(self, prompt): + return {"image": "data"} + + def upscale(self, image): + return {"image": "upscaled"} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert worker.worker_type == "QB" + assert worker.functions == ["process"] + assert len(worker.class_remotes) == 1 + assert worker.class_remotes[0]["name"] == "SimpleSD" + assert worker.class_remotes[0]["methods"] == ["generate_image", "upscale"] + + +def test_scan_class_only_worker(): + """Test scanning a file with only a class-based @remote.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "sd_worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="sd_worker") + +@remote(config) +class StableDiffusion: + def __init__(self): + self.model = None + + def generate(self, prompt): + return {"image": "data"} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert worker.worker_type == "QB" + assert worker.functions == [] + assert len(worker.class_remotes) == 1 + assert worker.class_remotes[0]["name"] == "StableDiffusion" + assert worker.class_remotes[0]["methods"] == ["generate"] + + +def test_codegen_class_single_method(): + """Test generated server.py for a class with a single method uses short URL.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("sd_worker.py"), + url_prefix="/sd_worker", + module_path="sd_worker", + resource_name="sd_worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "StableDiffusion", + "methods": ["generate"], + "method_params": {"generate": ["prompt"]}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_instance_StableDiffusion = StableDiffusion()" in content + assert ( + 'result = await _instance_StableDiffusion.generate(body.get("input", body))' + in content + ) + assert '"/sd_worker/run_sync"' in content + # Single method: no method name in URL + assert '"/sd_worker/generate/run_sync"' not in content + + +def test_codegen_class_multiple_methods(): + """Test generated server.py for a class with multiple methods uses method URLs.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("gpu_worker.py"), + url_prefix="/gpu_worker", + module_path="gpu_worker", + resource_name="gpu_worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "SimpleSD", + "methods": ["generate_image", "upscale"], + "method_params": { + "generate_image": ["prompt"], + "upscale": ["image"], + }, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_instance_SimpleSD = SimpleSD()" in content + assert '"/gpu_worker/generate_image/run_sync"' in content + assert '"/gpu_worker/upscale/run_sync"' in content + assert ( + 'await _instance_SimpleSD.generate_image(body.get("input", body))' + in content + ) + assert 'await _instance_SimpleSD.upscale(body.get("input", body))' in content + + +def test_codegen_mixed_function_and_class(): + """Test codegen when a worker has both functions and class remotes.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + class_remotes=[ + { + "name": "MyModel", + "methods": ["predict"], + "method_params": {"predict": ["data"]}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Both should use multi-callable URL pattern (total_callables = 2) + assert '"/worker/process/run_sync"' in content + assert '"/worker/predict/run_sync"' in content + assert "_instance_MyModel = MyModel()" in content + assert 'await _instance_MyModel.predict(body.get("input", body))' in content + + +def test_codegen_function_only_unchanged(): + """Test that function-only workers still generate the same code as before.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("simple.py"), + url_prefix="/simple", + module_path="simple", + resource_name="simple", + worker_type="QB", + functions=["process"], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Single function: short URL + assert '"/simple/run_sync"' in content + assert 'await process(body.get("input", body))' in content + # No instance creation + assert "_instance_" not in content + + +def test_codegen_zero_param_function(): + """Test generated code uses await fn() for zero-param functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["list_images"], + function_params={"list_images": []}, + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "await list_images()" in content + assert 'body.get("input"' not in content + # Handler should not accept body parameter + assert "async def worker_run_sync():" in content + + +def test_codegen_multi_param_function(): + """Test generated code uses await fn(**body.get(...)) for multi-param functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["transform"], + function_params={"transform": ["text", "operation"]}, + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert 'await transform(**body.get("input", body))' in content + + +def test_codegen_single_param_function(): + """Test generated code uses await fn(body.get(...)) for single-param functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + function_params={"process": ["data"]}, + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert 'await process(body.get("input", body))' in content + + +def test_codegen_zero_param_class_method(): + """Test generated code uses await instance.method() for zero-param class methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "ImageProcessor", + "methods": ["list_models"], + "method_params": {"list_models": []}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "await _instance_ImageProcessor.list_models()" in content + # Handler should not accept body parameter + assert "worker_ImageProcessor_run_sync():" in content + + +def test_codegen_multi_param_class_method(): + """Test generated code uses **body spread for multi-param class methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "ImageProcessor", + "methods": ["generate"], + "method_params": {"generate": ["prompt", "width"]}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert ( + 'await _instance_ImageProcessor.generate(**body.get("input", body))' + in content + ) + + +def test_codegen_backward_compat_no_method_params(): + """Test that missing method_params in class_remotes falls back to 1-param pattern.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=[], + class_remotes=[ + {"name": "OldStyle", "methods": ["process"]}, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Should fall back to 1-param pattern when method_params not provided + assert 'await _instance_OldStyle.process(body.get("input", body))' in content + + +def test_scan_populates_function_params(): + """Test that _scan_project_workers populates function_params from scanner.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def no_params() -> dict: + return {} + +@remote(config) +async def one_param(data: dict) -> dict: + return data + +@remote(config) +async def multi_params(text: str, mode: str = "default") -> dict: + return {"text": text} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert worker.function_params == { + "no_params": [], + "one_param": ["data"], + "multi_params": ["text", "mode"], + } + + +def test_scan_populates_class_method_params(): + """Test that _scan_project_workers populates method_params in class_remotes.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class Processor: + def run(self, data: dict): + return data + + def status(self): + return {"ok": True} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert len(worker.class_remotes) == 1 + cls = worker.class_remotes[0] + assert cls["method_params"] == { + "run": ["data"], + "status": [], + } + + +def test_codegen_lb_get_with_path_params(): + """Test LB GET route with path params generates proper Swagger-compatible handler.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["get_image"], + lb_routes=[ + { + "method": "GET", + "path": "/images/{file_id}", + "fn_name": "get_image", + "config_variable": "cpu_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Handler must declare file_id as a typed parameter for Swagger + assert "file_id: str" in content + # Path param must be forwarded in the body dict + assert '"file_id": file_id' in content + # Should NOT use bare request: Request as only param + assert ( + "async def _route_worker_get_image(file_id: str, request: Request):" + in content + ) + + +def test_codegen_lb_get_without_path_params(): + """Test LB GET route without path params uses request: Request.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["health"], + lb_routes=[ + { + "method": "GET", + "path": "/health", + "fn_name": "health", + "config_variable": "cpu_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "async def _route_worker_health(request: Request):" in content + assert "dict(request.query_params)" in content + + +def test_codegen_lb_post_with_path_params(): + """Test LB POST route with path params includes both body and path params.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["update_item"], + lb_routes=[ + { + "method": "POST", + "path": "/items/{item_id}", + "fn_name": "update_item", + "config_variable": "api_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # POST handler must have both body and path param + assert ( + "async def _route_worker_update_item(body: dict, item_id: str):" in content + ) + assert '"item_id": item_id' in content + assert "**body" in content + + +def test_codegen_lb_get_with_multiple_path_params(): + """Test LB GET route with multiple path params.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["get_version"], + lb_routes=[ + { + "method": "GET", + "path": "/items/{item_id}/versions/{version_id}", + "fn_name": "get_version", + "config_variable": "api_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "item_id: str" in content + assert "version_id: str" in content + assert '"item_id": item_id' in content + assert '"version_id": version_id' in content From 6881489c61a480a4525a9349a29ff5df55fa7571 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 10:22:22 -0800 Subject: [PATCH 20/25] feat(run): replace body: dict with Pydantic models for typed Swagger UI Use pydantic.create_model() at server startup to dynamically build input models from @remote function signatures. Swagger UI now shows typed form fields instead of a generic JSON text area. - Add make_input_model(), call_with_body(), to_dict() helpers - Codegen emits model creation lines and typed handler signatures - Simplify _build_call_expr to 2-way branch (zero-param vs body) - Fix class method introspection: use _class_type to bypass RemoteClassWrapper proxy signatures (*args, **kwargs) - Skip VAR_POSITIONAL/VAR_KEYWORD params in model creation as safety net - Fall back to dict when model creation fails (zero disruption) --- .../cli/commands/_run_server_helpers.py | 44 +++++ src/runpod_flash/cli/commands/run.py | 85 ++++++--- tests/unit/cli/commands/test_run.py | 62 +++--- .../cli/commands/test_run_server_helpers.py | 179 ++++++++++++++++++ tests/unit/cli/test_run.py | 13 +- 5 files changed, 328 insertions(+), 55 deletions(-) create mode 100644 tests/unit/cli/commands/test_run_server_helpers.py diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py index 70391bbd..4a92c31f 100644 --- a/src/runpod_flash/cli/commands/_run_server_helpers.py +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -1,8 +1,10 @@ """Helpers for the flash run dev server — loaded inside the generated server.py.""" import inspect +from typing import Any, get_type_hints from fastapi import HTTPException +from pydantic import create_model from runpod_flash.core.resources.resource_manager import ResourceManager from runpod_flash.stubs.load_balancer_sls import LoadBalancerSlsStub @@ -29,6 +31,48 @@ def _map_body_to_params(func, body): return {first_param: body} +def make_input_model(name: str, func) -> type | None: + """Create a Pydantic model from a function's signature for FastAPI body typing. + + Returns None for zero-param functions or on failure (caller uses ``or dict``). + """ + try: + sig = inspect.signature(func) + hints = get_type_hints(func) + except (ValueError, TypeError): + return None + + _SKIP_KINDS = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + fields: dict[str, Any] = {} + for param_name, param in sig.parameters.items(): + if param_name == "self" or param.kind in _SKIP_KINDS: + continue + annotation = hints.get(param_name, Any) + if param.default is not inspect.Parameter.empty: + fields[param_name] = (annotation, param.default) + else: + fields[param_name] = (annotation, ...) + + if not fields: + return None + + return create_model(name, **fields) + + +async def call_with_body(func, body): + """Call func with body kwargs, handling Pydantic models and dicts.""" + if hasattr(body, "model_dump"): + return await func(**body.model_dump()) + raw = body.get("input", body) if isinstance(body, dict) else body + kwargs = _map_body_to_params(func, raw) + return await func(**kwargs) + + +def to_dict(body) -> dict: + """Convert Pydantic model or dict to plain dict.""" + return body.model_dump() if hasattr(body, "model_dump") else body + + async def lb_execute(resource_config, func, body: dict): """Dispatch an LB route to the deployed endpoint via LoadBalancerSlsStub. diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 87fddac9..d416d9a4 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -231,15 +231,11 @@ def _build_call_expr(callable_name: str, params: list[str] | None) -> tuple[str, Returns: Tuple of (call_expression, needs_body). needs_body is False when the - handler signature should omit the ``body: dict`` parameter. + handler signature should omit the body parameter. """ if params is not None and len(params) == 0: return f"await {callable_name}()", False - elif params is not None and len(params) >= 2: - return f'await {callable_name}(**body.get("input", body))', True - else: - # 1 param or unknown (None) — preserve current behavior - return f'await {callable_name}(body.get("input", body))', True + return f"await _call_with_body({callable_name}, body)", True def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: @@ -299,10 +295,16 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat "", ] + lines += [ + "from runpod_flash.cli.commands._run_server_helpers import make_input_model as _make_input_model", + "from runpod_flash.cli.commands._run_server_helpers import call_with_body as _call_with_body", + ] + if has_lb_workers: lines += [ "from fastapi import FastAPI, Request", "from runpod_flash.cli.commands._run_server_helpers import lb_execute as _lb_execute", + "from runpod_flash.cli.commands._run_server_helpers import to_dict as _to_dict", "", ] else: @@ -355,6 +357,44 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat if any(worker.class_remotes for worker in workers): lines.append("") + # Module-level Pydantic model creation for typed Swagger UI + model_lines: list[str] = [] + for worker in workers: + if worker.worker_type == "QB": + for fn in worker.functions: + params = worker.function_params.get(fn) + if params is None or len(params) > 0: + model_var = f"_{worker.resource_name}_{fn}_Input" + model_lines.append( + f'{model_var} = _make_input_model("{model_var}", {fn}) or dict' + ) + for cls_info in worker.class_remotes: + cls_name = cls_info["name"] + method_params = cls_info.get("method_params", {}) + instance_var = f"_instance_{cls_name}" + for method in cls_info["methods"]: + params = method_params.get(method) + if params is None or len(params) > 0: + model_var = f"_{worker.resource_name}_{cls_name}_{method}_Input" + # Use _class_type to get the original unwrapped method + # (RemoteClassWrapper.__getattr__ returns proxies with (*args, **kwargs)) + class_ref = f"getattr({instance_var}, '_class_type', type({instance_var}))" + model_lines.append( + f'{model_var} = _make_input_model("{model_var}", {class_ref}.{method}) or dict' + ) + elif worker.worker_type == "LB": + for route in worker.lb_routes: + method = route["method"].lower() + if method in ("post", "put", "patch", "delete"): + fn_name = route["fn_name"] + model_var = f"_{worker.resource_name}_{fn_name}_Input" + model_lines.append( + f'{model_var} = _make_input_model("{model_var}", {fn_name}) or dict' + ) + if model_lines: + lines.extend(model_lines) + lines.append("") + for worker in workers: tag = f"{worker.url_prefix.lstrip('/')} [{worker.worker_type}]" lines.append(f"# {'─' * 60}") @@ -379,11 +419,11 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat sync_path = f"{worker.url_prefix}/run_sync" params = worker.function_params.get(fn) call_expr, needs_body = _build_call_expr(fn, params) - handler_sig = ( - f"async def {handler_name}(body: dict):" - if needs_body - else f"async def {handler_name}():" - ) + if needs_body: + model_var = f"_{worker.resource_name}_{fn}_Input" + handler_sig = f"async def {handler_name}(body: {model_var}):" + else: + handler_sig = f"async def {handler_name}():" lines += [ f'@app.post("{sync_path}", tags=["{tag}"])', handler_sig, @@ -414,11 +454,11 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat call_expr, needs_body = _build_call_expr( f"{instance_var}.{method}", params ) - handler_sig = ( - f"async def {handler_name}(body: dict):" - if needs_body - else f"async def {handler_name}():" - ) + if needs_body: + model_var = f"_{worker.resource_name}_{cls_name}_{method}_Input" + handler_sig = f"async def {handler_name}(body: {model_var}):" + else: + handler_sig = f"async def {handler_name}():" lines += [ f'@app.post("{sync_path}", tags=["{tag}"])', handler_sig, @@ -440,25 +480,26 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat path_params = _extract_path_params(full_path) has_body = method in ("post", "put", "patch", "delete") if has_body: - # POST/PUT/PATCH/DELETE: body + optional path params + model_var = f"_{worker.resource_name}_{fn_name}_Input" + # POST/PUT/PATCH/DELETE: typed body + optional path params if path_params: param_sig = ", ".join(f"{p}: str" for p in path_params) param_dict = ", ".join(f'"{p}": {p}' for p in path_params) lines += [ f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict, {param_sig}):", - f" return await _lb_execute({config_var}, {fn_name}, {{**body, {param_dict}}})", + f"async def {handler_name}(body: {model_var}, {param_sig}):", + f" return await _lb_execute({config_var}, {fn_name}, {{**_to_dict(body), {param_dict}}})", "", ] else: lines += [ f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: dict):", - f" return await _lb_execute({config_var}, {fn_name}, body)", + f"async def {handler_name}(body: {model_var}):", + f" return await _lb_execute({config_var}, {fn_name}, _to_dict(body))", "", ] else: - # GET/etc: path params + query params + # GET/etc: path params + query params (unchanged) if path_params: param_sig = ", ".join(f"{p}: str" for p in path_params) param_dict = ", ".join(f'"{p}": {p}' for p in path_params) diff --git a/tests/unit/cli/commands/test_run.py b/tests/unit/cli/commands/test_run.py index db24cd29..c8b3311e 100644 --- a/tests/unit/cli/commands/test_run.py +++ b/tests/unit/cli/commands/test_run.py @@ -107,10 +107,9 @@ def test_codegen_class_single_method(): content = server_path.read_text() assert "_instance_StableDiffusion = StableDiffusion()" in content - assert ( - 'result = await _instance_StableDiffusion.generate(body.get("input", body))' - in content - ) + assert "_call_with_body(_instance_StableDiffusion.generate, body)" in content + assert "body: _sd_worker_StableDiffusion_generate_Input" in content + assert "_make_input_model" in content assert '"/sd_worker/run_sync"' in content # Single method: no method name in URL assert '"/sd_worker/generate/run_sync"' not in content @@ -148,11 +147,10 @@ def test_codegen_class_multiple_methods(): assert "_instance_SimpleSD = SimpleSD()" in content assert '"/gpu_worker/generate_image/run_sync"' in content assert '"/gpu_worker/upscale/run_sync"' in content - assert ( - 'await _instance_SimpleSD.generate_image(body.get("input", body))' - in content - ) - assert 'await _instance_SimpleSD.upscale(body.get("input", body))' in content + assert "_call_with_body(_instance_SimpleSD.generate_image, body)" in content + assert "_call_with_body(_instance_SimpleSD.upscale, body)" in content + assert "body: _gpu_worker_SimpleSD_generate_image_Input" in content + assert "body: _gpu_worker_SimpleSD_upscale_Input" in content def test_codegen_mixed_function_and_class(): @@ -185,11 +183,12 @@ def test_codegen_mixed_function_and_class(): assert '"/worker/process/run_sync"' in content assert '"/worker/predict/run_sync"' in content assert "_instance_MyModel = MyModel()" in content - assert 'await _instance_MyModel.predict(body.get("input", body))' in content + assert "_call_with_body(_instance_MyModel.predict, body)" in content + assert "_call_with_body(process, body)" in content -def test_codegen_function_only_unchanged(): - """Test that function-only workers still generate the same code as before.""" +def test_codegen_function_only(): + """Test that function-only workers use Pydantic model and _call_with_body.""" with tempfile.TemporaryDirectory() as tmpdir: project_root = Path(tmpdir) @@ -209,7 +208,9 @@ def test_codegen_function_only_unchanged(): # Single function: short URL assert '"/simple/run_sync"' in content - assert 'await process(body.get("input", body))' in content + assert "_call_with_body(process, body)" in content + assert "_simple_process_Input = _make_input_model(" in content + assert "body: _simple_process_Input" in content # No instance creation assert "_instance_" not in content @@ -241,7 +242,7 @@ def test_codegen_zero_param_function(): def test_codegen_multi_param_function(): - """Test generated code uses await fn(**body.get(...)) for multi-param functions.""" + """Test generated code uses _call_with_body for multi-param functions.""" with tempfile.TemporaryDirectory() as tmpdir: project_root = Path(tmpdir) @@ -260,11 +261,13 @@ def test_codegen_multi_param_function(): server_path = _generate_flash_server(project_root, workers) content = server_path.read_text() - assert 'await transform(**body.get("input", body))' in content + assert "_call_with_body(transform, body)" in content + assert "_worker_transform_Input = _make_input_model(" in content + assert "body: _worker_transform_Input" in content def test_codegen_single_param_function(): - """Test generated code uses await fn(body.get(...)) for single-param functions.""" + """Test generated code uses _call_with_body for single-param functions.""" with tempfile.TemporaryDirectory() as tmpdir: project_root = Path(tmpdir) @@ -283,7 +286,8 @@ def test_codegen_single_param_function(): server_path = _generate_flash_server(project_root, workers) content = server_path.read_text() - assert 'await process(body.get("input", body))' in content + assert "_call_with_body(process, body)" in content + assert "body: _worker_process_Input" in content def test_codegen_zero_param_class_method(): @@ -318,7 +322,7 @@ def test_codegen_zero_param_class_method(): def test_codegen_multi_param_class_method(): - """Test generated code uses **body spread for multi-param class methods.""" + """Test generated code uses _call_with_body for multi-param class methods.""" with tempfile.TemporaryDirectory() as tmpdir: project_root = Path(tmpdir) @@ -343,14 +347,14 @@ def test_codegen_multi_param_class_method(): server_path = _generate_flash_server(project_root, workers) content = server_path.read_text() - assert ( - 'await _instance_ImageProcessor.generate(**body.get("input", body))' - in content - ) + assert "_call_with_body(_instance_ImageProcessor.generate, body)" in content + assert "body: _worker_ImageProcessor_generate_Input" in content + # Model creation uses _class_type to get original method signature + assert "_class_type" in content def test_codegen_backward_compat_no_method_params(): - """Test that missing method_params in class_remotes falls back to 1-param pattern.""" + """Test that missing method_params in class_remotes uses _call_with_body.""" with tempfile.TemporaryDirectory() as tmpdir: project_root = Path(tmpdir) @@ -371,8 +375,9 @@ def test_codegen_backward_compat_no_method_params(): server_path = _generate_flash_server(project_root, workers) content = server_path.read_text() - # Should fall back to 1-param pattern when method_params not provided - assert 'await _instance_OldStyle.process(body.get("input", body))' in content + # Should use _call_with_body when method_params not provided (params=None) + assert "_call_with_body(_instance_OldStyle.process, body)" in content + assert "body: _worker_OldStyle_process_Input" in content def test_scan_populates_function_params(): @@ -542,12 +547,13 @@ def test_codegen_lb_post_with_path_params(): server_path = _generate_flash_server(project_root, workers) content = server_path.read_text() - # POST handler must have both body and path param + # POST handler must have typed body and path param assert ( - "async def _route_worker_update_item(body: dict, item_id: str):" in content + "async def _route_worker_update_item(body: _worker_update_item_Input, item_id: str):" + in content ) assert '"item_id": item_id' in content - assert "**body" in content + assert "_to_dict(body)" in content def test_codegen_lb_get_with_multiple_path_params(): diff --git a/tests/unit/cli/commands/test_run_server_helpers.py b/tests/unit/cli/commands/test_run_server_helpers.py new file mode 100644 index 00000000..ed6529dd --- /dev/null +++ b/tests/unit/cli/commands/test_run_server_helpers.py @@ -0,0 +1,179 @@ +"""Tests for _run_server_helpers: make_input_model, call_with_body, to_dict.""" + +from typing import Any + +import pytest +from pydantic import BaseModel + +from runpod_flash.cli.commands._run_server_helpers import ( + call_with_body, + make_input_model, + to_dict, +) + + +# --- make_input_model --- + + +def test_make_input_model_basic(): + """Function with typed params produces a Pydantic model with correct fields.""" + + async def process(text: str, count: int): + pass + + Model = make_input_model("process_Input", process) + assert Model is not None + assert issubclass(Model, BaseModel) + fields = Model.model_fields + assert "text" in fields + assert "count" in fields + assert fields["text"].annotation is str + assert fields["count"].annotation is int + + +def test_make_input_model_with_defaults(): + """Default values are preserved in the generated model.""" + + async def transform(text: str, mode: str = "default", limit: int = 10): + pass + + Model = make_input_model("transform_Input", transform) + assert Model is not None + fields = Model.model_fields + assert fields["text"].is_required() + assert not fields["mode"].is_required() + assert fields["mode"].default == "default" + assert fields["limit"].default == 10 + + +def test_make_input_model_zero_params(): + """Zero-param function returns None.""" + + async def health(): + pass + + result = make_input_model("health_Input", health) + assert result is None + + +def test_make_input_model_skips_self(): + """Self parameter is excluded from the model (class methods).""" + + class Worker: + def generate(self, prompt: str): + pass + + Model = make_input_model("generate_Input", Worker().generate) + assert Model is not None + assert "self" not in Model.model_fields + assert "prompt" in Model.model_fields + + +def test_make_input_model_untyped_params(): + """Untyped params get Any annotation.""" + + async def process(data): + pass + + Model = make_input_model("process_Input", process) + assert Model is not None + assert Model.model_fields["data"].annotation is Any + + +def test_make_input_model_skips_var_positional_and_keyword(): + """Proxy-style (*args, **kwargs) signatures return None, not a model with args/kwargs fields.""" + + async def method_proxy(*args, **kwargs): + pass + + result = make_input_model("proxy_Input", method_proxy) + assert result is None + + +def test_make_input_model_mixed_regular_and_var_keyword(): + """Regular params are kept, **kwargs is skipped.""" + + async def process(text: str, **extra): + pass + + Model = make_input_model("process_Input", process) + assert Model is not None + assert "text" in Model.model_fields + assert "extra" not in Model.model_fields + + +def test_make_input_model_failure_graceful(): + """Bad input returns None instead of raising.""" + result = make_input_model("bad_Input", 42) + assert result is None + + +# --- call_with_body --- + + +@pytest.mark.asyncio +async def test_call_with_body_pydantic(): + """Pydantic model body is spread as kwargs via model_dump().""" + received = {} + + async def process(text: str, count: int): + received.update(text=text, count=count) + return {"ok": True} + + Model = make_input_model("process_Input", process) + body = Model(text="hello", count=5) + result = await call_with_body(process, body) + assert result == {"ok": True} + assert received == {"text": "hello", "count": 5} + + +@pytest.mark.asyncio +async def test_call_with_body_dict_fallback(): + """Plain dict body uses _map_body_to_params path.""" + received = {} + + async def process(data): + received["data"] = data + return {"ok": True} + + result = await call_with_body(process, {"data": "value"}) + assert result == {"ok": True} + assert received == {"data": "value"} + + +@pytest.mark.asyncio +async def test_call_with_body_dict_with_input_wrapper(): + """Dict body with 'input' key unwraps correctly.""" + received = {} + + async def process(text: str): + received["text"] = text + return text + + result = await call_with_body(process, {"input": {"text": "hello"}}) + assert result == "hello" + assert received == {"text": "hello"} + + +# --- to_dict --- + + +def test_to_dict_pydantic(): + """Pydantic model is converted to plain dict.""" + + async def process(text: str, count: int): + pass + + Model = make_input_model("process_Input", process) + body = Model(text="hello", count=5) + result = to_dict(body) + assert result == {"text": "hello", "count": 5} + assert isinstance(result, dict) + + +def test_to_dict_plain_dict(): + """Plain dict passes through unchanged.""" + body = {"text": "hello", "count": 5} + result = to_dict(body) + assert result == body + assert result is body diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index 6014e9a3..73bee3bd 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -456,12 +456,12 @@ def _make_lb_worker(self, tmp_path: Path, method: str = "GET") -> WorkerInfo: ) def test_post_lb_route_generates_body_param(self, tmp_path): - """POST/PUT/PATCH/DELETE LB routes use body: dict for OpenAPI docs.""" + """POST/PUT/PATCH/DELETE LB routes use typed body for OpenAPI docs.""" for method in ("POST", "PUT", "PATCH", "DELETE"): worker = self._make_lb_worker(tmp_path, method) content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "async def _route_api_list_routes(body: dict):" in content - assert "_lb_execute(api_config, list_routes, body)" in content + assert "body: _api_list_routes_Input" in content + assert "_lb_execute(api_config, list_routes, _to_dict(body))" in content def test_get_lb_route_uses_query_params(self, tmp_path): """GET LB routes pass query params as a dict.""" @@ -499,7 +499,7 @@ def test_qb_function_still_imported_directly(self, tmp_path): ) content = _generate_flash_server(tmp_path, [worker]).read_text() assert "from worker import process" in content - assert "await process(" in content + assert "_call_with_body(process, body)" in content class TestSanitizeFnName: @@ -618,7 +618,10 @@ def test_qb_numeric_dir_function_name_prefixed(self, tmp_path): content = _generate_flash_server(tmp_path, [worker]).read_text() # Function name must start with '_', not a digit - assert "async def _01_hello_gpu_worker_run_sync(body: dict):" in content + assert ( + "async def _01_hello_gpu_worker_run_sync(body: _01_hello_gpu_worker_gpu_hello_Input):" + in content + ) def test_lb_numeric_dir_uses_flash_import(self, tmp_path): """LB workers in numeric dirs use _flash_import for config and function imports.""" From ef92813b23c7ffd1c2143fdde159c540e61e2ded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 11:36:12 -0800 Subject: [PATCH 21/25] feat(stubs): inject remote dispatch stubs for stacked @remote execution When @remote funcA calls @remote funcB, the worker receives only funcA's source via exec(). funcB is undefined in that namespace, causing NameError. This adds dependency_resolver.py which AST-detects calls to other @remote functions, provisions their endpoints via ResourceManager, and generates async dispatch stubs that are prepended to the caller's source code. The worker's exec() then defines both the stubs and the caller in the same namespace, allowing stacked @remote calls to dispatch correctly. - Add dependency_resolver.py with detect, resolve, generate, and build - Change prepare_request to async in LiveServerlessStub and LoadBalancerSlsStub - Move LB stub timeout to constants.py as DEFAULT_LB_STUB_TIMEOUT (60s) - Update registry.py to await prepare_request - Add 24 unit tests for dependency resolver --- src/runpod_flash/core/resources/constants.py | 3 + src/runpod_flash/stubs/dependency_resolver.py | 248 ++++++++++++ src/runpod_flash/stubs/live_serverless.py | 18 +- src/runpod_flash/stubs/load_balancer_sls.py | 27 +- src/runpod_flash/stubs/registry.py | 2 +- tests/integration/test_lb_remote_execution.py | 2 +- tests/unit/test_dependency_resolver.py | 372 ++++++++++++++++++ tests/unit/test_load_balancer_sls_stub.py | 22 +- 8 files changed, 677 insertions(+), 17 deletions(-) create mode 100644 src/runpod_flash/stubs/dependency_resolver.py create mode 100644 tests/unit/test_dependency_resolver.py diff --git a/src/runpod_flash/core/resources/constants.py b/src/runpod_flash/core/resources/constants.py index e927c09b..c2aee1b3 100644 --- a/src/runpod_flash/core/resources/constants.py +++ b/src/runpod_flash/core/resources/constants.py @@ -39,3 +39,6 @@ def _endpoint_domain_from_base_url(base_url: str) -> str: MAX_TARBALL_SIZE_MB = 500 # Maximum tarball size in megabytes VALID_TARBALL_EXTENSIONS = (".tar.gz", ".tgz") # Valid tarball file extensions GZIP_MAGIC_BYTES = (0x1F, 0x8B) # Magic bytes for gzip files + +# Load balancer stub timeout (seconds) +DEFAULT_LB_STUB_TIMEOUT = 60.0 diff --git a/src/runpod_flash/stubs/dependency_resolver.py b/src/runpod_flash/stubs/dependency_resolver.py new file mode 100644 index 00000000..97b8a094 --- /dev/null +++ b/src/runpod_flash/stubs/dependency_resolver.py @@ -0,0 +1,248 @@ +"""Dependency resolver for stacked @remote function execution. + +When @remote funcA calls @remote funcB, the worker only receives funcA's source. +This module detects such dependencies, provisions their endpoints, and generates +dispatch stubs so funcB resolves correctly inside the worker's exec() namespace. +""" + +import ast +import inspect +import logging +from dataclasses import dataclass +from typing import Any + +from .live_serverless import get_function_source + +log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RemoteDependency: + """A resolved @remote dependency ready for stub generation.""" + + name: str + endpoint_id: str + source: str + dependencies: list[str] + system_dependencies: list[str] + + +def detect_remote_dependencies(source: str, func_globals: dict[str, Any]) -> list[str]: + """Find names of @remote functions called in *source*. + + Parses the source AST and checks each direct function call (ast.Name) + against *func_globals* for the ``__remote_config__`` attribute set by + the @remote decorator. + + Only direct calls (``await funcB(x)``) are detected. Attribute calls + (``obj.funcB(x)``) and indirect references (``f = funcB; f(x)``) are + intentionally ignored (V1 limitation). + + Args: + source: Source code string of the calling function. + func_globals: The ``__globals__`` dict of the calling function, + used to resolve called names. + + Returns: + Sorted list of names that resolve to @remote-decorated objects. + """ + tree = ast.parse(source) + called_names: set[str] = set() + + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + called_names.add(node.func.id) + + remote_deps = [ + name + for name in sorted(called_names) + if hasattr(func_globals.get(name), "__remote_config__") + ] + return remote_deps + + +async def resolve_dependencies( + source: str, func_globals: dict[str, Any] +) -> list[RemoteDependency]: + """Detect @remote dependencies and provision their endpoints. + + For each detected dependency: + 1. Extract resource_config from ``__remote_config__`` + 2. Provision via ``ResourceManager().get_or_deploy_resource()`` + 3. Return a ``RemoteDependency`` with the provisioned endpoint_id + + Args: + source: Source code of the calling function. + func_globals: The ``__globals__`` dict of the calling function. + + Returns: + List of resolved dependencies with endpoint IDs. + + Raises: + RuntimeError: If endpoint provisioning fails for any dependency. + """ + dep_names = detect_remote_dependencies(source, func_globals) + if not dep_names: + return [] + + from ..core.resources import ResourceManager + + resource_manager = ResourceManager() + resolved: list[RemoteDependency] = [] + + for name in dep_names: + dep_func = func_globals[name] + config = dep_func.__remote_config__ + + resource_config = config["resource_config"] + remote_resource = await resource_manager.get_or_deploy_resource(resource_config) + + # Get source of the dependency function + unwrapped = inspect.unwrap(dep_func) + dep_source, _ = get_function_source(unwrapped) + + resolved.append( + RemoteDependency( + name=name, + endpoint_id=remote_resource.id, + source=dep_source, + dependencies=config.get("dependencies") or [], + system_dependencies=config.get("system_dependencies") or [], + ) + ) + log.debug( + "Resolved dependency %s -> endpoint %s", + name, + remote_resource.id, + ) + + return resolved + + +def generate_stub_code(dep: RemoteDependency) -> str: + """Generate an async stub function that dispatches to a remote endpoint. + + The stub preserves the original function's parameter names so callers + can use ``await funcB(payload)`` naturally. Inside the stub, arguments + are serialized with cloudpickle and sent via aiohttp to the RunPod + runsync endpoint. + + Args: + dep: Resolved dependency with endpoint_id and source. + + Returns: + Python source code string defining the async stub function. + """ + # Parse the dependency source to extract parameter names + tree = ast.parse(dep.source) + params_str = "*args, **kwargs" + for node in ast.walk(tree): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == dep.name + ): + params_str = _extract_params(node) + break + + # Build serialization expressions for args/kwargs + ser_args_expr, ser_kwargs_expr = _build_serialization_exprs(tree, dep.name) + + I = " " # noqa: E741 — single indent level + lines = [ + f"async def {dep.name}({params_str}):", + f"{I}import os as _os", + f"{I}import base64 as _b64", + f"{I}import cloudpickle as _cp", + f"{I}import aiohttp as _aiohttp", + "", + f"{I}_endpoint_id = {repr(dep.endpoint_id)}", + f'{I}_api_key = _os.environ.get("RUNPOD_API_KEY", "")', + f'{I}_url = f"https://api.runpod.ai/v2/{{_endpoint_id}}/runsync"', + f'{I}_headers = {{"Content-Type": "application/json"}}', + f"{I}if _api_key:", + f'{I}{I}_headers["Authorization"] = f"Bearer {{_api_key}}"', + "", + f"{I}_func_source = {repr(dep.source)}", + f"{I}_ser_args = {ser_args_expr}", + f"{I}_ser_kwargs = {ser_kwargs_expr}", + f"{I}_payload = {{", + f'{I}{I}"input": {{', + f'{I}{I}{I}"function_name": {repr(dep.name)},', + f'{I}{I}{I}"function_code": _func_source,', + f'{I}{I}{I}"args": _ser_args,', + f'{I}{I}{I}"kwargs": _ser_kwargs,', + f'{I}{I}{I}"dependencies": {repr(dep.dependencies)},', + f'{I}{I}{I}"system_dependencies": {repr(dep.system_dependencies)},', + f"{I}{I}}}", + f"{I}}}", + "", + f"{I}_timeout = _aiohttp.ClientTimeout(total=300)", + f"{I}async with _aiohttp.ClientSession(timeout=_timeout) as _sess:", + f"{I}{I}async with _sess.post(_url, json=_payload, headers=_headers) as _resp:", + f"{I}{I}{I}if _resp.status != 200:", + f"{I}{I}{I}{I}_err = await _resp.text()", + f"{I}{I}{I}{I}raise RuntimeError(", + f'{I}{I}{I}{I}{I}f"Remote {dep.name} failed (HTTP {{_resp.status}}): {{_err}}"', + f"{I}{I}{I}{I})", + f"{I}{I}{I}_data = await _resp.json()", + f'{I}{I}{I}_output = _data.get("output", _data)', + f'{I}{I}{I}if not _output.get("success"):', + f"{I}{I}{I}{I}raise RuntimeError(", + f"{I}{I}{I}{I}{I}f\"Remote {dep.name} failed: {{_output.get('error')}}\"", + f"{I}{I}{I}{I})", + f'{I}{I}{I}return _cp.loads(_b64.b64decode(_output["result"]))', + ] + return "\n".join(lines) + "\n" + + +def build_augmented_source(original_source: str, stub_codes: list[str]) -> str: + """Prepend stub code blocks before the original function source. + + Args: + original_source: The calling function's source code. + stub_codes: List of stub code strings to prepend. + + Returns: + Combined source with stubs before the original function. + """ + if not stub_codes: + return original_source + + parts = stub_codes + [original_source] + return "\n\n".join(parts) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _extract_params(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + """Extract parameter list string from an AST function node.""" + params = [arg.arg for arg in func_node.args.args] + return ", ".join(params) if params else "*args, **kwargs" + + +def _build_serialization_exprs(tree: ast.Module, func_name: str) -> tuple[str, str]: + """Return (args_expr, kwargs_expr) for serializing function parameters. + + When the original signature has named params, we serialize each by name. + Otherwise fall back to generic *args/**kwargs serialization. + """ + for node in ast.walk(tree): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == func_name + ): + param_names = [arg.arg for arg in node.args.args] + if param_names: + items = ", ".join( + f"_b64.b64encode(_cp.dumps({p})).decode()" for p in param_names + ) + return f"[{items}]", "{}" + + # Fallback for *args, **kwargs + return ( + "[_b64.b64encode(_cp.dumps(a)).decode() for a in args]", + "{k: _b64.b64encode(_cp.dumps(v)).decode() for k, v in kwargs.items()}", + ) diff --git a/src/runpod_flash/stubs/live_serverless.py b/src/runpod_flash/stubs/live_serverless.py index 256e22d7..af0b9ab1 100644 --- a/src/runpod_flash/stubs/live_serverless.py +++ b/src/runpod_flash/stubs/live_serverless.py @@ -72,7 +72,7 @@ class LiveServerlessStub(RemoteExecutorStub): def __init__(self, server: LiveServerless): self.server = server - def prepare_request( + async def prepare_request( self, func, dependencies, @@ -83,6 +83,22 @@ def prepare_request( ): source, src_hash = get_function_source(func) + # Detect and resolve @remote dependencies for stacked execution + from .dependency_resolver import ( + build_augmented_source, + generate_stub_code, + resolve_dependencies, + ) + + original_func = inspect.unwrap(func) + remote_deps = await resolve_dependencies(source, original_func.__globals__) + if remote_deps: + stub_codes = [generate_stub_code(dep) for dep in remote_deps] + source = build_augmented_source(source, stub_codes) + # Recompute cache key to include dependency endpoints + dep_key = "|".join(f"{d.name}:{d.endpoint_id}" for d in remote_deps) + src_hash = hashlib.sha256((source + dep_key).encode("utf-8")).hexdigest() + request = { "function_name": func.__name__, "dependencies": dependencies, diff --git a/src/runpod_flash/stubs/load_balancer_sls.py b/src/runpod_flash/stubs/load_balancer_sls.py index d08a0c5a..1f096170 100644 --- a/src/runpod_flash/stubs/load_balancer_sls.py +++ b/src/runpod_flash/stubs/load_balancer_sls.py @@ -17,6 +17,8 @@ serialize_args, serialize_kwargs, ) +from runpod_flash.core.resources.constants import DEFAULT_LB_STUB_TIMEOUT + from .live_serverless import get_function_source log = logging.getLogger(__name__) @@ -47,17 +49,15 @@ class LoadBalancerSlsStub: result = await stub(my_func, deps, sys_deps, accel, arg1, arg2) """ - DEFAULT_TIMEOUT = 30.0 # Default timeout in seconds - def __init__(self, server: Any, timeout: Optional[float] = None) -> None: """Initialize stub with LoadBalancerSlsResource server. Args: server: LoadBalancerSlsResource instance with endpoint_url configured - timeout: Request timeout in seconds (default: 30.0) + timeout: Request timeout in seconds (default: DEFAULT_LB_STUB_TIMEOUT) """ self.server = server - self.timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT + self.timeout = timeout if timeout is not None else DEFAULT_LB_STUB_TIMEOUT def _should_use_execute_endpoint(self, func: Callable[..., Any]) -> bool: """Determine if /execute endpoint should be used for this function. @@ -138,7 +138,7 @@ async def __call__( # Determine execution path based on resource type and routing metadata if self._should_use_execute_endpoint(func): # Local development or backward compatibility: use /execute endpoint - request = self._prepare_request( + request = await self._prepare_request( func, dependencies, system_dependencies, @@ -159,7 +159,7 @@ async def __call__( **kwargs, ) - def _prepare_request( + async def _prepare_request( self, func: Callable[..., Any], dependencies: Optional[List[str]], @@ -171,6 +171,7 @@ def _prepare_request( """Prepare HTTP request payload. Extracts function source code and serializes arguments using cloudpickle. + Detects @remote dependencies and injects dispatch stubs for stacked execution. Args: func: Function to serialize @@ -184,6 +185,20 @@ def _prepare_request( Request dictionary with serialized function and arguments """ source, _ = get_function_source(func) + + # Detect and resolve @remote dependencies for stacked execution + from .dependency_resolver import ( + build_augmented_source, + generate_stub_code, + resolve_dependencies, + ) + + original_func = inspect.unwrap(func) + remote_deps = await resolve_dependencies(source, original_func.__globals__) + if remote_deps: + stub_codes = [generate_stub_code(dep) for dep in remote_deps] + source = build_augmented_source(source, stub_codes) + log.debug(f"Extracted source for {func.__name__} ({len(source)} bytes)") request = { diff --git a/src/runpod_flash/stubs/registry.py b/src/runpod_flash/stubs/registry.py index bbea9243..23a50fad 100644 --- a/src/runpod_flash/stubs/registry.py +++ b/src/runpod_flash/stubs/registry.py @@ -42,7 +42,7 @@ async def stubbed_resource( if args == (None,): args = [] - request = stub.prepare_request( + request = await stub.prepare_request( func, dependencies, system_dependencies, diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index 406e1521..11a3d14a 100644 --- a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ -72,7 +72,7 @@ def add(x: int, y: int) -> int: return x + y # Prepare request - request = stub._prepare_request(add, None, None, True, 5, 3) + request = await stub._prepare_request(add, None, None, True, 5, 3) # Verify request structure assert request["function_name"] == "add" diff --git a/tests/unit/test_dependency_resolver.py b/tests/unit/test_dependency_resolver.py new file mode 100644 index 00000000..c916c794 --- /dev/null +++ b/tests/unit/test_dependency_resolver.py @@ -0,0 +1,372 @@ +"""Unit tests for dependency_resolver module. + +Tests detection, stub generation, source assembly, and async resolution +of @remote function dependencies for stacked execution. +""" + +import ast +import textwrap +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from runpod_flash.stubs.dependency_resolver import ( + RemoteDependency, + build_augmented_source, + detect_remote_dependencies, + generate_stub_code, + resolve_dependencies, +) + + +# --------------------------------------------------------------------------- +# Helpers: fake @remote-decorated functions for detection tests +# --------------------------------------------------------------------------- + + +def _make_remote_func(name: str, source: str, resource_config=None): + """Create a fake function with __remote_config__ to simulate @remote.""" + ns: dict = {} + exec(compile(source, "", "exec"), ns) + func = ns[name] + func.__remote_config__ = { + "resource_config": resource_config or MagicMock(name=name), + "dependencies": ["numpy"], + "system_dependencies": [], + } + return func + + +# Shared globals dict simulating a module where both funcA and funcB live +_shared_globals: dict = {} + +_funcB_source = textwrap.dedent("""\ +async def funcB(param: dict) -> dict: + return {"result": param} +""") + +_funcB = _make_remote_func("funcB", _funcB_source) +_shared_globals["funcB"] = _funcB + +_funcC_source = textwrap.dedent("""\ +async def funcC(x: int) -> int: + return x + 1 +""") +_funcC = _make_remote_func("funcC", _funcC_source) +_shared_globals["funcC"] = _funcC + + +def _plain_helper(x): + """A plain function — no __remote_config__.""" + return x + + +_shared_globals["_plain_helper"] = _plain_helper + + +# funcA calls funcB (a @remote function) and _plain_helper (not @remote) +_funcA_source = textwrap.dedent("""\ +async def funcA(foo: str) -> dict: + payload = _plain_helper(foo) + return await funcB(payload) +""") + + +# funcD calls both funcB and funcC +_funcD_source = textwrap.dedent("""\ +async def funcD(data: str) -> dict: + b = await funcB({"key": data}) + c = await funcC(42) + return {"b": b, "c": c} +""") + + +# funcE calls nothing remote +_funcE_source = textwrap.dedent("""\ +async def funcE(x: int) -> int: + return x * 2 +""") + + +# funcF calls funcB via attribute (indirect — should NOT be detected) +_funcF_source = textwrap.dedent("""\ +async def funcF(x: int) -> int: + import somemodule + return somemodule.funcB(x) +""") + + +# --------------------------------------------------------------------------- +# Tests: detect_remote_dependencies +# --------------------------------------------------------------------------- + + +class TestDetectRemoteDependencies: + def test_detects_single_remote_dependency(self): + result = detect_remote_dependencies(_funcA_source, _shared_globals) + assert result == ["funcB"] + + def test_detects_multiple_remote_dependencies(self): + result = detect_remote_dependencies(_funcD_source, _shared_globals) + assert sorted(result) == ["funcB", "funcC"] + + def test_no_remote_dependencies(self): + result = detect_remote_dependencies(_funcE_source, _shared_globals) + assert result == [] + + def test_ignores_plain_helpers(self): + result = detect_remote_dependencies(_funcA_source, _shared_globals) + assert "_plain_helper" not in result + + def test_ignores_builtins(self): + source = textwrap.dedent("""\ + async def funcX(x: int) -> str: + return str(len([x])) + """) + result = detect_remote_dependencies(source, _shared_globals) + assert result == [] + + def test_ignores_attribute_calls(self): + """Only ast.Name calls are detected, not ast.Attribute calls.""" + result = detect_remote_dependencies(_funcF_source, _shared_globals) + assert "funcB" not in result + + def test_ignores_names_not_in_globals(self): + source = textwrap.dedent("""\ + async def funcX(x: int) -> int: + return unknown_func(x) + """) + result = detect_remote_dependencies(source, _shared_globals) + assert result == [] + + +# --------------------------------------------------------------------------- +# Tests: generate_stub_code +# --------------------------------------------------------------------------- + + +class TestGenerateStubCode: + def _make_dep(self, name="funcB", endpoint_id="ep-123", source=None): + return RemoteDependency( + name=name, + endpoint_id=endpoint_id, + source=source or _funcB_source, + dependencies=["numpy"], + system_dependencies=[], + ) + + def test_generates_valid_python(self): + dep = self._make_dep() + code = generate_stub_code(dep) + # Must compile without errors + compile(code, "", "exec") + + def test_stub_defines_correct_function_name(self): + dep = self._make_dep(name="funcB") + code = generate_stub_code(dep) + tree = ast.parse(code) + func_names = [ + node.name + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + assert "funcB" in func_names + + def test_stub_is_async(self): + dep = self._make_dep() + code = generate_stub_code(dep) + tree = ast.parse(code) + async_funcs = [ + node for node in ast.walk(tree) if isinstance(node, ast.AsyncFunctionDef) + ] + assert len(async_funcs) >= 1 + + def test_endpoint_id_embedded(self): + dep = self._make_dep(endpoint_id="ep-abc-999") + code = generate_stub_code(dep) + assert "ep-abc-999" in code + + def test_function_source_embedded(self): + dep = self._make_dep() + code = generate_stub_code(dep) + # The original source should appear somewhere in the stub (as a string) + assert "funcB" in code + + def test_preserves_original_signature(self): + """Stub should accept same params as original function.""" + dep = self._make_dep() + code = generate_stub_code(dep) + # The stub for funcB(param: dict) should have 'param' in its signature + assert "param" in code + + def test_handles_multi_param_function(self): + multi_src = textwrap.dedent("""\ + async def multi(a: int, b: str, c: float = 1.0) -> dict: + return {"a": a, "b": b, "c": c} + """) + dep = self._make_dep(name="multi", source=multi_src) + code = generate_stub_code(dep) + compile(code, "", "exec") + assert "multi" in code + + def test_handles_triple_quotes_in_source(self): + """Source with triple-quoted docstrings should be safely escaped.""" + src_with_docs = textwrap.dedent('''\ + async def documented(x: int) -> int: + """Process x with triple-quoted docstring.""" + return x + ''') + dep = self._make_dep(name="documented", source=src_with_docs) + code = generate_stub_code(dep) + compile(code, "", "exec") + + +# --------------------------------------------------------------------------- +# Tests: build_augmented_source +# --------------------------------------------------------------------------- + + +class TestBuildAugmentedSource: + def test_no_stubs_returns_original(self): + original = "async def funcA(x): return x\n" + result = build_augmented_source(original, []) + assert result == original + + def test_stubs_prepended_before_original(self): + original = "async def funcA(x): return x\n" + stub = "async def funcB(y): return y\n" + result = build_augmented_source(original, [stub]) + # stub should appear before original + assert result.index("funcB") < result.index("funcA") + + def test_augmented_source_is_valid_python(self): + original = textwrap.dedent("""\ + async def funcA(foo: str) -> dict: + return await funcB(foo) + """) + stub = textwrap.dedent("""\ + async def funcB(param: dict) -> dict: + return {"stub": True} + """) + result = build_augmented_source(original, [stub]) + compile(result, "", "exec") + + def test_multiple_stubs_prepended(self): + original = "async def funcA(x): return x\n" + stubs = [ + "async def funcB(y): return y\n", + "async def funcC(z): return z\n", + ] + result = build_augmented_source(original, stubs) + assert "funcB" in result + assert "funcC" in result + assert result.index("funcB") < result.index("funcA") + assert result.index("funcC") < result.index("funcA") + + +# --------------------------------------------------------------------------- +# Tests: resolve_dependencies (async, mocked ResourceManager) +# --------------------------------------------------------------------------- + + +class TestResolveDependencies: + """Tests for resolve_dependencies with mocked ResourceManager and get_function_source.""" + + def _patch_resolve(self, mock_rm): + """Return combined patch context for ResourceManager and get_function_source.""" + return ( + patch( + "runpod_flash.core.resources.ResourceManager", + return_value=mock_rm, + ), + patch( + "runpod_flash.stubs.dependency_resolver.get_function_source", + side_effect=lambda func: ( + f"async def {func.__name__}(): pass\n", + "hash", + ), + ), + ) + + @pytest.mark.asyncio + async def test_resolves_single_dependency(self): + mock_resource = MagicMock() + mock_resource.id = "ep-resolved-123" + + mock_rm = MagicMock() + mock_rm.get_or_deploy_resource = AsyncMock(return_value=mock_resource) + + rm_patch, gfs_patch = self._patch_resolve(mock_rm) + with rm_patch, gfs_patch: + deps = await resolve_dependencies(_funcA_source, _shared_globals) + + assert len(deps) == 1 + assert deps[0].name == "funcB" + assert deps[0].endpoint_id == "ep-resolved-123" + + @pytest.mark.asyncio + async def test_resolves_multiple_dependencies(self): + mock_resource_b = MagicMock() + mock_resource_b.id = "ep-b" + mock_resource_c = MagicMock() + mock_resource_c.id = "ep-c" + + async def mock_deploy(config): + if config is _funcB.__remote_config__["resource_config"]: + return mock_resource_b + return mock_resource_c + + mock_rm = MagicMock() + mock_rm.get_or_deploy_resource = AsyncMock(side_effect=mock_deploy) + + rm_patch, gfs_patch = self._patch_resolve(mock_rm) + with rm_patch, gfs_patch: + deps = await resolve_dependencies(_funcD_source, _shared_globals) + + assert len(deps) == 2 + names = {d.name for d in deps} + assert names == {"funcB", "funcC"} + + @pytest.mark.asyncio + async def test_no_dependencies_returns_empty(self): + deps = await resolve_dependencies(_funcE_source, _shared_globals) + assert deps == [] + + @pytest.mark.asyncio + async def test_provisioning_failure_raises(self): + mock_rm = MagicMock() + mock_rm.get_or_deploy_resource = AsyncMock( + side_effect=RuntimeError("deploy failed") + ) + + rm_patch, gfs_patch = self._patch_resolve(mock_rm) + with rm_patch, gfs_patch: + with pytest.raises(RuntimeError, match="deploy failed"): + await resolve_dependencies(_funcA_source, _shared_globals) + + +# --------------------------------------------------------------------------- +# Tests: exec() integration — verify augmented source works at runtime +# --------------------------------------------------------------------------- + + +class TestExecIntegration: + def test_exec_augmented_source_defines_both_functions(self): + """When we exec() augmented source, both funcA and the funcB stub exist.""" + dep = RemoteDependency( + name="funcB", + endpoint_id="ep-test", + source=_funcB_source, + dependencies=[], + system_dependencies=[], + ) + stub_code = generate_stub_code(dep) + augmented = build_augmented_source(_funcA_source, [stub_code]) + + namespace: dict = {"_plain_helper": lambda x: x} + exec(compile(augmented, "", "exec"), namespace) + + assert "funcA" in namespace + assert "funcB" in namespace + assert callable(namespace["funcA"]) + assert callable(namespace["funcB"]) diff --git a/tests/unit/test_load_balancer_sls_stub.py b/tests/unit/test_load_balancer_sls_stub.py index 206f8eea..7bb4bd5e 100644 --- a/tests/unit/test_load_balancer_sls_stub.py +++ b/tests/unit/test_load_balancer_sls_stub.py @@ -20,14 +20,15 @@ class TestLoadBalancerSlsStubPrepareRequest: """Test suite for _prepare_request method.""" - def test_prepare_request_with_no_args(self): + @pytest.mark.asyncio + async def test_prepare_request_with_no_args(self): """Test request preparation with no arguments.""" stub = LoadBalancerSlsStub(test_lb_resource) def test_func(): return "result" - request = stub._prepare_request(test_func, None, None, True) + request = await stub._prepare_request(test_func, None, None, True) assert request["function_name"] == "test_func" assert "def test_func" in request["function_code"] @@ -37,7 +38,8 @@ def test_func(): assert "args" not in request or request["args"] == [] assert "kwargs" not in request or request["kwargs"] == {} - def test_prepare_request_with_args(self): + @pytest.mark.asyncio + async def test_prepare_request_with_args(self): """Test request preparation with positional arguments.""" stub = LoadBalancerSlsStub(test_lb_resource) @@ -46,7 +48,7 @@ def add(x, y): arg1 = 5 arg2 = 3 - request = stub._prepare_request(add, None, None, True, arg1, arg2) + request = await stub._prepare_request(add, None, None, True, arg1, arg2) assert request["function_name"] == "add" assert len(request["args"]) == 2 @@ -57,14 +59,15 @@ def add(x, y): assert decoded_arg1 == 5 assert decoded_arg2 == 3 - def test_prepare_request_with_kwargs(self): + @pytest.mark.asyncio + async def test_prepare_request_with_kwargs(self): """Test request preparation with keyword arguments.""" stub = LoadBalancerSlsStub(test_lb_resource) def greet(name, greeting="Hello"): return f"{greeting}, {name}!" - request = stub._prepare_request( + request = await stub._prepare_request( greet, None, None, True, name="Alice", greeting="Hi" ) @@ -79,7 +82,8 @@ def greet(name, greeting="Hello"): assert decoded_name == "Alice" assert decoded_greeting == "Hi" - def test_prepare_request_with_dependencies(self): + @pytest.mark.asyncio + async def test_prepare_request_with_dependencies(self): """Test request preparation includes dependencies.""" stub = LoadBalancerSlsStub(test_lb_resource) @@ -89,7 +93,9 @@ def test_func(): dependencies = ["requests", "numpy"] system_deps = ["git"] - request = stub._prepare_request(test_func, dependencies, system_deps, True) + request = await stub._prepare_request( + test_func, dependencies, system_deps, True + ) assert request["dependencies"] == dependencies assert request["system_dependencies"] == system_deps From f262f654ed161b43ba7e908f7e65360cad35f12a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 12:06:11 -0800 Subject: [PATCH 22/25] refactor(run): group Swagger UI tags by project directory Tags now use the parent directory path instead of per-file worker type labels. Routes from the same project appear under a single collapsible group in the Swagger UI, making multi-worker projects easier to navigate. --- src/runpod_flash/cli/commands/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index d416d9a4..2b739906 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -396,7 +396,11 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat lines.append("") for worker in workers: - tag = f"{worker.url_prefix.lstrip('/')} [{worker.worker_type}]" + # Group routes by project directory in Swagger UI. + # Nested: /03_mixed_workers/cpu_worker -> "03_mixed_workers/" + # Root: /worker -> "worker" + prefix = worker.url_prefix.lstrip("/") + tag = f"{prefix.rsplit('/', 1)[0]}/" if "/" in prefix else prefix lines.append(f"# {'─' * 60}") lines.append(f"# {worker.worker_type}: {worker.file_path.name}") lines.append(f"# {'─' * 60}") From ea426c5d3ebd4c3e3b62445863d9bcb384271511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 12:55:14 -0800 Subject: [PATCH 23/25] fix: address PR 208 review feedback - Guard watcher_thread creation behind reload flag - Fix project_root derivation in build mode (use build_dir, not parent) - Update PRD QB route spec to match /run_sync-only implementation - Harden _cleanup_live_endpoints with granular error handling - Replace bare except in _watch_and_regenerate with specific handlers - Differentiate error types in lb_execute (422 for app errors, 500 for infra) - Add templateId/template mutual exclusivity validation - Improve _flash_import sys.path cleanup with index-based pop - Use explicit event loop in cleanup to avoid nested loop errors - Parallelize dependency provisioning with asyncio.gather - Clarify null-safety in detect_remote_dependencies --- PRD.md | 11 +- .../cli/commands/_run_server_helpers.py | 4 + .../cli/commands/build_utils/manifest.py | 6 +- src/runpod_flash/cli/commands/run.py | 140 ++++++++++-------- src/runpod_flash/core/resources/serverless.py | 6 + src/runpod_flash/stubs/dependency_resolver.py | 34 +++-- 6 files changed, 120 insertions(+), 81 deletions(-) diff --git a/PRD.md b/PRD.md index 2df30adc..d860c347 100644 --- a/PRD.md +++ b/PRD.md @@ -100,8 +100,8 @@ workers/gpu/inference.py → /workers/gpu/inference | Condition | Routes | |---|---| -| One `@remote` function in file | `POST {file_prefix}/run` and `POST {file_prefix}/run_sync` | -| Multiple `@remote` functions in file | `POST {file_prefix}/{fn_name}/run` and `POST {file_prefix}/{fn_name}/run_sync` | +| One `@remote` function in file | `POST {file_prefix}/run_sync` | +| Multiple `@remote` functions in file | `POST {file_prefix}/{fn_name}/run_sync` | ### 5.3 LB route generation @@ -180,9 +180,8 @@ Flash Dev Server http://localhost:8888 Local path Resource Type ────────────────────────────────── ─────────────────── ──── - POST /gpu_worker/run gpu_worker QB POST /gpu_worker/run_sync gpu_worker QB - POST /longruns/stage1/run longruns_stage1 QB + POST /longruns/stage1/run_sync longruns_stage1 QB POST /preprocess/first_pass/compute preprocess_first_pass LB Visit http://localhost:8888/docs for Swagger UI @@ -261,14 +260,12 @@ app = FastAPI( ) # QB: gpu_worker.py -@app.post("/gpu_worker/run", tags=["gpu_worker [QB]"]) @app.post("/gpu_worker/run_sync", tags=["gpu_worker [QB]"]) async def gpu_worker_run(body: dict): result = await gpu_hello(body.get("input", body)) return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} # QB: longruns/stage1.py -@app.post("/longruns/stage1/run", tags=["longruns/stage1 [QB]"]) @app.post("/longruns/stage1/run_sync", tags=["longruns/stage1 [QB]"]) async def longruns_stage1_run(body: dict): result = await stage1_process(body.get("input", body)) @@ -316,7 +313,7 @@ longruns/stage1.py has: stage1_preprocess, stage1_infer ## 13. Edge Cases - **No `@remote` functions found**: Error with clear message and usage instructions -- **Multiple `@remote` functions per file (QB)**: Sub-prefixed routes `/{file_prefix}/{fn_name}/run` +- **Multiple `@remote` functions per file (QB)**: Sub-prefixed routes `/{file_prefix}/{fn_name}/run_sync` - **`__init__.py` files**: Skipped — not treated as worker files - **File path with hyphens** (e.g., `my-worker.py`): Resource name sanitized to `my_worker`, URL prefix `/my-worker` (hyphens valid in URLs, underscores in Python identifiers) - **LB function calling another LB function**: Not supported via `@remote` — emit a warning at build time diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py index 4a92c31f..abf48f06 100644 --- a/src/runpod_flash/cli/commands/_run_server_helpers.py +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -102,5 +102,9 @@ async def lb_execute(resource_config, func, body: dict): raise HTTPException(status_code=504, detail=str(e)) except ConnectionError as e: raise HTTPException(status_code=502, detail=str(e)) + except HTTPException: + raise + except (ValueError, KeyError, TypeError) as e: + raise HTTPException(status_code=422, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index af2a283f..4c5998bf 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -207,8 +207,10 @@ def build(self) -> Dict[str, Any]: str, Dict[str, str] ] = {} # resource_name -> {route_key -> function_name} - # Determine project root for path derivation - project_root = self.build_dir.parent if self.build_dir else Path.cwd() + # Determine project root for path derivation. + # build_dir is .flash/.build which *contains* the copied project files, + # so use it directly (not its parent, which would be .flash/). + project_root = self.build_dir if self.build_dir else Path.cwd() for resource_name, functions in sorted(resources.items()): # Use actual resource type from first function in group diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 2b739906..db2d354c 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -287,9 +287,12 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat " try:", " return getattr(_importlib.import_module(module_path), name)", " finally:", - " if _path:", + " if _path is not None:", " try:", - " sys.path.remove(_path)", + " if sys.path and sys.path[0] == _path:", + " sys.path.pop(0)", + " else:", + " sys.path.remove(_path)", " except ValueError:", " pass", "", @@ -612,68 +615,74 @@ def _cleanup_live_endpoints() -> None: if not _RESOURCE_STATE_FILE.exists(): return - try: - import asyncio - import cloudpickle - from ...core.utils.file_lock import file_lock + import asyncio + import cloudpickle + from ...core.utils.file_lock import file_lock + # Load persisted resource state. If this fails (lock error, corruption), + # log and return — don't let it prevent the rest of shutdown. + try: with open(_RESOURCE_STATE_FILE, "rb") as f: with file_lock(f, exclusive=False): data = cloudpickle.load(f) + except Exception as e: + logger.warning(f"Could not read resource state for cleanup: {e}") + return - if isinstance(data, tuple): - resources, configs = data - else: - resources, configs = data, {} + if isinstance(data, tuple): + resources, configs = data + else: + resources, configs = data, {} - live_items = { - key: resource - for key, resource in resources.items() - if hasattr(resource, "name") - and resource.name - and resource.name.startswith("live-") - } + live_items = { + key: resource + for key, resource in resources.items() + if hasattr(resource, "name") + and resource.name + and resource.name.startswith("live-") + } - if not live_items: - return + if not live_items: + return - import time + import time - async def _do_cleanup(): - undeployed = 0 - for key, resource in live_items.items(): - name = getattr(resource, "name", key) - try: - success = await resource._do_undeploy() - if success: - console.print(f" Deprovisioned: {name}") - undeployed += 1 - else: - logger.warning(f"Failed to deprovision: {name}") - except Exception as e: - logger.warning(f"Error deprovisioning {name}: {e}") - return undeployed - - t0 = time.monotonic() - undeployed = asyncio.run(_do_cleanup()) - elapsed = time.monotonic() - t0 - console.print( - f" Cleanup completed: {undeployed}/{len(live_items)} " - f"resource(s) undeployed in {elapsed:.1f}s" - ) + async def _do_cleanup(): + undeployed = 0 + for key, resource in live_items.items(): + name = getattr(resource, "name", key) + try: + success = await resource._do_undeploy() + if success: + console.print(f" Deprovisioned: {name}") + undeployed += 1 + else: + logger.warning(f"Failed to deprovision: {name}") + except Exception as e: + logger.warning(f"Error deprovisioning {name}: {e}") + return undeployed - # Remove live- entries from persisted state so they don't linger. - remaining = {k: v for k, v in resources.items() if k not in live_items} - remaining_configs = {k: v for k, v in configs.items() if k not in live_items} - try: - with open(_RESOURCE_STATE_FILE, "wb") as f: - with file_lock(f, exclusive=True): - cloudpickle.dump((remaining, remaining_configs), f) - except Exception as e: - logger.warning(f"Could not update resource state after cleanup: {e}") + t0 = time.monotonic() + loop = asyncio.new_event_loop() + try: + undeployed = loop.run_until_complete(_do_cleanup()) + finally: + loop.close() + elapsed = time.monotonic() - t0 + console.print( + f" Cleanup completed: {undeployed}/{len(live_items)} " + f"resource(s) undeployed in {elapsed:.1f}s" + ) + # Remove live- entries from persisted state so they don't linger. + remaining = {k: v for k, v in resources.items() if k not in live_items} + remaining_configs = {k: v for k, v in configs.items() if k not in live_items} + try: + with open(_RESOURCE_STATE_FILE, "wb") as f: + with file_lock(f, exclusive=True): + cloudpickle.dump((remaining, remaining_configs), f) except Exception as e: - logger.warning(f"Live endpoint cleanup failed: {e}") + logger.warning(f"Could not update resource state after cleanup: {e}") def _is_reload() -> bool: @@ -707,8 +716,11 @@ def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> No logger.debug("server.py regenerated (%d changed)", len(py_changed)) except Exception as e: logger.warning("Failed to regenerate server.py: %s", e) - except Exception: - pass # stop_event was set or watchfiles unavailable — both are fine + except ModuleNotFoundError as e: + logger.warning("File watching disabled: %s", e) + except Exception as e: + if not stop_event.is_set(): + logger.exception("Unexpected error in file watcher: %s", e) def _discover_resources(project_root: Path): @@ -887,12 +899,14 @@ def run_command( ] stop_event = threading.Event() - watcher_thread = threading.Thread( - target=_watch_and_regenerate, - args=(project_root, stop_event), - daemon=True, - name="flash-watcher", - ) + watcher_thread = None + if reload: + watcher_thread = threading.Thread( + target=_watch_and_regenerate, + args=(project_root, stop_event), + daemon=True, + name="flash-watcher", + ) process = None try: @@ -903,7 +917,7 @@ def run_command( else: process = subprocess.Popen(cmd, preexec_fn=os.setsid) - if reload: + if watcher_thread is not None: watcher_thread.start() process.wait() @@ -912,7 +926,7 @@ def run_command( console.print("\n[yellow]Stopping server and cleaning up...[/yellow]") stop_event.set() - if watcher_thread.is_alive(): + if watcher_thread is not None and watcher_thread.is_alive(): watcher_thread.join(timeout=2) if process: @@ -942,7 +956,7 @@ def run_command( console.print(f"[red]Error:[/red] {e}") stop_event.set() - if watcher_thread.is_alive(): + if watcher_thread is not None and watcher_thread.is_alive(): watcher_thread.join(timeout=2) if process: diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 4793d5f7..09470549 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -479,6 +479,7 @@ def is_deployed(self) -> bool: # endpoint exists but the health API hasn't registered it yet. # Trusting the cached ID is correct here; actual failures surface # on the first real run/run_sync call. + # Case-insensitive check; unset env var defaults to "" via getenv. if os.getenv("FLASH_IS_LIVE_PROVISIONING", "").lower() == "true": return True @@ -495,6 +496,11 @@ def _payload_exclude(self) -> Set[str]: # When templateId is already set, exclude template from the payload. # RunPod rejects requests that contain both fields simultaneously. if self.templateId: + if self.template is not None: + raise ValueError( + "Invalid state: both 'templateId' and 'template' are set. " + "Only one may be provided." + ) exclude_fields.add("template") return exclude_fields diff --git a/src/runpod_flash/stubs/dependency_resolver.py b/src/runpod_flash/stubs/dependency_resolver.py index 97b8a094..f8449fb0 100644 --- a/src/runpod_flash/stubs/dependency_resolver.py +++ b/src/runpod_flash/stubs/dependency_resolver.py @@ -56,7 +56,7 @@ def detect_remote_dependencies(source: str, func_globals: dict[str, Any]) -> lis remote_deps = [ name for name in sorted(called_names) - if hasattr(func_globals.get(name), "__remote_config__") + if name in func_globals and hasattr(func_globals[name], "__remote_config__") ] return remote_deps @@ -85,29 +85,45 @@ async def resolve_dependencies( if not dep_names: return [] + import asyncio + from ..core.resources import ResourceManager resource_manager = ResourceManager() - resolved: list[RemoteDependency] = [] + # Gather metadata needed for each dependency before parallel provisioning. + dep_info: list[tuple[str, Any, str, list[str], list[str]]] = [] for name in dep_names: dep_func = func_globals[name] config = dep_func.__remote_config__ - - resource_config = config["resource_config"] - remote_resource = await resource_manager.get_or_deploy_resource(resource_config) - - # Get source of the dependency function unwrapped = inspect.unwrap(dep_func) dep_source, _ = get_function_source(unwrapped) + dep_info.append( + ( + name, + config["resource_config"], + dep_source, + config.get("dependencies") or [], + config.get("system_dependencies") or [], + ) + ) + + # Provision all endpoints in parallel. + remote_resources = await asyncio.gather( + *(resource_manager.get_or_deploy_resource(rc) for _, rc, _, _, _ in dep_info) + ) + resolved: list[RemoteDependency] = [] + for (name, _, dep_source, deps, sys_deps), remote_resource in zip( + dep_info, remote_resources + ): resolved.append( RemoteDependency( name=name, endpoint_id=remote_resource.id, source=dep_source, - dependencies=config.get("dependencies") or [], - system_dependencies=config.get("system_dependencies") or [], + dependencies=deps, + system_dependencies=sys_deps, ) ) log.debug( From 9f1928d4b0d250518256a005918eadb16b102c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 13:37:21 -0800 Subject: [PATCH 24/25] fix(serverless): prevent ValueError when deploy mutates config with both templateId and template _payload_exclude() raised ValueError after _do_deploy() set templateId on a config object that already had template from initialization. Remove the raise in favor of silently excluding template (templateId takes precedence), and clear self.template after deploy mutation to prevent the inconsistent state at its source. --- src/runpod_flash/core/resources/serverless.py | 10 +++++----- tests/unit/resources/test_serverless.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 09470549..869b85a8 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -495,12 +495,9 @@ def _payload_exclude(self) -> Set[str]: exclude_fields.discard("flashEnvironmentId") # When templateId is already set, exclude template from the payload. # RunPod rejects requests that contain both fields simultaneously. + # Both can coexist after deploy mutates config (sets templateId while + # template remains from initialization) — templateId takes precedence. if self.templateId: - if self.template is not None: - raise ValueError( - "Invalid state: both 'templateId' and 'template' are set. " - "Only one may be provided." - ) exclude_fields.add("template") return exclude_fields @@ -679,6 +676,9 @@ async def _do_deploy(self) -> "DeployableResource": endpoint = await self._sync_graphql_object_with_inputs(endpoint) self.id = endpoint.id self.templateId = endpoint.templateId + self.template = ( + None # templateId takes precedence; clear to avoid conflict + ) return endpoint raise ValueError("Deployment failed, no endpoint was returned.") diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 124eb136..25899904 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -973,6 +973,24 @@ def test_payload_exclude_adds_template_when_template_id_set(self): assert "template" in excluded + def test_payload_exclude_tolerates_both_template_id_and_template(self): + """_payload_exclude does not raise when both templateId and template are set. + + After deploy mutates the config object, both fields can coexist. + templateId takes precedence and template should be excluded. + """ + serverless = ServerlessResource(name="test") + serverless.templateId = "tmpl-123" + serverless.template = PodTemplate( + name="test-template", + imageName="runpod/test:latest", + containerDiskInGb=20, + ) + + excluded = serverless._payload_exclude() + + assert "template" in excluded + def test_payload_exclude_does_not_exclude_template_without_template_id(self): """_payload_exclude does not exclude template when templateId is absent.""" serverless = ServerlessResource(name="test") From 35af650f0419898cee9d805b492636801219986c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 20 Feb 2026 15:14:30 -0800 Subject: [PATCH 25/25] chore: remove PRD.md from branch Not meant to be committed; internal planning document. --- PRD.md | 338 --------------------------------------------------------- 1 file changed, 338 deletions(-) delete mode 100644 PRD.md diff --git a/PRD.md b/PRD.md deleted file mode 100644 index d860c347..00000000 --- a/PRD.md +++ /dev/null @@ -1,338 +0,0 @@ -# Flash SDK: Zero-Boilerplate Experience — Product Requirements Document - -## 1. Problem Statement - -Flash currently forces every project into a FastAPI-first model: - -- Users must create `main.py` with a `FastAPI()` instance -- HTTP routing boilerplate adds no semantic value — the routes simply call `@remote` functions -- No straightforward path for deploying a standalone QB function without wrapping it in a FastAPI app -- The "mothership" concept introduces an implicit coordinator with no clear ownership model -- `flash run` fails unless `main.py` exists with a FastAPI app, blocking the simplest use cases - -## 2. Goals - -- **Zero boilerplate**: a `@remote`-decorated function in any `.py` file is sufficient for `flash run` and `flash deploy` -- **File-system-as-namespace**: the project directory structure maps 1:1 to URL paths on the local dev server -- **Single command**: `flash run` works for all project topologies (one QB function, many files, mixed QB+LB) without any configuration -- **`flash deploy` requires no additional configuration** beyond the `@remote` declarations themselves -- **Peer endpoints**: every `@resource_config` is a first-class endpoint; no implicit coordinator - -## 3. Non-Goals - -- No backward compatibility with `main.py`/FastAPI-first style -- No implicit "mothership" concept; all endpoints are peers -- No changes to the QB runtime (`generic_handler.py`) or QB stub behavior -- No changes to deployed endpoint behavior (RunPod QB/LB APIs are unchanged) - -## 4. Developer Experience Specification - -### 4.1 Minimum viable QB project - -```python -# gpu_worker.py -from runpod_flash import LiveServerless, GpuGroup, remote - -gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) - -@remote(gpu_config) -async def process(input_data: dict) -> dict: - return {"result": "processed", "input": input_data} -``` - -`flash run` → `POST /gpu_worker/run_sync` -`flash deploy` → standalone QB endpoint at `api.runpod.ai/v2/{id}/run` - -### 4.2 LB endpoint - -```python -# api/routes.py -from runpod_flash import CpuLiveLoadBalancer, remote - -lb_config = CpuLiveLoadBalancer(name="api_routes") - -@remote(lb_config, method="POST", path="/compute") -async def compute(input_data: dict) -> dict: - return {"result": input_data} -``` - -`flash run` → `POST /api/routes/compute` -`flash deploy` → LB endpoint at `{id}.api.runpod.ai/compute` - -### 4.3 Mixed QB + LB (LB calling QB) - -```python -# api/routes.py (LB) -from runpod_flash import CpuLiveLoadBalancer, remote -from workers.gpu import heavy_compute # QB stub - -lb_config = CpuLiveLoadBalancer(name="api_routes") - -@remote(lb_config, method="POST", path="/process") -async def process_route(data: dict): - return await heavy_compute(data) # dispatches to QB endpoint - -# workers/gpu.py (QB) -from runpod_flash import LiveServerless, GpuGroup, remote - -gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) - -@remote(gpu_config) -async def heavy_compute(data: dict) -> dict: ... -``` - -## 5. URL Path Specification - -### 5.1 File prefix derivation - -The local dev server uses the project directory structure as a URL namespace. Each file's URL prefix is its path relative to the project root with `.py` stripped: - -``` -File Local URL prefix -────────────────────────────── ──────────────────────────── -gpu_worker.py → /gpu_worker -longruns/stage1.py → /longruns/stage1 -preprocess/first_pass.py → /preprocess/first_pass -workers/gpu/inference.py → /workers/gpu/inference -``` - -### 5.2 QB route generation - -| Condition | Routes | -|---|---| -| One `@remote` function in file | `POST {file_prefix}/run_sync` | -| Multiple `@remote` functions in file | `POST {file_prefix}/{fn_name}/run_sync` | - -### 5.3 LB route generation - -| Condition | Route | -|---|---| -| `@remote(lb_config, method="POST", path="/compute")` | `POST {file_prefix}/compute` | - -The declared `path=` is appended to the file prefix. The `method=` determines the HTTP verb. - -### 5.4 QB request/response envelope - -Mirrors RunPod's API for consistency: - -``` -POST /gpu_worker/run_sync -Body: {"input": {"key": "value"}} -Response: {"id": "uuid", "status": "COMPLETED", "output": {...}} -``` - -## 6. Deployed Topology Specification - -Each unique resource config gets its own RunPod endpoint: - -| Type | Deployed URL | Example | -|---|---|---| -| QB | `https://api.runpod.ai/v2/{endpoint_id}/run` | `https://api.runpod.ai/v2/uoy3n7hkyb052a/run` | -| QB sync | `https://api.runpod.ai/v2/{endpoint_id}/run_sync` | | -| LB | `https://{endpoint_id}.api.runpod.ai/{declared_path}` | `https://rzlk6lph6gw7dk.api.runpod.ai/compute` | - -## 7. `.flash/` Folder Specification - -All generated artifacts go to `.flash/` in the project root. Auto-created, gitignored, never committed. - -``` -my_project/ -├── gpu_worker.py -├── longruns/ -│ └── stage1.py -└── .flash/ - ├── server.py ← generated by flash run - └── manifest.json ← generated by flash build -``` - -- `.flash/` is added to `.gitignore` automatically on first `flash run` -- `server.py` and `manifest.json` are overwritten on each run/build; other files preserved -- The `.flash/` directory itself is never committed - -### 7.1 Dev server launch - -Uvicorn is launched with `--app-dir .flash/` so `server:app` is importable. The server inserts the project root into `sys.path` so user modules resolve: - -```bash -uvicorn server:app \ - --app-dir .flash/ \ - --reload \ - --reload-dir . \ - --reload-include "*.py" -``` - -## 8. `flash run` Behavior - -1. Scan project for all `@remote` functions (QB and LB) in any `.py` file - - Skip: `.flash/`, `__pycache__`, `*.pyc`, `__init__.py` -2. If none found: print error with usage instructions, exit 1 -3. Generate `.flash/server.py` with routes for all discovered functions -4. Add `.flash/` to `.gitignore` if not already present -5. Start uvicorn with `--reload` watching both `.flash/` and project root -6. Print startup table: local paths → resource names → types -7. Swagger UI available at `http://localhost:{port}/docs` -8. On exit (Ctrl+C or SIGTERM): deprovision all Live Serverless endpoints provisioned during this session - -### 8.1 Startup table format - -``` -Flash Dev Server http://localhost:8888 - - Local path Resource Type - ────────────────────────────────── ─────────────────── ──── - POST /gpu_worker/run_sync gpu_worker QB - POST /longruns/stage1/run_sync longruns_stage1 QB - POST /preprocess/first_pass/compute preprocess_first_pass LB - - Visit http://localhost:8888/docs for Swagger UI -``` - -## 9. `flash build` Behavior - -1. Scan project for all `@remote` functions (QB and LB) -2. Build `.flash/manifest.json` with flat resource structure (see §10) -3. For LB resources: generate deployed handler files using `module_path` -4. Package build artifact - -## 10. Manifest Structure - -Resource names are derived from file paths (slashes → underscores): - -```json -{ - "version": "1.0", - "project_name": "my_project", - "resources": { - "gpu_worker": { - "resource_type": "LiveServerless", - "file_path": "gpu_worker.py", - "local_path_prefix": "/gpu_worker", - "module_path": "gpu_worker", - "functions": ["gpu_hello"], - "is_load_balanced": false, - "makes_remote_calls": false - }, - "longruns_stage1": { - "resource_type": "LiveServerless", - "file_path": "longruns/stage1.py", - "local_path_prefix": "/longruns/stage1", - "module_path": "longruns.stage1", - "functions": ["stage1_process"], - "is_load_balanced": false, - "makes_remote_calls": false - }, - "preprocess_first_pass": { - "resource_type": "CpuLiveLoadBalancer", - "file_path": "preprocess/first_pass.py", - "local_path_prefix": "/preprocess/first_pass", - "module_path": "preprocess.first_pass", - "functions": [ - {"name": "first_pass_fn", "http_method": "POST", "http_path": "/compute"} - ], - "is_load_balanced": true, - "makes_remote_calls": true - } - } -} -``` - -## 11. `.flash/server.py` Structure - -```python -"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run.""" -import sys -import uuid -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from fastapi import FastAPI - -# QB imports -from gpu_worker import gpu_hello -from longruns.stage1 import stage1_process - -# LB imports -from preprocess.first_pass import first_pass_fn - -app = FastAPI( - title="Flash Dev Server", - description="Auto-generated by `flash run`. Visit /docs for interactive testing.", -) - -# QB: gpu_worker.py -@app.post("/gpu_worker/run_sync", tags=["gpu_worker [QB]"]) -async def gpu_worker_run(body: dict): - result = await gpu_hello(body.get("input", body)) - return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} - -# QB: longruns/stage1.py -@app.post("/longruns/stage1/run_sync", tags=["longruns/stage1 [QB]"]) -async def longruns_stage1_run(body: dict): - result = await stage1_process(body.get("input", body)) - return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} - -# LB: preprocess/first_pass.py -@app.post("/preprocess/first_pass/compute", tags=["preprocess/first_pass [LB]"]) -async def _route_first_pass_compute(body: dict): - return await first_pass_fn(body) - -# Health -@app.get("/", tags=["health"]) -def home(): - return {"message": "Flash Dev Server", "docs": "/docs"} - -@app.get("/ping", tags=["health"]) -def ping(): - return {"status": "healthy"} -``` - -Subdirectory imports use dotted module paths: `longruns/stage1.py` → `from longruns.stage1 import fn`. - -Multi-function QB files (2+ `@remote` functions) get sub-prefixed routes: -``` -longruns/stage1.py has: stage1_preprocess, stage1_infer -→ POST /longruns/stage1/stage1_preprocess/run -→ POST /longruns/stage1/stage1_preprocess/run_sync -→ POST /longruns/stage1/stage1_infer/run -→ POST /longruns/stage1/stage1_infer/run_sync -``` - -## 12. Acceptance Criteria - -- [ ] A file with one `@remote(QB_config)` function and nothing else is a valid Flash project -- [ ] `flash run` produces a Swagger UI showing all routes grouped by source file -- [ ] QB routes accept `{"input": {...}}` and return `{"id": ..., "status": "COMPLETED", "output": {...}}` -- [ ] Subdirectory files produce URL prefixes matching their relative path -- [ ] Multiple `@remote` functions in one file each get their own sub-prefixed routes -- [ ] LB route handler body executes directly (not dispatched remotely) -- [ ] QB calls inside LB route handler body route to the remote QB endpoint -- [ ] `flash deploy` creates a RunPod endpoint for each resource config -- [ ] `flash build` produces `.flash/manifest.json` with `file_path`, `local_path_prefix`, `module_path` per resource -- [ ] When `flash run` exits, all Live Serverless endpoints provisioned during that session are automatically undeployed - -## 13. Edge Cases - -- **No `@remote` functions found**: Error with clear message and usage instructions -- **Multiple `@remote` functions per file (QB)**: Sub-prefixed routes `/{file_prefix}/{fn_name}/run_sync` -- **`__init__.py` files**: Skipped — not treated as worker files -- **File path with hyphens** (e.g., `my-worker.py`): Resource name sanitized to `my_worker`, URL prefix `/my-worker` (hyphens valid in URLs, underscores in Python identifiers) -- **LB function calling another LB function**: Not supported via `@remote` — emit a warning at build time -- **`.flash/` already exists**: `server.py` and `manifest.json` overwritten; other files preserved -- **`flash deploy` with no LB endpoints**: QB-only deploy -- **Subdirectory `__init__.py`** imports needed: Generator checks and warns if missing - -## 14. Implementation Files - -| File | Change | -|------|--------| -| `flash/main/PRD.md` | This document | -| `src/runpod_flash/client.py` | Passthrough for LB route handlers (`__is_lb_route_handler__`) | -| `cli/commands/run.py` | Unified server generation; `--app-dir .flash/`; file-path-based route discovery | -| `cli/commands/build_utils/scanner.py` | Path utilities; `is_lb_route_handler` field; file-based resource identity | -| `cli/commands/build_utils/manifest.py` | Flat resource structure; `file_path`/`local_path_prefix`/`module_path` fields | -| `cli/commands/build_utils/lb_handler_generator.py` | Import module by `module_path`, walk `__is_lb_route_handler__`, register routes | -| `cli/commands/build.py` | Remove main.py requirement from `validate_project_structure` | -| `core/resources/serverless.py` | Inject `FLASH_MODULE_PATH` env var | -| `flash-examples/.../01_hello_world/` | Rewrite to bare minimum | -| `flash-examples/.../00_standalone_worker/` | New | -| `flash-examples/.../00_multi_resource/` | New |