diff --git a/PRD.md b/PRD.md
new file mode 100644
index 00000000..a5d9d98f
--- /dev/null
+++ b/PRD.md
@@ -0,0 +1,341 @@
+# Flash SDK: Zero-Boilerplate Experience — Product Requirements Document
+
+## 1. Problem Statement
+
+Flash currently forces every project into a FastAPI-first model:
+
+- Users must create `main.py` with a `FastAPI()` instance
+- HTTP routing boilerplate adds no semantic value — the routes simply call `@remote` functions
+- No straightforward path for deploying a standalone QB (queue-based) function without wrapping it in a FastAPI app
+- The "mothership" concept introduces an implicit coordinator with no clear ownership model
+- `flash run` fails unless `main.py` exists with a FastAPI app, blocking the simplest use cases
+
+## 2. Goals
+
+- **Zero boilerplate**: a `@remote`-decorated function in any `.py` file is sufficient for `flash run` and `flash deploy`
+- **File-system-as-namespace**: the project directory structure maps 1:1 to URL paths on the local dev server
+- **Single command**: `flash run` works for all project topologies (one QB function, many files, mixed QB and load-balanced (LB) endpoints) without any configuration
+- **Config-free deploy**: `flash deploy` requires no additional configuration beyond the `@remote` declarations themselves
+- **Peer endpoints**: every `@remote` resource config is a first-class endpoint; no implicit coordinator
+
+## 3. Non-Goals
+
+- No backward compatibility with `main.py`/FastAPI-first style
+- No implicit "mothership" concept; all endpoints are peers
+- No changes to the QB runtime (`generic_handler.py`) or QB stub behavior
+- No changes to deployed endpoint behavior (RunPod QB/LB APIs are unchanged)
+
+## 4. Developer Experience Specification
+
+### 4.1 Minimum viable QB project
+
+```python
+# gpu_worker.py
+from runpod_flash import LiveServerless, GpuGroup, remote
+
+gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY])
+
+@remote(gpu_config)
+async def process(input_data: dict) -> dict:
+    return {"result": "processed", "input": input_data}
+```
+
+`flash run` → `POST /gpu_worker/run` and `POST /gpu_worker/run_sync`
+`flash deploy` → standalone QB endpoint at `api.runpod.ai/v2/{id}/run`
+
+### 4.2 LB endpoint
+
+```python
+# api/routes.py
+from runpod_flash import CpuLiveLoadBalancer, remote
+
+lb_config = CpuLiveLoadBalancer(name="api_routes")
+
+@remote(lb_config, method="POST", path="/compute")
+async def compute(input_data: dict) -> dict:
+    return {"result": input_data}
+```
+
+`flash run` → `POST /api/routes/compute`
+`flash deploy` → LB endpoint at `{id}.api.runpod.ai/compute`
+
+### 4.3 Mixed QB + LB (LB calling QB)
+
+```python
+# api/routes.py (LB)
+from runpod_flash import CpuLiveLoadBalancer, remote
+from workers.gpu import heavy_compute  # QB stub
+
+lb_config = CpuLiveLoadBalancer(name="api_routes")
+
+@remote(lb_config, method="POST", path="/process")
+async def process_route(data: dict):
+    return await heavy_compute(data)  # dispatches to QB endpoint
+
+# workers/gpu.py (QB)
+from runpod_flash import LiveServerless, GpuGroup, remote
+
+gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY])
+
+@remote(gpu_config)
+async def heavy_compute(data: dict) -> dict: ...
+```
+
+## 5. URL Path Specification
+
+### 5.1 File prefix derivation
+
+The local dev server uses the project directory structure as a URL namespace.
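+
+For illustration, a minimal sketch of the derivation (mirroring the `file_to_url_prefix` helper added in `scanner.py`; the `url_prefix` name here is illustrative):
+
+```python
+from pathlib import Path
+
+def url_prefix(file_path: Path, project_root: Path) -> str:
+    # Relative path, extension stripped, rendered with forward slashes
+    rel = file_path.relative_to(project_root).with_suffix("")
+    return "/" + rel.as_posix()
+
+assert url_prefix(Path("/proj/longruns/stage1.py"), Path("/proj")) == "/longruns/stage1"
+```
+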
Each file's URL prefix is its path relative to the project root with `.py` stripped: + +``` +File Local URL prefix +────────────────────────────── ──────────────────────────── +gpu_worker.py → /gpu_worker +longruns/stage1.py → /longruns/stage1 +preprocess/first_pass.py → /preprocess/first_pass +workers/gpu/inference.py → /workers/gpu/inference +``` + +### 5.2 QB route generation + +| Condition | Routes | +|---|---| +| One `@remote` function in file | `POST {file_prefix}/run` and `POST {file_prefix}/run_sync` | +| Multiple `@remote` functions in file | `POST {file_prefix}/{fn_name}/run` and `POST {file_prefix}/{fn_name}/run_sync` | + +### 5.3 LB route generation + +| Condition | Route | +|---|---| +| `@remote(lb_config, method="POST", path="/compute")` | `POST {file_prefix}/compute` | + +The declared `path=` is appended to the file prefix. The `method=` determines the HTTP verb. + +### 5.4 QB request/response envelope + +Mirrors RunPod's API for consistency: + +``` +POST /gpu_worker/run_sync +Body: {"input": {"key": "value"}} +Response: {"id": "uuid", "status": "COMPLETED", "output": {...}} +``` + +## 6. Deployed Topology Specification + +Each unique resource config gets its own RunPod endpoint: + +| Type | Deployed URL | Example | +|---|---|---| +| QB | `https://api.runpod.ai/v2/{endpoint_id}/run` | `https://api.runpod.ai/v2/uoy3n7hkyb052a/run` | +| QB sync | `https://api.runpod.ai/v2/{endpoint_id}/run_sync` | | +| LB | `https://{endpoint_id}.api.runpod.ai/{declared_path}` | `https://rzlk6lph6gw7dk.api.runpod.ai/compute` | + +## 7. `.flash/` Folder Specification + +All generated artifacts go to `.flash/` in the project root. Auto-created, gitignored, never committed. + +``` +my_project/ +├── gpu_worker.py +├── longruns/ +│ └── stage1.py +└── .flash/ + ├── server.py ← generated by flash run + └── manifest.json ← generated by flash build +``` + +- `.flash/` is added to `.gitignore` automatically on first `flash run` +- `server.py` and `manifest.json` are overwritten on each run/build; other files preserved +- The `.flash/` directory itself is never committed + +### 7.1 Dev server launch + +Uvicorn is launched with `--app-dir .flash/` so `server:app` is importable. The server inserts the project root into `sys.path` so user modules resolve: + +```bash +uvicorn server:app \ + --app-dir .flash/ \ + --reload \ + --reload-dir . \ + --reload-include "*.py" +``` + +## 8. `flash run` Behavior + +1. Scan project for all `@remote` functions (QB and LB) in any `.py` file + - Skip: `.flash/`, `__pycache__`, `*.pyc`, `__init__.py` +2. If none found: print error with usage instructions, exit 1 +3. Generate `.flash/server.py` with routes for all discovered functions +4. Add `.flash/` to `.gitignore` if not already present +5. Start uvicorn with `--reload` watching both `.flash/` and project root +6. Print startup table: local paths → resource names → types +7. Swagger UI available at `http://localhost:{port}/docs` +8. On exit (Ctrl+C or SIGTERM): deprovision all Live Serverless endpoints provisioned during this session + +### 8.1 Startup table format + +``` +Flash Dev Server http://localhost:8888 + + Local path Resource Type + ────────────────────────────────── ─────────────────── ──── + POST /gpu_worker/run gpu_worker QB + POST /gpu_worker/run_sync gpu_worker QB + POST /longruns/stage1/run longruns_stage1 QB + POST /preprocess/first_pass/compute preprocess_first_pass LB + + Visit http://localhost:8888/docs for Swagger UI +``` + +## 9. `flash build` Behavior + +1. 
Scan project for all `@remote` functions (QB and LB) +2. Build `.flash/manifest.json` with flat resource structure (see §10) +3. For LB resources: generate deployed handler files using `module_path` +4. Package build artifact + +## 10. Manifest Structure + +Resource names are derived from file paths (slashes → underscores): + +```json +{ + "version": "1.0", + "project_name": "my_project", + "resources": { + "gpu_worker": { + "resource_type": "LiveServerless", + "file_path": "gpu_worker.py", + "local_path_prefix": "/gpu_worker", + "module_path": "gpu_worker", + "functions": ["gpu_hello"], + "is_load_balanced": false, + "makes_remote_calls": false + }, + "longruns_stage1": { + "resource_type": "LiveServerless", + "file_path": "longruns/stage1.py", + "local_path_prefix": "/longruns/stage1", + "module_path": "longruns.stage1", + "functions": ["stage1_process"], + "is_load_balanced": false, + "makes_remote_calls": false + }, + "preprocess_first_pass": { + "resource_type": "CpuLiveLoadBalancer", + "file_path": "preprocess/first_pass.py", + "local_path_prefix": "/preprocess/first_pass", + "module_path": "preprocess.first_pass", + "functions": [ + {"name": "first_pass_fn", "http_method": "POST", "http_path": "/compute"} + ], + "is_load_balanced": true, + "makes_remote_calls": true + } + } +} +``` + +## 11. `.flash/server.py` Structure + +```python +"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run.""" +import sys +import uuid +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from fastapi import FastAPI + +# QB imports +from gpu_worker import gpu_hello +from longruns.stage1 import stage1_process + +# LB imports +from preprocess.first_pass import first_pass_fn + +app = FastAPI( + title="Flash Dev Server", + description="Auto-generated by `flash run`. Visit /docs for interactive testing.", +) + +# QB: gpu_worker.py +@app.post("/gpu_worker/run", tags=["gpu_worker [QB]"]) +@app.post("/gpu_worker/run_sync", tags=["gpu_worker [QB]"]) +async def gpu_worker_run(body: dict): + result = await gpu_hello(body.get("input", body)) + return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} + +# QB: longruns/stage1.py +@app.post("/longruns/stage1/run", tags=["longruns/stage1 [QB]"]) +@app.post("/longruns/stage1/run_sync", tags=["longruns/stage1 [QB]"]) +async def longruns_stage1_run(body: dict): + result = await stage1_process(body.get("input", body)) + return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result} + +# LB: preprocess/first_pass.py +@app.post("/preprocess/first_pass/compute", tags=["preprocess/first_pass [LB]"]) +async def _route_first_pass_compute(body: dict): + return await first_pass_fn(body) + +# Health +@app.get("/", tags=["health"]) +def home(): + return {"message": "Flash Dev Server", "docs": "/docs"} + +@app.get("/ping", tags=["health"]) +def ping(): + return {"status": "healthy"} +``` + +Subdirectory imports use dotted module paths: `longruns/stage1.py` → `from longruns.stage1 import fn`. + +Multi-function QB files (2+ `@remote` functions) get sub-prefixed routes: +``` +longruns/stage1.py has: stage1_preprocess, stage1_infer +→ POST /longruns/stage1/stage1_preprocess/run +→ POST /longruns/stage1/stage1_preprocess/run_sync +→ POST /longruns/stage1/stage1_infer/run +→ POST /longruns/stage1/stage1_infer/run_sync +``` + +## 12. 
Acceptance Criteria
+
+- [ ] A file with one `@remote(QB_config)` function and nothing else is a valid Flash project
+- [ ] `flash run` produces a Swagger UI showing all routes grouped by source file
+- [ ] QB routes accept `{"input": {...}}` and return `{"id": ..., "status": "COMPLETED", "output": {...}}`
+- [ ] Subdirectory files produce URL prefixes matching their relative path
+- [ ] Multiple `@remote` functions in one file each get their own sub-prefixed routes
+- [ ] LB route handler body executes directly (not dispatched remotely)
+- [ ] QB calls inside LB route handler body route to the remote QB endpoint
+- [ ] `flash deploy` creates a RunPod endpoint for each resource config
+- [ ] `flash build` produces `.flash/manifest.json` with `file_path`, `local_path_prefix`, `module_path` per resource
+- [ ] When `flash run` exits, all Live Serverless endpoints provisioned during that session are automatically undeployed
+
+## 13. Edge Cases
+
+- **No `@remote` functions found**: Error with clear message and usage instructions
+- **Multiple `@remote` functions per file (QB)**: Sub-prefixed routes `/{file_prefix}/{fn_name}/run`
+- **`__init__.py` files**: Skipped — not treated as worker files
+- **File path with hyphens** (e.g., `my-worker.py`): Resource name sanitized to `my_worker`, URL prefix `/my-worker` (hyphens valid in URLs, underscores in Python identifiers)
+- **LB function calling another LB function**: Not supported via `@remote` — emit a warning at build time
+- **`.flash/` already exists**: `server.py` and `manifest.json` overwritten; other files preserved
+- **`flash deploy` with no LB endpoints**: QB-only deploy
+- **Subdirectory missing `__init__.py`**: generated imports expect packages; the generator checks and warns when `__init__.py` is absent
+
+## 14. Implementation Files
+
+| File | Change |
+|------|--------|
+| `flash/main/PRD.md` | This document |
+| `src/runpod_flash/client.py` | Passthrough for LB route handlers (`__is_lb_route_handler__`) |
+| `cli/commands/run.py` | Unified server generation; `--app-dir .flash/`; file-path-based route discovery |
+| `cli/commands/build_utils/scanner.py` | Path utilities; `is_lb_route_handler` field; file-based resource identity |
+| `cli/commands/build_utils/manifest.py` | Flat resource structure; `file_path`/`local_path_prefix`/`module_path` fields |
+| `cli/commands/build_utils/lb_handler_generator.py` | Import module by `module_path`, walk `__is_lb_route_handler__`, register routes |
+| `cli/commands/build.py` | Remove main.py requirement from `validate_project_structure` |
+| `core/resources/serverless.py` | Inject `FLASH_MODULE_PATH` env var |
+| `flash-examples/.../01_hello_world/` | Rewrite to bare minimum |
+| `flash-examples/.../00_standalone_worker/` | New |
+| `flash-examples/.../00_multi_resource/` | New |
diff --git a/src/runpod_flash/cli/commands/build.py b/src/runpod_flash/cli/commands/build.py
index 00b152ff..6b325de0 100644
--- a/src/runpod_flash/cli/commands/build.py
+++ b/src/runpod_flash/cli/commands/build.py
@@ -23,6 +23,7 @@ from runpod_flash.core.resources.constants import MAX_TARBALL_SIZE_MB
 
 from ..utils.ignore import get_file_tree, load_ignore_patterns
+from .build_utils.lb_handler_generator import LBHandlerGenerator
 from .build_utils.manifest import ManifestBuilder
 from .build_utils.scanner import RemoteDecoratorScanner
 
@@ -239,6 +240,9 @@ def run_build(
     manifest_path = build_dir / "flash_manifest.json"
     manifest_path.write_text(json.dumps(manifest, indent=2))
 
+    lb_generator = LBHandlerGenerator(manifest, build_dir)
+    lb_generator.generate_handlers()
+
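+    # Mirror the manifest into .flash/ as the deployment manifest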
flash_dir = project_dir / ".flash" deployment_manifest_path = flash_dir / "flash_manifest.json" shutil.copy2(manifest_path, deployment_manifest_path) @@ -425,28 +429,19 @@ def validate_project_structure(project_dir: Path) -> bool: """ Validate that directory is a Flash project. + A Flash project is any directory containing Python files. The + RemoteDecoratorScanner validates that @remote functions exist. + Args: project_dir: Directory to validate Returns: True if valid Flash project """ - main_py = project_dir / "main.py" - - if not main_py.exists(): - console.print(f"[red]Error:[/red] main.py not found in {project_dir}") + py_files = list(project_dir.rglob("*.py")) + if not py_files: + console.print(f"[red]Error:[/red] No Python files found in {project_dir}") return False - - # Check if main.py has FastAPI app - try: - content = main_py.read_text(encoding="utf-8") - if "FastAPI" not in content: - console.print( - "[yellow]Warning:[/yellow] main.py does not appear to have a FastAPI app" - ) - except Exception: - pass - return True diff --git a/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py b/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py index dcd0845d..a0d28601 100644 --- a/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py +++ b/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py @@ -21,13 +21,10 @@ - Real-time communication patterns """ -import asyncio import logging from contextlib import asynccontextmanager -from pathlib import Path -from typing import Optional -from fastapi import FastAPI, Request +from fastapi import FastAPI from runpod_flash.runtime.lb_handler import create_lb_handler logger = logging.getLogger(__name__) @@ -45,57 +42,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): """Handle application startup and shutdown.""" - # Startup logger.info("Starting {resource_name} endpoint") - - # Check if this is the mothership and run reconciliation - # Note: Resources are now provisioned upfront by the CLI during deployment. - # This background task runs reconciliation on mothership startup to ensure - # all resources are still deployed and in sync with the manifest. 
- try: - from runpod_flash.runtime.mothership_provisioner import ( - is_mothership, - reconcile_children, - get_mothership_url, - ) - from runpod_flash.runtime.state_manager_client import StateManagerClient - - if is_mothership(): - logger.info("=" * 60) - logger.info("Mothership detected - Starting reconciliation task") - logger.info("Resources are provisioned upfront by the CLI") - logger.info("This task ensures all resources remain in sync") - logger.info("=" * 60) - try: - mothership_url = get_mothership_url() - logger.info(f"Mothership URL: {{mothership_url}}") - - # Initialize State Manager client for reconciliation - state_client = StateManagerClient() - - # Spawn background reconciliation task (non-blocking) - # This will verify all resources from manifest are deployed - manifest_path = Path(__file__).parent / "flash_manifest.json" - task = asyncio.create_task( - reconcile_children(manifest_path, mothership_url, state_client) - ) - # Add error callback to catch and log background task exceptions - task.add_done_callback( - lambda t: logger.error(f"Reconciliation task failed: {{t.exception()}}") - if t.exception() - else None - ) - - except Exception as e: - logger.error(f"Failed to start reconciliation task: {{e}}") - # Don't fail startup - continue serving traffic - - except ImportError: - logger.debug("Mothership provisioning modules not available") - yield - - # Shutdown logger.info("Shutting down {resource_name} endpoint") diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index b67ce9bd..af2a283f 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -9,45 +9,17 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from runpod_flash.core.resources.constants import ( - DEFAULT_WORKERS_MAX, - DEFAULT_WORKERS_MIN, - FLASH_CPU_LB_IMAGE, - FLASH_LB_IMAGE, +from .scanner import ( + RemoteFunctionMetadata, + file_to_module_path, + file_to_url_prefix, ) -from .scanner import RemoteFunctionMetadata, detect_explicit_mothership, detect_main_app - logger = logging.getLogger(__name__) RESERVED_PATHS = ["/execute", "/ping"] -def _serialize_routes(routes: List[RemoteFunctionMetadata]) -> List[Dict[str, Any]]: - """Convert RemoteFunctionMetadata to manifest dict format. - - Args: - routes: List of route metadata objects - - Returns: - List of dicts with route information for manifest - """ - return [ - { - "name": route.function_name, - "module": route.module_path, - "is_async": route.is_async, - "is_class": route.is_class, - "is_load_balanced": route.is_load_balanced, - "is_live_resource": route.is_live_resource, - "config_variable": route.config_variable, - "http_method": route.http_method, - "http_path": route.http_path, - } - for route in routes - ] - - @dataclass class ManifestFunction: """Function entry in manifest.""" @@ -213,83 +185,13 @@ def _extract_deployment_config( return config - def _create_mothership_resource(self, main_app_config: dict) -> Dict[str, Any]: - """Create implicit mothership resource from main.py. 
- - Args: - main_app_config: Dict with 'file_path', 'app_variable', 'has_routes', 'fastapi_routes' keys - - Returns: - Dictionary representing the mothership resource for the manifest - """ - # Extract FastAPI routes if present - fastapi_routes = main_app_config.get("fastapi_routes", []) - functions_list = _serialize_routes(fastapi_routes) - - return { - "resource_type": "CpuLiveLoadBalancer", - "functions": functions_list, - "is_load_balanced": True, - "is_live_resource": True, - "is_mothership": True, - "main_file": main_app_config["file_path"].name, - "app_variable": main_app_config["app_variable"], - "imageName": FLASH_CPU_LB_IMAGE, - "workersMin": DEFAULT_WORKERS_MIN, - "workersMax": DEFAULT_WORKERS_MAX, - } - - def _create_mothership_from_explicit( - self, explicit_config: dict, search_dir: Path - ) -> Dict[str, Any]: - """Create mothership resource from explicit mothership.py configuration. - - Args: - explicit_config: Configuration dict from detect_explicit_mothership() - search_dir: Project directory + def build(self) -> Dict[str, Any]: + """Build the manifest dictionary. - Returns: - Dictionary representing the mothership resource for the manifest + Resources are keyed by resource_config_name for runtime compatibility. + Each resource entry includes file_path, local_path_prefix, and module_path + for the dev server and LB handler generator. """ - # Detect FastAPI app details for handler generation - main_app_config = detect_main_app(search_dir, explicit_mothership_exists=False) - - if not main_app_config: - # No FastAPI app found, use defaults - main_file = "main.py" - app_variable = "app" - fastapi_routes = [] - else: - main_file = main_app_config["file_path"].name - app_variable = main_app_config["app_variable"] - fastapi_routes = main_app_config.get("fastapi_routes", []) - - # Extract FastAPI routes into functions list - functions_list = _serialize_routes(fastapi_routes) - - # Map resource type to image name - resource_type = explicit_config.get("resource_type", "CpuLiveLoadBalancer") - if resource_type == "LiveLoadBalancer": - image_name = FLASH_LB_IMAGE # GPU load balancer - else: - image_name = FLASH_CPU_LB_IMAGE # CPU load balancer - - return { - "resource_type": resource_type, - "functions": functions_list, - "is_load_balanced": True, - "is_live_resource": True, - "is_mothership": True, - "is_explicit": True, # Flag to indicate explicit configuration - "main_file": main_file, - "app_variable": app_variable, - "imageName": image_name, - "workersMin": explicit_config.get("workersMin", DEFAULT_WORKERS_MIN), - "workersMax": explicit_config.get("workersMax", DEFAULT_WORKERS_MAX), - } - - def build(self) -> Dict[str, Any]: - """Build the manifest dictionary.""" # Group functions by resource_config_name resources: Dict[str, List[RemoteFunctionMetadata]] = {} @@ -305,6 +207,9 @@ def build(self) -> Dict[str, Any]: str, Dict[str, str] ] = {} # resource_name -> {route_key -> function_name} + # Determine project root for path derivation + project_root = self.build_dir.parent if self.build_dir else Path.cwd() + for resource_name, functions in sorted(resources.items()): # Use actual resource type from first function in group resource_type = ( @@ -315,6 +220,27 @@ def build(self) -> Dict[str, Any]: is_load_balanced = functions[0].is_load_balanced if functions else False is_live_resource = functions[0].is_live_resource if functions else False + # Derive path fields from the first function's source file. + # All functions in a resource share the same source file per convention. 
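+            # e.g. longruns/stage1.py →
+            #   file_path="longruns/stage1.py", local_path_prefix="/longruns/stage1",
+            #   module_path="longruns.stage1"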
+ first_file = functions[0].file_path if functions else None + file_path_str = "" + local_path_prefix = "" + resource_module_path = functions[0].module_path if functions else "" + + if first_file and first_file.exists(): + try: + file_path_str = str(first_file.relative_to(project_root)) + local_path_prefix = file_to_url_prefix(first_file, project_root) + resource_module_path = file_to_module_path(first_file, project_root) + except ValueError: + # File is outside project root — fall back to module_path + file_path_str = str(first_file) + local_path_prefix = "/" + functions[0].module_path.replace(".", "/") + elif first_file: + # File path may be relative (in test scenarios) + file_path_str = str(first_file) + local_path_prefix = "/" + functions[0].module_path.replace(".", "/") + # Validate and collect routing for LB endpoints resource_routes = {} if is_load_balanced: @@ -374,6 +300,9 @@ def build(self) -> Dict[str, Any]: resources_dict[resource_name] = { "resource_type": resource_type, + "file_path": file_path_str, + "local_path_prefix": local_path_prefix, + "module_path": resource_module_path, "functions": functions_list, "is_load_balanced": is_load_balanced, "is_live_resource": is_live_resource, @@ -395,68 +324,6 @@ def build(self) -> Dict[str, Any]: ) function_registry[f.function_name] = resource_name - # === MOTHERSHIP DETECTION (EXPLICIT THEN FALLBACK) === - search_dir = self.build_dir if self.build_dir else Path.cwd() - - # Step 1: Check for explicit mothership.py - explicit_mothership = detect_explicit_mothership(search_dir) - - if explicit_mothership: - # Use explicit configuration - logger.debug("Found explicit mothership configuration in mothership.py") - - # Check for name conflict - mothership_name = explicit_mothership.get("name", "mothership") - if mothership_name in resources_dict: - logger.warning( - f"Project has a @remote resource named '{mothership_name}'. " - f"Using 'mothership-entrypoint' for explicit mothership endpoint." - ) - mothership_name = "mothership-entrypoint" - - # Create mothership resource from explicit config - mothership_resource = self._create_mothership_from_explicit( - explicit_mothership, search_dir - ) - resources_dict[mothership_name] = mothership_resource - - else: - # Step 2: Fallback to auto-detection - main_app_config = detect_main_app( - search_dir, explicit_mothership_exists=False - ) - - if main_app_config and main_app_config["has_routes"]: - logger.warning( - "Auto-detected FastAPI app in main.py (no mothership.py found). " - "Consider running 'flash init' to create explicit mothership configuration." - ) - - # Check for name conflict - if "mothership" in resources_dict: - logger.warning( - "Project has a @remote resource named 'mothership'. " - "Using 'mothership-entrypoint' for auto-generated mothership endpoint." 
- ) - mothership_name = "mothership-entrypoint" - else: - mothership_name = "mothership" - - # Create mothership resource from auto-detection (legacy behavior) - mothership_resource = self._create_mothership_resource(main_app_config) - resources_dict[mothership_name] = mothership_resource - - # Extract routes from mothership resources - for resource_name, resource in resources_dict.items(): - if resource.get("is_mothership") and resource.get("functions"): - mothership_routes = {} - for func in resource["functions"]: - if func.get("http_method") and func.get("http_path"): - route_key = f"{func['http_method']} {func['http_path']}" - mothership_routes[route_key] = func["name"] - if mothership_routes: - routes_dict[resource_name] = mothership_routes - manifest = { "version": "1.0", "generated_at": datetime.now(timezone.utc) diff --git a/src/runpod_flash/cli/commands/build_utils/scanner.py b/src/runpod_flash/cli/commands/build_utils/scanner.py index 2215ab9e..d217dcb3 100644 --- a/src/runpod_flash/cli/commands/build_utils/scanner.py +++ b/src/runpod_flash/cli/commands/build_utils/scanner.py @@ -3,6 +3,7 @@ import ast import importlib import logging +import os import re from dataclasses import dataclass, field from pathlib import Path @@ -11,6 +12,61 @@ logger = logging.getLogger(__name__) +def file_to_url_prefix(file_path: Path, project_root: Path) -> str: + """Derive the local dev server URL prefix from a source file path. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + URL prefix starting with "/" (e.g., /longruns/stage1) + + Example: + longruns/stage1.py → /longruns/stage1 + """ + rel = file_path.relative_to(project_root).with_suffix("") + return "/" + str(rel).replace(os.sep, "/") + + +def file_to_resource_name(file_path: Path, project_root: Path) -> str: + """Derive the manifest resource name from a source file path. + + Slashes and hyphens are replaced with underscores to produce a valid + Python identifier suitable for use as a resource name. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + Resource name using underscores (e.g., longruns_stage1) + + Example: + longruns/stage1.py → longruns_stage1 + my-worker.py → my_worker + """ + rel = file_path.relative_to(project_root).with_suffix("") + return str(rel).replace(os.sep, "_").replace("/", "_").replace("-", "_") + + +def file_to_module_path(file_path: Path, project_root: Path) -> str: + """Derive the Python dotted module path from a source file path. 
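+
+    Counterpart to file_to_url_prefix; parts are joined with "." so the result
+    is importable once the project root is on sys.path.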
+ + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + Dotted module path (e.g., longruns.stage1) + + Example: + longruns/stage1.py → longruns.stage1 + """ + rel = file_path.relative_to(project_root).with_suffix("") + return str(rel).replace(os.sep, ".").replace("/", ".") + + @dataclass class RemoteFunctionMetadata: """Metadata about a @remote decorated function or class.""" @@ -35,6 +91,9 @@ class RemoteFunctionMetadata: called_remote_functions: List[str] = field( default_factory=list ) # Names of @remote functions called + is_lb_route_handler: bool = ( + False # LB @remote with method= and path= — runs directly as HTTP handler + ) class RemoteDecoratorScanner: @@ -62,7 +121,9 @@ def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: rel_path = f.relative_to(self.project_dir) # Check if first part of path is in excluded_root_dirs if rel_path.parts and rel_path.parts[0] not in excluded_root_dirs: - self.py_files.append(f) + # Exclude __init__.py — not valid worker entry points + if f.name != "__init__.py": + self.py_files.append(f) except (ValueError, IndexError): # Include files that can't be made relative self.py_files.append(f) @@ -220,6 +281,15 @@ def _extract_remote_functions( {"is_load_balanced": False, "is_live_resource": False}, ) + # An LB route handler is an LB @remote function that has + # both method= and path= declared. Its body runs directly + # on the LB endpoint — it is NOT a remote dispatch stub. + is_lb_route_handler = ( + flags["is_load_balanced"] + and http_method is not None + and http_path is not None + ) + metadata = RemoteFunctionMetadata( function_name=node.name, module_path=module_path, @@ -235,6 +305,7 @@ def _extract_remote_functions( config_variable=self.resource_variables.get( resource_config_name ), + is_lb_route_handler=is_lb_route_handler, ) functions.append(metadata) diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 051115a8..86fceb59 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -5,16 +5,391 @@ import signal import subprocess import sys +import threading +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import List -import questionary import typer from rich.console import Console +from rich.table import Table +from watchfiles import DefaultFilter as _WatchfilesDefaultFilter +from watchfiles import watch as _watchfiles_watch + +from .build_utils.scanner import ( + RemoteDecoratorScanner, + file_to_module_path, + file_to_resource_name, + file_to_url_prefix, +) logger = logging.getLogger(__name__) console = Console() +# Resource state file written by ResourceManager in the uvicorn subprocess. +_RESOURCE_STATE_FILE = Path(".runpod") / "resources.pkl" + + +@dataclass +class WorkerInfo: + """Info about a discovered @remote function for dev server generation.""" + + file_path: Path + url_prefix: str # e.g. /longruns/stage1 + module_path: str # e.g. longruns.stage1 + resource_name: str # e.g. longruns_stage1 + worker_type: str # "QB" or "LB" + functions: List[str] # function names + lb_routes: List[dict] = field(default_factory=list) # [{method, path, fn_name}] + + +def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: + """Scan the project for all @remote decorated functions. 
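+
+    This is the discovery step behind flash run: its output drives both
+    .flash/server.py generation and the startup table.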
+ + Walks all .py files (excluding .flash/, __pycache__, __init__.py) and + builds WorkerInfo for each file that contains @remote functions. + + Files with QB functions produce one WorkerInfo per file (QB type). + Files with LB functions produce one WorkerInfo per file (LB type). + A file can have both QB and LB functions (unusual but supported). + + Args: + project_root: Root directory of the Flash project + + Returns: + List of WorkerInfo, one entry per discovered source file + """ + scanner = RemoteDecoratorScanner(project_root) + remote_functions = scanner.discover_remote_functions() + + # Group by file path + by_file: dict[Path, List] = {} + for func in remote_functions: + by_file.setdefault(func.file_path, []).append(func) + + workers: List[WorkerInfo] = [] + for file_path, funcs in sorted(by_file.items()): + url_prefix = file_to_url_prefix(file_path, project_root) + module_path = file_to_module_path(file_path, project_root) + resource_name = file_to_resource_name(file_path, project_root) + + qb_funcs = [f for f in funcs if not f.is_load_balanced] + lb_funcs = [f for f in funcs if f.is_load_balanced and f.is_lb_route_handler] + + if qb_funcs: + workers.append( + WorkerInfo( + file_path=file_path, + url_prefix=url_prefix, + module_path=module_path, + resource_name=resource_name, + worker_type="QB", + functions=[f.function_name for f in qb_funcs], + ) + ) + + if lb_funcs: + lb_routes = [ + { + "method": f.http_method, + "path": f.http_path, + "fn_name": f.function_name, + } + for f in lb_funcs + ] + workers.append( + WorkerInfo( + file_path=file_path, + url_prefix=url_prefix, + module_path=module_path, + resource_name=resource_name, + worker_type="LB", + functions=[f.function_name for f in lb_funcs], + lb_routes=lb_routes, + ) + ) + + return workers + + +def _ensure_gitignore(project_root: Path) -> None: + """Add .flash/ to .gitignore if not already present.""" + gitignore = project_root / ".gitignore" + entry = ".flash/" + + if gitignore.exists(): + content = gitignore.read_text(encoding="utf-8") + if entry in content: + return + # Append with a newline + if not content.endswith("\n"): + content += "\n" + gitignore.write_text(content + entry + "\n", encoding="utf-8") + else: + gitignore.write_text(entry + "\n", encoding="utf-8") + + +def _sanitize_fn_name(name: str) -> str: + """Sanitize a string for use as a Python function name.""" + return name.replace("/", "_").replace(".", "_").replace("-", "_") + + +def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: + """Generate .flash/server.py from the discovered workers. + + Args: + project_root: Root of the Flash project + workers: List of discovered worker infos + + Returns: + Path to the generated server.py + """ + flash_dir = project_root / ".flash" + flash_dir.mkdir(exist_ok=True) + + _ensure_gitignore(project_root) + + lines = [ + '"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run."""', + "import sys", + "import uuid", + "from pathlib import Path", + "sys.path.insert(0, str(Path(__file__).parent.parent))", + "", + "from fastapi import FastAPI", + "", + ] + + # Collect all imports + all_imports: List[str] = [] + for worker in workers: + for fn_name in worker.functions: + all_imports.append(f"from {worker.module_path} import {fn_name}") + + if all_imports: + lines.extend(all_imports) + lines.append("") + + lines += [ + "app = FastAPI(", + ' title="Flash Dev Server",', + ' description="Auto-generated by `flash run`. 
Visit /docs for interactive testing.",', + ")", + "", + ] + + for worker in workers: + tag = f"{worker.url_prefix.lstrip('/')} [{worker.worker_type}]" + lines.append(f"# {'─' * 60}") + lines.append(f"# {worker.worker_type}: {worker.file_path.name}") + lines.append(f"# {'─' * 60}") + + if worker.worker_type == "QB": + if len(worker.functions) == 1: + fn = worker.functions[0] + handler_name = _sanitize_fn_name(f"{worker.resource_name}_run") + run_path = f"{worker.url_prefix}/run" + sync_path = f"{worker.url_prefix}/run_sync" + lines += [ + f'@app.post("{run_path}", tags=["{tag}"])', + f'@app.post("{sync_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f' result = await {fn}(body.get("input", body))', + ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', + "", + ] + else: + for fn in worker.functions: + handler_name = _sanitize_fn_name(f"{worker.resource_name}_{fn}_run") + run_path = f"{worker.url_prefix}/{fn}/run" + sync_path = f"{worker.url_prefix}/{fn}/run_sync" + lines += [ + f'@app.post("{run_path}", tags=["{tag}"])', + f'@app.post("{sync_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f' result = await {fn}(body.get("input", body))', + ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', + "", + ] + + elif worker.worker_type == "LB": + for route in worker.lb_routes: + method = route["method"].lower() + sub_path = route["path"].lstrip("/") + fn_name = route["fn_name"] + full_path = f"{worker.url_prefix}/{sub_path}" + handler_name = _sanitize_fn_name( + f"_route_{worker.resource_name}_{fn_name}" + ) + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: dict):", + f" return await {fn_name}(body)", + "", + ] + + # Health endpoints + lines += [ + "# Health", + '@app.get("/", tags=["health"])', + "def home():", + ' return {"message": "Flash Dev Server", "docs": "/docs"}', + "", + '@app.get("/ping", tags=["health"])', + "def ping():", + ' return {"status": "healthy"}', + "", + ] + + server_path = flash_dir / "server.py" + server_path.write_text("\n".join(lines), encoding="utf-8") + return server_path + + +def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> None: + """Print the startup table showing local paths, resource names, and types.""" + console.print(f"\n[bold green]Flash Dev Server[/bold green] http://{host}:{port}") + console.print() + + table = Table(show_header=True, header_style="bold") + table.add_column("Local path", style="cyan") + table.add_column("Resource", style="white") + table.add_column("Type", style="yellow") + + for worker in workers: + if worker.worker_type == "QB": + if len(worker.functions) == 1: + table.add_row( + f"POST {worker.url_prefix}/run", + worker.resource_name, + "QB", + ) + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) + else: + for fn in worker.functions: + table.add_row( + f"POST {worker.url_prefix}/{fn}/run", + worker.resource_name, + "QB", + ) + table.add_row( + f"POST {worker.url_prefix}/{fn}/run_sync", + worker.resource_name, + "QB", + ) + elif worker.worker_type == "LB": + for route in worker.lb_routes: + sub_path = route["path"].lstrip("/") + full_path = f"{worker.url_prefix}/{sub_path}" + table.add_row( + f"{route['method']} {full_path}", + worker.resource_name, + "LB", + ) + + console.print(table) + console.print(f"\n Visit [bold]http://{host}:{port}/docs[/bold] for Swagger UI\n") + + +def _cleanup_live_endpoints() -> None: + 
"""Deprovision all Live Serverless endpoints created during this session. + + Reads the resource state file written by the uvicorn subprocess, finds + all endpoints with the 'live-' name prefix, and deprovisions them. + Best-effort: errors per endpoint are logged but do not prevent cleanup + of other endpoints. + """ + if not _RESOURCE_STATE_FILE.exists(): + return + + try: + import asyncio + import cloudpickle + from ...core.utils.file_lock import file_lock + + with open(_RESOURCE_STATE_FILE, "rb") as f: + with file_lock(f, exclusive=False): + data = cloudpickle.load(f) + + if isinstance(data, tuple): + resources, configs = data + else: + resources, configs = data, {} + + live_items = { + key: resource + for key, resource in resources.items() + if hasattr(resource, "name") + and resource.name + and resource.name.startswith("live-") + } + + if not live_items: + return + + async def _do_cleanup(): + for key, resource in live_items.items(): + name = getattr(resource, "name", key) + try: + success = await resource._do_undeploy() + if success: + console.print(f" Deprovisioned: {name}") + else: + logger.warning(f"Failed to deprovision: {name}") + except Exception as e: + logger.warning(f"Error deprovisioning {name}: {e}") + + asyncio.run(_do_cleanup()) + + # Remove live- entries from persisted state so they don't linger. + remaining = {k: v for k, v in resources.items() if k not in live_items} + remaining_configs = {k: v for k, v in configs.items() if k not in live_items} + try: + with open(_RESOURCE_STATE_FILE, "wb") as f: + with file_lock(f, exclusive=True): + cloudpickle.dump((remaining, remaining_configs), f) + except Exception as e: + logger.warning(f"Could not update resource state after cleanup: {e}") + + except Exception as e: + logger.warning(f"Live endpoint cleanup failed: {e}") + + +def _is_reload() -> bool: + """Check if running in uvicorn reload subprocess.""" + return "UVICORN_RELOADER_PID" in os.environ + + +def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> None: + """Watch project .py files and regenerate server.py when they change. + + Ignores .flash/ to avoid reacting to our own writes. Runs until + stop_event is set. + """ + watch_filter = _WatchfilesDefaultFilter(ignore_paths=[str(project_root / ".flash")]) + + try: + for changes in _watchfiles_watch( + project_root, + watch_filter=watch_filter, + stop_event=stop_event, + ): + py_changed = [p for _, p in changes if p.endswith(".py")] + if not py_changed: + continue + try: + workers = _scan_project_workers(project_root) + _generate_flash_server(project_root, workers) + logger.debug("server.py regenerated (%d changed)", len(py_changed)) + except Exception as e: + logger.warning("Failed to regenerate server.py: %s", e) + except Exception: + pass # stop_event was set or watchfiles unavailable — both are fine + def run_command( host: str = typer.Option( @@ -33,68 +408,51 @@ def run_command( reload: bool = typer.Option( True, "--reload/--no-reload", help="Enable auto-reload" ), - auto_provision: bool = typer.Option( - False, - "--auto-provision", - help="Auto-provision deployable resources on startup", - ), ): - """Run Flash development server with uvicorn.""" + """Run Flash development server. 
- # Discover entry point - entry_point = discover_entry_point() - if not entry_point: - console.print("[red]Error:[/red] No entry point found") - console.print("Create main.py with a FastAPI app") - raise typer.Exit(1) + Scans the project for @remote decorated functions, generates a dev server + at .flash/server.py, and starts uvicorn with hot-reload. - # Check if entry point has FastAPI app - app_location = check_fastapi_app(entry_point) - if not app_location: - console.print(f"[red]Error:[/red] No FastAPI app found in {entry_point}") - console.print("Make sure your main.py contains: app = FastAPI()") - raise typer.Exit(1) + No main.py or FastAPI boilerplate required. Any .py file with @remote + decorated functions is a valid Flash project. + """ + project_root = Path.cwd() - # Set flag for all flash run sessions to ensure both auto-provisioned - # and on-the-fly provisioned resources get the live- prefix + # Set flag for live provisioning so stubs get the live- prefix if not _is_reload(): os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true" - # Auto-provision resources if flag is set and not a reload - if auto_provision and not _is_reload(): - try: - resources = _discover_resources(entry_point) + # Discover @remote functions + workers = _scan_project_workers(project_root) - if resources: - # If many resources found, ask for confirmation - if len(resources) > 5: - if not _confirm_large_provisioning(resources): - console.print("[yellow]Auto-provisioning cancelled[/yellow]\n") - else: - _provision_resources(resources) - else: - _provision_resources(resources) - except Exception as e: - logger.error("Auto-provisioning failed", exc_info=True) - console.print( - f"[yellow]Warning:[/yellow] Resource provisioning failed: {e}" - ) - console.print( - "[yellow]Note:[/yellow] Resources will be deployed on-demand when first called" - ) + if not workers: + console.print("[red]Error:[/red] No @remote functions found.") + console.print("Add @remote decorators to your functions to get started.") + console.print("\nExample:") + console.print( + " from runpod_flash import LiveServerless, remote\n" + " gpu_config = LiveServerless(name='my_worker')\n" + "\n" + " @remote(gpu_config)\n" + " async def process(input_data: dict) -> dict:\n" + " return {'result': input_data}" + ) + raise typer.Exit(1) + + # Generate .flash/server.py + _generate_flash_server(project_root, workers) - console.print("\n[green]Starting Flash Server[/green]") - console.print(f"Entry point: [bold]{app_location}[/bold]") - console.print(f"Server: [bold]http://{host}:{port}[/bold]") - console.print(f"Auto-reload: [bold]{'enabled' if reload else 'disabled'}[/bold]") - console.print("\nPress CTRL+C to stop\n") + _print_startup_table(workers, host, port) - # Build uvicorn command + # Build uvicorn command using --app-dir so server:app is importable cmd = [ sys.executable, "-m", "uvicorn", - app_location, + "server:app", + "--app-dir", + ".flash", "--host", host, "--port", @@ -104,13 +462,24 @@ def run_command( ] if reload: - cmd.append("--reload") + cmd += [ + "--reload", + "--reload-dir", + ".flash", + "--reload-include", + "server.py", + ] + + stop_event = threading.Event() + watcher_thread = threading.Thread( + target=_watch_and_regenerate, + args=(project_root, stop_event), + daemon=True, + name="flash-watcher", + ) - # Run uvicorn with proper process group handling process = None try: - # Create new process group to ensure all child processes can be killed together - # On Unix systems, use process group; on Windows, CREATE_NEW_PROCESS_GROUP 
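+        # Launch uvicorn in its own process group (Unix) or console process
+        # group (Windows) so cleanup can terminate the whole process tree.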
        if sys.platform == "win32":
            process = subprocess.Popen(
                cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
            )
        else:
            process = subprocess.Popen(cmd, preexec_fn=os.setsid)

-        # Wait for process to complete
+        if reload:
+            watcher_thread.start()
+
        process.wait()

    except KeyboardInterrupt:
-        console.print("\n[yellow]Stopping server and cleaning up processes...[/yellow]")
+        console.print("\n[yellow]Stopping server and cleaning up...[/yellow]")
+
+        stop_event.set()
+        if watcher_thread.is_alive():
+            watcher_thread.join(timeout=2)

-        # Kill the entire process group to ensure all child processes are terminated
        if process:
            try:
                if sys.platform == "win32":
-                    # Windows: terminate the process
                    process.terminate()
                else:
-                    # Unix: kill entire process group
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)

-                # Wait briefly for graceful shutdown
                try:
                    process.wait(timeout=2)
                except subprocess.TimeoutExpired:
-                    # Force kill if didn't terminate gracefully
                    if sys.platform == "win32":
                        process.kill()
                    else:
@@ -146,14 +515,18 @@
                        process.wait()

            except (ProcessLookupError, OSError):
-                # Process already terminated
                pass

+        _cleanup_live_endpoints()
        console.print("[green]Server stopped[/green]")
        raise typer.Exit(0)

    except Exception as e:
        console.print(f"[red]Error:[/red] {e}")
+
+        stop_event.set()
+        if watcher_thread.is_alive():
+            watcher_thread.join(timeout=2)
+
        if process:
            try:
                if sys.platform == "win32":
                    process.terminate()
                else:
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
            except (ProcessLookupError, OSError):
                pass
+        _cleanup_live_endpoints()
        raise typer.Exit(1)
-
-
-def discover_entry_point() -> Optional[str]:
-    """Discover the main entry point file."""
-    candidates = ["main.py", "app.py", "server.py"]
-
-    for candidate in candidates:
-        if Path(candidate).exists():
-            return candidate
-
-    return None
-
-
-def check_fastapi_app(entry_point: str) -> Optional[str]:
-    """
-    Check if entry point has a FastAPI app and return the app location.
-
-    Returns:
-        App location in format "module:app" or None
-    """
-    try:
-        # Read the file
-        content = Path(entry_point).read_text()
-
-        # Check for FastAPI app
-        if "app = FastAPI(" in content or "app=FastAPI(" in content:
-            # Extract module name from file path
-            module = entry_point.replace(".py", "").replace("/", ".")
-            return f"{module}:app"
-
-        return None
-
-    except Exception:
-        return None
-
-
-def _is_reload() -> bool:
-    """Check if running in uvicorn reload subprocess.
-
-    Returns:
-        True if running in a reload subprocess
-    """
-    return "UVICORN_RELOADER_PID" in os.environ
-
-
-def _discover_resources(entry_point: str):
-    """Discover deployable resources in entry point.
-
-    Args:
-        entry_point: Path to entry point file
-
-    Returns:
-        List of discovered DeployableResource instances
-    """
-    from ...core.discovery import ResourceDiscovery
-
-    try:
-        discovery = ResourceDiscovery(entry_point, max_depth=2)
-        resources = discovery.discover()
-
-        # Debug: Log what was discovered
-        if resources:
-            console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]")
-            for res in resources:
-                res_name = getattr(res, "name", "Unknown")
-                res_type = res.__class__.__name__
-                console.print(f"  [dim]• {res_name} ({res_type})[/dim]")
-            console.print()
-
-        return resources
-    except Exception as e:
-        console.print(f"[yellow]Warning:[/yellow] Resource discovery failed: {e}")
-        return []
-
-
-def _confirm_large_provisioning(resources) -> bool:
-    """Show resources and prompt user for confirmation.
- - Args: - resources: List of resources to provision - - Returns: - True if user confirms, False otherwise - """ - try: - console.print( - f"\n[yellow]Found {len(resources)} resources to provision:[/yellow]" - ) - - for resource in resources: - name = getattr(resource, "name", "Unknown") - resource_type = resource.__class__.__name__ - console.print(f" • {name} ({resource_type})") - - console.print() - - confirmed = questionary.confirm( - "This may take several minutes. Do you want to proceed?" - ).ask() - - return confirmed if confirmed is not None else False - - except (KeyboardInterrupt, EOFError): - console.print("\n[yellow]Cancelled[/yellow]") - return False - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Confirmation failed: {e}") - return False - - -def _provision_resources(resources): - """Provision resources and wait for completion. - - Args: - resources: List of resources to provision - """ - import asyncio - from ...core.deployment import DeploymentOrchestrator - - try: - console.print(f"\n[bold]Provisioning {len(resources)} resource(s)...[/bold]") - orchestrator = DeploymentOrchestrator(max_concurrent=3) - - # Run provisioning with progress shown - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(orchestrator.deploy_all(resources, show_progress=True)) - loop.close() - - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Provisioning failed: {e}") diff --git a/src/runpod_flash/client.py b/src/runpod_flash/client.py index ed68bc30..8709cf75 100644 --- a/src/runpod_flash/client.py +++ b/src/runpod_flash/client.py @@ -159,6 +159,20 @@ def decorator(func_or_class): "system_dependencies": system_dependencies, } + # LB route handler passthrough — return the function unwrapped. + # + # When @remote is applied to an LB resource (LiveLoadBalancer, + # CpuLiveLoadBalancer, LoadBalancerSlsResource) with method= and path=, + # the decorated function IS the HTTP route handler. Its body executes + # directly on the LB endpoint server; it is not dispatched to a remote + # process. QB @remote calls inside its body still use their own stubs. + is_lb_route_handler = is_lb_resource and method is not None and path is not None + if is_lb_route_handler: + routing_config["is_lb_route_handler"] = True + func_or_class.__remote_config__ = routing_config + func_or_class.__is_lb_route_handler__ = True + return func_or_class + # Local execution mode - execute without provisioning remote servers if local: func_or_class.__remote_config__ = routing_config diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index d2a0f7b0..27d811a4 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -474,6 +474,14 @@ def is_deployed(self) -> bool: if not self.id: return False + # During flash run, skip the health check. Newly-created endpoints + # can fail health checks due to RunPod propagation delay — the + # endpoint exists but the health API hasn't registered it yet. + # Trusting the cached ID is correct here; actual failures surface + # on the first real run/run_sync call. 
+ if os.getenv("FLASH_IS_LIVE_PROVISIONING", "").lower() == "true": + return True + response = self.endpoint.health() return response is not None except Exception as e: @@ -484,6 +492,10 @@ def _payload_exclude(self) -> Set[str]: # flashEnvironmentId is input-only but must be sent when provided exclude_fields = set(self._input_only or set()) exclude_fields.discard("flashEnvironmentId") + # When templateId is already set, exclude template from the payload. + # RunPod rejects requests that contain both fields simultaneously. + if self.templateId: + exclude_fields.add("template") return exclude_fields @staticmethod @@ -564,12 +576,45 @@ def _check_makes_remote_calls(self) -> bool: ) return True # Safe default on error + def _get_module_path(self) -> Optional[str]: + """Get module_path from build manifest for this resource. + + Returns: + Dotted module path (e.g., 'preprocess.first_pass'), or None if not found. + """ + try: + manifest_path = Path.cwd() / "flash_manifest.json" + if not manifest_path.exists(): + manifest_path = Path("/flash_manifest.json") + if not manifest_path.exists(): + return None + + with open(manifest_path) as f: + manifest_data = json.load(f) + + resources = manifest_data.get("resources", {}) + + lookup_name = self.name + if lookup_name.endswith("-fb"): + lookup_name = lookup_name[:-3] + if lookup_name.startswith(LIVE_PREFIX): + lookup_name = lookup_name[len(LIVE_PREFIX) :] + + resource_config = resources.get(lookup_name) + if not resource_config: + return None + + return resource_config.get("module_path") + + except Exception: + return None + async def _do_deploy(self) -> "DeployableResource": """ Deploys the serverless resource using the provided configuration. - For queue-based endpoints that make remote calls, injects RUNPOD_API_KEY - into environment variables if not already set. + For queue-based endpoints that make remote calls, injects RUNPOD_API_KEY. + For load-balanced endpoints, injects FLASH_MODULE_PATH. Returns a DeployableResource object. 
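+
+        FLASH_MODULE_PATH lets the deployed LB handler import the user's
+        module (resolved from the build manifest by _get_module_path).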
""" @@ -604,6 +649,17 @@ async def _do_deploy(self) -> "DeployableResource": self.env = env_dict + # Inject module path for load-balanced endpoints + elif self.type == ServerlessType.LB: + env_dict = self.env or {} + + module_path = self._get_module_path() + if module_path and "FLASH_MODULE_PATH" not in env_dict: + env_dict["FLASH_MODULE_PATH"] = module_path + log.info(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}") + + self.env = env_dict + # Ensure network volume is deployed first await self._ensure_network_volume_deployed() diff --git a/tests/integration/test_run_auto_provision.py b/tests/integration/test_run_auto_provision.py deleted file mode 100644 index 9478f442..00000000 --- a/tests/integration/test_run_auto_provision.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Integration tests for flash run --auto-provision command.""" - -import pytest -from unittest.mock import patch, MagicMock -from textwrap import dedent -from typer.testing import CliRunner - -from runpod_flash.cli.main import app - -runner = CliRunner() - - -class TestRunAutoProvision: - """Test flash run --auto-provision integration.""" - - @pytest.fixture - def temp_project(self, tmp_path): - """Create temporary Flash project for testing.""" - # Create main.py with FastAPI app - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - app = FastAPI() - - gpu_config = ServerlessResource( - name="test-gpu", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(resource_config=gpu_config) - async def gpu_task(): - return "result" - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - return tmp_path - - @pytest.fixture - def temp_project_many_resources(self, tmp_path): - """Create temporary project with many resources (> 5).""" - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - app = FastAPI() - - # Create 6 resources to trigger confirmation - configs = [ - ServerlessResource( - name=f"endpoint-{i}", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - for i in range(6) - ] - - @remote(resource_config=configs[0]) - async def task1(): pass - - @remote(resource_config=configs[1]) - async def task2(): pass - - @remote(resource_config=configs[2]) - async def task3(): pass - - @remote(resource_config=configs[3]) - async def task4(): pass - - @remote(resource_config=configs[4]) - async def task5(): pass - - @remote(resource_config=configs[5]) - async def task6(): pass - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - return tmp_path - - def test_run_without_auto_provision(self, temp_project, monkeypatch): - """Test that flash run without --auto-provision doesn't deploy resources.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations to prevent hanging - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with 
patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to track if it was called - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - runner.invoke(app, ["run"]) - - # Discovery should not be called - mock_discover.assert_not_called() - - def test_run_with_auto_provision_single_resource(self, temp_project, monkeypatch): - """Test flash run --auto-provision with single resource.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock deployment orchestrator - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should be called - mock_provision.assert_called_once() - - def test_run_with_auto_provision_skips_reload(self, temp_project, monkeypatch): - """Test that auto-provision is skipped on reload.""" - monkeypatch.chdir(temp_project) - - # Simulate reload environment - monkeypatch.setenv("UVICORN_RELOADER_PID", "12345") - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock provisioning - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should NOT be called on reload - mock_provision.assert_not_called() - - def test_run_with_auto_provision_many_resources_confirmed( - self, temp_project, monkeypatch - ): - """Test auto-provision with > 5 resources and user confirmation.""" - monkeypatch.chdir(temp_project) - - # Create 6 mock resources - mock_resources = [MagicMock(name=f"endpoint-{i}") for i in range(6)] - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to return > 5 resources - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = mock_resources - - # Mock questionary to simulate user confirmation - with patch( - "runpod_flash.cli.commands.run.questionary.confirm" - ) as mock_confirm: - mock_confirm.return_value.ask.return_value = True - - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # 
Should prompt for confirmation - mock_confirm.assert_called_once() - - # Should provision after confirmation - mock_provision.assert_called_once() - - def test_run_with_auto_provision_many_resources_cancelled( - self, temp_project, monkeypatch - ): - """Test auto-provision with > 5 resources and user cancellation.""" - monkeypatch.chdir(temp_project) - - # Create 6 mock resources - mock_resources = [MagicMock(name=f"endpoint-{i}") for i in range(6)] - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to return > 5 resources - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = mock_resources - - # Mock questionary to simulate user cancellation - with patch( - "runpod_flash.cli.commands.run.questionary.confirm" - ) as mock_confirm: - mock_confirm.return_value.ask.return_value = False - - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Should prompt for confirmation - mock_confirm.assert_called_once() - - # Should NOT provision after cancellation - mock_provision.assert_not_called() - - def test_run_auto_provision_discovery_error(self, temp_project, monkeypatch): - """Test that run handles discovery errors gracefully.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to raise exception - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = [] - - runner.invoke(app, ["run", "--auto-provision"]) - - # Server should still start despite discovery error - mock_popen.assert_called_once() - - def test_run_auto_provision_no_resources_found(self, tmp_path, monkeypatch): - """Test auto-provision when no resources are found.""" - monkeypatch.chdir(tmp_path) - - # Create main.py without any @remote decorators - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - - app = FastAPI() - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - with patch( - 
"runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should not be called - mock_provision.assert_not_called() - - # Server should still start - mock_popen.assert_called_once() diff --git a/tests/unit/cli/commands/build_utils/test_manifest_mothership.py b/tests/unit/cli/commands/build_utils/test_manifest_mothership.py deleted file mode 100644 index 896eefdf..00000000 --- a/tests/unit/cli/commands/build_utils/test_manifest_mothership.py +++ /dev/null @@ -1,404 +0,0 @@ -"""Tests for mothership resource creation in manifest.""" - -import tempfile -from pathlib import Path -from unittest.mock import patch - -from runpod_flash.cli.commands.build_utils.manifest import ManifestBuilder -from runpod_flash.cli.commands.build_utils.scanner import RemoteFunctionMetadata -from runpod_flash.core.resources.constants import ( - FLASH_CPU_LB_IMAGE, - FLASH_LB_IMAGE, -) - - -class TestManifestMothership: - """Test mothership resource creation in manifest.""" - - def test_manifest_includes_mothership_with_main_py(self): - """Test mothership resource added to manifest when main.py detected.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create a simple function file - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -gpu_config = LiveServerless(name="gpu_worker") - -@remote(resource_config=gpu_config) -def process(data): - return data -""" - ) - - # Change to project directory for detection - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", - remote_functions=[], - ) - manifest = builder.build() - - # Check mothership is in resources - assert "mothership" in manifest["resources"] - mothership = manifest["resources"]["mothership"] - assert mothership["is_mothership"] is True - assert mothership["main_file"] == "main.py" - assert mothership["app_variable"] == "app" - assert mothership["resource_type"] == "CpuLiveLoadBalancer" - assert mothership["imageName"] == FLASH_CPU_LB_IMAGE - - def test_manifest_skips_mothership_without_routes(self): - """Test mothership NOT added if main.py has no routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py without routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() -# No routes defined -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Mothership should NOT be in resources - assert "mothership" not in manifest["resources"] - - def test_manifest_skips_mothership_without_main_py(self): - """Test mothership NOT added if no main.py exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # 
-                # Mothership should NOT be in resources
-                assert "mothership" not in manifest["resources"]
-
-    def test_manifest_handles_mothership_name_conflict(self):
-        """Test mothership uses alternate name if conflict with @remote resource."""
-        with tempfile.TemporaryDirectory() as tmpdir:
-            project_root = Path(tmpdir)
-
-            # Create main.py with routes
-            main_file = project_root / "main.py"
-            main_file.write_text(
-                """
-from fastapi import FastAPI
-app = FastAPI()
-
-@app.get("/")
-def root():
-    return {"msg": "Hello"}
-"""
-            )
-
-            # Create a remote function with name "mothership" (conflict)
-            func_file = project_root / "functions.py"
-            func_file.write_text(
-                """
-from runpod_flash import remote
-from runpod_flash import LiveServerless
-
-mothership_config = LiveServerless(name="mothership")
-
-@remote(resource_config=mothership_config)
-def process(data):
-    return data
-"""
-            )
-
-            # Create remote function metadata with resource named "mothership"
-            remote_func = RemoteFunctionMetadata(
-                function_name="process",
-                module_path="functions",
-                resource_config_name="mothership",
-                resource_type="LiveServerless",
-                is_async=False,
-                is_class=False,
-                file_path=func_file,
-            )
-
-            with patch(
-                "runpod_flash.cli.commands.build_utils.manifest.Path.cwd",
-                return_value=project_root,
-            ):
-                builder = ManifestBuilder(
-                    project_name="test", remote_functions=[remote_func]
-                )
-                manifest = builder.build()
-
-                # Original mothership should be in resources
-                assert "mothership" in manifest["resources"]
-                # Auto-generated mothership should use alternate name
-                assert "mothership-entrypoint" in manifest["resources"]
-                entrypoint = manifest["resources"]["mothership-entrypoint"]
-                assert entrypoint["is_mothership"] is True
-
-    def test_mothership_resource_config(self):
-        """Test mothership resource has correct configuration."""
-        with tempfile.TemporaryDirectory() as tmpdir:
-            project_root = Path(tmpdir)
-
-            main_file = project_root / "main.py"
-            main_file.write_text(
-                """
-from fastapi import FastAPI
-app = FastAPI()
-
-@app.get("/")
-def root():
-    return {"msg": "Hello"}
-"""
-            )
-
-            with patch(
-                "runpod_flash.cli.commands.build_utils.manifest.Path.cwd",
-                return_value=project_root,
-            ):
-                builder = ManifestBuilder(project_name="test", remote_functions=[])
-                manifest = builder.build()
-
-                mothership = manifest["resources"]["mothership"]
-
-                # Check all expected fields
-                assert mothership["resource_type"] == "CpuLiveLoadBalancer"
-                # Functions should include the FastAPI route
-                assert len(mothership["functions"]) == 1
-                assert mothership["functions"][0]["name"] == "root"
-                assert mothership["functions"][0]["http_method"] == "GET"
-                assert mothership["functions"][0]["http_path"] == "/"
-                assert mothership["is_load_balanced"] is True
-                assert mothership["is_live_resource"] is True
-                assert mothership["imageName"] == FLASH_CPU_LB_IMAGE
-                assert mothership["workersMin"] == 1
-                assert mothership["workersMax"] == 1
-
-    def test_manifest_uses_explicit_mothership_config(self):
-        """Test explicit mothership.py config takes precedence over auto-detection."""
-        with tempfile.TemporaryDirectory() as tmpdir:
-            project_root = Path(tmpdir)
-
-            # Create main.py with FastAPI routes
-            main_file = project_root / "main.py"
-            main_file.write_text(
-                """
-from fastapi import FastAPI
-app = FastAPI()
-
-@app.get("/")
-def root():
-    return {"msg": "Hello"}
-"""
-            )
-
-            # Create explicit mothership.py with custom config
-            mothership_file = project_root / "mothership.py"
-            mothership_file.write_text(
-                """
-from runpod_flash import CpuLiveLoadBalancer
-
-mothership = CpuLiveLoadBalancer(
-    name="my-api",
-    workersMin=3,
-    workersMax=7,
-)
-"""
-            )
-
-            with patch(
-                "runpod_flash.cli.commands.build_utils.manifest.Path.cwd",
-                return_value=project_root,
-            ):
-                builder = ManifestBuilder(project_name="test", remote_functions=[])
-                manifest = builder.build()
-
-                # Check explicit config is used
-                assert "my-api" in manifest["resources"]
-                mothership = manifest["resources"]["my-api"]
-                assert mothership["is_explicit"] is True
-                assert mothership["workersMin"] == 3
-                assert mothership["workersMax"] == 7
-
-    def test_manifest_skips_auto_detect_with_explicit_config(self):
-        """Test auto-detection is skipped when explicit config exists."""
-        with tempfile.TemporaryDirectory() as tmpdir:
-            project_root = Path(tmpdir)
-
-            # Create main.py with FastAPI routes
-            main_file = project_root / "main.py"
-            main_file.write_text(
-                """
-from fastapi import FastAPI
-app = FastAPI()
-
-@app.get("/")
-def root():
-    return {"msg": "Hello"}
-"""
-            )
-
-            # Create explicit mothership.py
-            mothership_file = project_root / "mothership.py"
-            mothership_file.write_text(
-                """
-from runpod_flash import CpuLiveLoadBalancer
-
-mothership = CpuLiveLoadBalancer(
-    name="explicit-mothership",
-    workersMin=2,
-    workersMax=4,
-)
-"""
-            )
-
-            with patch(
-                "runpod_flash.cli.commands.build_utils.manifest.Path.cwd",
-                return_value=project_root,
-            ):
-                builder = ManifestBuilder(project_name="test", remote_functions=[])
-                manifest = builder.build()
-
-                # Check only explicit config is in resources (not auto-detected "mothership")
-                assert "explicit-mothership" in manifest["resources"]
-                assert (
-                    manifest["resources"]["explicit-mothership"]["is_explicit"] is True
-                )
-                assert "mothership" not in manifest["resources"]
-
-    def test_manifest_handles_explicit_mothership_name_conflict(self):
-        """Test explicit mothership uses alternate name if conflict with @remote."""
-        with tempfile.TemporaryDirectory() as tmpdir:
-            project_root = Path(tmpdir)
-
-            # Create explicit mothership.py with name that conflicts with resource
-            mothership_file = project_root / "mothership.py"
-            mothership_file.write_text(
-                """
-from runpod_flash import CpuLiveLoadBalancer
-
-mothership = CpuLiveLoadBalancer(
-    name="api",  # Will conflict with @remote resource named "api"
-    workersMin=1,
-    workersMax=3,
-)
-"""
-            )
-
-            # Create a remote function with name "api" (conflict)
-            func_file = project_root / "functions.py"
-            func_file.write_text(
-                """
-from runpod_flash import remote
-from runpod_flash import LiveServerless
-
-api_config = LiveServerless(name="api")
-
-@remote(resource_config=api_config)
-def process(data):
-    return data
-"""
-            )
-
-            remote_func = RemoteFunctionMetadata(
-                function_name="process",
-                module_path="functions",
-                resource_config_name="api",
-                resource_type="LiveServerless",
-                is_async=False,
-                is_class=False,
-                file_path=func_file,
-            )
-
-            with patch(
-                "runpod_flash.cli.commands.build_utils.manifest.Path.cwd",
-                return_value=project_root,
-            ):
-                builder = ManifestBuilder(
-                    project_name="test", remote_functions=[remote_func]
-                )
-                manifest = builder.build()
-
-                # Original resource should be in resources
-                assert "api" in manifest["resources"]
-                # Explicit mothership should use alternate name
-                assert "mothership-entrypoint" in manifest["resources"]
-                entrypoint = manifest["resources"]["mothership-entrypoint"]
-                assert entrypoint["is_explicit"] is True
-
-    def test_manifest_explicit_mothership_with_gpu_load_balancer(self):
-        """Test explicit GPU-based load balancer config."""
-        with tempfile.TemporaryDirectory() as tmpdir:
-            project_root = Path(tmpdir)
-
-            # Create explicit mothership.py with GPU load balancer
-            mothership_file = project_root / "mothership.py"
-            mothership_file.write_text(
-                """
-from runpod_flash import LiveLoadBalancer
-
-mothership = LiveLoadBalancer(
-    name="gpu-mothership",
-    workersMin=1,
-    workersMax=2,
-)
-"""
-            )
-
-            # Create main.py for FastAPI app
-            main_file = project_root / "main.py"
-            main_file.write_text(
-                """
-from fastapi import FastAPI
-app = FastAPI()
-
-@app.get("/")
-def root():
-    return {"msg": "Hello"}
-"""
-            )
-
-            with patch(
-                "runpod_flash.cli.commands.build_utils.manifest.Path.cwd",
-                return_value=project_root,
-            ):
-                builder = ManifestBuilder(project_name="test", remote_functions=[])
-                manifest = builder.build()
-
-                mothership = manifest["resources"]["gpu-mothership"]
-                assert mothership["resource_type"] == "LiveLoadBalancer"
-                assert mothership["imageName"] == FLASH_LB_IMAGE
-                assert mothership["is_explicit"] is True
diff --git a/tests/unit/cli/commands/build_utils/test_path_utilities.py b/tests/unit/cli/commands/build_utils/test_path_utilities.py
new file mode 100644
index 00000000..73ec3557
--- /dev/null
+++ b/tests/unit/cli/commands/build_utils/test_path_utilities.py
@@ -0,0 +1,217 @@
+"""TDD tests for scanner path utility functions.
+
+Written first (failing) per the plan's TDD requirement.
+These test file_to_url_prefix, file_to_resource_name, file_to_module_path.
+"""
+
+import os
+
+
+from runpod_flash.cli.commands.build_utils.scanner import (
+    file_to_module_path,
+    file_to_resource_name,
+    file_to_url_prefix,
+)
+
+
+class TestFileToUrlPrefix:
+    """Tests for file_to_url_prefix utility."""
+
+    def test_root_level_file(self, tmp_path):
+        """gpu_worker.py → /gpu_worker"""
+        f = tmp_path / "gpu_worker.py"
+        assert file_to_url_prefix(f, tmp_path) == "/gpu_worker"
+
+    def test_single_subdir(self, tmp_path):
+        """longruns/stage1.py → /longruns/stage1"""
+        f = tmp_path / "longruns" / "stage1.py"
+        assert file_to_url_prefix(f, tmp_path) == "/longruns/stage1"
+
+    def test_nested_subdir(self, tmp_path):
+        """preprocess/first_pass.py → /preprocess/first_pass"""
+        f = tmp_path / "preprocess" / "first_pass.py"
+        assert file_to_url_prefix(f, tmp_path) == "/preprocess/first_pass"
+
+    def test_deep_nested(self, tmp_path):
+        """workers/gpu/inference.py → /workers/gpu/inference"""
+        f = tmp_path / "workers" / "gpu" / "inference.py"
+        assert file_to_url_prefix(f, tmp_path) == "/workers/gpu/inference"
+
+    def test_hyphenated_filename(self, tmp_path):
+        """my-worker.py → /my-worker (hyphens valid in URLs)"""
+        f = tmp_path / "my-worker.py"
+        assert file_to_url_prefix(f, tmp_path) == "/my-worker"
+
+    def test_starts_with_slash(self, tmp_path):
+        """Result always starts with /"""
+        f = tmp_path / "worker.py"
+        result = file_to_url_prefix(f, tmp_path)
+        assert result.startswith("/")
+
+    def test_no_py_extension(self, tmp_path):
+        """Result does not include .py extension"""
+        f = tmp_path / "worker.py"
+        result = file_to_url_prefix(f, tmp_path)
+        assert ".py" not in result
+
+
+class TestFileToResourceName:
+    """Tests for file_to_resource_name utility."""
+
+    def test_root_level_file(self, tmp_path):
+        """gpu_worker.py → gpu_worker"""
+        f = tmp_path / "gpu_worker.py"
+        assert file_to_resource_name(f, tmp_path) == "gpu_worker"
+
+    def test_single_subdir(self, tmp_path):
+        """longruns/stage1.py → longruns_stage1"""
+        f = tmp_path / "longruns" / "stage1.py"
+        assert file_to_resource_name(f, tmp_path) == "longruns_stage1"
+
+    def test_nested_subdir(self, tmp_path):
+        """preprocess/first_pass.py → preprocess_first_pass"""
+        f = tmp_path / "preprocess" / "first_pass.py"
+        assert file_to_resource_name(f, tmp_path) == "preprocess_first_pass"
+
+    def test_deep_nested(self, tmp_path):
+        """workers/gpu/inference.py → workers_gpu_inference"""
+        f = tmp_path / "workers" / "gpu" / "inference.py"
+        assert file_to_resource_name(f, tmp_path) == "workers_gpu_inference"
+
+    def test_hyphenated_filename(self, tmp_path):
+        """my-worker.py → my_worker (hyphens replaced with underscores for Python identifiers)"""
+        f = tmp_path / "my-worker.py"
+        assert file_to_resource_name(f, tmp_path) == "my_worker"
+
+    def test_no_py_extension(self, tmp_path):
+        """Result does not include .py extension"""
+        f = tmp_path / "worker.py"
+        result = file_to_resource_name(f, tmp_path)
+        assert ".py" not in result
+
+    def test_no_path_separators(self, tmp_path):
+        """Result contains no / or os.sep characters"""
+        f = tmp_path / "a" / "b" / "worker.py"
+        result = file_to_resource_name(f, tmp_path)
+        assert "/" not in result
+        assert os.sep not in result
+
+
+class TestFileToModulePath:
+    """Tests for file_to_module_path utility."""
+
+    def test_root_level_file(self, tmp_path):
+        """gpu_worker.py → gpu_worker"""
+        f = tmp_path / "gpu_worker.py"
+        assert file_to_module_path(f, tmp_path) == "gpu_worker"
+
+    def test_single_subdir(self, tmp_path):
+        """longruns/stage1.py → longruns.stage1"""
+        f = tmp_path / "longruns" / "stage1.py"
+        assert file_to_module_path(f, tmp_path) == "longruns.stage1"
+
+    def test_nested_subdir(self, tmp_path):
+        """preprocess/first_pass.py → preprocess.first_pass"""
+        f = tmp_path / "preprocess" / "first_pass.py"
+        assert file_to_module_path(f, tmp_path) == "preprocess.first_pass"
+
+    def test_deep_nested(self, tmp_path):
+        """workers/gpu/inference.py → workers.gpu.inference"""
+        f = tmp_path / "workers" / "gpu" / "inference.py"
+        assert file_to_module_path(f, tmp_path) == "workers.gpu.inference"
+
+    def test_no_py_extension(self, tmp_path):
+        """Result does not include .py extension"""
+        f = tmp_path / "worker.py"
+        result = file_to_module_path(f, tmp_path)
+        assert ".py" not in result
+
+    def test_uses_dots_not_slashes(self, tmp_path):
+        """Result uses dots as separators, not slashes"""
+        f = tmp_path / "a" / "b" / "worker.py"
+        result = file_to_module_path(f, tmp_path)
+        assert "." in result
+        assert "/" not in result
+        assert os.sep not in result
+
+
+class TestIsLbRouteHandlerField:
+    """Tests that RemoteFunctionMetadata.is_lb_route_handler is set correctly."""
+
+    def test_lb_function_with_method_and_path_is_handler(self, tmp_path):
+        """An LB @remote function with method= and path= is marked as LB route handler."""
+        from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner
+
+        (tmp_path / "routes.py").write_text(
+            """
+from runpod_flash import CpuLiveLoadBalancer, remote
+
+lb_config = CpuLiveLoadBalancer(name="my_lb")
+
+@remote(lb_config, method="POST", path="/compute")
+async def compute(data: dict) -> dict:
+    return data
+"""
+        )
+
+        scanner = RemoteDecoratorScanner(tmp_path)
+        functions = scanner.discover_remote_functions()
+
+        assert len(functions) == 1
+        assert functions[0].is_lb_route_handler is True
+
+    def test_qb_function_is_not_handler(self, tmp_path):
+        """A QB @remote function is NOT marked as LB route handler."""
+        from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner
+
+        (tmp_path / "worker.py").write_text(
+            """
+from runpod_flash import LiveServerless, GpuGroup, remote
+
+gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY])
+
+@remote(gpu_config)
+async def process(data: dict) -> dict:
+    return data
+"""
+        )
+
+        scanner = RemoteDecoratorScanner(tmp_path)
+        functions = scanner.discover_remote_functions()
+
+        assert len(functions) == 1
+        assert functions[0].is_lb_route_handler is False
+
+    def test_init_py_files_excluded(self, tmp_path):
+        """__init__.py files are excluded from scanning."""
+        from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner
+
+        (tmp_path / "__init__.py").write_text(
+            """
+from runpod_flash import LiveServerless, remote
+
+gpu_config = LiveServerless(name="gpu_worker")
+
+@remote(gpu_config)
+async def process(data: dict) -> dict:
+    return data
+"""
+        )
+        (tmp_path / "worker.py").write_text(
+            """
+from runpod_flash import LiveServerless, GpuGroup, remote
+
+gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY])
+
+@remote(gpu_config)
+async def process(data: dict) -> dict:
+    return data
+"""
+        )
+
+        scanner = RemoteDecoratorScanner(tmp_path)
+        functions = scanner.discover_remote_functions()
+
+        # Only the worker.py function should be discovered, not __init__.py
+        assert len(functions) == 1
+        assert functions[0].file_path.name == "worker.py"
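For orientation, the behavior pinned down by the test classes above fits in a few lines; a sketch inferred from the assertions (the shipped `scanner.py` implementation may differ in validation and edge-case handling):

```python
# Reference sketch of the three path utilities exercised above. Names match
# the scanner module under test; bodies are inferred from the test expectations.
from pathlib import Path


def file_to_url_prefix(file_path: Path, project_root: Path) -> str:
    """gpu_worker.py -> /gpu_worker; workers/gpu/inference.py -> /workers/gpu/inference"""
    rel = file_path.relative_to(project_root).with_suffix("")
    return "/" + "/".join(rel.parts)


def file_to_resource_name(file_path: Path, project_root: Path) -> str:
    """longruns/stage1.py -> longruns_stage1; hyphens become underscores."""
    rel = file_path.relative_to(project_root).with_suffix("")
    return "_".join(rel.parts).replace("-", "_")


def file_to_module_path(file_path: Path, project_root: Path) -> str:
    """preprocess/first_pass.py -> preprocess.first_pass"""
    rel = file_path.relative_to(project_root).with_suffix("")
    return ".".join(rel.parts)
```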
diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py
index a652aa75..1e0c549a 100644
--- a/tests/unit/cli/test_run.py
+++ b/tests/unit/cli/test_run.py
@@ -15,15 +15,27 @@ def runner():

 @pytest.fixture
 def temp_fastapi_app(tmp_path):
-    """Create minimal FastAPI app for testing."""
-    main_file = tmp_path / "main.py"
-    main_file.write_text("from fastapi import FastAPI\napp = FastAPI()")
+    """Create minimal Flash project with @remote function for testing."""
+    worker_file = tmp_path / "worker.py"
+    worker_file.write_text(
+        "from runpod_flash import LiveServerless, remote\n"
+        "gpu_config = LiveServerless(name='test_worker')\n"
+        "@remote(gpu_config)\n"
+        "async def process(data: dict) -> dict:\n"
+        "    return data\n"
+    )
     return tmp_path


 class TestRunCommandEnvironmentVariables:
     """Test flash run command environment variable support."""

+    @pytest.fixture(autouse=True)
+    def patch_watcher(self):
+        """Prevent the background watcher thread from blocking tests."""
+        with patch("runpod_flash.cli.commands.run._watch_and_regenerate"):
+            yield
+
     def test_port_from_environment_variable(
         self, runner, temp_fastapi_app, monkeypatch
     ):
@@ -215,3 +227,199 @@ def test_short_port_flag_overrides_environment(
         assert "--port" in call_args
         port_index = call_args.index("--port")
         assert call_args[port_index + 1] == "7000"
+
+
+class TestRunCommandHotReload:
+    """Test flash run hot-reload behavior."""
+
+    @pytest.fixture(autouse=True)
+    def patch_watcher(self):
+        """Prevent the background watcher thread from blocking tests."""
+        with patch("runpod_flash.cli.commands.run._watch_and_regenerate"):
+            yield
+
+    def _invoke_run(self, runner, monkeypatch, temp_fastapi_app, extra_args=None):
+        """Helper: invoke flash run and return the Popen call args."""
+        monkeypatch.chdir(temp_fastapi_app)
+        monkeypatch.delenv("FLASH_PORT", raising=False)
+        monkeypatch.delenv("FLASH_HOST", raising=False)
+
+        with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen:
+            mock_process = MagicMock()
+            mock_process.pid = 12345
+            mock_process.wait.side_effect = KeyboardInterrupt()
+            mock_popen.return_value = mock_process
+
+            with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345):
+                with patch("runpod_flash.cli.commands.run.os.killpg"):
+                    runner.invoke(app, ["run"] + (extra_args or []))
+
+        return mock_popen.call_args[0][0]
+
+    def test_reload_watches_flash_server_py(
+        self, runner, temp_fastapi_app, monkeypatch
+    ):
+        """Uvicorn watches .flash/server.py, not the whole project."""
+        cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app)
+
+        assert "--reload" in cmd
+        assert "--reload-dir" in cmd
+        reload_dir_index = cmd.index("--reload-dir")
+        assert cmd[reload_dir_index + 1] == ".flash"
+
+        assert "--reload-include" in cmd
+        reload_include_index = cmd.index("--reload-include")
+        assert cmd[reload_include_index + 1] == "server.py"
+
+    def test_reload_does_not_watch_project_root(
+        self, runner, temp_fastapi_app, monkeypatch
+    ):
+        """Uvicorn reload-dir must not be '.' to prevent double-reload."""
+        cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app)
+
+        reload_dir_index = cmd.index("--reload-dir")
+        assert cmd[reload_dir_index + 1] != "."
+
+    def test_no_reload_skips_watcher_thread(
+        self, runner, temp_fastapi_app, monkeypatch
+    ):
+        """--no-reload: neither uvicorn reload args nor watcher thread started."""
+        monkeypatch.chdir(temp_fastapi_app)
+
+        with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen:
+            mock_process = MagicMock()
+            mock_process.pid = 12345
+            mock_process.wait.side_effect = KeyboardInterrupt()
+            mock_popen.return_value = mock_process
+
+            with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345):
+                with patch("runpod_flash.cli.commands.run.os.killpg"):
+                    with patch(
+                        "runpod_flash.cli.commands.run.threading.Thread"
+                    ) as mock_thread_cls:
+                        mock_thread = MagicMock()
+                        mock_thread_cls.return_value = mock_thread
+
+                        runner.invoke(app, ["run", "--no-reload"])
+
+                        cmd = mock_popen.call_args[0][0]
+                        assert "--reload" not in cmd
+                        mock_thread.start.assert_not_called()
+
+    def test_watcher_thread_started_on_reload(
+        self, runner, temp_fastapi_app, monkeypatch, patch_watcher
+    ):
+        """When reload=True, the background watcher thread is started."""
+        monkeypatch.chdir(temp_fastapi_app)
+
+        with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen:
+            mock_process = MagicMock()
+            mock_process.pid = 12345
+            mock_process.wait.side_effect = KeyboardInterrupt()
+            mock_popen.return_value = mock_process
+
+            with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345):
+                with patch("runpod_flash.cli.commands.run.os.killpg"):
+                    with patch(
+                        "runpod_flash.cli.commands.run.threading.Thread"
+                    ) as mock_thread_cls:
+                        mock_thread = MagicMock()
+                        mock_thread_cls.return_value = mock_thread
+
+                        runner.invoke(app, ["run"])
+
+                        mock_thread.start.assert_called_once()
+
+    def test_watcher_thread_stopped_on_keyboard_interrupt(
+        self, runner, temp_fastapi_app, monkeypatch
+    ):
+        """KeyboardInterrupt sets stop_event and joins the watcher thread."""
+        monkeypatch.chdir(temp_fastapi_app)
+
+        with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen:
+            mock_process = MagicMock()
+            mock_process.pid = 12345
+            mock_process.wait.side_effect = KeyboardInterrupt()
+            mock_popen.return_value = mock_process
+
+            with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345):
+                with patch("runpod_flash.cli.commands.run.os.killpg"):
+                    with patch(
+                        "runpod_flash.cli.commands.run.threading.Thread"
+                    ) as mock_thread_cls:
+                        mock_thread = MagicMock()
+                        mock_thread_cls.return_value = mock_thread
+                        with patch(
+                            "runpod_flash.cli.commands.run.threading.Event"
+                        ) as mock_event_cls:
+                            mock_stop = MagicMock()
+                            mock_event_cls.return_value = mock_stop
+
+                            runner.invoke(app, ["run"])
+
+                            mock_stop.set.assert_called_once()
+                            mock_thread.join.assert_called_once_with(timeout=2)
+
+
+class TestWatchAndRegenerate:
+    """Unit tests for the _watch_and_regenerate background function."""
+
+    def test_regenerates_server_py_on_py_file_change(self, tmp_path):
+        """When a .py file changes, server.py is regenerated."""
+        import threading
+        from runpod_flash.cli.commands.run import _watch_and_regenerate
+
+        stop = threading.Event()
+
+        with patch(
+            "runpod_flash.cli.commands.run._scan_project_workers", return_value=[]
+        ) as mock_scan:
+            with patch(
+                "runpod_flash.cli.commands.run._generate_flash_server"
+            ) as mock_gen:
+                with patch(
+                    "runpod_flash.cli.commands.run._watchfiles_watch"
+                ) as mock_watch:
+                    # Yield one batch of changes then stop
+                    mock_watch.return_value = iter([{(1, "/path/to/worker.py")}])
+                    stop.set()  # ensures the loop exits after one iteration
+                    _watch_and_regenerate(tmp_path, stop)
+
+                    mock_scan.assert_called_once_with(tmp_path)
+                    mock_gen.assert_called_once()
+
+    def test_ignores_non_py_changes(self, tmp_path):
+        """Changes to non-.py files do not trigger regeneration."""
+        import threading
+        from runpod_flash.cli.commands.run import _watch_and_regenerate
+
+        stop = threading.Event()
+
+        with patch("runpod_flash.cli.commands.run._scan_project_workers") as mock_scan:
+            with patch(
+                "runpod_flash.cli.commands.run._generate_flash_server"
+            ) as mock_gen:
+                with patch(
+                    "runpod_flash.cli.commands.run._watchfiles_watch"
+                ) as mock_watch:
+                    mock_watch.return_value = iter([{(1, "/path/to/README.md")}])
+                    _watch_and_regenerate(tmp_path, stop)
+
+                    mock_scan.assert_not_called()
+                    mock_gen.assert_not_called()
+
+    def test_scan_error_does_not_crash_watcher(self, tmp_path):
+        """If regeneration raises, the watcher logs a warning and continues."""
+        import threading
+        from runpod_flash.cli.commands.run import _watch_and_regenerate
+
+        stop = threading.Event()
+
+        with patch(
+            "runpod_flash.cli.commands.run._scan_project_workers",
+            side_effect=RuntimeError("scan failed"),
+        ):
+            with patch("runpod_flash.cli.commands.run._watchfiles_watch") as mock_watch:
+                mock_watch.return_value = iter([{(1, "/path/to/worker.py")}])
+                # Should not raise
+                _watch_and_regenerate(tmp_path, stop)
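These tests mock `_watchfiles_watch`, `_scan_project_workers`, and `_generate_flash_server`; the loop they imply looks roughly like the sketch below. Call signatures and logging are assumptions reconstructed from the mock targets, not the actual source:

```python
# Sketch of the watcher loop implied by the tests above; it is inferred from
# the mocked call surface, not copied from runpod_flash/cli/commands/run.py.
import logging
import threading
from pathlib import Path

log = logging.getLogger(__name__)


def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> None:
    # _watchfiles_watch is assumed to wrap watchfiles.watch(); each batch is a
    # set of (change_type, path) tuples, matching the mocked return values.
    for changes in _watchfiles_watch(project_root):
        if any(str(path).endswith(".py") for _, path in changes):
            try:
                workers = _scan_project_workers(project_root)
                # Rewriting .flash/server.py is what trips uvicorn's
                # --reload-include "server.py" watch.
                _generate_flash_server(workers)
            except Exception as exc:
                # A half-edited user file must not kill the watcher thread.
                log.warning(f"flash run: server regeneration failed: {exc}")
        if stop_event.is_set():
            break
```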
diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py
index b5e3cbbe..124eb136 100644
--- a/tests/unit/resources/test_serverless.py
+++ b/tests/unit/resources/test_serverless.py
@@ -467,6 +467,32 @@ def test_is_deployed_false_when_no_id(self):

         assert serverless.is_deployed() is False

+    def test_is_deployed_skips_health_check_during_live_provisioning(self, monkeypatch):
+        """During flash run, is_deployed returns True based on ID alone."""
+        monkeypatch.setenv("FLASH_IS_LIVE_PROVISIONING", "true")
+        serverless = ServerlessResource(name="test")
+        serverless.id = "ep-live-123"
+
+        # health() must NOT be called — no mock needed, any call would raise
+        assert serverless.is_deployed() is True
+
+    def test_is_deployed_uses_health_check_outside_live_provisioning(self, monkeypatch):
+        """Outside flash run, is_deployed falls back to health check."""
+        monkeypatch.delenv("FLASH_IS_LIVE_PROVISIONING", raising=False)
+        serverless = ServerlessResource(name="test")
+        serverless.id = "ep-123"
+
+        mock_endpoint = MagicMock()
+        mock_endpoint.health.return_value = {"workers": {}}
+
+        with patch.object(
+            type(serverless),
+            "endpoint",
+            new_callable=lambda: property(lambda self: mock_endpoint),
+        ):
+            assert serverless.is_deployed() is True
+            mock_endpoint.health.assert_called_once()
+
     @pytest.mark.asyncio
     async def test_deploy_already_deployed(self):
         """Test deploy returns early when already deployed."""
@@ -938,6 +964,24 @@ def test_is_deployed_with_exception(self):

         assert result is False

+    def test_payload_exclude_adds_template_when_template_id_set(self):
+        """_payload_exclude excludes template field when templateId is already set."""
+        serverless = ServerlessResource(name="test")
+        serverless.templateId = "tmpl-123"
+
+        excluded = serverless._payload_exclude()
+
+        assert "template" in excluded
+
+    def test_payload_exclude_does_not_exclude_template_without_template_id(self):
+        """_payload_exclude does not exclude template when templateId is absent."""
+        serverless = ServerlessResource(name="test")
+        serverless.templateId = None
+
+        excluded = serverless._payload_exclude()
+
+        assert "template" not in excluded
+
     def test_reverse_sync_from_backend_response(self):
         """Test reverse sync when receiving backend response with gpuIds."""
         # This tests the lines 173-176 which convert gpuIds back to gpus list