diff --git a/scripts/validate-wheel.sh b/scripts/validate-wheel.sh index 7e4dd517..a38db25f 100755 --- a/scripts/validate-wheel.sh +++ b/scripts/validate-wheel.sh @@ -21,14 +21,12 @@ REQUIRED_TEMPLATE_FILES=( "runpod_flash/cli/utils/skeleton_template/.env.example" "runpod_flash/cli/utils/skeleton_template/.gitignore" "runpod_flash/cli/utils/skeleton_template/.flashignore" - "runpod_flash/cli/utils/skeleton_template/main.py" + "runpod_flash/cli/utils/skeleton_template/cpu_worker.py" + "runpod_flash/cli/utils/skeleton_template/gpu_worker.py" + "runpod_flash/cli/utils/skeleton_template/lb_worker.py" + "runpod_flash/cli/utils/skeleton_template/pyproject.toml" "runpod_flash/cli/utils/skeleton_template/README.md" "runpod_flash/cli/utils/skeleton_template/requirements.txt" - "runpod_flash/cli/utils/skeleton_template/workers/__init__.py" - "runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py" - "runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py" - "runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py" - "runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py" ) MISSING_IN_WHEEL=0 @@ -77,7 +75,7 @@ flash init test_project > /dev/null 2>&1 # Verify critical files exist echo "" echo "Verifying created files..." 
-REQUIRED_FILES=(".env.example" ".gitignore" ".flashignore" "main.py" "README.md" "requirements.txt") +REQUIRED_FILES=(".env.example" ".gitignore" ".flashignore" "cpu_worker.py" "gpu_worker.py" "lb_worker.py" "pyproject.toml" "README.md" "requirements.txt") MISSING_IN_OUTPUT=0 for file in "${REQUIRED_FILES[@]}"; do @@ -94,15 +92,6 @@ for file in "${REQUIRED_FILES[@]}"; do fi done -# Verify workers directory structure -if [ -d "test_project/workers/cpu" ] && [ -d "test_project/workers/gpu" ]; then - echo "[OK] workers/cpu/" - echo "[OK] workers/gpu/" -else - echo "[MISSING] workers directory structure" - MISSING_IN_OUTPUT=$((MISSING_IN_OUTPUT + 1)) -fi - # Cleanup deactivate cd - > /dev/null diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py new file mode 100644 index 00000000..abf48f06 --- /dev/null +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -0,0 +1,110 @@ +"""Helpers for the flash run dev server — loaded inside the generated server.py.""" + +import inspect +from typing import Any, get_type_hints + +from fastapi import HTTPException +from pydantic import create_model + +from runpod_flash.core.resources.resource_manager import ResourceManager +from runpod_flash.stubs.load_balancer_sls import LoadBalancerSlsStub + +_resource_manager = ResourceManager() + + +def _map_body_to_params(func, body): + """Map an HTTP request body to function parameters. + + If the body is a dict whose keys match the function's parameter names, + spread it as kwargs. Otherwise pass the whole body as the value of the + first parameter (mirrors how FastAPI maps a JSON body to a single param). 
+ """ + sig = inspect.signature(func) + param_names = set(sig.parameters.keys()) + + if isinstance(body, dict) and body.keys() <= param_names: + return body + + first_param = next(iter(sig.parameters), None) + if first_param is None: + return {} + return {first_param: body} + + +def make_input_model(name: str, func) -> type | None: + """Create a Pydantic model from a function's signature for FastAPI body typing. + + Returns None for zero-param functions or on failure (caller uses ``or dict``). + """ + try: + sig = inspect.signature(func) + hints = get_type_hints(func) + except (ValueError, TypeError): + return None + + _SKIP_KINDS = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + fields: dict[str, Any] = {} + for param_name, param in sig.parameters.items(): + if param_name == "self" or param.kind in _SKIP_KINDS: + continue + annotation = hints.get(param_name, Any) + if param.default is not inspect.Parameter.empty: + fields[param_name] = (annotation, param.default) + else: + fields[param_name] = (annotation, ...) + + if not fields: + return None + + return create_model(name, **fields) + + +async def call_with_body(func, body): + """Call func with body kwargs, handling Pydantic models and dicts.""" + if hasattr(body, "model_dump"): + return await func(**body.model_dump()) + raw = body.get("input", body) if isinstance(body, dict) else body + kwargs = _map_body_to_params(func, raw) + return await func(**kwargs) + + +def to_dict(body) -> dict: + """Convert Pydantic model or dict to plain dict.""" + return body.model_dump() if hasattr(body, "model_dump") else body + + +async def lb_execute(resource_config, func, body: dict): + """Dispatch an LB route to the deployed endpoint via LoadBalancerSlsStub. + + Provisions the endpoint via ResourceManager, maps the HTTP body to + function kwargs, then dispatches through the stub's /execute path + which serializes the function via cloudpickle to the remote container. 
+ + Args: + resource_config: The resource config object (e.g. LiveLoadBalancer instance). + func: The @remote LB route handler function. + body: Parsed request body (from FastAPI's automatic JSON parsing). + """ + try: + deployed = await _resource_manager.get_or_deploy_resource(resource_config) + except Exception as e: + raise HTTPException( + status_code=503, + detail=f"Failed to provision '{resource_config.name}': {e}", + ) + + stub = LoadBalancerSlsStub(deployed) + kwargs = _map_body_to_params(func, body) + + try: + return await stub(func, None, None, False, **kwargs) + except TimeoutError as e: + raise HTTPException(status_code=504, detail=str(e)) + except ConnectionError as e: + raise HTTPException(status_code=502, detail=str(e)) + except HTTPException: + raise + except (ValueError, KeyError, TypeError) as e: + raise HTTPException(status_code=422, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/runpod_flash/cli/commands/build.py b/src/runpod_flash/cli/commands/build.py index 44a1000e..1945c1ad 100644 --- a/src/runpod_flash/cli/commands/build.py +++ b/src/runpod_flash/cli/commands/build.py @@ -23,6 +23,7 @@ from runpod_flash.core.resources.constants import MAX_TARBALL_SIZE_MB from ..utils.ignore import get_file_tree, load_ignore_patterns +from .build_utils.lb_handler_generator import LBHandlerGenerator from .build_utils.manifest import ManifestBuilder from .build_utils.scanner import RemoteDecoratorScanner @@ -240,6 +241,9 @@ def run_build( manifest_path = build_dir / "flash_manifest.json" manifest_path.write_text(json.dumps(manifest, indent=2)) + lb_generator = LBHandlerGenerator(manifest, build_dir) + lb_generator.generate_handlers() + flash_dir = project_dir / ".flash" deployment_manifest_path = flash_dir / "flash_manifest.json" shutil.copy2(manifest_path, deployment_manifest_path) @@ -426,28 +430,19 @@ def validate_project_structure(project_dir: Path) -> bool: """ Validate that directory is a 
Flash project. + A Flash project is any directory containing Python files. The + RemoteDecoratorScanner validates that @remote functions exist. + Args: project_dir: Directory to validate Returns: True if valid Flash project """ - main_py = project_dir / "main.py" - - if not main_py.exists(): - console.print(f"[red]Error:[/red] main.py not found in {project_dir}") + py_files = list(project_dir.rglob("*.py")) + if not py_files: + console.print(f"[red]Error:[/red] No Python files found in {project_dir}") return False - - # Check if main.py has FastAPI app - try: - content = main_py.read_text(encoding="utf-8") - if "FastAPI" not in content: - console.print( - "[yellow]Warning:[/yellow] main.py does not appear to have a FastAPI app" - ) - except Exception: - pass - return True diff --git a/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py b/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py index dcd0845d..a0d28601 100644 --- a/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py +++ b/src/runpod_flash/cli/commands/build_utils/lb_handler_generator.py @@ -21,13 +21,10 @@ - Real-time communication patterns """ -import asyncio import logging from contextlib import asynccontextmanager -from pathlib import Path -from typing import Optional -from fastapi import FastAPI, Request +from fastapi import FastAPI from runpod_flash.runtime.lb_handler import create_lb_handler logger = logging.getLogger(__name__) @@ -45,57 +42,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): """Handle application startup and shutdown.""" - # Startup logger.info("Starting {resource_name} endpoint") - - # Check if this is the mothership and run reconciliation - # Note: Resources are now provisioned upfront by the CLI during deployment. - # This background task runs reconciliation on mothership startup to ensure - # all resources are still deployed and in sync with the manifest. 
- try: - from runpod_flash.runtime.mothership_provisioner import ( - is_mothership, - reconcile_children, - get_mothership_url, - ) - from runpod_flash.runtime.state_manager_client import StateManagerClient - - if is_mothership(): - logger.info("=" * 60) - logger.info("Mothership detected - Starting reconciliation task") - logger.info("Resources are provisioned upfront by the CLI") - logger.info("This task ensures all resources remain in sync") - logger.info("=" * 60) - try: - mothership_url = get_mothership_url() - logger.info(f"Mothership URL: {{mothership_url}}") - - # Initialize State Manager client for reconciliation - state_client = StateManagerClient() - - # Spawn background reconciliation task (non-blocking) - # This will verify all resources from manifest are deployed - manifest_path = Path(__file__).parent / "flash_manifest.json" - task = asyncio.create_task( - reconcile_children(manifest_path, mothership_url, state_client) - ) - # Add error callback to catch and log background task exceptions - task.add_done_callback( - lambda t: logger.error(f"Reconciliation task failed: {{t.exception()}}") - if t.exception() - else None - ) - - except Exception as e: - logger.error(f"Failed to start reconciliation task: {{e}}") - # Don't fail startup - continue serving traffic - - except ImportError: - logger.debug("Mothership provisioning modules not available") - yield - - # Shutdown logger.info("Shutting down {resource_name} endpoint") diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index b67ce9bd..4c5998bf 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -9,45 +9,17 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from runpod_flash.core.resources.constants import ( - DEFAULT_WORKERS_MAX, - DEFAULT_WORKERS_MIN, - FLASH_CPU_LB_IMAGE, - FLASH_LB_IMAGE, +from .scanner import ( + 
RemoteFunctionMetadata, + file_to_module_path, + file_to_url_prefix, ) -from .scanner import RemoteFunctionMetadata, detect_explicit_mothership, detect_main_app - logger = logging.getLogger(__name__) RESERVED_PATHS = ["/execute", "/ping"] -def _serialize_routes(routes: List[RemoteFunctionMetadata]) -> List[Dict[str, Any]]: - """Convert RemoteFunctionMetadata to manifest dict format. - - Args: - routes: List of route metadata objects - - Returns: - List of dicts with route information for manifest - """ - return [ - { - "name": route.function_name, - "module": route.module_path, - "is_async": route.is_async, - "is_class": route.is_class, - "is_load_balanced": route.is_load_balanced, - "is_live_resource": route.is_live_resource, - "config_variable": route.config_variable, - "http_method": route.http_method, - "http_path": route.http_path, - } - for route in routes - ] - - @dataclass class ManifestFunction: """Function entry in manifest.""" @@ -213,83 +185,13 @@ def _extract_deployment_config( return config - def _create_mothership_resource(self, main_app_config: dict) -> Dict[str, Any]: - """Create implicit mothership resource from main.py. 
- - Args: - main_app_config: Dict with 'file_path', 'app_variable', 'has_routes', 'fastapi_routes' keys - - Returns: - Dictionary representing the mothership resource for the manifest - """ - # Extract FastAPI routes if present - fastapi_routes = main_app_config.get("fastapi_routes", []) - functions_list = _serialize_routes(fastapi_routes) - - return { - "resource_type": "CpuLiveLoadBalancer", - "functions": functions_list, - "is_load_balanced": True, - "is_live_resource": True, - "is_mothership": True, - "main_file": main_app_config["file_path"].name, - "app_variable": main_app_config["app_variable"], - "imageName": FLASH_CPU_LB_IMAGE, - "workersMin": DEFAULT_WORKERS_MIN, - "workersMax": DEFAULT_WORKERS_MAX, - } - - def _create_mothership_from_explicit( - self, explicit_config: dict, search_dir: Path - ) -> Dict[str, Any]: - """Create mothership resource from explicit mothership.py configuration. - - Args: - explicit_config: Configuration dict from detect_explicit_mothership() - search_dir: Project directory + def build(self) -> Dict[str, Any]: + """Build the manifest dictionary. - Returns: - Dictionary representing the mothership resource for the manifest + Resources are keyed by resource_config_name for runtime compatibility. + Each resource entry includes file_path, local_path_prefix, and module_path + for the dev server and LB handler generator. 
""" - # Detect FastAPI app details for handler generation - main_app_config = detect_main_app(search_dir, explicit_mothership_exists=False) - - if not main_app_config: - # No FastAPI app found, use defaults - main_file = "main.py" - app_variable = "app" - fastapi_routes = [] - else: - main_file = main_app_config["file_path"].name - app_variable = main_app_config["app_variable"] - fastapi_routes = main_app_config.get("fastapi_routes", []) - - # Extract FastAPI routes into functions list - functions_list = _serialize_routes(fastapi_routes) - - # Map resource type to image name - resource_type = explicit_config.get("resource_type", "CpuLiveLoadBalancer") - if resource_type == "LiveLoadBalancer": - image_name = FLASH_LB_IMAGE # GPU load balancer - else: - image_name = FLASH_CPU_LB_IMAGE # CPU load balancer - - return { - "resource_type": resource_type, - "functions": functions_list, - "is_load_balanced": True, - "is_live_resource": True, - "is_mothership": True, - "is_explicit": True, # Flag to indicate explicit configuration - "main_file": main_file, - "app_variable": app_variable, - "imageName": image_name, - "workersMin": explicit_config.get("workersMin", DEFAULT_WORKERS_MIN), - "workersMax": explicit_config.get("workersMax", DEFAULT_WORKERS_MAX), - } - - def build(self) -> Dict[str, Any]: - """Build the manifest dictionary.""" # Group functions by resource_config_name resources: Dict[str, List[RemoteFunctionMetadata]] = {} @@ -305,6 +207,11 @@ def build(self) -> Dict[str, Any]: str, Dict[str, str] ] = {} # resource_name -> {route_key -> function_name} + # Determine project root for path derivation. + # build_dir is .flash/.build which *contains* the copied project files, + # so use it directly (not its parent, which would be .flash/). 
+ project_root = self.build_dir if self.build_dir else Path.cwd() + for resource_name, functions in sorted(resources.items()): # Use actual resource type from first function in group resource_type = ( @@ -315,6 +222,27 @@ def build(self) -> Dict[str, Any]: is_load_balanced = functions[0].is_load_balanced if functions else False is_live_resource = functions[0].is_live_resource if functions else False + # Derive path fields from the first function's source file. + # All functions in a resource share the same source file per convention. + first_file = functions[0].file_path if functions else None + file_path_str = "" + local_path_prefix = "" + resource_module_path = functions[0].module_path if functions else "" + + if first_file and first_file.exists(): + try: + file_path_str = str(first_file.relative_to(project_root)) + local_path_prefix = file_to_url_prefix(first_file, project_root) + resource_module_path = file_to_module_path(first_file, project_root) + except ValueError: + # File is outside project root — fall back to module_path + file_path_str = str(first_file) + local_path_prefix = "/" + functions[0].module_path.replace(".", "/") + elif first_file: + # File path may be relative (in test scenarios) + file_path_str = str(first_file) + local_path_prefix = "/" + functions[0].module_path.replace(".", "/") + # Validate and collect routing for LB endpoints resource_routes = {} if is_load_balanced: @@ -374,6 +302,9 @@ def build(self) -> Dict[str, Any]: resources_dict[resource_name] = { "resource_type": resource_type, + "file_path": file_path_str, + "local_path_prefix": local_path_prefix, + "module_path": resource_module_path, "functions": functions_list, "is_load_balanced": is_load_balanced, "is_live_resource": is_live_resource, @@ -395,68 +326,6 @@ def build(self) -> Dict[str, Any]: ) function_registry[f.function_name] = resource_name - # === MOTHERSHIP DETECTION (EXPLICIT THEN FALLBACK) === - search_dir = self.build_dir if self.build_dir else Path.cwd() - - # Step 1: 
Check for explicit mothership.py - explicit_mothership = detect_explicit_mothership(search_dir) - - if explicit_mothership: - # Use explicit configuration - logger.debug("Found explicit mothership configuration in mothership.py") - - # Check for name conflict - mothership_name = explicit_mothership.get("name", "mothership") - if mothership_name in resources_dict: - logger.warning( - f"Project has a @remote resource named '{mothership_name}'. " - f"Using 'mothership-entrypoint' for explicit mothership endpoint." - ) - mothership_name = "mothership-entrypoint" - - # Create mothership resource from explicit config - mothership_resource = self._create_mothership_from_explicit( - explicit_mothership, search_dir - ) - resources_dict[mothership_name] = mothership_resource - - else: - # Step 2: Fallback to auto-detection - main_app_config = detect_main_app( - search_dir, explicit_mothership_exists=False - ) - - if main_app_config and main_app_config["has_routes"]: - logger.warning( - "Auto-detected FastAPI app in main.py (no mothership.py found). " - "Consider running 'flash init' to create explicit mothership configuration." - ) - - # Check for name conflict - if "mothership" in resources_dict: - logger.warning( - "Project has a @remote resource named 'mothership'. " - "Using 'mothership-entrypoint' for auto-generated mothership endpoint." 
- ) - mothership_name = "mothership-entrypoint" - else: - mothership_name = "mothership" - - # Create mothership resource from auto-detection (legacy behavior) - mothership_resource = self._create_mothership_resource(main_app_config) - resources_dict[mothership_name] = mothership_resource - - # Extract routes from mothership resources - for resource_name, resource in resources_dict.items(): - if resource.get("is_mothership") and resource.get("functions"): - mothership_routes = {} - for func in resource["functions"]: - if func.get("http_method") and func.get("http_path"): - route_key = f"{func['http_method']} {func['http_path']}" - mothership_routes[route_key] = func["name"] - if mothership_routes: - routes_dict[resource_name] = mothership_routes - manifest = { "version": "1.0", "generated_at": datetime.now(timezone.utc) diff --git a/src/runpod_flash/cli/commands/build_utils/scanner.py b/src/runpod_flash/cli/commands/build_utils/scanner.py index 2215ab9e..892f3bcf 100644 --- a/src/runpod_flash/cli/commands/build_utils/scanner.py +++ b/src/runpod_flash/cli/commands/build_utils/scanner.py @@ -3,6 +3,7 @@ import ast import importlib import logging +import os import re from dataclasses import dataclass, field from pathlib import Path @@ -11,6 +12,61 @@ logger = logging.getLogger(__name__) +def file_to_url_prefix(file_path: Path, project_root: Path) -> str: + """Derive the local dev server URL prefix from a source file path. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + URL prefix starting with "/" (e.g., /longruns/stage1) + + Example: + longruns/stage1.py → /longruns/stage1 + """ + rel = file_path.relative_to(project_root).with_suffix("") + return "/" + str(rel).replace(os.sep, "/") + + +def file_to_resource_name(file_path: Path, project_root: Path) -> str: + """Derive the manifest resource name from a source file path. 
+ + Slashes and hyphens are replaced with underscores to produce a valid + Python identifier suitable for use as a resource name. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + Resource name using underscores (e.g., longruns_stage1) + + Example: + longruns/stage1.py → longruns_stage1 + my-worker.py → my_worker + """ + rel = file_path.relative_to(project_root).with_suffix("") + return str(rel).replace(os.sep, "_").replace("/", "_").replace("-", "_") + + +def file_to_module_path(file_path: Path, project_root: Path) -> str: + """Derive the Python dotted module path from a source file path. + + Args: + file_path: Absolute path to the Python source file + project_root: Absolute path to the project root directory + + Returns: + Dotted module path (e.g., longruns.stage1) + + Example: + longruns/stage1.py → longruns.stage1 + """ + rel = file_path.relative_to(project_root).with_suffix("") + return str(rel).replace(os.sep, ".").replace("/", ".") + + @dataclass class RemoteFunctionMetadata: """Metadata about a @remote decorated function or class.""" @@ -35,6 +91,18 @@ class RemoteFunctionMetadata: called_remote_functions: List[str] = field( default_factory=list ) # Names of @remote functions called + is_lb_route_handler: bool = ( + False # LB @remote with method= and path= — runs directly as HTTP handler + ) + class_methods: List[str] = field( + default_factory=list + ) # Public methods for @remote classes + param_names: List[str] = field( + default_factory=list + ) # Function params excluding self + class_method_params: Dict[str, List[str]] = field( + default_factory=dict + ) # method_name -> param_names (for classes) class RemoteDecoratorScanner: @@ -62,7 +130,9 @@ def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: rel_path = f.relative_to(self.project_dir) # Check if first part of path is in excluded_root_dirs if rel_path.parts and rel_path.parts[0] not in 
excluded_root_dirs: - self.py_files.append(f) + # Exclude __init__.py — not valid worker entry points + if f.name != "__init__.py": + self.py_files.append(f) except (ValueError, IndexError): # Include files that can't be made relative self.py_files.append(f) @@ -220,6 +290,39 @@ def _extract_remote_functions( {"is_load_balanced": False, "is_live_resource": False}, ) + # An LB route handler is an LB @remote function that has + # both method= and path= declared. Its body runs directly + # on the LB endpoint — it is NOT a remote dispatch stub. + is_lb_route_handler = ( + flags["is_load_balanced"] + and http_method is not None + and http_path is not None + ) + + # Extract public methods for @remote classes + class_methods: List[str] = [] + class_method_params: Dict[str, List[str]] = {} + if is_class: + for n in node.body: + if isinstance( + n, (ast.FunctionDef, ast.AsyncFunctionDef) + ) and not n.name.startswith("_"): + class_methods.append(n.name) + class_method_params[n.name] = [ + arg.arg + for arg in n.args.args + if arg.arg != "self" + ] + + # Extract param names for functions (not classes) + param_names: List[str] = [] + if not is_class and isinstance( + node, (ast.FunctionDef, ast.AsyncFunctionDef) + ): + param_names = [ + arg.arg for arg in node.args.args if arg.arg != "self" + ] + metadata = RemoteFunctionMetadata( function_name=node.name, module_path=module_path, @@ -235,6 +338,10 @@ def _extract_remote_functions( config_variable=self.resource_variables.get( resource_config_name ), + is_lb_route_handler=is_lb_route_handler, + class_methods=class_methods, + param_names=param_names, + class_method_params=class_method_params, ) functions.append(metadata) diff --git a/src/runpod_flash/cli/commands/init.py b/src/runpod_flash/cli/commands/init.py index 15a96d3d..eabd7583 100644 --- a/src/runpod_flash/cli/commands/init.py +++ b/src/runpod_flash/cli/commands/init.py @@ -5,6 +5,8 @@ import typer from rich.console import Console +from rich.panel import Panel +from 
rich.table import Table from ..utils.skeleton import create_project_skeleton, detect_file_conflicts @@ -19,70 +21,99 @@ def init_command( ): """Create new Flash project with Flash Server and GPU workers.""" + # Determine target directory and initialization mode if project_name is None or project_name == ".": + # Initialize in current directory project_dir = Path.cwd() is_current_dir = True + # Use current directory name as project name actual_project_name = project_dir.name else: + # Create new directory project_dir = Path(project_name) is_current_dir = False actual_project_name = project_name + # Create project directory if needed if not is_current_dir: project_dir.mkdir(parents=True, exist_ok=True) + # Check for file conflicts in target directory conflicts = detect_file_conflicts(project_dir) - should_overwrite = force + should_overwrite = force # Start with force flag value if conflicts and not force: + # Show warning and prompt user console.print( - "[yellow]Warning:[/yellow] The following files will be overwritten:\n" + Panel( + "[yellow]Warning: The following files will be overwritten:[/yellow]\n\n" + + "\n".join(f" • {conflict}" for conflict in conflicts), + title="File Conflicts Detected", + expand=False, + ) ) - for conflict in conflicts: - console.print(f" {conflict}") - console.print() + # Prompt user for confirmation proceed = typer.confirm("Continue and overwrite these files?", default=False) if not proceed: - console.print("[yellow]Cancelled[/yellow]") + console.print("[yellow]Initialization aborted.[/yellow]") raise typer.Exit(0) + # User confirmed, so we should overwrite should_overwrite = True + # Create project skeleton status_msg = ( - "Initializing Flash project..." + "Initializing Flash project in current directory..." if is_current_dir else f"Creating Flash project '{project_name}'..." 
) with console.status(status_msg): create_project_skeleton(project_dir, should_overwrite) - console.print(f"[green]Created[/green] [bold]{actual_project_name}[/bold]\n") - - prefix = "./" if is_current_dir else f"{actual_project_name}/" - console.print(f" {prefix}") - console.print(" ├── main.py FastAPI server") - console.print(" ├── mothership.py Mothership config") - console.print(" ├── pyproject.toml") - console.print(" ├── workers/") - console.print(" │ ├── gpu/") - console.print(" │ └── cpu/") - console.print(" ├── .env.example") - console.print(" ├── requirements.txt") - console.print(" └── README.md") - + # Success output + if is_current_dir: + panel_content = f"Flash project '[bold]{actual_project_name}[/bold]' initialized in current directory!\n\n" + panel_content += "Project structure:\n" + panel_content += " ./\n" + else: + panel_content = f"Flash project '[bold]{actual_project_name}[/bold]' created successfully!\n\n" + panel_content += "Project structure:\n" + panel_content += f" {actual_project_name}/\n" + + panel_content += " ├── gpu_worker.py # GPU serverless worker\n" + panel_content += " ├── cpu_worker.py # CPU serverless worker\n" + panel_content += " ├── lb_worker.py # CPU load-balanced API\n" + panel_content += " ├── pyproject.toml\n" + panel_content += " ├── .env.example\n" + panel_content += " ├── requirements.txt\n" + panel_content += " └── README.md\n" + + title = "Project Initialized" if is_current_dir else "Project Created" + console.print(Panel(panel_content, title=title, expand=False)) + + # Next steps console.print("\n[bold]Next steps:[/bold]") + steps_table = Table(show_header=False, box=None, padding=(0, 1)) + steps_table.add_column("Step", style="bold cyan") + steps_table.add_column("Description") + step_num = 1 if not is_current_dir: - console.print(f" {step_num}. cd {actual_project_name}") + steps_table.add_row(f"{step_num}.", f"cd {actual_project_name}") step_num += 1 - console.print(f" {step_num}. 
pip install -r requirements.txt") + + steps_table.add_row(f"{step_num}.", "pip install -r requirements.txt") + step_num += 1 + steps_table.add_row(f"{step_num}.", "cp .env.example .env") step_num += 1 - console.print(f" {step_num}. cp .env.example .env && add RUNPOD_API_KEY") + steps_table.add_row(f"{step_num}.", "Add your RUNPOD_API_KEY to .env") step_num += 1 - console.print(f" {step_num}. flash run") + steps_table.add_row(f"{step_num}.", "flash run") - console.print( - "\n [dim]API keys: https://docs.runpod.io/get-started/api-keys[/dim]" - ) - console.print(" [dim]Docs: http://localhost:8888/docs (after running)[/dim]") + console.print(steps_table) + + console.print("\n[bold]Get your API key:[/bold]") + console.print(" https://docs.runpod.io/get-started/api-keys") + console.print("\nVisit http://localhost:8888/docs after running") + console.print("\nCheck out the README.md for more") diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 051115a8..db2d354c 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -2,19 +2,804 @@ import logging import os +import re import signal import subprocess import sys +import threading +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import List -import questionary import typer from rich.console import Console +from rich.table import Table + +try: + from watchfiles import DefaultFilter as _WatchfilesDefaultFilter + from watchfiles import watch as _watchfiles_watch +except ModuleNotFoundError: + + def _watchfiles_watch(*_a, **_kw): # type: ignore[misc] + raise ModuleNotFoundError( + "watchfiles is required for flash run --reload. 
" + "Install it with: pip install watchfiles" + ) + + class _WatchfilesDefaultFilter: # type: ignore[no-redef] + def __init__(self, **_kw): + pass + + +from .build_utils.scanner import ( + RemoteDecoratorScanner, + file_to_module_path, + file_to_resource_name, + file_to_url_prefix, +) logger = logging.getLogger(__name__) console = Console() +# Resource state file written by ResourceManager in the uvicorn subprocess. +_RESOURCE_STATE_FILE = Path(".runpod") / "resources.pkl" + + +@dataclass +class WorkerInfo: + """Info about a discovered @remote function for dev server generation.""" + + file_path: Path + url_prefix: str # e.g. /longruns/stage1 + module_path: str # e.g. longruns.stage1 + resource_name: str # e.g. longruns_stage1 + worker_type: str # "QB" or "LB" + functions: List[str] # function names + class_remotes: List[dict] = field( + default_factory=list + ) # [{name, methods, method_params}] + lb_routes: List[dict] = field(default_factory=list) # [{method, path, fn_name}] + function_params: dict[str, list[str]] = field( + default_factory=dict + ) # fn_name -> param_names + + +def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: + """Scan the project for all @remote decorated functions. + + Walks all .py files (excluding .flash/, __pycache__, __init__.py) and + builds WorkerInfo for each file that contains @remote functions. + + Files with QB functions produce one WorkerInfo per file (QB type). + Files with LB functions produce one WorkerInfo per file (LB type). + A file can have both QB and LB functions (unusual but supported). 
+ + Args: + project_root: Root directory of the Flash project + + Returns: + List of WorkerInfo, one entry per discovered source file + """ + scanner = RemoteDecoratorScanner(project_root) + remote_functions = scanner.discover_remote_functions() + + # Group by file path + by_file: dict[Path, List] = {} + for func in remote_functions: + by_file.setdefault(func.file_path, []).append(func) + + workers: List[WorkerInfo] = [] + for file_path, funcs in sorted(by_file.items()): + url_prefix = file_to_url_prefix(file_path, project_root) + module_path = file_to_module_path(file_path, project_root) + resource_name = file_to_resource_name(file_path, project_root) + + qb_funcs = [f for f in funcs if not f.is_load_balanced and not f.is_class] + qb_classes = [f for f in funcs if not f.is_load_balanced and f.is_class] + lb_funcs = [f for f in funcs if f.is_load_balanced and f.is_lb_route_handler] + + if qb_funcs or qb_classes: + workers.append( + WorkerInfo( + file_path=file_path, + url_prefix=url_prefix, + module_path=module_path, + resource_name=resource_name, + worker_type="QB", + functions=[f.function_name for f in qb_funcs], + class_remotes=[ + { + "name": c.function_name, + "methods": c.class_methods, + "method_params": c.class_method_params, + } + for c in qb_classes + ], + function_params={f.function_name: f.param_names for f in qb_funcs}, + ) + ) + + if lb_funcs: + lb_routes = [ + { + "method": f.http_method, + "path": f.http_path, + "fn_name": f.function_name, + "config_variable": f.config_variable, + } + for f in lb_funcs + ] + workers.append( + WorkerInfo( + file_path=file_path, + url_prefix=url_prefix, + module_path=module_path, + resource_name=resource_name, + worker_type="LB", + functions=[f.function_name for f in lb_funcs], + lb_routes=lb_routes, + ) + ) + + return workers + + +def _ensure_gitignore(project_root: Path) -> None: + """Add .flash/ to .gitignore if not already present.""" + gitignore = project_root / ".gitignore" + entry = ".flash/" + + if 
gitignore.exists(): + content = gitignore.read_text(encoding="utf-8") + if entry in content: + return + # Append with a newline + if not content.endswith("\n"): + content += "\n" + gitignore.write_text(content + entry + "\n", encoding="utf-8") + else: + gitignore.write_text(entry + "\n", encoding="utf-8") + + +def _sanitize_fn_name(name: str) -> str: + """Sanitize a string for use as a Python function name. + + Replaces non-identifier characters with underscores and prepends '_' + if the result starts with a digit (Python identifiers cannot start + with digits). + """ + result = name.replace("/", "_").replace(".", "_").replace("-", "_") + if result and result[0].isdigit(): + result = "_" + result + return result + + +def _has_numeric_module_segments(module_path: str) -> bool: + """Check if any segment in a dotted module path starts with a digit. + + Python identifiers cannot start with digits, so ``from 01_foo import bar`` + is a SyntaxError. Callers should use ``importlib.import_module()`` instead. + """ + return any(seg and seg[0].isdigit() for seg in module_path.split(".")) + + +def _module_parent_subdir(module_path: str) -> str | None: + """Return the parent sub-directory for a dotted module path, or None for top-level. + + Example: ``01_getting_started.03_mixed.pipeline`` → ``01_getting_started/03_mixed`` + """ + parts = module_path.rsplit(".", 1) + if len(parts) == 1: + return None + return parts[0].replace(".", "/") + + +def _make_import_line(module_path: str, name: str) -> str: + """Build an import statement for *name* from *module_path*. + + Uses a regular ``from … import …`` when the module path is a valid + Python identifier chain. Falls back to ``_flash_import()`` (a generated + helper in server.py) when any segment starts with a digit. The helper + temporarily scopes ``sys.path`` so sibling imports in the target module + resolve to the correct directory. 
+ """ + if _has_numeric_module_segments(module_path): + subdir = _module_parent_subdir(module_path) + if subdir: + return f'{name} = _flash_import("{module_path}", "{name}", "{subdir}")' + return f'{name} = _flash_import("{module_path}", "{name}")' + return f"from {module_path} import {name}" + + +_PATH_PARAM_RE = re.compile(r"\{(\w+)\}") + + +def _extract_path_params(path: str) -> list[str]: + """Extract path parameter names from a FastAPI-style route path. + + Example: "/images/{file_id}" -> ["file_id"] + """ + return _PATH_PARAM_RE.findall(path) + + +def _build_call_expr(callable_name: str, params: list[str] | None) -> tuple[str, bool]: + """Build an async call expression based on parameter count. + + Args: + callable_name: Fully qualified callable (e.g. "fn" or "instance.method") + params: List of param names, or None if unknown (backward compat) + + Returns: + Tuple of (call_expression, needs_body). needs_body is False when the + handler signature should omit the body parameter. + """ + if params is not None and len(params) == 0: + return f"await {callable_name}()", False + return f"await _call_with_body({callable_name}, body)", True + + +def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: + """Generate .flash/server.py from the discovered workers. + + Args: + project_root: Root of the Flash project + workers: List of discovered worker infos + + Returns: + Path to the generated server.py + """ + flash_dir = project_root / ".flash" + flash_dir.mkdir(exist_ok=True) + + _ensure_gitignore(project_root) + + has_lb_workers = any(w.worker_type == "LB" for w in workers) + + lines = [ + '"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run."""', + "import sys", + "import uuid", + "from pathlib import Path", + "_project_root = Path(__file__).parent.parent", + "sys.path.insert(0, str(_project_root))", + "", + ] + + # When modules live in directories with numeric prefixes (e.g. 
01_hello/), + # we cannot use ``from … import …`` — Python identifiers cannot start with + # digits. Instead we emit a small ``_flash_import`` helper that uses + # ``importlib.import_module()`` *and* temporarily scopes ``sys.path`` so + # that sibling imports inside the loaded module (e.g. ``from cpu_worker + # import …``) resolve to the correct directory rather than a same-named + # file from a different example subdirectory. + needs_importlib = any(_has_numeric_module_segments(w.module_path) for w in workers) + + if needs_importlib: + lines += [ + "import importlib as _importlib", + "", + "", + "def _flash_import(module_path, name, subdir=None):", + ' """Import *name* from *module_path* with scoped sys.path for sibling imports."""', + " _path = str(_project_root / subdir) if subdir else None", + " if _path:", + " sys.path.insert(0, _path)", + " try:", + " return getattr(_importlib.import_module(module_path), name)", + " finally:", + " if _path is not None:", + " try:", + " if sys.path and sys.path[0] == _path:", + " sys.path.pop(0)", + " else:", + " sys.path.remove(_path)", + " except ValueError:", + " pass", + "", + ] + + lines += [ + "from runpod_flash.cli.commands._run_server_helpers import make_input_model as _make_input_model", + "from runpod_flash.cli.commands._run_server_helpers import call_with_body as _call_with_body", + ] + + if has_lb_workers: + lines += [ + "from fastapi import FastAPI, Request", + "from runpod_flash.cli.commands._run_server_helpers import lb_execute as _lb_execute", + "from runpod_flash.cli.commands._run_server_helpers import to_dict as _to_dict", + "", + ] + else: + lines += [ + "from fastapi import FastAPI", + "", + ] + + # Collect imports — QB functions are called directly, LB config variables and + # functions are passed to lb_execute for dispatch via LoadBalancerSlsStub. 
+ all_imports: List[str] = [] + for worker in workers: + if worker.worker_type == "QB": + for fn_name in worker.functions: + all_imports.append(_make_import_line(worker.module_path, fn_name)) + for cls_info in worker.class_remotes: + all_imports.append( + _make_import_line(worker.module_path, cls_info["name"]) + ) + elif worker.worker_type == "LB": + # Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...)) + config_vars = { + r["config_variable"] + for r in worker.lb_routes + if r.get("config_variable") + } + for var in sorted(config_vars): + all_imports.append(_make_import_line(worker.module_path, var)) + for fn_name in worker.functions: + all_imports.append(_make_import_line(worker.module_path, fn_name)) + + if all_imports: + lines.extend(all_imports) + lines.append("") + + lines += [ + "app = FastAPI(", + ' title="Flash Dev Server",', + ' description="Auto-generated by `flash run`. Visit /docs for interactive testing.",', + ")", + "", + ] + + # Module-level instance creation for @remote classes + for worker in workers: + for cls_info in worker.class_remotes: + cls_name = cls_info["name"] + lines.append(f"_instance_{cls_name} = {cls_name}()") + # Add blank line if any instances were created + if any(worker.class_remotes for worker in workers): + lines.append("") + + # Module-level Pydantic model creation for typed Swagger UI + model_lines: list[str] = [] + for worker in workers: + if worker.worker_type == "QB": + for fn in worker.functions: + params = worker.function_params.get(fn) + if params is None or len(params) > 0: + model_var = f"_{worker.resource_name}_{fn}_Input" + model_lines.append( + f'{model_var} = _make_input_model("{model_var}", {fn}) or dict' + ) + for cls_info in worker.class_remotes: + cls_name = cls_info["name"] + method_params = cls_info.get("method_params", {}) + instance_var = f"_instance_{cls_name}" + for method in cls_info["methods"]: + params = method_params.get(method) + if params is None or len(params) > 0: + 
model_var = f"_{worker.resource_name}_{cls_name}_{method}_Input" + # Use _class_type to get the original unwrapped method + # (RemoteClassWrapper.__getattr__ returns proxies with (*args, **kwargs)) + class_ref = f"getattr({instance_var}, '_class_type', type({instance_var}))" + model_lines.append( + f'{model_var} = _make_input_model("{model_var}", {class_ref}.{method}) or dict' + ) + elif worker.worker_type == "LB": + for route in worker.lb_routes: + method = route["method"].lower() + if method in ("post", "put", "patch", "delete"): + fn_name = route["fn_name"] + model_var = f"_{worker.resource_name}_{fn_name}_Input" + model_lines.append( + f'{model_var} = _make_input_model("{model_var}", {fn_name}) or dict' + ) + if model_lines: + lines.extend(model_lines) + lines.append("") + + for worker in workers: + # Group routes by project directory in Swagger UI. + # Nested: /03_mixed_workers/cpu_worker -> "03_mixed_workers/" + # Root: /worker -> "worker" + prefix = worker.url_prefix.lstrip("/") + tag = f"{prefix.rsplit('/', 1)[0]}/" if "/" in prefix else prefix + lines.append(f"# {'─' * 60}") + lines.append(f"# {worker.worker_type}: {worker.file_path.name}") + lines.append(f"# {'─' * 60}") + + if worker.worker_type == "QB": + # Total callable count: functions + sum of class methods + total_class_methods = sum(len(c["methods"]) for c in worker.class_remotes) + total_callables = len(worker.functions) + total_class_methods + use_multi = total_callables > 1 + + # Function-based routes + for fn in worker.functions: + if use_multi: + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{fn}_run_sync" + ) + sync_path = f"{worker.url_prefix}/{fn}/run_sync" + else: + handler_name = _sanitize_fn_name(f"{worker.resource_name}_run_sync") + sync_path = f"{worker.url_prefix}/run_sync" + params = worker.function_params.get(fn) + call_expr, needs_body = _build_call_expr(fn, params) + if needs_body: + model_var = f"_{worker.resource_name}_{fn}_Input" + handler_sig = f"async def 
{handler_name}(body: {model_var}):" + else: + handler_sig = f"async def {handler_name}():" + lines += [ + f'@app.post("{sync_path}", tags=["{tag}"])', + handler_sig, + f" result = {call_expr}", + ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', + "", + ] + + # Class-based routes + for cls_info in worker.class_remotes: + cls_name = cls_info["name"] + methods = cls_info["methods"] + method_params = cls_info.get("method_params", {}) + instance_var = f"_instance_{cls_name}" + + for method in methods: + if use_multi: + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{cls_name}_{method}_run_sync" + ) + sync_path = f"{worker.url_prefix}/{method}/run_sync" + else: + handler_name = _sanitize_fn_name( + f"{worker.resource_name}_{cls_name}_run_sync" + ) + sync_path = f"{worker.url_prefix}/run_sync" + params = method_params.get(method) + call_expr, needs_body = _build_call_expr( + f"{instance_var}.{method}", params + ) + if needs_body: + model_var = f"_{worker.resource_name}_{cls_name}_{method}_Input" + handler_sig = f"async def {handler_name}(body: {model_var}):" + else: + handler_sig = f"async def {handler_name}():" + lines += [ + f'@app.post("{sync_path}", tags=["{tag}"])', + handler_sig, + f" result = {call_expr}", + ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', + "", + ] + + elif worker.worker_type == "LB": + for route in worker.lb_routes: + method = route["method"].lower() + sub_path = route["path"].lstrip("/") + fn_name = route["fn_name"] + config_var = route["config_variable"] + full_path = f"{worker.url_prefix}/{sub_path}" + handler_name = _sanitize_fn_name( + f"_route_{worker.resource_name}_{fn_name}" + ) + path_params = _extract_path_params(full_path) + has_body = method in ("post", "put", "patch", "delete") + if has_body: + model_var = f"_{worker.resource_name}_{fn_name}_Input" + # POST/PUT/PATCH/DELETE: typed body + optional path params + if path_params: + param_sig = ", ".join(f"{p}: str" 
for p in path_params) + param_dict = ", ".join(f'"{p}": {p}' for p in path_params) + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: {model_var}, {param_sig}):", + f" return await _lb_execute({config_var}, {fn_name}, {{**_to_dict(body), {param_dict}}})", + "", + ] + else: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(body: {model_var}):", + f" return await _lb_execute({config_var}, {fn_name}, _to_dict(body))", + "", + ] + else: + # GET/etc: path params + query params (unchanged) + if path_params: + param_sig = ", ".join(f"{p}: str" for p in path_params) + param_dict = ", ".join(f'"{p}": {p}' for p in path_params) + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}({param_sig}, request: Request):", + f" return await _lb_execute({config_var}, {fn_name}, {{**dict(request.query_params), {param_dict}}})", + "", + ] + else: + lines += [ + f'@app.{method}("{full_path}", tags=["{tag}"])', + f"async def {handler_name}(request: Request):", + f" return await _lb_execute({config_var}, {fn_name}, dict(request.query_params))", + "", + ] + + # Health endpoints + lines += [ + "# Health", + '@app.get("/", tags=["health"])', + "def home():", + ' return {"message": "Flash Dev Server", "docs": "/docs"}', + "", + '@app.get("/ping", tags=["health"])', + "def ping():", + ' return {"status": "healthy"}', + "", + ] + + server_path = flash_dir / "server.py" + server_path.write_text("\n".join(lines), encoding="utf-8") + return server_path + + +def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> None: + """Print the startup table showing local paths, resource names, and types.""" + console.print(f"\n[bold green]Flash Dev Server[/bold green] http://{host}:{port}") + console.print() + + table = Table(show_header=True, header_style="bold") + table.add_column("Local path", style="cyan") + table.add_column("Resource", style="white") + 
table.add_column("Type", style="yellow") + + for worker in workers: + if worker.worker_type == "QB": + total_class_methods = sum(len(c["methods"]) for c in worker.class_remotes) + total_callables = len(worker.functions) + total_class_methods + use_multi = total_callables > 1 + + for fn in worker.functions: + if use_multi: + table.add_row( + f"POST {worker.url_prefix}/{fn}/run_sync", + worker.resource_name, + "QB", + ) + else: + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) + + for cls_info in worker.class_remotes: + methods = cls_info["methods"] + for method in methods: + if use_multi: + table.add_row( + f"POST {worker.url_prefix}/{method}/run_sync", + worker.resource_name, + "QB", + ) + else: + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) + elif worker.worker_type == "LB": + for route in worker.lb_routes: + sub_path = route["path"].lstrip("/") + full_path = f"{worker.url_prefix}/{sub_path}" + table.add_row( + f"{route['method']} {full_path}", + worker.resource_name, + "LB", + ) + + console.print(table) + console.print(f"\n Visit [bold]http://{host}:{port}/docs[/bold] for Swagger UI") + console.print( + " Press [bold]Ctrl+C[/bold] to stop — provisioned endpoints are cleaned up automatically\n" + ) + + +def _cleanup_live_endpoints() -> None: + """Deprovision all Live Serverless endpoints created during this session. + + Reads the resource state file written by the uvicorn subprocess, finds + all endpoints with the 'live-' name prefix, and deprovisions them. + Best-effort: errors per endpoint are logged but do not prevent cleanup + of other endpoints. + """ + if not _RESOURCE_STATE_FILE.exists(): + return + + import asyncio + import cloudpickle + from ...core.utils.file_lock import file_lock + + # Load persisted resource state. If this fails (lock error, corruption), + # log and return — don't let it prevent the rest of shutdown. 
+ try: + with open(_RESOURCE_STATE_FILE, "rb") as f: + with file_lock(f, exclusive=False): + data = cloudpickle.load(f) + except Exception as e: + logger.warning(f"Could not read resource state for cleanup: {e}") + return + + if isinstance(data, tuple): + resources, configs = data + else: + resources, configs = data, {} + + live_items = { + key: resource + for key, resource in resources.items() + if hasattr(resource, "name") + and resource.name + and resource.name.startswith("live-") + } + + if not live_items: + return + + import time + + async def _do_cleanup(): + undeployed = 0 + for key, resource in live_items.items(): + name = getattr(resource, "name", key) + try: + success = await resource._do_undeploy() + if success: + console.print(f" Deprovisioned: {name}") + undeployed += 1 + else: + logger.warning(f"Failed to deprovision: {name}") + except Exception as e: + logger.warning(f"Error deprovisioning {name}: {e}") + return undeployed + + t0 = time.monotonic() + loop = asyncio.new_event_loop() + try: + undeployed = loop.run_until_complete(_do_cleanup()) + finally: + loop.close() + elapsed = time.monotonic() - t0 + console.print( + f" Cleanup completed: {undeployed}/{len(live_items)} " + f"resource(s) undeployed in {elapsed:.1f}s" + ) + + # Remove live- entries from persisted state so they don't linger. 
+ remaining = {k: v for k, v in resources.items() if k not in live_items} + remaining_configs = {k: v for k, v in configs.items() if k not in live_items} + try: + with open(_RESOURCE_STATE_FILE, "wb") as f: + with file_lock(f, exclusive=True): + cloudpickle.dump((remaining, remaining_configs), f) + except Exception as e: + logger.warning(f"Could not update resource state after cleanup: {e}") + + +def _is_reload() -> bool: + """Check if running in uvicorn reload subprocess.""" + return "UVICORN_RELOADER_PID" in os.environ + + +def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> None: + """Watch project .py files and regenerate server.py when they change. + + Ignores .flash/ to avoid reacting to our own writes. Runs until + stop_event is set. + """ + # Suppress watchfiles' internal debug chatter (filter hits, rust timeouts). + logging.getLogger("watchfiles").setLevel(logging.WARNING) + + watch_filter = _WatchfilesDefaultFilter(ignore_paths=[str(project_root / ".flash")]) + + try: + for changes in _watchfiles_watch( + project_root, + watch_filter=watch_filter, + stop_event=stop_event, + ): + py_changed = [p for _, p in changes if p.endswith(".py")] + if not py_changed: + continue + try: + workers = _scan_project_workers(project_root) + _generate_flash_server(project_root, workers) + logger.debug("server.py regenerated (%d changed)", len(py_changed)) + except Exception as e: + logger.warning("Failed to regenerate server.py: %s", e) + except ModuleNotFoundError as e: + logger.warning("File watching disabled: %s", e) + except Exception as e: + if not stop_event.is_set(): + logger.exception("Unexpected error in file watcher: %s", e) + + +def _discover_resources(project_root: Path): + """Discover deployable resources in project files. + + Uses ResourceDiscovery to find all DeployableResource instances by + parsing @remote decorators and importing the referenced config variables. 
+ + Args: + project_root: Root directory of the Flash project + + Returns: + List of discovered DeployableResource instances + """ + from ...core.discovery import ResourceDiscovery + + py_files = sorted( + p + for p in project_root.rglob("*.py") + if not any( + skip in p.parts + for skip in (".flash", ".venv", "venv", "__pycache__", ".git") + ) + ) + + # Add project root to sys.path so cross-module imports resolve + # (e.g. api/routes.py doing "from longruns.stage1 import stage1_process"). + root_str = str(project_root) + added_to_path = root_str not in sys.path + if added_to_path: + sys.path.insert(0, root_str) + + resources = [] + try: + for py_file in py_files: + try: + discovery = ResourceDiscovery(str(py_file), max_depth=0) + resources.extend(discovery.discover()) + except Exception as e: + logger.debug("Discovery failed for %s: %s", py_file, e) + finally: + if added_to_path: + sys.path.remove(root_str) + + if resources: + console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") + for res in resources: + res_name = getattr(res, "name", "Unknown") + res_type = res.__class__.__name__ + console.print(f" [dim]- {res_name} ({res_type})[/dim]") + console.print() + + return resources + + +def _provision_resources(resources) -> None: + """Provision resources in parallel and wait for completion. 
+ + Args: + resources: List of DeployableResource instances to provision + """ + import asyncio + + from ...core.deployment import DeploymentOrchestrator + + try: + console.print(f"[bold]Provisioning {len(resources)} resource(s)...[/bold]") + orchestrator = DeploymentOrchestrator(max_concurrent=3) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(orchestrator.deploy_all(resources, show_progress=True)) + loop.close() + except Exception as e: + console.print(f"[yellow]Warning:[/yellow] Provisioning failed: {e}") + console.print( + "[dim]Resources will be provisioned on-demand at first request.[/dim]" + ) + def run_command( host: str = typer.Option( @@ -36,65 +821,66 @@ def run_command( auto_provision: bool = typer.Option( False, "--auto-provision", - help="Auto-provision deployable resources on startup", + help="Auto-provision all endpoints on startup (eliminates cold-start on first request)", ), ): - """Run Flash development server with uvicorn.""" + """Run Flash development server. - # Discover entry point - entry_point = discover_entry_point() - if not entry_point: - console.print("[red]Error:[/red] No entry point found") - console.print("Create main.py with a FastAPI app") - raise typer.Exit(1) + Scans the project for @remote decorated functions, generates a dev server + at .flash/server.py, and starts uvicorn with hot-reload. - # Check if entry point has FastAPI app - app_location = check_fastapi_app(entry_point) - if not app_location: - console.print(f"[red]Error:[/red] No FastAPI app found in {entry_point}") - console.print("Make sure your main.py contains: app = FastAPI()") - raise typer.Exit(1) + No main.py or FastAPI boilerplate required. Any .py file with @remote + decorated functions is a valid Flash project. 
+ """ + project_root = Path.cwd() - # Set flag for all flash run sessions to ensure both auto-provisioned - # and on-the-fly provisioned resources get the live- prefix + # Set flag for live provisioning so stubs get the live- prefix if not _is_reload(): os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true" - # Auto-provision resources if flag is set and not a reload + # Auto-provision all endpoints upfront (eliminates cold-start) if auto_provision and not _is_reload(): try: - resources = _discover_resources(entry_point) - + resources = _discover_resources(project_root) if resources: - # If many resources found, ask for confirmation - if len(resources) > 5: - if not _confirm_large_provisioning(resources): - console.print("[yellow]Auto-provisioning cancelled[/yellow]\n") - else: - _provision_resources(resources) - else: - _provision_resources(resources) + _provision_resources(resources) except Exception as e: logger.error("Auto-provisioning failed", exc_info=True) + console.print(f"[yellow]Warning:[/yellow] Auto-provisioning failed: {e}") console.print( - f"[yellow]Warning:[/yellow] Resource provisioning failed: {e}" - ) - console.print( - "[yellow]Note:[/yellow] Resources will be deployed on-demand when first called" + "[dim]Resources will be provisioned on-demand at first request.[/dim]" ) - console.print("\n[green]Starting Flash Server[/green]") - console.print(f"Entry point: [bold]{app_location}[/bold]") - console.print(f"Server: [bold]http://{host}:{port}[/bold]") - console.print(f"Auto-reload: [bold]{'enabled' if reload else 'disabled'}[/bold]") - console.print("\nPress CTRL+C to stop\n") + # Discover @remote functions + workers = _scan_project_workers(project_root) - # Build uvicorn command + if not workers: + console.print("[red]Error:[/red] No @remote functions found.") + console.print("Add @remote decorators to your functions to get started.") + console.print("\nExample:") + console.print( + " from runpod_flash import LiveServerless, remote\n" + " gpu_config = 
LiveServerless(name='my_worker')\n" + "\n" + " @remote(gpu_config)\n" + " async def process(input_data: dict) -> dict:\n" + " return {'result': input_data}" + ) + raise typer.Exit(1) + + # Generate .flash/server.py + _generate_flash_server(project_root, workers) + + _print_startup_table(workers, host, port) + + # Build uvicorn command using --app-dir so server:app is importable cmd = [ sys.executable, "-m", "uvicorn", - app_location, + "server:app", + "--app-dir", + ".flash", "--host", host, "--port", @@ -104,13 +890,26 @@ def run_command( ] if reload: - cmd.append("--reload") + cmd += [ + "--reload", + "--reload-dir", + ".flash", + "--reload-include", + "server.py", + ] + + stop_event = threading.Event() + watcher_thread = None + if reload: + watcher_thread = threading.Thread( + target=_watch_and_regenerate, + args=(project_root, stop_event), + daemon=True, + name="flash-watcher", + ) - # Run uvicorn with proper process group handling process = None try: - # Create new process group to ensure all child processes can be killed together - # On Unix systems, use process group; on Windows, CREATE_NEW_PROCESS_GROUP if sys.platform == "win32": process = subprocess.Popen( cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP @@ -118,27 +917,28 @@ def run_command( else: process = subprocess.Popen(cmd, preexec_fn=os.setsid) - # Wait for process to complete + if watcher_thread is not None: + watcher_thread.start() + process.wait() except KeyboardInterrupt: - console.print("\n[yellow]Stopping server and cleaning up processes...[/yellow]") + console.print("\n[yellow]Stopping server and cleaning up...[/yellow]") + + stop_event.set() + if watcher_thread is not None and watcher_thread.is_alive(): + watcher_thread.join(timeout=2) - # Kill the entire process group to ensure all child processes are terminated if process: try: if sys.platform == "win32": - # Windows: terminate the process process.terminate() else: - # Unix: kill entire process group 
os.killpg(os.getpgid(process.pid), signal.SIGTERM) - # Wait briefly for graceful shutdown try: process.wait(timeout=2) except subprocess.TimeoutExpired: - # Force kill if didn't terminate gracefully if sys.platform == "win32": process.kill() else: @@ -146,14 +946,19 @@ def run_command( process.wait() except (ProcessLookupError, OSError): - # Process already terminated pass + _cleanup_live_endpoints() console.print("[green]Server stopped[/green]") raise typer.Exit(0) except Exception as e: console.print(f"[red]Error:[/red] {e}") + + stop_event.set() + if watcher_thread is not None and watcher_thread.is_alive(): + watcher_thread.join(timeout=2) + if process: try: if sys.platform == "win32": @@ -162,135 +967,5 @@ def run_command( os.killpg(os.getpgid(process.pid), signal.SIGTERM) except (ProcessLookupError, OSError): pass + _cleanup_live_endpoints() raise typer.Exit(1) - - -def discover_entry_point() -> Optional[str]: - """Discover the main entry point file.""" - candidates = ["main.py", "app.py", "server.py"] - - for candidate in candidates: - if Path(candidate).exists(): - return candidate - - return None - - -def check_fastapi_app(entry_point: str) -> Optional[str]: - """ - Check if entry point has a FastAPI app and return the app location. - - Returns: - App location in format "module:app" or None - """ - try: - # Read the file - content = Path(entry_point).read_text() - - # Check for FastAPI app - if "app = FastAPI(" in content or "app=FastAPI(" in content: - # Extract module name from file path - module = entry_point.replace(".py", "").replace("/", ".") - return f"{module}:app" - - return None - - except Exception: - return None - - -def _is_reload() -> bool: - """Check if running in uvicorn reload subprocess. - - Returns: - True if running in a reload subprocess - """ - return "UVICORN_RELOADER_PID" in os.environ - - -def _discover_resources(entry_point: str): - """Discover deployable resources in entry point. 
- - Args: - entry_point: Path to entry point file - - Returns: - List of discovered DeployableResource instances - """ - from ...core.discovery import ResourceDiscovery - - try: - discovery = ResourceDiscovery(entry_point, max_depth=2) - resources = discovery.discover() - - # Debug: Log what was discovered - if resources: - console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") - for res in resources: - res_name = getattr(res, "name", "Unknown") - res_type = res.__class__.__name__ - console.print(f" [dim]• {res_name} ({res_type})[/dim]") - console.print() - - return resources - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Resource discovery failed: {e}") - return [] - - -def _confirm_large_provisioning(resources) -> bool: - """Show resources and prompt user for confirmation. - - Args: - resources: List of resources to provision - - Returns: - True if user confirms, False otherwise - """ - try: - console.print( - f"\n[yellow]Found {len(resources)} resources to provision:[/yellow]" - ) - - for resource in resources: - name = getattr(resource, "name", "Unknown") - resource_type = resource.__class__.__name__ - console.print(f" • {name} ({resource_type})") - - console.print() - - confirmed = questionary.confirm( - "This may take several minutes. Do you want to proceed?" - ).ask() - - return confirmed if confirmed is not None else False - - except (KeyboardInterrupt, EOFError): - console.print("\n[yellow]Cancelled[/yellow]") - return False - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Confirmation failed: {e}") - return False - - -def _provision_resources(resources): - """Provision resources and wait for completion. 
- - Args: - resources: List of resources to provision - """ - import asyncio - from ...core.deployment import DeploymentOrchestrator - - try: - console.print(f"\n[bold]Provisioning {len(resources)} resource(s)...[/bold]") - orchestrator = DeploymentOrchestrator(max_concurrent=3) - - # Run provisioning with progress shown - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(orchestrator.deploy_all(resources, show_progress=True)) - loop.close() - - except Exception as e: - console.print(f"[yellow]Warning:[/yellow] Provisioning failed: {e}") diff --git a/src/runpod_flash/cli/utils/skeleton_template/.gitignore b/src/runpod_flash/cli/utils/skeleton_template/.gitignore index f0673581..0e3b93d7 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/.gitignore +++ b/src/runpod_flash/cli/utils/skeleton_template/.gitignore @@ -36,6 +36,7 @@ wheels/ .env.local # Flash +.flash/ .runpod/ dist/ diff --git a/src/runpod_flash/cli/utils/skeleton_template/README.md b/src/runpod_flash/cli/utils/skeleton_template/README.md index 6c4801e5..328a8ab3 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/README.md +++ b/src/runpod_flash/cli/utils/skeleton_template/README.md @@ -1,243 +1,176 @@ # {{project_name}} -Flash application demonstrating distributed GPU and CPU computing on Runpod's serverless infrastructure. - -## About This Template - -This project was generated using `flash init`. The `{{project_name}}` placeholder is automatically replaced with your actual project name during initialization. +Runpod Flash application with GPU and CPU workers on Runpod serverless infrastructure. ## Quick Start -### 1. Install Dependencies +Install [uv](https://docs.astral.sh/uv/getting-started/installation/) (recommended Python package manager): ```bash -pip install -r requirements.txt +curl -LsSf https://astral.sh/uv/install.sh | sh ``` -### 2. 
Configure Environment - -Create `.env` file: +Set up the project: ```bash -RUNPOD_API_KEY=your_api_key_here +uv venv && source .venv/bin/activate +uv sync +cp .env.example .env # Add your RUNPOD_API_KEY +flash run ``` -Get your API key from [Runpod Settings](https://www.runpod.io/console/user/settings). - -### 3. Run Locally +Or with pip: ```bash -# Standard run +python -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt +cp .env.example .env # Add your RUNPOD_API_KEY flash run - -# Faster development: pre-provision endpoints (eliminates cold-start delays) -flash run --auto-provision ``` -Server starts at **http://localhost:8000** +Server starts at **http://localhost:8888**. Visit **http://localhost:8888/docs** for interactive Swagger UI. -With `--auto-provision`, all serverless endpoints deploy before testing begins. This is much faster for development because endpoints are cached and reused across server restarts. Subsequent runs skip deployment and start immediately. +Use `flash run --auto-provision` to pre-deploy all endpoints on startup, eliminating cold-start delays on first request. Provisioned endpoints are cached and reused across restarts. -### 4. Test the API +When you stop the server with Ctrl+C, all endpoints provisioned during the session are automatically cleaned up. -```bash -# Health check -curl http://localhost:8000/ping +Get your API key from [Runpod Settings](https://www.runpod.io/console/user/settings). +Learn more about it from our [Documentation](https://docs.runpod.io/get-started/api-keys). 
+ +## Test the API -# GPU worker -curl -X POST http://localhost:8000/gpu/hello \ +```bash +# Queue-based GPU worker +curl -X POST http://localhost:8888/gpu_worker/run_sync \ -H "Content-Type: application/json" \ -d '{"message": "Hello GPU!"}' -# CPU worker -curl -X POST http://localhost:8000/cpu/hello \ +# Queue-based CPU worker +curl -X POST http://localhost:8888/cpu_worker/run_sync \ -H "Content-Type: application/json" \ -d '{"message": "Hello CPU!"}' -``` - -Visit **http://localhost:8000/docs** for interactive API documentation. -## What This Demonstrates - -### GPU Worker (`workers/gpu/`) -Simple GPU-based serverless function: -- Remote execution with `@remote` decorator -- GPU resource configuration -- Automatic scaling (0-3 workers) -- No external dependencies required - -```python -@remote( - resource_config=LiveServerless( - name="gpu_worker", - gpus=[GpuGroup.ADA_24], # RTX 4090 - workersMin=0, - workersMax=3, - ) -) -async def gpu_hello(input_data: dict) -> dict: - # Your GPU code here - return {"status": "success", "message": "Hello from GPU!"} -``` - -### CPU Worker (`workers/cpu/`) -Simple CPU-based serverless function: -- CPU-only execution (no GPU overhead) -- CpuLiveServerless configuration -- Efficient for API endpoints -- Automatic scaling (0-5 workers) +# Load-balanced HTTP endpoint +curl -X POST http://localhost:8888/lb_worker/process \ + -H "Content-Type: application/json" \ + -d '{"input": "test"}' -```python -@remote( - resource_config=CpuLiveServerless( - name="cpu_worker", - instanceIds=[CpuInstanceType.CPU3G_2_8], # 2 vCPU, 8GB RAM - workersMin=0, - workersMax=5, - ) -) -async def cpu_hello(input_data: dict) -> dict: - # Your CPU code here - return {"status": "success", "message": "Hello from CPU!"} +# Load-balanced health check +curl http://localhost:8888/lb_worker/health ``` ## Project Structure ``` {{project_name}}/ -├── main.py # FastAPI application -├── workers/ -│ ├── gpu/ # GPU worker -│ │ ├── __init__.py # FastAPI router -│ │ └── 
endpoint.py # @remote decorated function -│ └── cpu/ # CPU worker -│ ├── __init__.py # FastAPI router -│ └── endpoint.py # @remote decorated function -├── .env # Environment variables -├── requirements.txt # Dependencies -└── README.md # This file +├── gpu_worker.py # GPU serverless worker (queue-based) +├── cpu_worker.py # CPU serverless worker (queue-based) +├── lb_worker.py # CPU load-balanced HTTP endpoint +├── .env.example # Environment variable template +├── requirements.txt # Python dependencies +└── README.md ``` -## Key Concepts - -### Remote Execution -The `@remote` decorator transparently executes functions on serverless infrastructure: -- Code runs locally during development -- Automatically deploys to Runpod when configured -- Handles serialization, dependencies, and resource management - -### Resource Scaling -Both workers scale to zero when idle to minimize costs: -- **idleTimeout**: Seconds before scaling down (default: 60) -- **workersMin**: 0 = completely scales to zero -- **workersMax**: Maximum concurrent workers - -### GPU Types -Available GPU options for `LiveServerless`: -- `GpuGroup.ADA_24` - RTX 4090 (24GB) -- `GpuGroup.ADA_48_PRO` - RTX 6000 Ada, L40 (48GB) -- `GpuGroup.AMPERE_80` - A100 (80GB) -- `GpuGroup.ANY` - Any available GPU - -### CPU Types -Available CPU options for `CpuLiveServerless`: -- `CpuInstanceType.CPU3G_2_8` - 2 vCPU, 8GB RAM (General Purpose) -- `CpuInstanceType.CPU3C_4_8` - 4 vCPU, 8GB RAM (Compute Optimized) -- `CpuInstanceType.CPU5G_4_16` - 4 vCPU, 16GB RAM (Latest Gen) -- `CpuInstanceType.ANY` - Any available GPU - -## Development Workflow - -### Test Workers Locally -```bash -# Test GPU worker -python -m workers.gpu.endpoint +## Worker Types -# Test CPU worker -python -m workers.cpu.endpoint -``` +### Queue-Based (QB) Workers -### Run the Application -```bash -flash run -``` - -### Deploy to Production -```bash -# Build and deploy in one step -flash deploy - -# Or deploy to a specific environment -flash deploy --env 
production -``` - -## Adding New Workers - -### Add a GPU Worker +QB workers process jobs from a queue. Each call to `/run_sync` sends a job and waits +for the result. Use QB for compute-heavy tasks that may take seconds to minutes. -1. Create `workers/my_worker/endpoint.py`: +**gpu_worker.py** — GPU serverless function: ```python -from runpod_flash import remote, LiveServerless +from runpod_flash import GpuType, LiveServerless, remote -config = LiveServerless(name="my_worker") +gpu_config = LiveServerless( + name="gpu_worker", + gpus=[GpuType.ANY], +) -@remote(resource_config=config, dependencies=["torch"]) -async def my_function(data: dict) -> dict: +@remote(resource_config=gpu_config, dependencies=["torch"]) +async def gpu_hello(input_data: dict) -> dict: import torch - # Your code here - return {"result": "success"} + gpu_available = torch.cuda.is_available() + gpu_name = torch.cuda.get_device_name(0) if gpu_available else "No GPU detected" + return {"message": gpu_name} ``` -2. Create `workers/my_worker/__init__.py`: +**cpu_worker.py** — CPU serverless function: ```python -from fastapi import APIRouter -from .endpoint import my_function +from runpod_flash import CpuLiveServerless, remote -router = APIRouter() +cpu_config = CpuLiveServerless(name="cpu_worker") -@router.post("/process") -async def handler(data: dict): - return await my_function(data) +@remote(resource_config=cpu_config) +async def cpu_hello(input_data: dict = {}) -> dict: + return {"message": "Hello from CPU!", **input_data} ``` -3. Add to `main.py`: -```python -from workers.my_worker import router as my_router -app.include_router(my_router, prefix="/my_worker") -``` +### Load-Balanced (LB) Workers -### Add a CPU Worker +LB workers expose standard HTTP endpoints (GET, POST, etc.) behind a load balancer. +Use LB for low-latency API endpoints that need horizontal scaling. 
-Same pattern but use `CpuLiveServerless`: +**lb_worker.py** — HTTP endpoints on a load-balanced container: ```python -from runpod_flash import remote, CpuLiveServerless, CpuInstanceType +from runpod_flash import CpuLiveLoadBalancer, remote -config = CpuLiveServerless( - name="my_cpu_worker", - instanceIds=[CpuInstanceType.CPU3G_2_8] +api_config = CpuLiveLoadBalancer( + name="lb_worker", + workersMin=1, ) -@remote(resource_config=config, dependencies=["requests"]) -async def fetch_data(url: str) -> dict: - import requests - return requests.get(url).json() +@remote(resource_config=api_config, method="POST", path="/process") +async def process(input_data: dict) -> dict: + return {"status": "success", "echo": input_data} + +@remote(resource_config=api_config, method="GET", path="/health") +async def health() -> dict: + return {"status": "healthy"} ``` -## Adding Dependencies +## Adding New Workers + +Create a new `.py` file with a `@remote` function. `flash run` auto-discovers all +`@remote` functions in the project. -Specify dependencies in the `@remote` decorator: ```python -@remote( - resource_config=config, - dependencies=["torch>=2.0.0", "transformers"], # Python packages - system_dependencies=["ffmpeg"] # System packages -) -async def my_function(data: dict) -> dict: - # Dependencies are automatically installed - import torch - import transformers +# my_worker.py +from runpod_flash import LiveServerless, GpuType, remote + +config = LiveServerless(name="my_worker", gpus=[GpuType.NVIDIA_GEFORCE_RTX_4090]) + +@remote(resource_config=config, dependencies=["transformers"]) +async def predict(input_data: dict) -> dict: + from transformers import pipeline + pipe = pipeline("sentiment-analysis") + return pipe(input_data["text"])[0] ``` +Then run `flash run` — the new worker appears automatically. 
+ +## GPU Types + +| Config | Hardware | VRAM | +|--------|----------|------| +| `GpuType.ANY` | Any available GPU | varies | +| `GpuType.NVIDIA_GEFORCE_RTX_4090` | RTX 4090 | 24 GB | +| `GpuType.NVIDIA_GEFORCE_RTX_5090` | RTX 5090 | 32 GB | +| `GpuType.NVIDIA_RTX_6000_ADA_GENERATION` | RTX 6000 Ada | 48 GB | +| `GpuType.NVIDIA_L4` | L4 | 24 GB | +| `GpuType.NVIDIA_A100_80GB_PCIe` | A100 PCIe | 80 GB | +| `GpuType.NVIDIA_A100_SXM4_80GB` | A100 SXM4 | 80 GB | +| `GpuType.NVIDIA_H100_80GB_HBM3` | H100 | 80 GB | +| `GpuType.NVIDIA_H200` | H200 | 141 GB | + +## CPU Types + +| Config | vCPU | RAM | +|--------|------|-----| +| `CpuInstanceType.CPU3G_2_8` | 2 | 8 GB | +| `CpuInstanceType.CPU3C_4_8` | 4 | 8 GB | +| `CpuInstanceType.CPU5G_4_16` | 4 | 16 GB | + ## Environment Variables ```bash @@ -245,16 +178,13 @@ async def my_function(data: dict) -> dict: RUNPOD_API_KEY=your_api_key # Optional -FLASH_HOST=localhost # Host to bind the server to (default: localhost) -FLASH_PORT=8888 # Port to bind the server to (default: 8888) -LOG_LEVEL=INFO # Logging level (default: INFO) +FLASH_HOST=localhost # Server host (default: localhost) +FLASH_PORT=8888 # Server port (default: 8888) +LOG_LEVEL=INFO # Logging level (default: INFO) ``` -## Next Steps +## Deploy -- Add your ML models or processing logic -- Configure GPU/CPU resources based on your needs -- Add authentication to your endpoints -- Implement error handling and retries -- Add monitoring and logging -- Deploy to production with `flash deploy` +```bash +flash deploy +``` diff --git a/src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py b/src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py new file mode 100644 index 00000000..aee4b5a3 --- /dev/null +++ b/src/runpod_flash/cli/utils/skeleton_template/cpu_worker.py @@ -0,0 +1,17 @@ +from runpod_flash import CpuLiveServerless, remote + +cpu_config = CpuLiveServerless(name="cpu_worker") + + +@remote(resource_config=cpu_config) +async def cpu_hello(input_data: dict) 
-> dict: + """CPU worker — lightweight processing without GPU.""" + import platform + from datetime import datetime + + return { + "message": input_data.get("message", "Hello from CPU worker!"), + "timestamp": datetime.now().isoformat(), + "platform": platform.system(), + "python_version": platform.python_version(), + } diff --git a/src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py b/src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py new file mode 100644 index 00000000..d787e1e9 --- /dev/null +++ b/src/runpod_flash/cli/utils/skeleton_template/gpu_worker.py @@ -0,0 +1,27 @@ +from runpod_flash import GpuType, LiveServerless, remote + +gpu_config = LiveServerless( + name="gpu_worker", + gpus=[GpuType.ANY], +) + + +@remote(resource_config=gpu_config, dependencies=["torch"]) +async def gpu_hello(input_data: dict) -> dict: + """GPU worker — detects available GPU hardware.""" + import platform + + try: + import torch + + gpu_available = torch.cuda.is_available() + gpu_name = torch.cuda.get_device_name(0) if gpu_available else "No GPU detected" + except Exception as e: + gpu_available = False + gpu_name = f"Error: {e}" + + return { + "message": input_data.get("message", "Hello from GPU worker!"), + "gpu": {"available": gpu_available, "name": gpu_name}, + "python_version": platform.python_version(), + } diff --git a/src/runpod_flash/cli/utils/skeleton_template/lb_worker.py b/src/runpod_flash/cli/utils/skeleton_template/lb_worker.py new file mode 100644 index 00000000..1b40ed0c --- /dev/null +++ b/src/runpod_flash/cli/utils/skeleton_template/lb_worker.py @@ -0,0 +1,24 @@ +from runpod_flash import CpuLiveLoadBalancer, remote + +api_config = CpuLiveLoadBalancer( + name="lb_worker", + workersMin=1, +) + + +@remote(resource_config=api_config, method="POST", path="/process") +async def process(input_data: dict) -> dict: + """Process input data on a load-balanced CPU endpoint.""" + from datetime import datetime + + return { + "status": "success", + "echo": 
input_data, + "timestamp": datetime.now().isoformat(), + } + + +@remote(resource_config=api_config, method="GET", path="/health") +async def health() -> dict: + """Health check for the load-balanced endpoint.""" + return {"status": "healthy"} diff --git a/src/runpod_flash/cli/utils/skeleton_template/main.py b/src/runpod_flash/cli/utils/skeleton_template/main.py deleted file mode 100644 index ad3ce717..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/main.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -import os - -from fastapi import FastAPI - -from workers.cpu import cpu_router -from workers.gpu import gpu_router - -logger = logging.getLogger(__name__) - - -app = FastAPI( - title="Flash Application", - description="Distributed GPU and CPU computing with Runpod Flash", - version="0.1.0", -) - -# Include routers -app.include_router(gpu_router, prefix="/gpu", tags=["GPU Workers"]) -app.include_router(cpu_router, prefix="/cpu", tags=["CPU Workers"]) - - -@app.get("/") -def home(): - return { - "message": "Flash Application", - "docs": "/docs", - "endpoints": {"gpu_hello": "/gpu/hello", "cpu_hello": "/cpu/hello"}, - } - - -@app.get("/ping") -def ping(): - return {"status": "healthy"} - - -if __name__ == "__main__": - import uvicorn - - host = os.getenv("FLASH_HOST", "localhost") - port = int(os.getenv("FLASH_PORT", 8888)) - logger.info(f"Starting Flash server on {host}:{port}") - - uvicorn.run(app, host=host, port=port) diff --git a/src/runpod_flash/cli/utils/skeleton_template/mothership.py b/src/runpod_flash/cli/utils/skeleton_template/mothership.py deleted file mode 100644 index 85779bfc..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/mothership.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Mothership Endpoint Configuration - -The mothership endpoint serves your FastAPI application routes. -It is automatically deployed as a CPU-optimized load-balanced endpoint. 
- -To customize this configuration: -- Modify worker scaling: change workersMin and workersMax values -- Use GPU load balancer: import LiveLoadBalancer instead of CpuLiveLoadBalancer -- Change endpoint name: update the 'name' parameter - -To disable mothership deployment: -- Delete this file, or -- Comment out the 'mothership' variable below - -Documentation: https://docs.runpod.io/flash/mothership -""" - -from runpod_flash import CpuLiveLoadBalancer - -# Mothership endpoint configuration -# This serves your FastAPI app routes from main.py -mothership = CpuLiveLoadBalancer( - name="mothership", - workersMin=1, - workersMax=1, -) - -# Examples of customization: - -# Increase scaling for high traffic -# mothership = CpuLiveLoadBalancer( -# name="mothership", -# workersMin=2, -# workersMax=10, -# ) - -# Use GPU-based load balancer instead of CPU -# (requires importing LiveLoadBalancer) -# from runpod_flash import LiveLoadBalancer -# mothership = LiveLoadBalancer( -# name="mothership", -# gpus=[GpuGroup.ANY], -# ) - -# Custom endpoint name -# mothership = CpuLiveLoadBalancer( -# name="my-api-gateway", -# workersMin=1, -# workersMax=1, -# ) - -# To disable mothership: -# - Delete this entire file, or -# - Comment out the 'mothership' variable above diff --git a/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml b/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml index 7987ad22..a58ae558 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml +++ b/src/runpod_flash/cli/utils/skeleton_template/pyproject.toml @@ -5,54 +5,9 @@ build-backend = "setuptools.build_meta" [project] name = "{{project_name}}" version = "0.1.0" -description = "Flash serverless application" +description = "Runpod Flash Serverless Application" readme = "README.md" requires-python = ">=3.11" dependencies = [ "runpod-flash", - "fastapi>=0.104.0", - "uvicorn>=0.24.0", -] - -[project.optional-dependencies] -dev = [ - "pytest>=7.0", - "pytest-asyncio>=0.21", - 
"pytest-cov>=4.0", - "ruff>=0.1", - "mypy>=1.0", -] - -[tool.ruff] -line-length = 100 -target-version = "py311" - -[tool.ruff.lint] -select = ["E", "F", "I", "N", "W"] -ignore = ["E501"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py", "*_test.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -asyncio_mode = "auto" - -[tool.mypy] -python_version = "3.11" -warn_return_any = false -warn_unused_configs = true -disallow_untyped_defs = false - -[tool.coverage.run] -source = ["src"] -omit = ["*/tests/*"] - -[tool.coverage.report] -exclude_lines = [ - "pragma: no cover", - "def __repr__", - "raise AssertionError", - "raise NotImplementedError", - "if __name__ == .__main__.:", ] diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/__init__.py b/src/runpod_flash/cli/utils/skeleton_template/workers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py b/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py deleted file mode 100644 index aef10a1a..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter -from pydantic import BaseModel - -from .endpoint import cpu_hello - -cpu_router = APIRouter() - - -class MessageRequest(BaseModel): - """Request model for CPU worker.""" - - message: str = "Hello from CPU!" 
- - -@cpu_router.post("/hello") -async def hello(request: MessageRequest): - """Simple CPU worker endpoint.""" - result = await cpu_hello({"message": request.message}) - return result diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py b/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py deleted file mode 100644 index 8161e5a7..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/cpu/endpoint.py +++ /dev/null @@ -1,36 +0,0 @@ -from runpod_flash import CpuLiveServerless, remote - -cpu_config = CpuLiveServerless( - name="cpu_worker", - workersMin=0, - workersMax=1, - idleTimeout=60, -) - - -@remote(resource_config=cpu_config) -async def cpu_hello(input_data: dict) -> dict: - """Simple CPU worker example.""" - import platform - from datetime import datetime - - message = input_data.get("message", "Hello from CPU worker!") - - return { - "status": "success", - "message": message, - "worker_type": "CPU", - "timestamp": datetime.now().isoformat(), - "platform": platform.system(), - "python_version": platform.python_version(), - } - - -# Test locally with: python -m workers.cpu.endpoint -if __name__ == "__main__": - import asyncio - - test_payload = {"message": "Testing CPU worker"} - print(f"Testing CPU worker with payload: {test_payload}") - result = asyncio.run(cpu_hello(test_payload)) - print(f"Result: {result}") diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py b/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py deleted file mode 100644 index a6a3caad..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter -from pydantic import BaseModel - -from .endpoint import gpu_hello - -gpu_router = APIRouter() - - -class MessageRequest(BaseModel): - """Request model for GPU worker.""" - - message: str = "Hello from GPU!" 
- - -@gpu_router.post("/hello") -async def hello(request: MessageRequest): - """Simple GPU worker endpoint.""" - result = await gpu_hello({"message": request.message}) - return result diff --git a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py b/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py deleted file mode 100644 index f3c4466c..00000000 --- a/src/runpod_flash/cli/utils/skeleton_template/workers/gpu/endpoint.py +++ /dev/null @@ -1,61 +0,0 @@ -from runpod_flash import GpuGroup, LiveServerless, remote - -gpu_config = LiveServerless( - name="gpu_worker", - gpus=[GpuGroup.ANY], - workersMin=0, - workersMax=1, - idleTimeout=60, -) - - -@remote(resource_config=gpu_config, dependencies=["torch"]) -async def gpu_hello(input_data: dict) -> dict: - """Simple GPU worker example with GPU detection.""" - import platform - from datetime import datetime - - try: - import torch - - gpu_available = torch.cuda.is_available() - if gpu_available: - gpu_name = torch.cuda.get_device_name(0) - gpu_count = torch.cuda.device_count() - gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) - else: - gpu_name = "No GPU detected" - gpu_count = 0 - gpu_memory = 0 - except Exception as e: - gpu_available = False - gpu_name = f"Error detecting GPU: {str(e)}" - gpu_count = 0 - gpu_memory = 0 - - message = input_data.get("message", "Hello from GPU worker!") - - return { - "status": "success", - "message": message, - "worker_type": "GPU", - "gpu_info": { - "available": gpu_available, - "name": gpu_name, - "count": gpu_count, - "memory_gb": round(gpu_memory, 2) if gpu_memory else 0, - }, - "timestamp": datetime.now().isoformat(), - "platform": platform.system(), - "python_version": platform.python_version(), - } - - -# Test locally with: python -m workers.gpu.endpoint -if __name__ == "__main__": - import asyncio - - test_payload = {"message": "Testing GPU worker"} - print(f"Testing GPU worker with payload: {test_payload}") - result = 
asyncio.run(gpu_hello(test_payload)) - print(f"Result: {result}") diff --git a/src/runpod_flash/client.py b/src/runpod_flash/client.py index ed68bc30..8709cf75 100644 --- a/src/runpod_flash/client.py +++ b/src/runpod_flash/client.py @@ -159,6 +159,20 @@ def decorator(func_or_class): "system_dependencies": system_dependencies, } + # LB route handler passthrough — return the function unwrapped. + # + # When @remote is applied to an LB resource (LiveLoadBalancer, + # CpuLiveLoadBalancer, LoadBalancerSlsResource) with method= and path=, + # the decorated function IS the HTTP route handler. Its body executes + # directly on the LB endpoint server; it is not dispatched to a remote + # process. QB @remote calls inside its body still use their own stubs. + is_lb_route_handler = is_lb_resource and method is not None and path is not None + if is_lb_route_handler: + routing_config["is_lb_route_handler"] = True + func_or_class.__remote_config__ = routing_config + func_or_class.__is_lb_route_handler__ = True + return func_or_class + # Local execution mode - execute without provisioning remote servers if local: func_or_class.__remote_config__ = routing_config diff --git a/src/runpod_flash/core/api/runpod.py b/src/runpod_flash/core/api/runpod.py index bc30219a..478428b8 100644 --- a/src/runpod_flash/core/api/runpod.py +++ b/src/runpod_flash/core/api/runpod.py @@ -202,7 +202,7 @@ async def save_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: variables = {"input": input_data} - log.debug(f"Saving endpoint with GraphQL: {input_data.get('name', 'unnamed')}") + log.debug(f"GraphQL saveEndpoint: {input_data.get('name', 'unnamed')}") result = await self._execute_graphql(mutation, variables) @@ -211,7 +211,7 @@ async def save_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: endpoint_data = result["saveEndpoint"] log.debug( - f"Saved endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}" + f"GraphQL response: 
{endpoint_data.get('id', 'unknown')} ({endpoint_data.get('name', 'unnamed')})" ) return endpoint_data diff --git a/src/runpod_flash/core/resources/constants.py b/src/runpod_flash/core/resources/constants.py index e927c09b..c2aee1b3 100644 --- a/src/runpod_flash/core/resources/constants.py +++ b/src/runpod_flash/core/resources/constants.py @@ -39,3 +39,6 @@ def _endpoint_domain_from_base_url(base_url: str) -> str: MAX_TARBALL_SIZE_MB = 500 # Maximum tarball size in megabytes VALID_TARBALL_EXTENSIONS = (".tar.gz", ".tgz") # Valid tarball file extensions GZIP_MAGIC_BYTES = (0x1F, 0x8B) # Magic bytes for gzip files + +# Load balancer stub timeout (seconds) +DEFAULT_LB_STUB_TIMEOUT = 60.0 diff --git a/src/runpod_flash/core/resources/load_balancer_sls_resource.py b/src/runpod_flash/core/resources/load_balancer_sls_resource.py index eb664ed0..df84d622 100644 --- a/src/runpod_flash/core/resources/load_balancer_sls_resource.py +++ b/src/runpod_flash/core/resources/load_balancer_sls_resource.py @@ -259,10 +259,10 @@ async def _do_deploy(self) -> "LoadBalancerSlsResource": self.env["FLASH_IS_MOTHERSHIP"] = "true" # Call parent deploy (creates endpoint via RunPod API) - log.debug(f"Deploying LB endpoint {self.name}...") + log.info(f"Deploying LB endpoint: {self.name}") deployed = await super()._do_deploy() - log.debug(f"LB endpoint {self.name} ({deployed.id}) deployed successfully") + log.info(f"Deployed: {self.name} ({deployed.url})") return deployed except Exception as e: diff --git a/src/runpod_flash/core/resources/resource_manager.py b/src/runpod_flash/core/resources/resource_manager.py index 0cd18f51..82eebe2b 100644 --- a/src/runpod_flash/core/resources/resource_manager.py +++ b/src/runpod_flash/core/resources/resource_manager.py @@ -245,7 +245,6 @@ async def get_or_deploy_resource( deployed_resource = await self._deploy_with_error_context( config ) - log.debug(f"URL: {deployed_resource.url}") self._add_resource(resource_key, deployed_resource) return deployed_resource 
except Exception: @@ -278,7 +277,6 @@ async def get_or_deploy_resource( deployed_resource = await self._deploy_with_error_context( config ) - log.debug(f"URL: {deployed_resource.url}") self._add_resource(resource_key, deployed_resource) return deployed_resource except Exception: @@ -292,13 +290,11 @@ async def get_or_deploy_resource( raise # Config unchanged, reuse existing - log.info(f"URL: {existing.url}") return existing # No existing resource, deploy new one try: deployed_resource = await self._deploy_with_error_context(config) - log.debug(f"URL: {deployed_resource.url}") self._add_resource(resource_key, deployed_resource) return deployed_resource except Exception: diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index e7a14403..869b85a8 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -474,6 +474,15 @@ def is_deployed(self) -> bool: if not self.id: return False + # During flash run, skip the health check. Newly-created endpoints + # can fail health checks due to RunPod propagation delay — the + # endpoint exists but the health API hasn't registered it yet. + # Trusting the cached ID is correct here; actual failures surface + # on the first real run/run_sync call. + # Case-insensitive check; unset env var defaults to "" via getenv. + if os.getenv("FLASH_IS_LIVE_PROVISIONING", "").lower() == "true": + return True + response = self.endpoint.health() return response is not None except Exception as e: @@ -484,6 +493,12 @@ def _payload_exclude(self) -> Set[str]: # flashEnvironmentId is input-only but must be sent when provided exclude_fields = set(self._input_only or set()) exclude_fields.discard("flashEnvironmentId") + # When templateId is already set, exclude template from the payload. + # RunPod rejects requests that contain both fields simultaneously. 
+ # Both can coexist after deploy mutates config (sets templateId while + # template remains from initialization) — templateId takes precedence. + if self.templateId: + exclude_fields.add("template") return exclude_fields @staticmethod @@ -564,12 +579,45 @@ def _check_makes_remote_calls(self) -> bool: ) return True # Safe default on error + def _get_module_path(self) -> Optional[str]: + """Get module_path from build manifest for this resource. + + Returns: + Dotted module path (e.g., 'preprocess.first_pass'), or None if not found. + """ + try: + manifest_path = Path.cwd() / "flash_manifest.json" + if not manifest_path.exists(): + manifest_path = Path("/flash_manifest.json") + if not manifest_path.exists(): + return None + + with open(manifest_path) as f: + manifest_data = json.load(f) + + resources = manifest_data.get("resources", {}) + + lookup_name = self.name + if lookup_name.endswith("-fb"): + lookup_name = lookup_name[:-3] + if lookup_name.startswith(LIVE_PREFIX): + lookup_name = lookup_name[len(LIVE_PREFIX) :] + + resource_config = resources.get(lookup_name) + if not resource_config: + return None + + return resource_config.get("module_path") + + except Exception: + return None + async def _do_deploy(self) -> "DeployableResource": """ Deploys the serverless resource using the provided configuration. - For queue-based endpoints that make remote calls, injects RUNPOD_API_KEY - into environment variables if not already set. + For queue-based endpoints that make remote calls, injects RUNPOD_API_KEY. + For load-balanced endpoints, injects FLASH_MODULE_PATH. Returns a DeployableResource object. 
""" @@ -604,6 +652,17 @@ async def _do_deploy(self) -> "DeployableResource": self.env = env_dict + # Inject module path for load-balanced endpoints + elif self.type == ServerlessType.LB: + env_dict = self.env or {} + + module_path = self._get_module_path() + if module_path and "FLASH_MODULE_PATH" not in env_dict: + env_dict["FLASH_MODULE_PATH"] = module_path + log.info(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}") + + self.env = env_dict + # Ensure network volume is deployed first await self._ensure_network_volume_deployed() @@ -617,6 +676,9 @@ async def _do_deploy(self) -> "DeployableResource": endpoint = await self._sync_graphql_object_with_inputs(endpoint) self.id = endpoint.id self.templateId = endpoint.templateId + self.template = ( + None # templateId takes precedence; clear to avoid conflict + ) return endpoint raise ValueError("Deployment failed, no endpoint was returned.") diff --git a/src/runpod_flash/stubs/dependency_resolver.py b/src/runpod_flash/stubs/dependency_resolver.py new file mode 100644 index 00000000..f8449fb0 --- /dev/null +++ b/src/runpod_flash/stubs/dependency_resolver.py @@ -0,0 +1,264 @@ +"""Dependency resolver for stacked @remote function execution. + +When @remote funcA calls @remote funcB, the worker only receives funcA's source. +This module detects such dependencies, provisions their endpoints, and generates +dispatch stubs so funcB resolves correctly inside the worker's exec() namespace. 
+""" + +import ast +import inspect +import logging +from dataclasses import dataclass +from typing import Any + +from .live_serverless import get_function_source + +log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RemoteDependency: + """A resolved @remote dependency ready for stub generation.""" + + name: str + endpoint_id: str + source: str + dependencies: list[str] + system_dependencies: list[str] + + +def detect_remote_dependencies(source: str, func_globals: dict[str, Any]) -> list[str]: + """Find names of @remote functions called in *source*. + + Parses the source AST and checks each direct function call (ast.Name) + against *func_globals* for the ``__remote_config__`` attribute set by + the @remote decorator. + + Only direct calls (``await funcB(x)``) are detected. Attribute calls + (``obj.funcB(x)``) and indirect references (``f = funcB; f(x)``) are + intentionally ignored (V1 limitation). + + Args: + source: Source code string of the calling function. + func_globals: The ``__globals__`` dict of the calling function, + used to resolve called names. + + Returns: + Sorted list of names that resolve to @remote-decorated objects. + """ + tree = ast.parse(source) + called_names: set[str] = set() + + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + called_names.add(node.func.id) + + remote_deps = [ + name + for name in sorted(called_names) + if name in func_globals and hasattr(func_globals[name], "__remote_config__") + ] + return remote_deps + + +async def resolve_dependencies( + source: str, func_globals: dict[str, Any] +) -> list[RemoteDependency]: + """Detect @remote dependencies and provision their endpoints. + + For each detected dependency: + 1. Extract resource_config from ``__remote_config__`` + 2. Provision via ``ResourceManager().get_or_deploy_resource()`` + 3. Return a ``RemoteDependency`` with the provisioned endpoint_id + + Args: + source: Source code of the calling function. 
+ func_globals: The ``__globals__`` dict of the calling function. + + Returns: + List of resolved dependencies with endpoint IDs. + + Raises: + RuntimeError: If endpoint provisioning fails for any dependency. + """ + dep_names = detect_remote_dependencies(source, func_globals) + if not dep_names: + return [] + + import asyncio + + from ..core.resources import ResourceManager + + resource_manager = ResourceManager() + + # Gather metadata needed for each dependency before parallel provisioning. + dep_info: list[tuple[str, Any, str, list[str], list[str]]] = [] + for name in dep_names: + dep_func = func_globals[name] + config = dep_func.__remote_config__ + unwrapped = inspect.unwrap(dep_func) + dep_source, _ = get_function_source(unwrapped) + dep_info.append( + ( + name, + config["resource_config"], + dep_source, + config.get("dependencies") or [], + config.get("system_dependencies") or [], + ) + ) + + # Provision all endpoints in parallel. + remote_resources = await asyncio.gather( + *(resource_manager.get_or_deploy_resource(rc) for _, rc, _, _, _ in dep_info) + ) + + resolved: list[RemoteDependency] = [] + for (name, _, dep_source, deps, sys_deps), remote_resource in zip( + dep_info, remote_resources + ): + resolved.append( + RemoteDependency( + name=name, + endpoint_id=remote_resource.id, + source=dep_source, + dependencies=deps, + system_dependencies=sys_deps, + ) + ) + log.debug( + "Resolved dependency %s -> endpoint %s", + name, + remote_resource.id, + ) + + return resolved + + +def generate_stub_code(dep: RemoteDependency) -> str: + """Generate an async stub function that dispatches to a remote endpoint. + + The stub preserves the original function's parameter names so callers + can use ``await funcB(payload)`` naturally. Inside the stub, arguments + are serialized with cloudpickle and sent via aiohttp to the RunPod + runsync endpoint. + + Args: + dep: Resolved dependency with endpoint_id and source. 
+ + Returns: + Python source code string defining the async stub function. + """ + # Parse the dependency source to extract parameter names + tree = ast.parse(dep.source) + params_str = "*args, **kwargs" + for node in ast.walk(tree): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == dep.name + ): + params_str = _extract_params(node) + break + + # Build serialization expressions for args/kwargs + ser_args_expr, ser_kwargs_expr = _build_serialization_exprs(tree, dep.name) + + I = " " # noqa: E741 — single indent level + lines = [ + f"async def {dep.name}({params_str}):", + f"{I}import os as _os", + f"{I}import base64 as _b64", + f"{I}import cloudpickle as _cp", + f"{I}import aiohttp as _aiohttp", + "", + f"{I}_endpoint_id = {repr(dep.endpoint_id)}", + f'{I}_api_key = _os.environ.get("RUNPOD_API_KEY", "")', + f'{I}_url = f"https://api.runpod.ai/v2/{{_endpoint_id}}/runsync"', + f'{I}_headers = {{"Content-Type": "application/json"}}', + f"{I}if _api_key:", + f'{I}{I}_headers["Authorization"] = f"Bearer {{_api_key}}"', + "", + f"{I}_func_source = {repr(dep.source)}", + f"{I}_ser_args = {ser_args_expr}", + f"{I}_ser_kwargs = {ser_kwargs_expr}", + f"{I}_payload = {{", + f'{I}{I}"input": {{', + f'{I}{I}{I}"function_name": {repr(dep.name)},', + f'{I}{I}{I}"function_code": _func_source,', + f'{I}{I}{I}"args": _ser_args,', + f'{I}{I}{I}"kwargs": _ser_kwargs,', + f'{I}{I}{I}"dependencies": {repr(dep.dependencies)},', + f'{I}{I}{I}"system_dependencies": {repr(dep.system_dependencies)},', + f"{I}{I}}}", + f"{I}}}", + "", + f"{I}_timeout = _aiohttp.ClientTimeout(total=300)", + f"{I}async with _aiohttp.ClientSession(timeout=_timeout) as _sess:", + f"{I}{I}async with _sess.post(_url, json=_payload, headers=_headers) as _resp:", + f"{I}{I}{I}if _resp.status != 200:", + f"{I}{I}{I}{I}_err = await _resp.text()", + f"{I}{I}{I}{I}raise RuntimeError(", + f'{I}{I}{I}{I}{I}f"Remote {dep.name} failed (HTTP {{_resp.status}}): {{_err}}"', + f"{I}{I}{I}{I})", 
+ f"{I}{I}{I}_data = await _resp.json()", + f'{I}{I}{I}_output = _data.get("output", _data)', + f'{I}{I}{I}if not _output.get("success"):', + f"{I}{I}{I}{I}raise RuntimeError(", + f"{I}{I}{I}{I}{I}f\"Remote {dep.name} failed: {{_output.get('error')}}\"", + f"{I}{I}{I}{I})", + f'{I}{I}{I}return _cp.loads(_b64.b64decode(_output["result"]))', + ] + return "\n".join(lines) + "\n" + + +def build_augmented_source(original_source: str, stub_codes: list[str]) -> str: + """Prepend stub code blocks before the original function source. + + Args: + original_source: The calling function's source code. + stub_codes: List of stub code strings to prepend. + + Returns: + Combined source with stubs before the original function. + """ + if not stub_codes: + return original_source + + parts = stub_codes + [original_source] + return "\n\n".join(parts) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _extract_params(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + """Extract parameter list string from an AST function node.""" + params = [arg.arg for arg in func_node.args.args] + return ", ".join(params) if params else "*args, **kwargs" + + +def _build_serialization_exprs(tree: ast.Module, func_name: str) -> tuple[str, str]: + """Return (args_expr, kwargs_expr) for serializing function parameters. + + When the original signature has named params, we serialize each by name. + Otherwise fall back to generic *args/**kwargs serialization. 
+ """ + for node in ast.walk(tree): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == func_name + ): + param_names = [arg.arg for arg in node.args.args] + if param_names: + items = ", ".join( + f"_b64.b64encode(_cp.dumps({p})).decode()" for p in param_names + ) + return f"[{items}]", "{}" + + # Fallback for *args, **kwargs + return ( + "[_b64.b64encode(_cp.dumps(a)).decode() for a in args]", + "{k: _b64.b64encode(_cp.dumps(v)).decode() for k, v in kwargs.items()}", + ) diff --git a/src/runpod_flash/stubs/live_serverless.py b/src/runpod_flash/stubs/live_serverless.py index 256e22d7..af0b9ab1 100644 --- a/src/runpod_flash/stubs/live_serverless.py +++ b/src/runpod_flash/stubs/live_serverless.py @@ -72,7 +72,7 @@ class LiveServerlessStub(RemoteExecutorStub): def __init__(self, server: LiveServerless): self.server = server - def prepare_request( + async def prepare_request( self, func, dependencies, @@ -83,6 +83,22 @@ def prepare_request( ): source, src_hash = get_function_source(func) + # Detect and resolve @remote dependencies for stacked execution + from .dependency_resolver import ( + build_augmented_source, + generate_stub_code, + resolve_dependencies, + ) + + original_func = inspect.unwrap(func) + remote_deps = await resolve_dependencies(source, original_func.__globals__) + if remote_deps: + stub_codes = [generate_stub_code(dep) for dep in remote_deps] + source = build_augmented_source(source, stub_codes) + # Recompute cache key to include dependency endpoints + dep_key = "|".join(f"{d.name}:{d.endpoint_id}" for d in remote_deps) + src_hash = hashlib.sha256((source + dep_key).encode("utf-8")).hexdigest() + request = { "function_name": func.__name__, "dependencies": dependencies, diff --git a/src/runpod_flash/stubs/load_balancer_sls.py b/src/runpod_flash/stubs/load_balancer_sls.py index d30a9ce5..1f096170 100644 --- a/src/runpod_flash/stubs/load_balancer_sls.py +++ b/src/runpod_flash/stubs/load_balancer_sls.py @@ -17,6 +17,8 @@ 
serialize_args, serialize_kwargs, ) +from runpod_flash.core.resources.constants import DEFAULT_LB_STUB_TIMEOUT + from .live_serverless import get_function_source log = logging.getLogger(__name__) @@ -47,17 +49,15 @@ class LoadBalancerSlsStub: result = await stub(my_func, deps, sys_deps, accel, arg1, arg2) """ - DEFAULT_TIMEOUT = 30.0 # Default timeout in seconds - def __init__(self, server: Any, timeout: Optional[float] = None) -> None: """Initialize stub with LoadBalancerSlsResource server. Args: server: LoadBalancerSlsResource instance with endpoint_url configured - timeout: Request timeout in seconds (default: 30.0) + timeout: Request timeout in seconds (default: DEFAULT_LB_STUB_TIMEOUT) """ self.server = server - self.timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT + self.timeout = timeout if timeout is not None else DEFAULT_LB_STUB_TIMEOUT def _should_use_execute_endpoint(self, func: Callable[..., Any]) -> bool: """Determine if /execute endpoint should be used for this function. 
@@ -75,11 +75,15 @@ def _should_use_execute_endpoint(self, func: Callable[..., Any]) -> bool: Returns: True if /execute should be used, False if user route should be used """ - from ..core.resources.live_serverless import LiveLoadBalancer + from ..core.resources.live_serverless import LiveServerlessMixin - # Always use /execute for LiveLoadBalancer (local development) - if isinstance(self.server, LiveLoadBalancer): - log.debug(f"Using /execute endpoint for LiveLoadBalancer: {func.__name__}") + # Always use /execute for live resources (local development) + if isinstance(self.server, LiveServerlessMixin): + log.debug( + "Using /execute endpoint for live resource %s (type=%s)", + func.__name__, + type(self.server).__name__, + ) return True # Check if function has routing metadata @@ -134,7 +138,7 @@ async def __call__( # Determine execution path based on resource type and routing metadata if self._should_use_execute_endpoint(func): # Local development or backward compatibility: use /execute endpoint - request = self._prepare_request( + request = await self._prepare_request( func, dependencies, system_dependencies, @@ -155,7 +159,7 @@ async def __call__( **kwargs, ) - def _prepare_request( + async def _prepare_request( self, func: Callable[..., Any], dependencies: Optional[List[str]], @@ -167,6 +171,7 @@ def _prepare_request( """Prepare HTTP request payload. Extracts function source code and serializes arguments using cloudpickle. + Detects @remote dependencies and injects dispatch stubs for stacked execution. 
Args: func: Function to serialize @@ -180,6 +185,20 @@ def _prepare_request( Request dictionary with serialized function and arguments """ source, _ = get_function_source(func) + + # Detect and resolve @remote dependencies for stacked execution + from .dependency_resolver import ( + build_augmented_source, + generate_stub_code, + resolve_dependencies, + ) + + original_func = inspect.unwrap(func) + remote_deps = await resolve_dependencies(source, original_func.__globals__) + if remote_deps: + stub_codes = [generate_stub_code(dep) for dep in remote_deps] + source = build_augmented_source(source, stub_codes) + log.debug(f"Extracted source for {func.__name__} ({len(source)} bytes)") request = { diff --git a/src/runpod_flash/stubs/registry.py b/src/runpod_flash/stubs/registry.py index 674e0085..23a50fad 100644 --- a/src/runpod_flash/stubs/registry.py +++ b/src/runpod_flash/stubs/registry.py @@ -3,6 +3,7 @@ from functools import singledispatch from ..core.resources import ( + CpuLiveLoadBalancer, CpuLiveServerless, CpuServerlessEndpoint, LiveLoadBalancer, @@ -41,7 +42,7 @@ async def stubbed_resource( if args == (None,): args = [] - request = stub.prepare_request( + request = await stub.prepare_request( func, dependencies, system_dependencies, @@ -209,3 +210,28 @@ async def stubbed_resource( ) return stubbed_resource + + +@stub_resource.register(CpuLiveLoadBalancer) +def _(resource, **extra): + """Create stub for CpuLiveLoadBalancer (HTTP-based execution, local testing).""" + stub = LoadBalancerSlsStub(resource) + + async def stubbed_resource( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) -> dict: + return await stub( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) + + return stubbed_resource diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index 406e1521..11a3d14a 100644 --- 
a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ -72,7 +72,7 @@ def add(x: int, y: int) -> int: return x + y # Prepare request - request = stub._prepare_request(add, None, None, True, 5, 3) + request = await stub._prepare_request(add, None, None, True, 5, 3) # Verify request structure assert request["function_name"] == "add" diff --git a/tests/integration/test_run_auto_provision.py b/tests/integration/test_run_auto_provision.py deleted file mode 100644 index 9478f442..00000000 --- a/tests/integration/test_run_auto_provision.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Integration tests for flash run --auto-provision command.""" - -import pytest -from unittest.mock import patch, MagicMock -from textwrap import dedent -from typer.testing import CliRunner - -from runpod_flash.cli.main import app - -runner = CliRunner() - - -class TestRunAutoProvision: - """Test flash run --auto-provision integration.""" - - @pytest.fixture - def temp_project(self, tmp_path): - """Create temporary Flash project for testing.""" - # Create main.py with FastAPI app - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - app = FastAPI() - - gpu_config = ServerlessResource( - name="test-gpu", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(resource_config=gpu_config) - async def gpu_task(): - return "result" - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - return tmp_path - - @pytest.fixture - def temp_project_many_resources(self, tmp_path): - """Create temporary project with many resources (> 5).""" - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - 
app = FastAPI() - - # Create 6 resources to trigger confirmation - configs = [ - ServerlessResource( - name=f"endpoint-{i}", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - for i in range(6) - ] - - @remote(resource_config=configs[0]) - async def task1(): pass - - @remote(resource_config=configs[1]) - async def task2(): pass - - @remote(resource_config=configs[2]) - async def task3(): pass - - @remote(resource_config=configs[3]) - async def task4(): pass - - @remote(resource_config=configs[4]) - async def task5(): pass - - @remote(resource_config=configs[5]) - async def task6(): pass - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - return tmp_path - - def test_run_without_auto_provision(self, temp_project, monkeypatch): - """Test that flash run without --auto-provision doesn't deploy resources.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations to prevent hanging - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to track if it was called - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - runner.invoke(app, ["run"]) - - # Discovery should not be called - mock_discover.assert_not_called() - - def test_run_with_auto_provision_single_resource(self, temp_project, monkeypatch): - """Test flash run --auto-provision with single resource.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - 
mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock deployment orchestrator - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should be called - mock_provision.assert_called_once() - - def test_run_with_auto_provision_skips_reload(self, temp_project, monkeypatch): - """Test that auto-provision is skipped on reload.""" - monkeypatch.chdir(temp_project) - - # Simulate reload environment - monkeypatch.setenv("UVICORN_RELOADER_PID", "12345") - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock provisioning - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should NOT be called on reload - mock_provision.assert_not_called() - - def test_run_with_auto_provision_many_resources_confirmed( - self, temp_project, monkeypatch - ): - """Test auto-provision with > 5 resources and user confirmation.""" - monkeypatch.chdir(temp_project) - - # Create 6 mock resources - mock_resources = [MagicMock(name=f"endpoint-{i}") for i in range(6)] - - # Mock subprocess to prevent actual uvicorn start - with 
patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to return > 5 resources - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = mock_resources - - # Mock questionary to simulate user confirmation - with patch( - "runpod_flash.cli.commands.run.questionary.confirm" - ) as mock_confirm: - mock_confirm.return_value.ask.return_value = True - - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Should prompt for confirmation - mock_confirm.assert_called_once() - - # Should provision after confirmation - mock_provision.assert_called_once() - - def test_run_with_auto_provision_many_resources_cancelled( - self, temp_project, monkeypatch - ): - """Test auto-provision with > 5 resources and user cancellation.""" - monkeypatch.chdir(temp_project) - - # Create 6 mock resources - mock_resources = [MagicMock(name=f"endpoint-{i}") for i in range(6)] - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to return > 5 resources - with patch( - 
"runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = mock_resources - - # Mock questionary to simulate user cancellation - with patch( - "runpod_flash.cli.commands.run.questionary.confirm" - ) as mock_confirm: - mock_confirm.return_value.ask.return_value = False - - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Should prompt for confirmation - mock_confirm.assert_called_once() - - # Should NOT provision after cancellation - mock_provision.assert_not_called() - - def test_run_auto_provision_discovery_error(self, temp_project, monkeypatch): - """Test that run handles discovery errors gracefully.""" - monkeypatch.chdir(temp_project) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Mock discovery to raise exception - with patch( - "runpod_flash.cli.commands.run._discover_resources" - ) as mock_discover: - mock_discover.return_value = [] - - runner.invoke(app, ["run", "--auto-provision"]) - - # Server should still start despite discovery error - mock_popen.assert_called_once() - - def test_run_auto_provision_no_resources_found(self, tmp_path, monkeypatch): - """Test auto-provision when no resources are found.""" - monkeypatch.chdir(tmp_path) - - # Create main.py without any @remote decorators - main_file = tmp_path / "main.py" - main_file.write_text( - dedent( - """ - from fastapi import FastAPI - - app = FastAPI() - - @app.get("/") - def root(): - return {"message": "Hello"} - 
""" - ) - ) - - # Mock subprocess to prevent actual uvicorn start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - - with patch("runpod_flash.cli.commands.run.os.killpg"): - with patch( - "runpod_flash.cli.commands.run._provision_resources" - ) as mock_provision: - runner.invoke(app, ["run", "--auto-provision"]) - - # Provisioning should not be called - mock_provision.assert_not_called() - - # Server should still start - mock_popen.assert_called_once() diff --git a/tests/unit/cli/commands/build_utils/test_manifest_mothership.py b/tests/unit/cli/commands/build_utils/test_manifest_mothership.py deleted file mode 100644 index 896eefdf..00000000 --- a/tests/unit/cli/commands/build_utils/test_manifest_mothership.py +++ /dev/null @@ -1,404 +0,0 @@ -"""Tests for mothership resource creation in manifest.""" - -import tempfile -from pathlib import Path -from unittest.mock import patch - -from runpod_flash.cli.commands.build_utils.manifest import ManifestBuilder -from runpod_flash.cli.commands.build_utils.scanner import RemoteFunctionMetadata -from runpod_flash.core.resources.constants import ( - FLASH_CPU_LB_IMAGE, - FLASH_LB_IMAGE, -) - - -class TestManifestMothership: - """Test mothership resource creation in manifest.""" - - def test_manifest_includes_mothership_with_main_py(self): - """Test mothership resource added to manifest when main.py detected.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} 
-""" - ) - - # Create a simple function file - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -gpu_config = LiveServerless(name="gpu_worker") - -@remote(resource_config=gpu_config) -def process(data): - return data -""" - ) - - # Change to project directory for detection - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", - remote_functions=[], - ) - manifest = builder.build() - - # Check mothership is in resources - assert "mothership" in manifest["resources"] - mothership = manifest["resources"]["mothership"] - assert mothership["is_mothership"] is True - assert mothership["main_file"] == "main.py" - assert mothership["app_variable"] == "app" - assert mothership["resource_type"] == "CpuLiveLoadBalancer" - assert mothership["imageName"] == FLASH_CPU_LB_IMAGE - - def test_manifest_skips_mothership_without_routes(self): - """Test mothership NOT added if main.py has no routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py without routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() -# No routes defined -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Mothership should NOT be in resources - assert "mothership" not in manifest["resources"] - - def test_manifest_skips_mothership_without_main_py(self): - """Test mothership NOT added if no main.py exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = 
ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Mothership should NOT be in resources - assert "mothership" not in manifest["resources"] - - def test_manifest_handles_mothership_name_conflict(self): - """Test mothership uses alternate name if conflict with @remote resource.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create a remote function with name "mothership" (conflict) - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -mothership_config = LiveServerless(name="mothership") - -@remote(resource_config=mothership_config) -def process(data): - return data -""" - ) - - # Create remote function metadata with resource named "mothership" - remote_func = RemoteFunctionMetadata( - function_name="process", - module_path="functions", - resource_config_name="mothership", - resource_type="LiveServerless", - is_async=False, - is_class=False, - file_path=func_file, - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", remote_functions=[remote_func] - ) - manifest = builder.build() - - # Original mothership should be in resources - assert "mothership" in manifest["resources"] - # Auto-generated mothership should use alternate name - assert "mothership-entrypoint" in manifest["resources"] - entrypoint = manifest["resources"]["mothership-entrypoint"] - assert entrypoint["is_mothership"] is True - - def test_mothership_resource_config(self): - """Test mothership resource has correct configuration.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = 
Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - mothership = manifest["resources"]["mothership"] - - # Check all expected fields - assert mothership["resource_type"] == "CpuLiveLoadBalancer" - # Functions should include the FastAPI route - assert len(mothership["functions"]) == 1 - assert mothership["functions"][0]["name"] == "root" - assert mothership["functions"][0]["http_method"] == "GET" - assert mothership["functions"][0]["http_path"] == "/" - assert mothership["is_load_balanced"] is True - assert mothership["is_live_resource"] is True - assert mothership["imageName"] == FLASH_CPU_LB_IMAGE - assert mothership["workersMin"] == 1 - assert mothership["workersMax"] == 1 - - def test_manifest_uses_explicit_mothership_config(self): - """Test explicit mothership.py config takes precedence over auto-detection.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create explicit mothership.py with custom config - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer - -mothership = CpuLiveLoadBalancer( - name="my-api", - workersMin=3, - workersMax=7, -) -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Check explicit 
config is used - assert "my-api" in manifest["resources"] - mothership = manifest["resources"]["my-api"] - assert mothership["is_explicit"] is True - assert mothership["workersMin"] == 3 - assert mothership["workersMax"] == 7 - - def test_manifest_skips_auto_detect_with_explicit_config(self): - """Test auto-detection is skipped when explicit config exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with FastAPI routes - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Create explicit mothership.py - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer - -mothership = CpuLiveLoadBalancer( - name="explicit-mothership", - workersMin=2, - workersMax=4, -) -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = builder.build() - - # Check only explicit config is in resources (not auto-detected "mothership") - assert "explicit-mothership" in manifest["resources"] - assert ( - manifest["resources"]["explicit-mothership"]["is_explicit"] is True - ) - assert "mothership" not in manifest["resources"] - - def test_manifest_handles_explicit_mothership_name_conflict(self): - """Test explicit mothership uses alternate name if conflict with @remote.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create explicit mothership.py with name that conflicts with resource - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer - -mothership = CpuLiveLoadBalancer( - name="api", # Will conflict with @remote resource named "api" - workersMin=1, - workersMax=3, -) 
-""" - ) - - # Create a remote function with name "api" (conflict) - func_file = project_root / "functions.py" - func_file.write_text( - """ -from runpod_flash import remote -from runpod_flash import LiveServerless - -api_config = LiveServerless(name="api") - -@remote(resource_config=api_config) -def process(data): - return data -""" - ) - - remote_func = RemoteFunctionMetadata( - function_name="process", - module_path="functions", - resource_config_name="api", - resource_type="LiveServerless", - is_async=False, - is_class=False, - file_path=func_file, - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder( - project_name="test", remote_functions=[remote_func] - ) - manifest = builder.build() - - # Original resource should be in resources - assert "api" in manifest["resources"] - # Explicit mothership should use alternate name - assert "mothership-entrypoint" in manifest["resources"] - entrypoint = manifest["resources"]["mothership-entrypoint"] - assert entrypoint["is_explicit"] is True - - def test_manifest_explicit_mothership_with_gpu_load_balancer(self): - """Test explicit GPU-based load balancer config.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create explicit mothership.py with GPU load balancer - mothership_file = project_root / "mothership.py" - mothership_file.write_text( - """ -from runpod_flash import LiveLoadBalancer - -mothership = LiveLoadBalancer( - name="gpu-mothership", - workersMin=1, - workersMax=2, -) -""" - ) - - # Create main.py for FastAPI app - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - with patch( - "runpod_flash.cli.commands.build_utils.manifest.Path.cwd", - return_value=project_root, - ): - builder = ManifestBuilder(project_name="test", remote_functions=[]) - manifest = 
builder.build() - - mothership = manifest["resources"]["gpu-mothership"] - assert mothership["resource_type"] == "LiveLoadBalancer" - assert mothership["imageName"] == FLASH_LB_IMAGE - assert mothership["is_explicit"] is True diff --git a/tests/unit/cli/commands/build_utils/test_path_utilities.py b/tests/unit/cli/commands/build_utils/test_path_utilities.py new file mode 100644 index 00000000..73ec3557 --- /dev/null +++ b/tests/unit/cli/commands/build_utils/test_path_utilities.py @@ -0,0 +1,217 @@ +"""TDD tests for scanner path utility functions. + +Written first (failing) per the plan's TDD requirement. +These test file_to_url_prefix, file_to_resource_name, file_to_module_path. +""" + +import os + + +from runpod_flash.cli.commands.build_utils.scanner import ( + file_to_module_path, + file_to_resource_name, + file_to_url_prefix, +) + + +class TestFileToUrlPrefix: + """Tests for file_to_url_prefix utility.""" + + def test_root_level_file(self, tmp_path): + """gpu_worker.py → /gpu_worker""" + f = tmp_path / "gpu_worker.py" + assert file_to_url_prefix(f, tmp_path) == "/gpu_worker" + + def test_single_subdir(self, tmp_path): + """longruns/stage1.py → /longruns/stage1""" + f = tmp_path / "longruns" / "stage1.py" + assert file_to_url_prefix(f, tmp_path) == "/longruns/stage1" + + def test_nested_subdir(self, tmp_path): + """preprocess/first_pass.py → /preprocess/first_pass""" + f = tmp_path / "preprocess" / "first_pass.py" + assert file_to_url_prefix(f, tmp_path) == "/preprocess/first_pass" + + def test_deep_nested(self, tmp_path): + """workers/gpu/inference.py → /workers/gpu/inference""" + f = tmp_path / "workers" / "gpu" / "inference.py" + assert file_to_url_prefix(f, tmp_path) == "/workers/gpu/inference" + + def test_hyphenated_filename(self, tmp_path): + """my-worker.py → /my-worker (hyphens valid in URLs)""" + f = tmp_path / "my-worker.py" + assert file_to_url_prefix(f, tmp_path) == "/my-worker" + + def test_starts_with_slash(self, tmp_path): + """Result always starts 
with /""" + f = tmp_path / "worker.py" + result = file_to_url_prefix(f, tmp_path) + assert result.startswith("/") + + def test_no_py_extension(self, tmp_path): + """Result does not include .py extension""" + f = tmp_path / "worker.py" + result = file_to_url_prefix(f, tmp_path) + assert ".py" not in result + + +class TestFileToResourceName: + """Tests for file_to_resource_name utility.""" + + def test_root_level_file(self, tmp_path): + """gpu_worker.py → gpu_worker""" + f = tmp_path / "gpu_worker.py" + assert file_to_resource_name(f, tmp_path) == "gpu_worker" + + def test_single_subdir(self, tmp_path): + """longruns/stage1.py → longruns_stage1""" + f = tmp_path / "longruns" / "stage1.py" + assert file_to_resource_name(f, tmp_path) == "longruns_stage1" + + def test_nested_subdir(self, tmp_path): + """preprocess/first_pass.py → preprocess_first_pass""" + f = tmp_path / "preprocess" / "first_pass.py" + assert file_to_resource_name(f, tmp_path) == "preprocess_first_pass" + + def test_deep_nested(self, tmp_path): + """workers/gpu/inference.py → workers_gpu_inference""" + f = tmp_path / "workers" / "gpu" / "inference.py" + assert file_to_resource_name(f, tmp_path) == "workers_gpu_inference" + + def test_hyphenated_filename(self, tmp_path): + """my-worker.py → my_worker (hyphens replaced with underscores for Python identifiers)""" + f = tmp_path / "my-worker.py" + assert file_to_resource_name(f, tmp_path) == "my_worker" + + def test_no_py_extension(self, tmp_path): + """Result does not include .py extension""" + f = tmp_path / "worker.py" + result = file_to_resource_name(f, tmp_path) + assert ".py" not in result + + def test_no_path_separators(self, tmp_path): + """Result contains no / or os.sep characters""" + f = tmp_path / "a" / "b" / "worker.py" + result = file_to_resource_name(f, tmp_path) + assert "/" not in result + assert os.sep not in result + + +class TestFileToModulePath: + """Tests for file_to_module_path utility.""" + + def test_root_level_file(self, 
tmp_path): + """gpu_worker.py → gpu_worker""" + f = tmp_path / "gpu_worker.py" + assert file_to_module_path(f, tmp_path) == "gpu_worker" + + def test_single_subdir(self, tmp_path): + """longruns/stage1.py → longruns.stage1""" + f = tmp_path / "longruns" / "stage1.py" + assert file_to_module_path(f, tmp_path) == "longruns.stage1" + + def test_nested_subdir(self, tmp_path): + """preprocess/first_pass.py → preprocess.first_pass""" + f = tmp_path / "preprocess" / "first_pass.py" + assert file_to_module_path(f, tmp_path) == "preprocess.first_pass" + + def test_deep_nested(self, tmp_path): + """workers/gpu/inference.py → workers.gpu.inference""" + f = tmp_path / "workers" / "gpu" / "inference.py" + assert file_to_module_path(f, tmp_path) == "workers.gpu.inference" + + def test_no_py_extension(self, tmp_path): + """Result does not include .py extension""" + f = tmp_path / "worker.py" + result = file_to_module_path(f, tmp_path) + assert ".py" not in result + + def test_uses_dots_not_slashes(self, tmp_path): + """Result uses dots as separators, not slashes""" + f = tmp_path / "a" / "b" / "worker.py" + result = file_to_module_path(f, tmp_path) + assert "." 
in result + assert "/" not in result + assert os.sep not in result + + +class TestIsLbRouteHandlerField: + """Tests that RemoteFunctionMetadata.is_lb_route_handler is set correctly.""" + + def test_lb_function_with_method_and_path_is_handler(self, tmp_path): + """An LB @remote function with method= and path= is marked as LB route handler.""" + from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + + (tmp_path / "routes.py").write_text( + """ +from runpod_flash import CpuLiveLoadBalancer, remote + +lb_config = CpuLiveLoadBalancer(name="my_lb") + +@remote(lb_config, method="POST", path="/compute") +async def compute(data: dict) -> dict: + return data +""" + ) + + scanner = RemoteDecoratorScanner(tmp_path) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].is_lb_route_handler is True + + def test_qb_function_is_not_handler(self, tmp_path): + """A QB @remote function is NOT marked as LB route handler.""" + from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + + (tmp_path / "worker.py").write_text( + """ +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def process(data: dict) -> dict: + return data +""" + ) + + scanner = RemoteDecoratorScanner(tmp_path) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].is_lb_route_handler is False + + def test_init_py_files_excluded(self, tmp_path): + """__init__.py files are excluded from scanning.""" + from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + + (tmp_path / "__init__.py").write_text( + """ +from runpod_flash import LiveServerless, remote + +gpu_config = LiveServerless(name="gpu_worker") + +@remote(gpu_config) +async def process(data: dict) -> dict: + return data +""" + ) + (tmp_path / "worker.py").write_text( + """ +from 
runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def process(data: dict) -> dict: + return data +""" + ) + + scanner = RemoteDecoratorScanner(tmp_path) + functions = scanner.discover_remote_functions() + + # Only the worker.py function should be discovered, not __init__.py + assert len(functions) == 1 + assert functions[0].file_path.name == "worker.py" diff --git a/tests/unit/cli/commands/build_utils/test_scanner.py b/tests/unit/cli/commands/build_utils/test_scanner.py index 0572b7ce..3b5227c4 100644 --- a/tests/unit/cli/commands/build_utils/test_scanner.py +++ b/tests/unit/cli/commands/build_utils/test_scanner.py @@ -878,3 +878,252 @@ async def process_data(): assert routes[0].http_path == "/api/process" assert routes[0].is_async is True assert routes[0].http_method == "POST" + + +def test_class_methods_extraction(): + """Test that public methods are extracted from @remote classes.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "gpu_worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +gpu_config = LiveServerless(name="gpu_worker") + +@remote(gpu_config) +class SimpleSD: + def __init__(self): + self.model = None + + def generate_image(self, prompt): + return {"image": "base64..."} + + def upscale(self, image): + return {"image": "upscaled..."} + + def _load_model(self): + pass + + def __repr__(self): + return "SimpleSD" +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + meta = functions[0] + assert meta.function_name == "SimpleSD" + assert meta.is_class is True + assert meta.class_methods == ["generate_image", "upscale"] + + +def test_class_methods_excludes_private_and_dunder(): + """Test that _private and __dunder__ methods are excluded from class_methods.""" + with 
tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class MyWorker: + def __init__(self): + pass + + def __repr__(self): + return "MyWorker" + + def _internal_helper(self): + pass + + async def process(self, data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].class_methods == ["process"] + + +def test_class_with_no_public_methods(): + """Test @remote class with only private/dunder methods has empty class_methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class EmptyWorker: + def __init__(self): + pass + + def __call__(self, data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].class_methods == [] + + +def test_function_has_empty_class_methods(): + """Test that regular @remote functions have empty class_methods list.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def my_function(data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].is_class is False + assert functions[0].class_methods == [] + + +def test_param_names_single_param(): + """Test that 
param_names extracts a single parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def process(data): + return data +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].param_names == ["data"] + + +def test_param_names_zero_params(): + """Test that param_names is empty for zero-parameter functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def list_images() -> dict: + return {"images": []} +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].param_names == [] + + +def test_param_names_multiple_params(): + """Test that param_names extracts multiple parameters.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def transform(text: str, operation: str = "uppercase") -> dict: + return {"result": text} +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].param_names == ["text", "operation"] + + +def test_class_method_params_extraction(): + """Test that class_method_params extracts params for each public method.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) 
+ + test_file = project_dir / "worker.py" + test_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class ImageProcessor: + def __init__(self): + pass + + def generate(self, prompt: str, width: int = 512): + return {} + + def list_models(self): + return [] + + def _internal(self): + pass +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + meta = functions[0] + assert meta.is_class is True + assert meta.class_methods == ["generate", "list_models"] + assert meta.class_method_params == { + "generate": ["prompt", "width"], + "list_models": [], + } + # Classes should have empty param_names + assert meta.param_names == [] diff --git a/tests/unit/cli/commands/test_run.py b/tests/unit/cli/commands/test_run.py new file mode 100644 index 00000000..c8b3311e --- /dev/null +++ b/tests/unit/cli/commands/test_run.py @@ -0,0 +1,589 @@ +"""Tests for flash run dev server generation.""" + +import tempfile +from pathlib import Path + +from runpod_flash.cli.commands.run import ( + WorkerInfo, + _generate_flash_server, + _scan_project_workers, +) + + +def test_scan_separates_classes_from_functions(): + """Test that _scan_project_workers puts classes in class_remotes, not functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "gpu_worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="gpu_worker") + +@remote(config) +async def process(data): + return data + +@remote(config) +class SimpleSD: + def generate_image(self, prompt): + return {"image": "data"} + + def upscale(self, image): + return {"image": "upscaled"} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert worker.worker_type == "QB" + assert worker.functions 
== ["process"] + assert len(worker.class_remotes) == 1 + assert worker.class_remotes[0]["name"] == "SimpleSD" + assert worker.class_remotes[0]["methods"] == ["generate_image", "upscale"] + + +def test_scan_class_only_worker(): + """Test scanning a file with only a class-based @remote.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "sd_worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="sd_worker") + +@remote(config) +class StableDiffusion: + def __init__(self): + self.model = None + + def generate(self, prompt): + return {"image": "data"} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert worker.worker_type == "QB" + assert worker.functions == [] + assert len(worker.class_remotes) == 1 + assert worker.class_remotes[0]["name"] == "StableDiffusion" + assert worker.class_remotes[0]["methods"] == ["generate"] + + +def test_codegen_class_single_method(): + """Test generated server.py for a class with a single method uses short URL.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("sd_worker.py"), + url_prefix="/sd_worker", + module_path="sd_worker", + resource_name="sd_worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "StableDiffusion", + "methods": ["generate"], + "method_params": {"generate": ["prompt"]}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_instance_StableDiffusion = StableDiffusion()" in content + assert "_call_with_body(_instance_StableDiffusion.generate, body)" in content + assert "body: _sd_worker_StableDiffusion_generate_Input" in content + assert "_make_input_model" in content + assert '"/sd_worker/run_sync"' in content + # Single method: no method name in URL 
+ assert '"/sd_worker/generate/run_sync"' not in content + + +def test_codegen_class_multiple_methods(): + """Test generated server.py for a class with multiple methods uses method URLs.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("gpu_worker.py"), + url_prefix="/gpu_worker", + module_path="gpu_worker", + resource_name="gpu_worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "SimpleSD", + "methods": ["generate_image", "upscale"], + "method_params": { + "generate_image": ["prompt"], + "upscale": ["image"], + }, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_instance_SimpleSD = SimpleSD()" in content + assert '"/gpu_worker/generate_image/run_sync"' in content + assert '"/gpu_worker/upscale/run_sync"' in content + assert "_call_with_body(_instance_SimpleSD.generate_image, body)" in content + assert "_call_with_body(_instance_SimpleSD.upscale, body)" in content + assert "body: _gpu_worker_SimpleSD_generate_image_Input" in content + assert "body: _gpu_worker_SimpleSD_upscale_Input" in content + + +def test_codegen_mixed_function_and_class(): + """Test codegen when a worker has both functions and class remotes.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + class_remotes=[ + { + "name": "MyModel", + "methods": ["predict"], + "method_params": {"predict": ["data"]}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Both should use multi-callable URL pattern (total_callables = 2) + assert '"/worker/process/run_sync"' in content + assert '"/worker/predict/run_sync"' in content + assert 
"_instance_MyModel = MyModel()" in content + assert "_call_with_body(_instance_MyModel.predict, body)" in content + assert "_call_with_body(process, body)" in content + + +def test_codegen_function_only(): + """Test that function-only workers use Pydantic model and _call_with_body.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("simple.py"), + url_prefix="/simple", + module_path="simple", + resource_name="simple", + worker_type="QB", + functions=["process"], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Single function: short URL + assert '"/simple/run_sync"' in content + assert "_call_with_body(process, body)" in content + assert "_simple_process_Input = _make_input_model(" in content + assert "body: _simple_process_Input" in content + # No instance creation + assert "_instance_" not in content + + +def test_codegen_zero_param_function(): + """Test generated code uses await fn() for zero-param functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["list_images"], + function_params={"list_images": []}, + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "await list_images()" in content + assert 'body.get("input"' not in content + # Handler should not accept body parameter + assert "async def worker_run_sync():" in content + + +def test_codegen_multi_param_function(): + """Test generated code uses _call_with_body for multi-param functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", 
+ worker_type="QB", + functions=["transform"], + function_params={"transform": ["text", "operation"]}, + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_call_with_body(transform, body)" in content + assert "_worker_transform_Input = _make_input_model(" in content + assert "body: _worker_transform_Input" in content + + +def test_codegen_single_param_function(): + """Test generated code uses _call_with_body for single-param functions.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + function_params={"process": ["data"]}, + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_call_with_body(process, body)" in content + assert "body: _worker_process_Input" in content + + +def test_codegen_zero_param_class_method(): + """Test generated code uses await instance.method() for zero-param class methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "ImageProcessor", + "methods": ["list_models"], + "method_params": {"list_models": []}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "await _instance_ImageProcessor.list_models()" in content + # Handler should not accept body parameter + assert "worker_ImageProcessor_run_sync():" in content + + +def test_codegen_multi_param_class_method(): + """Test generated code uses _call_with_body for multi-param class methods.""" + with tempfile.TemporaryDirectory() as tmpdir: + 
project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=[], + class_remotes=[ + { + "name": "ImageProcessor", + "methods": ["generate"], + "method_params": {"generate": ["prompt", "width"]}, + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "_call_with_body(_instance_ImageProcessor.generate, body)" in content + assert "body: _worker_ImageProcessor_generate_Input" in content + # Model creation uses _class_type to get original method signature + assert "_class_type" in content + + +def test_codegen_backward_compat_no_method_params(): + """Test that missing method_params in class_remotes uses _call_with_body.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=[], + class_remotes=[ + {"name": "OldStyle", "methods": ["process"]}, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Should use _call_with_body when method_params not provided (params=None) + assert "_call_with_body(_instance_OldStyle.process, body)" in content + assert "body: _worker_OldStyle_process_Input" in content + + +def test_scan_populates_function_params(): + """Test that _scan_project_workers populates function_params from scanner.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +async def no_params() -> dict: + return {} + +@remote(config) +async def one_param(data: dict) -> dict: + return data + 
+@remote(config) +async def multi_params(text: str, mode: str = "default") -> dict: + return {"text": text} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert worker.function_params == { + "no_params": [], + "one_param": ["data"], + "multi_params": ["text", "mode"], + } + + +def test_scan_populates_class_method_params(): + """Test that _scan_project_workers populates method_params in class_remotes.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + worker_file = project_root / "worker.py" + worker_file.write_text( + """ +from runpod_flash import LiveServerless, remote + +config = LiveServerless(name="worker") + +@remote(config) +class Processor: + def run(self, data: dict): + return data + + def status(self): + return {"ok": True} +""" + ) + + workers = _scan_project_workers(project_root) + + assert len(workers) == 1 + worker = workers[0] + assert len(worker.class_remotes) == 1 + cls = worker.class_remotes[0] + assert cls["method_params"] == { + "run": ["data"], + "status": [], + } + + +def test_codegen_lb_get_with_path_params(): + """Test LB GET route with path params generates proper Swagger-compatible handler.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["get_image"], + lb_routes=[ + { + "method": "GET", + "path": "/images/{file_id}", + "fn_name": "get_image", + "config_variable": "cpu_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # Handler must declare file_id as a typed parameter for Swagger + assert "file_id: str" in content + # Path param must be forwarded in the body dict + assert '"file_id": file_id' in content + # Should NOT use bare request: Request as only param + 
assert ( + "async def _route_worker_get_image(file_id: str, request: Request):" + in content + ) + + +def test_codegen_lb_get_without_path_params(): + """Test LB GET route without path params uses request: Request.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["health"], + lb_routes=[ + { + "method": "GET", + "path": "/health", + "fn_name": "health", + "config_variable": "cpu_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "async def _route_worker_health(request: Request):" in content + assert "dict(request.query_params)" in content + + +def test_codegen_lb_post_with_path_params(): + """Test LB POST route with path params includes both body and path params.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["update_item"], + lb_routes=[ + { + "method": "POST", + "path": "/items/{item_id}", + "fn_name": "update_item", + "config_variable": "api_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + # POST handler must have typed body and path param + assert ( + "async def _route_worker_update_item(body: _worker_update_item_Input, item_id: str):" + in content + ) + assert '"item_id": item_id' in content + assert "_to_dict(body)" in content + + +def test_codegen_lb_get_with_multiple_path_params(): + """Test LB GET route with multiple path params.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + workers = [ + WorkerInfo( + file_path=Path("worker.py"), + url_prefix="/worker", + 
module_path="worker", + resource_name="worker", + worker_type="LB", + functions=["get_version"], + lb_routes=[ + { + "method": "GET", + "path": "/items/{item_id}/versions/{version_id}", + "fn_name": "get_version", + "config_variable": "api_config", + }, + ], + ), + ] + + server_path = _generate_flash_server(project_root, workers) + content = server_path.read_text() + + assert "item_id: str" in content + assert "version_id: str" in content + assert '"item_id": item_id' in content + assert '"version_id": version_id' in content diff --git a/tests/unit/cli/commands/test_run_server_helpers.py b/tests/unit/cli/commands/test_run_server_helpers.py new file mode 100644 index 00000000..ed6529dd --- /dev/null +++ b/tests/unit/cli/commands/test_run_server_helpers.py @@ -0,0 +1,179 @@ +"""Tests for _run_server_helpers: make_input_model, call_with_body, to_dict.""" + +from typing import Any + +import pytest +from pydantic import BaseModel + +from runpod_flash.cli.commands._run_server_helpers import ( + call_with_body, + make_input_model, + to_dict, +) + + +# --- make_input_model --- + + +def test_make_input_model_basic(): + """Function with typed params produces a Pydantic model with correct fields.""" + + async def process(text: str, count: int): + pass + + Model = make_input_model("process_Input", process) + assert Model is not None + assert issubclass(Model, BaseModel) + fields = Model.model_fields + assert "text" in fields + assert "count" in fields + assert fields["text"].annotation is str + assert fields["count"].annotation is int + + +def test_make_input_model_with_defaults(): + """Default values are preserved in the generated model.""" + + async def transform(text: str, mode: str = "default", limit: int = 10): + pass + + Model = make_input_model("transform_Input", transform) + assert Model is not None + fields = Model.model_fields + assert fields["text"].is_required() + assert not fields["mode"].is_required() + assert fields["mode"].default == "default" + assert 
fields["limit"].default == 10 + + +def test_make_input_model_zero_params(): + """Zero-param function returns None.""" + + async def health(): + pass + + result = make_input_model("health_Input", health) + assert result is None + + +def test_make_input_model_skips_self(): + """Self parameter is excluded from the model (class methods).""" + + class Worker: + def generate(self, prompt: str): + pass + + Model = make_input_model("generate_Input", Worker().generate) + assert Model is not None + assert "self" not in Model.model_fields + assert "prompt" in Model.model_fields + + +def test_make_input_model_untyped_params(): + """Untyped params get Any annotation.""" + + async def process(data): + pass + + Model = make_input_model("process_Input", process) + assert Model is not None + assert Model.model_fields["data"].annotation is Any + + +def test_make_input_model_skips_var_positional_and_keyword(): + """Proxy-style (*args, **kwargs) signatures return None, not a model with args/kwargs fields.""" + + async def method_proxy(*args, **kwargs): + pass + + result = make_input_model("proxy_Input", method_proxy) + assert result is None + + +def test_make_input_model_mixed_regular_and_var_keyword(): + """Regular params are kept, **kwargs is skipped.""" + + async def process(text: str, **extra): + pass + + Model = make_input_model("process_Input", process) + assert Model is not None + assert "text" in Model.model_fields + assert "extra" not in Model.model_fields + + +def test_make_input_model_failure_graceful(): + """Bad input returns None instead of raising.""" + result = make_input_model("bad_Input", 42) + assert result is None + + +# --- call_with_body --- + + +@pytest.mark.asyncio +async def test_call_with_body_pydantic(): + """Pydantic model body is spread as kwargs via model_dump().""" + received = {} + + async def process(text: str, count: int): + received.update(text=text, count=count) + return {"ok": True} + + Model = make_input_model("process_Input", process) + body = 
Model(text="hello", count=5) + result = await call_with_body(process, body) + assert result == {"ok": True} + assert received == {"text": "hello", "count": 5} + + +@pytest.mark.asyncio +async def test_call_with_body_dict_fallback(): + """Plain dict body uses _map_body_to_params path.""" + received = {} + + async def process(data): + received["data"] = data + return {"ok": True} + + result = await call_with_body(process, {"data": "value"}) + assert result == {"ok": True} + assert received == {"data": "value"} + + +@pytest.mark.asyncio +async def test_call_with_body_dict_with_input_wrapper(): + """Dict body with 'input' key unwraps correctly.""" + received = {} + + async def process(text: str): + received["text"] = text + return text + + result = await call_with_body(process, {"input": {"text": "hello"}}) + assert result == "hello" + assert received == {"text": "hello"} + + +# --- to_dict --- + + +def test_to_dict_pydantic(): + """Pydantic model is converted to plain dict.""" + + async def process(text: str, count: int): + pass + + Model = make_input_model("process_Input", process) + body = Model(text="hello", count=5) + result = to_dict(body) + assert result == {"text": "hello", "count": 5} + assert isinstance(result, dict) + + +def test_to_dict_plain_dict(): + """Plain dict passes through unchanged.""" + body = {"text": "hello", "count": 5} + result = to_dict(body) + assert result == body + assert result is body diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index a652aa75..73bee3bd 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -1,10 +1,19 @@ """Unit tests for run CLI command.""" import pytest +from pathlib import Path from unittest.mock import patch, MagicMock from typer.testing import CliRunner from runpod_flash.cli.main import app +from runpod_flash.cli.commands.run import ( + WorkerInfo, + _generate_flash_server, + _has_numeric_module_segments, + _make_import_line, + _module_parent_subdir, + _sanitize_fn_name, 
+) @pytest.fixture @@ -15,15 +24,27 @@ def runner(): @pytest.fixture def temp_fastapi_app(tmp_path): - """Create minimal FastAPI app for testing.""" - main_file = tmp_path / "main.py" - main_file.write_text("from fastapi import FastAPI\napp = FastAPI()") + """Create minimal Flash project with @remote function for testing.""" + worker_file = tmp_path / "worker.py" + worker_file.write_text( + "from runpod_flash import LiveServerless, remote\n" + "gpu_config = LiveServerless(name='test_worker')\n" + "@remote(gpu_config)\n" + "async def process(data: dict) -> dict:\n" + " return data\n" + ) return tmp_path class TestRunCommandEnvironmentVariables: """Test flash run command environment variable support.""" + @pytest.fixture(autouse=True) + def patch_watcher(self): + """Prevent the background watcher thread from blocking tests.""" + with patch("runpod_flash.cli.commands.run._watch_and_regenerate"): + yield + def test_port_from_environment_variable( self, runner, temp_fastapi_app, monkeypatch ): @@ -215,3 +236,571 @@ def test_short_port_flag_overrides_environment( assert "--port" in call_args port_index = call_args.index("--port") assert call_args[port_index + 1] == "7000" + + +class TestRunCommandHotReload: + """Test flash run hot-reload behavior.""" + + @pytest.fixture(autouse=True) + def patch_watcher(self): + """Prevent the background watcher thread from blocking tests.""" + with patch("runpod_flash.cli.commands.run._watch_and_regenerate"): + yield + + def _invoke_run(self, runner, monkeypatch, temp_fastapi_app, extra_args=None): + """Helper: invoke flash run and return the Popen call args.""" + monkeypatch.chdir(temp_fastapi_app) + monkeypatch.delenv("FLASH_PORT", raising=False) + monkeypatch.delenv("FLASH_HOST", raising=False) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with 
patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + runner.invoke(app, ["run"] + (extra_args or [])) + + return mock_popen.call_args[0][0] + + def test_reload_watches_flash_server_py( + self, runner, temp_fastapi_app, monkeypatch + ): + """Uvicorn watches .flash/server.py, not the whole project.""" + cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app) + + assert "--reload" in cmd + assert "--reload-dir" in cmd + reload_dir_index = cmd.index("--reload-dir") + assert cmd[reload_dir_index + 1] == ".flash" + + assert "--reload-include" in cmd + reload_include_index = cmd.index("--reload-include") + assert cmd[reload_include_index + 1] == "server.py" + + def test_reload_does_not_watch_project_root( + self, runner, temp_fastapi_app, monkeypatch + ): + """Uvicorn reload-dir must not be '.' to prevent double-reload.""" + cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app) + + reload_dir_index = cmd.index("--reload-dir") + assert cmd[reload_dir_index + 1] != "." 
+ + def test_no_reload_skips_watcher_thread( + self, runner, temp_fastapi_app, monkeypatch + ): + """--no-reload: neither uvicorn reload args nor watcher thread started.""" + monkeypatch.chdir(temp_fastapi_app) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + with patch( + "runpod_flash.cli.commands.run.threading.Thread" + ) as mock_thread_cls: + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + + runner.invoke(app, ["run", "--no-reload"]) + + cmd = mock_popen.call_args[0][0] + assert "--reload" not in cmd + mock_thread.start.assert_not_called() + + def test_watcher_thread_started_on_reload( + self, runner, temp_fastapi_app, monkeypatch, patch_watcher + ): + """When reload=True, the background watcher thread is started.""" + monkeypatch.chdir(temp_fastapi_app) + + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + with patch( + "runpod_flash.cli.commands.run.threading.Thread" + ) as mock_thread_cls: + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + + runner.invoke(app, ["run"]) + + mock_thread.start.assert_called_once() + + def test_watcher_thread_stopped_on_keyboard_interrupt( + self, runner, temp_fastapi_app, monkeypatch + ): + """KeyboardInterrupt sets stop_event and joins the watcher thread.""" + monkeypatch.chdir(temp_fastapi_app) + + with 
patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + + with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): + with patch("runpod_flash.cli.commands.run.os.killpg"): + with patch( + "runpod_flash.cli.commands.run.threading.Thread" + ) as mock_thread_cls: + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + with patch( + "runpod_flash.cli.commands.run.threading.Event" + ) as mock_event_cls: + mock_stop = MagicMock() + mock_event_cls.return_value = mock_stop + + runner.invoke(app, ["run"]) + + mock_stop.set.assert_called_once() + mock_thread.join.assert_called_once_with(timeout=2) + + +class TestWatchAndRegenerate: + """Unit tests for the _watch_and_regenerate background function.""" + + def test_regenerates_server_py_on_py_file_change(self, tmp_path): + """When a .py file changes, server.py is regenerated.""" + import threading + from runpod_flash.cli.commands.run import _watch_and_regenerate + + stop = threading.Event() + + with patch( + "runpod_flash.cli.commands.run._scan_project_workers", return_value=[] + ) as mock_scan: + with patch( + "runpod_flash.cli.commands.run._generate_flash_server" + ) as mock_gen: + with patch( + "runpod_flash.cli.commands.run._watchfiles_watch" + ) as mock_watch: + # Yield one batch of changes then stop + mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) + stop.set() # ensures the loop exits after one iteration + _watch_and_regenerate(tmp_path, stop) + + mock_scan.assert_called_once_with(tmp_path) + mock_gen.assert_called_once() + + def test_ignores_non_py_changes(self, tmp_path): + """Changes to non-.py files do not trigger regeneration.""" + import threading + from runpod_flash.cli.commands.run import _watch_and_regenerate + + stop = threading.Event() + + with 
patch("runpod_flash.cli.commands.run._scan_project_workers") as mock_scan: + with patch( + "runpod_flash.cli.commands.run._generate_flash_server" + ) as mock_gen: + with patch( + "runpod_flash.cli.commands.run._watchfiles_watch" + ) as mock_watch: + mock_watch.return_value = iter([{(1, "/path/to/README.md")}]) + _watch_and_regenerate(tmp_path, stop) + + mock_scan.assert_not_called() + mock_gen.assert_not_called() + + def test_scan_error_does_not_crash_watcher(self, tmp_path): + """If regeneration raises, the watcher logs a warning and continues.""" + import threading + from runpod_flash.cli.commands.run import _watch_and_regenerate + + stop = threading.Event() + + with patch( + "runpod_flash.cli.commands.run._scan_project_workers", + side_effect=RuntimeError("scan failed"), + ): + with patch("runpod_flash.cli.commands.run._watchfiles_watch") as mock_watch: + mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) + # Should not raise + _watch_and_regenerate(tmp_path, stop) + + +class TestGenerateFlashServer: + """Test _generate_flash_server() route code generation.""" + + def _make_lb_worker(self, tmp_path: Path, method: str = "GET") -> WorkerInfo: + return WorkerInfo( + file_path=tmp_path / "api.py", + url_prefix="/api", + module_path="api", + resource_name="api", + worker_type="LB", + functions=["list_routes"], + lb_routes=[ + { + "method": method, + "path": "/routes/list", + "fn_name": "list_routes", + "config_variable": "api_config", + } + ], + ) + + def test_post_lb_route_generates_body_param(self, tmp_path): + """POST/PUT/PATCH/DELETE LB routes use typed body for OpenAPI docs.""" + for method in ("POST", "PUT", "PATCH", "DELETE"): + worker = self._make_lb_worker(tmp_path, method) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "body: _api_list_routes_Input" in content + assert "_lb_execute(api_config, list_routes, _to_dict(body))" in content + + def test_get_lb_route_uses_query_params(self, tmp_path): + """GET LB routes pass 
query params as a dict.""" + worker = self._make_lb_worker(tmp_path, "GET") + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "async def _route_api_list_routes(request: Request):" in content + assert ( + "_lb_execute(api_config, list_routes, dict(request.query_params))" + in content + ) + + def test_lb_config_var_and_function_imported(self, tmp_path): + """LB config vars and functions are both imported for remote dispatch.""" + worker = self._make_lb_worker(tmp_path) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "from api import api_config" in content + assert "from api import list_routes" in content + + def test_lb_execute_import_present_when_lb_routes_exist(self, tmp_path): + """server.py imports _lb_execute when there are LB workers.""" + worker = self._make_lb_worker(tmp_path) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "_lb_execute" in content + assert "lb_execute" in content + + def test_qb_function_still_imported_directly(self, tmp_path): + """QB workers still import and call functions directly.""" + worker = WorkerInfo( + file_path=tmp_path / "worker.py", + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "from worker import process" in content + assert "_call_with_body(process, body)" in content + + +class TestSanitizeFnName: + """Test _sanitize_fn_name handles leading-digit identifiers.""" + + def test_normal_name_unchanged(self): + assert _sanitize_fn_name("worker_run_sync") == "worker_run_sync" + + def test_leading_digit_gets_underscore_prefix(self): + assert _sanitize_fn_name("01_hello_run_sync") == "_01_hello_run_sync" + + def test_slashes_replaced(self): + assert _sanitize_fn_name("a/b/c") == "a_b_c" + + def test_dots_and_hyphens_replaced(self): + assert _sanitize_fn_name("a.b-c") == "a_b_c" + + def 
test_numeric_after_slash(self): + assert _sanitize_fn_name("01_foo/02_bar") == "_01_foo_02_bar" + + +class TestHasNumericModuleSegments: + """Test _has_numeric_module_segments detects digit-prefixed segments.""" + + def test_normal_module_path(self): + assert _has_numeric_module_segments("worker") is False + + def test_dotted_normal(self): + assert _has_numeric_module_segments("longruns.stage1") is False + + def test_leading_digit_first_segment(self): + assert _has_numeric_module_segments("01_hello.worker") is True + + def test_leading_digit_nested_segment(self): + assert _has_numeric_module_segments("getting_started.01_hello.worker") is True + + def test_digit_in_middle_not_leading(self): + assert _has_numeric_module_segments("stage1.worker") is False + + +class TestModuleParentSubdir: + """Test _module_parent_subdir extracts parent directory from dotted path.""" + + def test_top_level_returns_none(self): + assert _module_parent_subdir("worker") is None + + def test_single_parent(self): + assert _module_parent_subdir("01_hello.gpu_worker") == "01_hello" + + def test_nested_parent(self): + assert ( + _module_parent_subdir("01_getting_started.03_mixed.pipeline") + == "01_getting_started/03_mixed" + ) + + +class TestMakeImportLine: + """Test _make_import_line generates correct import syntax.""" + + def test_normal_module_uses_from_import(self): + result = _make_import_line("worker", "process") + assert result == "from worker import process" + + def test_numeric_module_uses_flash_import(self): + result = _make_import_line("01_hello.gpu_worker", "gpu_hello") + assert ( + result + == 'gpu_hello = _flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' + ) + + def test_nested_numeric_includes_full_subdir(self): + result = _make_import_line( + "01_getting_started.01_hello.gpu_worker", "gpu_hello" + ) + assert '"01_getting_started/01_hello"' in result + + def test_top_level_numeric_module_no_subdir(self): + result = _make_import_line("01_worker", "process") + assert 
result == 'process = _flash_import("01_worker", "process")' + + +class TestGenerateFlashServerNumericDirs: + """Test _generate_flash_server with numeric-prefixed directory names.""" + + def test_qb_numeric_dir_uses_flash_import(self, tmp_path): + """QB workers in numeric dirs use _flash_import with scoped sys.path.""" + worker = WorkerInfo( + file_path=tmp_path / "01_hello" / "gpu_worker.py", + url_prefix="/01_hello/gpu_worker", + module_path="01_hello.gpu_worker", + resource_name="01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + # Must NOT contain invalid 'from 01_hello...' import + assert "from 01_hello" not in content + # Must have _flash_import helper and importlib + assert "import importlib as _importlib" in content + assert "def _flash_import(" in content + assert ( + '_flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' in content + ) + + def test_qb_numeric_dir_function_name_prefixed(self, tmp_path): + """QB handler function names starting with digits get '_' prefix.""" + worker = WorkerInfo( + file_path=tmp_path / "01_hello" / "gpu_worker.py", + url_prefix="/01_hello/gpu_worker", + module_path="01_hello.gpu_worker", + resource_name="01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + # Function name must start with '_', not a digit + assert ( + "async def _01_hello_gpu_worker_run_sync(body: _01_hello_gpu_worker_gpu_hello_Input):" + in content + ) + + def test_lb_numeric_dir_uses_flash_import(self, tmp_path): + """LB workers in numeric dirs use _flash_import for config and function imports.""" + worker = WorkerInfo( + file_path=tmp_path / "03_advanced" / "05_lb" / "cpu_lb.py", + url_prefix="/03_advanced/05_lb/cpu_lb", + module_path="03_advanced.05_lb.cpu_lb", + resource_name="03_advanced_05_lb_cpu_lb", + worker_type="LB", + functions=["validate_data"], + 
lb_routes=[ + { + "method": "POST", + "path": "/validate", + "fn_name": "validate_data", + "config_variable": "cpu_config", + } + ], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + assert "from 03_advanced" not in content + assert ( + '_flash_import("03_advanced.05_lb.cpu_lb", "cpu_config", "03_advanced/05_lb")' + in content + ) + assert ( + '_flash_import("03_advanced.05_lb.cpu_lb", "validate_data", "03_advanced/05_lb")' + in content + ) + + def test_mixed_numeric_and_normal_dirs(self, tmp_path): + """Normal modules use 'from' imports, numeric modules use _flash_import.""" + normal_worker = WorkerInfo( + file_path=tmp_path / "worker.py", + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + ) + numeric_worker = WorkerInfo( + file_path=tmp_path / "01_hello" / "gpu_worker.py", + url_prefix="/01_hello/gpu_worker", + module_path="01_hello.gpu_worker", + resource_name="01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + content = _generate_flash_server( + tmp_path, [normal_worker, numeric_worker] + ).read_text() + + # Normal worker uses standard import + assert "from worker import process" in content + # Numeric worker uses scoped _flash_import + assert ( + '_flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' in content + ) + + def test_no_importlib_when_all_normal_dirs(self, tmp_path): + """importlib and _flash_import are not emitted when no numeric dirs exist.""" + worker = WorkerInfo( + file_path=tmp_path / "worker.py", + url_prefix="/worker", + module_path="worker", + resource_name="worker", + worker_type="QB", + functions=["process"], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + assert "importlib" not in content + assert "_flash_import" not in content + + def test_scoped_import_includes_subdir(self, tmp_path): + """_flash_import calls pass the subdirectory for sibling import scoping.""" + worker = WorkerInfo( + 
file_path=tmp_path / "01_getting_started" / "03_mixed" / "pipeline.py", + url_prefix="/01_getting_started/03_mixed/pipeline", + module_path="01_getting_started.03_mixed.pipeline", + resource_name="01_getting_started_03_mixed_pipeline", + worker_type="LB", + functions=["classify"], + lb_routes=[ + { + "method": "POST", + "path": "/classify", + "fn_name": "classify", + "config_variable": "pipeline_config", + } + ], + ) + content = _generate_flash_server(tmp_path, [worker]).read_text() + + # Must scope to correct subdirectory, not add all dirs to sys.path + assert '"01_getting_started/03_mixed"' in content + # No global sys.path additions for subdirs — only the project root + # line at the top and the one inside _flash_import helper body + lines = content.split("\n") + global_sys_path_lines = [ + line + for line in lines + if "sys.path.insert" in line and not line.startswith(" ") + ] + assert len(global_sys_path_lines) == 1 + + def test_generated_server_is_valid_python(self, tmp_path): + """Generated server.py with numeric dirs must be parseable Python.""" + worker = WorkerInfo( + file_path=tmp_path / "01_getting_started" / "01_hello" / "gpu_worker.py", + url_prefix="/01_getting_started/01_hello/gpu_worker", + module_path="01_getting_started.01_hello.gpu_worker", + resource_name="01_getting_started_01_hello_gpu_worker", + worker_type="QB", + functions=["gpu_hello"], + ) + server_path = _generate_flash_server(tmp_path, [worker]) + content = server_path.read_text() + + # Must parse without SyntaxError + import ast + + ast.parse(content) + + +class TestMapBodyToParams: + """Tests for _map_body_to_params — maps HTTP body to function arguments.""" + + def test_body_keys_match_params_spreads_as_kwargs(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def process(name: str, value: int): + pass + + result = _map_body_to_params(process, {"name": "test", "value": 42}) + assert result == {"name": "test", "value": 42} + + def 
test_body_keys_mismatch_wraps_in_first_param(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def run_pipeline(input_data: dict): + pass + + body = {"text": "hello", "mode": "fast"} + result = _map_body_to_params(run_pipeline, body) + assert result == {"input_data": {"text": "hello", "mode": "fast"}} + + def test_non_dict_body_wraps_in_first_param(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def run_pipeline(input_data): + pass + + result = _map_body_to_params(run_pipeline, [1, 2, 3]) + assert result == {"input_data": [1, 2, 3]} + + def test_no_params_returns_empty(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def no_args(): + pass + + result = _map_body_to_params(no_args, {"key": "val"}) + assert result == {} + + def test_partial_key_match_wraps_in_first_param(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def process(name: str, value: int): + pass + + result = _map_body_to_params(process, {"name": "test", "extra": "bad"}) + assert result == {"name": {"name": "test", "extra": "bad"}} + + def test_empty_dict_body_spreads_as_empty_kwargs(self): + from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params + + def run_pipeline(input_data: dict): + pass + + result = _map_body_to_params(run_pipeline, {}) + assert result == {} diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index b5e3cbbe..25899904 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -467,6 +467,32 @@ def test_is_deployed_false_when_no_id(self): assert serverless.is_deployed() is False + def test_is_deployed_skips_health_check_during_live_provisioning(self, monkeypatch): + """During flash run, is_deployed returns True based on ID alone.""" + monkeypatch.setenv("FLASH_IS_LIVE_PROVISIONING", "true") + serverless = 
ServerlessResource(name="test") + serverless.id = "ep-live-123" + + # health() must NOT be called — no mock needed, any call would raise + assert serverless.is_deployed() is True + + def test_is_deployed_uses_health_check_outside_live_provisioning(self, monkeypatch): + """Outside flash run, is_deployed falls back to health check.""" + monkeypatch.delenv("FLASH_IS_LIVE_PROVISIONING", raising=False) + serverless = ServerlessResource(name="test") + serverless.id = "ep-123" + + mock_endpoint = MagicMock() + mock_endpoint.health.return_value = {"workers": {}} + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + assert serverless.is_deployed() is True + mock_endpoint.health.assert_called_once() + @pytest.mark.asyncio async def test_deploy_already_deployed(self): """Test deploy returns early when already deployed.""" @@ -938,6 +964,42 @@ def test_is_deployed_with_exception(self): assert result is False + def test_payload_exclude_adds_template_when_template_id_set(self): + """_payload_exclude excludes template field when templateId is already set.""" + serverless = ServerlessResource(name="test") + serverless.templateId = "tmpl-123" + + excluded = serverless._payload_exclude() + + assert "template" in excluded + + def test_payload_exclude_tolerates_both_template_id_and_template(self): + """_payload_exclude does not raise when both templateId and template are set. + + After deploy mutates the config object, both fields can coexist. + templateId takes precedence and template should be excluded. 
+ """ + serverless = ServerlessResource(name="test") + serverless.templateId = "tmpl-123" + serverless.template = PodTemplate( + name="test-template", + imageName="runpod/test:latest", + containerDiskInGb=20, + ) + + excluded = serverless._payload_exclude() + + assert "template" in excluded + + def test_payload_exclude_does_not_exclude_template_without_template_id(self): + """_payload_exclude does not exclude template when templateId is absent.""" + serverless = ServerlessResource(name="test") + serverless.templateId = None + + excluded = serverless._payload_exclude() + + assert "template" not in excluded + def test_reverse_sync_from_backend_response(self): """Test reverse sync when receiving backend response with gpuIds.""" # This tests the lines 173-176 which convert gpuIds back to gpus list diff --git a/tests/unit/test_dependency_resolver.py b/tests/unit/test_dependency_resolver.py new file mode 100644 index 00000000..c916c794 --- /dev/null +++ b/tests/unit/test_dependency_resolver.py @@ -0,0 +1,372 @@ +"""Unit tests for dependency_resolver module. + +Tests detection, stub generation, source assembly, and async resolution +of @remote function dependencies for stacked execution. 
+""" + +import ast +import textwrap +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from runpod_flash.stubs.dependency_resolver import ( + RemoteDependency, + build_augmented_source, + detect_remote_dependencies, + generate_stub_code, + resolve_dependencies, +) + + +# --------------------------------------------------------------------------- +# Helpers: fake @remote-decorated functions for detection tests +# --------------------------------------------------------------------------- + + +def _make_remote_func(name: str, source: str, resource_config=None): + """Create a fake function with __remote_config__ to simulate @remote.""" + ns: dict = {} + exec(compile(source, "", "exec"), ns) + func = ns[name] + func.__remote_config__ = { + "resource_config": resource_config or MagicMock(name=name), + "dependencies": ["numpy"], + "system_dependencies": [], + } + return func + + +# Shared globals dict simulating a module where both funcA and funcB live +_shared_globals: dict = {} + +_funcB_source = textwrap.dedent("""\ +async def funcB(param: dict) -> dict: + return {"result": param} +""") + +_funcB = _make_remote_func("funcB", _funcB_source) +_shared_globals["funcB"] = _funcB + +_funcC_source = textwrap.dedent("""\ +async def funcC(x: int) -> int: + return x + 1 +""") +_funcC = _make_remote_func("funcC", _funcC_source) +_shared_globals["funcC"] = _funcC + + +def _plain_helper(x): + """A plain function — no __remote_config__.""" + return x + + +_shared_globals["_plain_helper"] = _plain_helper + + +# funcA calls funcB (a @remote function) and _plain_helper (not @remote) +_funcA_source = textwrap.dedent("""\ +async def funcA(foo: str) -> dict: + payload = _plain_helper(foo) + return await funcB(payload) +""") + + +# funcD calls both funcB and funcC +_funcD_source = textwrap.dedent("""\ +async def funcD(data: str) -> dict: + b = await funcB({"key": data}) + c = await funcC(42) + return {"b": b, "c": c} +""") + + +# funcE calls nothing remote 
+_funcE_source = textwrap.dedent("""\ +async def funcE(x: int) -> int: + return x * 2 +""") + + +# funcF calls funcB via attribute (indirect — should NOT be detected) +_funcF_source = textwrap.dedent("""\ +async def funcF(x: int) -> int: + import somemodule + return somemodule.funcB(x) +""") + + +# --------------------------------------------------------------------------- +# Tests: detect_remote_dependencies +# --------------------------------------------------------------------------- + + +class TestDetectRemoteDependencies: + def test_detects_single_remote_dependency(self): + result = detect_remote_dependencies(_funcA_source, _shared_globals) + assert result == ["funcB"] + + def test_detects_multiple_remote_dependencies(self): + result = detect_remote_dependencies(_funcD_source, _shared_globals) + assert sorted(result) == ["funcB", "funcC"] + + def test_no_remote_dependencies(self): + result = detect_remote_dependencies(_funcE_source, _shared_globals) + assert result == [] + + def test_ignores_plain_helpers(self): + result = detect_remote_dependencies(_funcA_source, _shared_globals) + assert "_plain_helper" not in result + + def test_ignores_builtins(self): + source = textwrap.dedent("""\ + async def funcX(x: int) -> str: + return str(len([x])) + """) + result = detect_remote_dependencies(source, _shared_globals) + assert result == [] + + def test_ignores_attribute_calls(self): + """Only ast.Name calls are detected, not ast.Attribute calls.""" + result = detect_remote_dependencies(_funcF_source, _shared_globals) + assert "funcB" not in result + + def test_ignores_names_not_in_globals(self): + source = textwrap.dedent("""\ + async def funcX(x: int) -> int: + return unknown_func(x) + """) + result = detect_remote_dependencies(source, _shared_globals) + assert result == [] + + +# --------------------------------------------------------------------------- +# Tests: generate_stub_code +# --------------------------------------------------------------------------- + + 
+class TestGenerateStubCode: + def _make_dep(self, name="funcB", endpoint_id="ep-123", source=None): + return RemoteDependency( + name=name, + endpoint_id=endpoint_id, + source=source or _funcB_source, + dependencies=["numpy"], + system_dependencies=[], + ) + + def test_generates_valid_python(self): + dep = self._make_dep() + code = generate_stub_code(dep) + # Must compile without errors + compile(code, "", "exec") + + def test_stub_defines_correct_function_name(self): + dep = self._make_dep(name="funcB") + code = generate_stub_code(dep) + tree = ast.parse(code) + func_names = [ + node.name + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + assert "funcB" in func_names + + def test_stub_is_async(self): + dep = self._make_dep() + code = generate_stub_code(dep) + tree = ast.parse(code) + async_funcs = [ + node for node in ast.walk(tree) if isinstance(node, ast.AsyncFunctionDef) + ] + assert len(async_funcs) >= 1 + + def test_endpoint_id_embedded(self): + dep = self._make_dep(endpoint_id="ep-abc-999") + code = generate_stub_code(dep) + assert "ep-abc-999" in code + + def test_function_source_embedded(self): + dep = self._make_dep() + code = generate_stub_code(dep) + # The original source should appear somewhere in the stub (as a string) + assert "funcB" in code + + def test_preserves_original_signature(self): + """Stub should accept same params as original function.""" + dep = self._make_dep() + code = generate_stub_code(dep) + # The stub for funcB(param: dict) should have 'param' in its signature + assert "param" in code + + def test_handles_multi_param_function(self): + multi_src = textwrap.dedent("""\ + async def multi(a: int, b: str, c: float = 1.0) -> dict: + return {"a": a, "b": b, "c": c} + """) + dep = self._make_dep(name="multi", source=multi_src) + code = generate_stub_code(dep) + compile(code, "", "exec") + assert "multi" in code + + def test_handles_triple_quotes_in_source(self): + """Source with triple-quoted 
docstrings should be safely escaped.""" + src_with_docs = textwrap.dedent('''\ + async def documented(x: int) -> int: + """Process x with triple-quoted docstring.""" + return x + ''') + dep = self._make_dep(name="documented", source=src_with_docs) + code = generate_stub_code(dep) + compile(code, "", "exec") + + +# --------------------------------------------------------------------------- +# Tests: build_augmented_source +# --------------------------------------------------------------------------- + + +class TestBuildAugmentedSource: + def test_no_stubs_returns_original(self): + original = "async def funcA(x): return x\n" + result = build_augmented_source(original, []) + assert result == original + + def test_stubs_prepended_before_original(self): + original = "async def funcA(x): return x\n" + stub = "async def funcB(y): return y\n" + result = build_augmented_source(original, [stub]) + # stub should appear before original + assert result.index("funcB") < result.index("funcA") + + def test_augmented_source_is_valid_python(self): + original = textwrap.dedent("""\ + async def funcA(foo: str) -> dict: + return await funcB(foo) + """) + stub = textwrap.dedent("""\ + async def funcB(param: dict) -> dict: + return {"stub": True} + """) + result = build_augmented_source(original, [stub]) + compile(result, "", "exec") + + def test_multiple_stubs_prepended(self): + original = "async def funcA(x): return x\n" + stubs = [ + "async def funcB(y): return y\n", + "async def funcC(z): return z\n", + ] + result = build_augmented_source(original, stubs) + assert "funcB" in result + assert "funcC" in result + assert result.index("funcB") < result.index("funcA") + assert result.index("funcC") < result.index("funcA") + + +# --------------------------------------------------------------------------- +# Tests: resolve_dependencies (async, mocked ResourceManager) +# --------------------------------------------------------------------------- + + +class TestResolveDependencies: + """Tests 
class TestResolveDependencies:
    """Tests for resolve_dependencies with mocked ResourceManager and get_function_source."""

    def _patch_resolve(self, rm_mock):
        """Build the (ResourceManager, get_function_source) patch pair."""
        rm_patch = patch(
            "runpod_flash.core.resources.ResourceManager",
            return_value=rm_mock,
        )
        gfs_patch = patch(
            "runpod_flash.stubs.dependency_resolver.get_function_source",
            side_effect=lambda func: (
                f"async def {func.__name__}(): pass\n",
                "hash",
            ),
        )
        return rm_patch, gfs_patch

    @pytest.mark.asyncio
    async def test_resolves_single_dependency(self):
        """A single remote call resolves to its deployed endpoint id."""
        resource = MagicMock()
        resource.id = "ep-resolved-123"

        rm_mock = MagicMock()
        rm_mock.get_or_deploy_resource = AsyncMock(return_value=resource)

        patch_rm, patch_gfs = self._patch_resolve(rm_mock)
        with patch_rm, patch_gfs:
            resolved = await resolve_dependencies(_funcA_source, _shared_globals)

        assert len(resolved) == 1
        assert resolved[0].name == "funcB"
        assert resolved[0].endpoint_id == "ep-resolved-123"

    @pytest.mark.asyncio
    async def test_resolves_multiple_dependencies(self):
        """Two distinct remote calls each resolve against their own resource."""
        resource_b = MagicMock()
        resource_b.id = "ep-b"
        resource_c = MagicMock()
        resource_c.id = "ep-c"

        async def fake_deploy(config):
            # Route funcB's config to its resource; anything else is funcC's.
            if config is _funcB.__remote_config__["resource_config"]:
                return resource_b
            return resource_c

        rm_mock = MagicMock()
        rm_mock.get_or_deploy_resource = AsyncMock(side_effect=fake_deploy)

        patch_rm, patch_gfs = self._patch_resolve(rm_mock)
        with patch_rm, patch_gfs:
            resolved = await resolve_dependencies(_funcD_source, _shared_globals)

        assert len(resolved) == 2
        assert {dep.name for dep in resolved} == {"funcB", "funcC"}

    @pytest.mark.asyncio
    async def test_no_dependencies_returns_empty(self):
        """Source without remote calls resolves to an empty list."""
        assert await resolve_dependencies(_funcE_source, _shared_globals) == []

    @pytest.mark.asyncio
    async def test_provisioning_failure_raises(self):
        """Deployment errors propagate out of resolve_dependencies."""
        rm_mock = MagicMock()
        rm_mock.get_or_deploy_resource = AsyncMock(
            side_effect=RuntimeError("deploy failed")
        )
side_effect=RuntimeError("deploy failed") + ) + + rm_patch, gfs_patch = self._patch_resolve(mock_rm) + with rm_patch, gfs_patch: + with pytest.raises(RuntimeError, match="deploy failed"): + await resolve_dependencies(_funcA_source, _shared_globals) + + +# --------------------------------------------------------------------------- +# Tests: exec() integration — verify augmented source works at runtime +# --------------------------------------------------------------------------- + + +class TestExecIntegration: + def test_exec_augmented_source_defines_both_functions(self): + """When we exec() augmented source, both funcA and the funcB stub exist.""" + dep = RemoteDependency( + name="funcB", + endpoint_id="ep-test", + source=_funcB_source, + dependencies=[], + system_dependencies=[], + ) + stub_code = generate_stub_code(dep) + augmented = build_augmented_source(_funcA_source, [stub_code]) + + namespace: dict = {"_plain_helper": lambda x: x} + exec(compile(augmented, "", "exec"), namespace) + + assert "funcA" in namespace + assert "funcB" in namespace + assert callable(namespace["funcA"]) + assert callable(namespace["funcB"]) diff --git a/tests/unit/test_load_balancer_sls_stub.py b/tests/unit/test_load_balancer_sls_stub.py index 206f8eea..7bb4bd5e 100644 --- a/tests/unit/test_load_balancer_sls_stub.py +++ b/tests/unit/test_load_balancer_sls_stub.py @@ -20,14 +20,15 @@ class TestLoadBalancerSlsStubPrepareRequest: """Test suite for _prepare_request method.""" - def test_prepare_request_with_no_args(self): + @pytest.mark.asyncio + async def test_prepare_request_with_no_args(self): """Test request preparation with no arguments.""" stub = LoadBalancerSlsStub(test_lb_resource) def test_func(): return "result" - request = stub._prepare_request(test_func, None, None, True) + request = await stub._prepare_request(test_func, None, None, True) assert request["function_name"] == "test_func" assert "def test_func" in request["function_code"] @@ -37,7 +38,8 @@ def test_func(): 
assert "args" not in request or request["args"] == [] assert "kwargs" not in request or request["kwargs"] == {} - def test_prepare_request_with_args(self): + @pytest.mark.asyncio + async def test_prepare_request_with_args(self): """Test request preparation with positional arguments.""" stub = LoadBalancerSlsStub(test_lb_resource) @@ -46,7 +48,7 @@ def add(x, y): arg1 = 5 arg2 = 3 - request = stub._prepare_request(add, None, None, True, arg1, arg2) + request = await stub._prepare_request(add, None, None, True, arg1, arg2) assert request["function_name"] == "add" assert len(request["args"]) == 2 @@ -57,14 +59,15 @@ def add(x, y): assert decoded_arg1 == 5 assert decoded_arg2 == 3 - def test_prepare_request_with_kwargs(self): + @pytest.mark.asyncio + async def test_prepare_request_with_kwargs(self): """Test request preparation with keyword arguments.""" stub = LoadBalancerSlsStub(test_lb_resource) def greet(name, greeting="Hello"): return f"{greeting}, {name}!" - request = stub._prepare_request( + request = await stub._prepare_request( greet, None, None, True, name="Alice", greeting="Hi" ) @@ -79,7 +82,8 @@ def greet(name, greeting="Hello"): assert decoded_name == "Alice" assert decoded_greeting == "Hi" - def test_prepare_request_with_dependencies(self): + @pytest.mark.asyncio + async def test_prepare_request_with_dependencies(self): """Test request preparation includes dependencies.""" stub = LoadBalancerSlsStub(test_lb_resource) @@ -89,7 +93,9 @@ def test_func(): dependencies = ["requests", "numpy"] system_deps = ["git"] - request = stub._prepare_request(test_func, dependencies, system_deps, True) + request = await stub._prepare_request( + test_func, dependencies, system_deps, True + ) assert request["dependencies"] == dependencies assert request["system_dependencies"] == system_deps diff --git a/tests/unit/test_skeleton.py b/tests/unit/test_skeleton.py index 0c3962b4..ea788bd6 100644 --- a/tests/unit/test_skeleton.py +++ b/tests/unit/test_skeleton.py @@ -85,13 
+85,13 @@ def test_detect_no_conflicts_empty_directory(self, tmp_path): def test_detect_conflict_with_existing_file(self, tmp_path): """Test that existing files are detected as conflicts.""" # Create a file that exists in the template - (tmp_path / "main.py").write_text("# existing file") + (tmp_path / "gpu_worker.py").write_text("# existing file") conflicts = detect_file_conflicts(tmp_path) - # Should detect main.py as a conflict + # Should detect gpu_worker.py as a conflict conflict_names = [str(c) for c in conflicts] - assert "main.py" in conflict_names + assert "gpu_worker.py" in conflict_names def test_detect_conflict_with_hidden_file(self, tmp_path): """Test that existing hidden files are detected as conflicts.""" @@ -138,7 +138,9 @@ def test_create_skeleton_in_empty_directory(self, tmp_path): assert len(created_files) > 0 # Check that key files exist - assert (tmp_path / "main.py").exists() + assert (tmp_path / "gpu_worker.py").exists() + assert (tmp_path / "cpu_worker.py").exists() + assert (tmp_path / "lb_worker.py").exists() assert (tmp_path / "README.md").exists() assert (tmp_path / "requirements.txt").exists() @@ -147,13 +149,6 @@ def test_create_skeleton_in_empty_directory(self, tmp_path): assert (tmp_path / ".gitignore").exists() assert (tmp_path / ".flashignore").exists() - # Check that workers directory structure exists - assert (tmp_path / "workers").is_dir() - assert (tmp_path / "workers" / "cpu").is_dir() - assert (tmp_path / "workers" / "gpu").is_dir() - assert (tmp_path / "workers" / "cpu" / "__init__.py").exists() - assert (tmp_path / "workers" / "gpu" / "__init__.py").exists() - def test_create_skeleton_with_project_name_substitution(self, tmp_path): """Test that {{project_name}} placeholder is replaced.""" project_dir = tmp_path / "my_test_project" @@ -169,14 +164,14 @@ def test_create_skeleton_with_project_name_substitution(self, tmp_path): def test_create_skeleton_skips_existing_files_without_force(self, tmp_path): """Test that existing 
files are not overwritten without force flag.""" # Create an existing file with specific content - existing_content = "# This is my custom main.py" - (tmp_path / "main.py").write_text(existing_content) + existing_content = "# This is my custom gpu_worker.py" + (tmp_path / "gpu_worker.py").write_text(existing_content) # Create skeleton without force create_project_skeleton(tmp_path, force=False) # Existing file should not be overwritten - assert (tmp_path / "main.py").read_text() == existing_content + assert (tmp_path / "gpu_worker.py").read_text() == existing_content # But other files should be created assert (tmp_path / ".env.example").exists() @@ -184,16 +179,16 @@ def test_create_skeleton_skips_existing_files_without_force(self, tmp_path): def test_create_skeleton_overwrites_with_force(self, tmp_path): """Test that existing files are overwritten with force=True.""" # Create an existing file - existing_content = "# This is my custom main.py" - (tmp_path / "main.py").write_text(existing_content) + existing_content = "# This is my custom gpu_worker.py" + (tmp_path / "gpu_worker.py").write_text(existing_content) # Create skeleton with force create_project_skeleton(tmp_path, force=True) # Existing file should be overwritten - new_content = (tmp_path / "main.py").read_text() + new_content = (tmp_path / "gpu_worker.py").read_text() assert new_content != existing_content - assert "# This is my custom main.py" not in new_content + assert "# This is my custom gpu_worker.py" not in new_content def test_create_skeleton_ignores_pycache(self, tmp_path): """Test that __pycache__ directories are not copied.""" @@ -225,7 +220,7 @@ def test_create_skeleton_creates_parent_directories(self, tmp_path): # All parent directories should exist assert project_dir.exists() - assert (project_dir / "main.py").exists() + assert (project_dir / "gpu_worker.py").exists() def test_create_skeleton_returns_created_files_list(self, tmp_path): """Test that function returns list of created files.""" 
@@ -236,14 +231,14 @@ def test_create_skeleton_returns_created_files_list(self, tmp_path): assert all(isinstance(f, str) for f in created_files) # Should contain expected files - assert "main.py" in created_files + assert "gpu_worker.py" in created_files assert ".env.example" in created_files assert "README.md" in created_files def test_create_skeleton_handles_readonly_files_gracefully(self, tmp_path): """Test handling of read-only files during creation.""" # Create a read-only file - readonly_file = tmp_path / "main.py" + readonly_file = tmp_path / "gpu_worker.py" readonly_file.write_text("# readonly") readonly_file.chmod(0o444) @@ -287,7 +282,9 @@ def test_full_init_workflow_in_place(self, tmp_path): # Verify all expected files exist expected_files = [ - "main.py", + "gpu_worker.py", + "cpu_worker.py", + "lb_worker.py", "README.md", "requirements.txt", ".env.example", @@ -297,14 +294,10 @@ for filename in expected_files: assert (tmp_path / filename).exists(), f"{filename} should exist" - # Verify workers structure - assert (tmp_path / "workers" / "cpu" / "endpoint.py").exists() - assert (tmp_path / "workers" / "gpu" / "endpoint.py").exists() - def test_full_init_workflow_with_conflicts(self, tmp_path): """Test complete workflow when conflicts exist.""" # Create some existing files - (tmp_path / "main.py").write_text("# my custom main") + (tmp_path / "gpu_worker.py").write_text("# my custom worker") (tmp_path / ".env.example").write_text("MY_VAR=123") # Detect conflicts @@ -312,14 +305,14 @@ assert len(conflicts) == 2 conflict_names = [str(c) for c in conflicts] - assert "main.py" in conflict_names + assert "gpu_worker.py" in conflict_names assert ".env.example" in conflict_names # Create skeleton without force (should preserve existing) create_project_skeleton(tmp_path, force=False) # Check that existing files were preserved - assert (tmp_path /
"main.py").read_text() == "# my custom main" + assert (tmp_path / "gpu_worker.py").read_text() == "# my custom worker" assert (tmp_path / ".env.example").read_text() == "MY_VAR=123" # But new files should be created