diff --git a/README.md b/README.md index b820939..dbbd4f4 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,14 @@ edgewalker cve # Check for known CVEs edgewalker report # View security report ``` +### CI/CD & Automation +EdgeWalker supports non-interactive execution for automated environments: +```bash +# Run a silent scan with explicit telemetry opt-in +edgewalker --silent --accept-telemetry scan --target 192.168.1.0/24 +``` +See the [Configuration Guide](docs/configuration.md#non-interactive-silent-mode) for more details. + --- ## The Periphery Mission diff --git a/docs/configuration.md b/docs/configuration.md index c0ad082..89f38a1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -44,10 +44,34 @@ Environment variables prefixed with `EW_` override all settings. `edgewalker/con | `EW_THEME` | `periphery` | Active UI theme slug | | `EW_IOT_PORTS` | `[21, 22, ...]` | Common IoT ports for quick scan | | `EW_TELEMETRY_ENABLED` | `None` | User opt-in status for anonymous data sharing | +| `EW_SILENT_MODE` | `False` | Run in non-interactive mode (bypass prompts) | +| `EW_SUPPRESS_WARNINGS` | `False` | Suppress configuration and security warnings in the console | | `EW_CONFIG_DIR` | `~/.config/edgewalker` | Configuration directory override | | `EW_CACHE_DIR` | `~/.cache/edgewalker` | Cache directory override | | `EW_DEMO_MODE` | `0` | Set to `1` to enable demo mode with mock data | +## Non-Interactive (Silent) Mode + +For CI/CD pipelines and automated environments, EdgeWalker provides a non-interactive mode that bypasses all user prompts. + +### Global Flags + +These flags can be used with any command: + +- `--silent` or `-s`: Enables non-interactive mode. +- `--suppress-warnings`: Hides configuration override panels and security warnings from the console. +- `--accept-telemetry`: Explicitly opts-in to anonymous telemetry (required in silent mode if no preference is set). +- `--decline-telemetry`: Explicitly opts-out of anonymous telemetry (required in silent mode if no preference is set). + +### CI/CD Usage + +When running in a fresh environment (like a GitHub Action), you must provide a telemetry choice if you use `--silent`. If no choice is provided, the CLI will exit with an error to ensure an explicit decision is made. + +```bash +# Run a scan in CI/CD without any prompts +edgewalker --silent --suppress-warnings --accept-telemetry scan --target 192.168.1.0/24 +``` + ## Security Validation EdgeWalker enforces security best practices for its configuration: diff --git a/docs/data-privacy.md b/docs/data-privacy.md index 1550b53..626a15e 100644 --- a/docs/data-privacy.md +++ b/docs/data-privacy.md @@ -57,6 +57,7 @@ The findings feed back into improving EdgeWalker's credential database and infor ### How to Opt Out - **During first run**: Select "No thanks" when prompted. +- **In Silent Mode**: Use the `--decline-telemetry` flag. - **After opting in**: Opt out via the TUI settings menu or by deleting the configuration file: ```bash # On macOS: @@ -65,6 +66,15 @@ The findings feed back into improving EdgeWalker's credential database and infor rm ~/.config/edgewalker/config.yaml ``` +### Non-Interactive (Silent) Mode Telemetry + +When running EdgeWalker in automated environments (CI/CD) using the `--silent` flag, the tool requires an explicit telemetry choice if one has not been previously set. This ensures that data sharing is never enabled by default without a conscious decision. + +- Use `--accept-telemetry` to opt-in. +- Use `--decline-telemetry` to opt-out. + +If neither flag is provided in silent mode on a fresh installation, EdgeWalker will exit with an error. + ### Server-Side Security The data collection API features hardening: diff --git a/pyproject.toml b/pyproject.toml index ad53c38..b6bedfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,4 +133,4 @@ color = true [tool.pytest.ini_options] pythonpath = ["src"] -addopts = "--cov=src --cov-report=term-missing --cov-fail-under=85" +addopts = "--cov=src --cov-report=term-missing --cov-fail-under=90" diff --git a/src/edgewalker/cli/cli.py b/src/edgewalker/cli/cli.py index b3c9280..7a4549b 100644 --- a/src/edgewalker/cli/cli.py +++ b/src/edgewalker/cli/cli.py @@ -181,7 +181,7 @@ def run_guided_scan( overrides = get_active_overrides() if (security_warnings or overrides) and not allow_override: - if security_warnings: + if security_warnings and not settings.suppress_warnings: console.print( f"\n[bold {theme.RISK_CRITICAL}]SECURITY WARNING: " f"Non-standard or insecure API endpoints detected![/bold {theme.RISK_CRITICAL}]" @@ -193,7 +193,7 @@ def run_guided_scan( "sensitive data like API keys.[/dim]" ) - if overrides: + if overrides and not settings.suppress_warnings: sources = ", ".join(sorted(set(overrides.values()))) console.print( f"\n[bold {theme.WARNING}]CONFIGURATION OVERRIDES ACTIVE " @@ -205,14 +205,21 @@ def run_guided_scan( "\n[dim]These settings will take precedence over your config.yaml file.[/dim]" ) - console.print("") - confirm = typer.confirm("Do you want to proceed with the scan using these settings?") - if not confirm: + if not settings.silent_mode: + console.print("") + confirm = typer.confirm("Do you want to proceed with the scan using these settings?") + if not confirm: + console.print( + "\n[dim]Scan cancelled. Use [bold]--allow-override[/bold] or " + "[bold]-ao[/bold] to bypass this check.[/dim]" + ) + raise typer.Exit() + elif (security_warnings or overrides) and not settings.suppress_warnings: + console.print("") console.print( - "\n[dim]Scan cancelled. Use [bold]--allow-override[/bold] or " - "[bold]-ao[/bold] to bypass this check.[/dim]" + f"[{theme.WARNING}]Silent mode active: proceeding with scan " + f"despite security warnings.[/{theme.WARNING}]" ) - raise typer.Exit() ensure_telemetry_choice() controller = ScanController() @@ -371,8 +378,39 @@ def main( log_file: Optional[str] = typer.Option( None, "--log-file", help="Path to write logs to a file." ), + silent: bool = typer.Option( + False, + "--silent", + "-s", + help="Run in non-interactive mode (bypass prompts).", + ), + suppress_warnings: bool = typer.Option( + False, + "--suppress-warnings", + help="Suppress configuration and security warnings in the console.", + ), + accept_telemetry: bool = typer.Option( + False, + "--accept-telemetry", + help="Explicitly opt-in to telemetry (used in silent mode).", + ), + decline_telemetry: bool = typer.Option( + False, + "--decline-telemetry", + help="Explicitly opt-out of telemetry (used in silent mode).", + ), ) -> None: """EdgeWalker - IoT Home Network Security Scanner.""" + # Update settings with global flags + if silent: + update_setting("silent_mode", True) + if suppress_warnings: + update_setting("suppress_warnings", True) + if accept_telemetry: + update_setting("accept_telemetry", True) + if decline_telemetry: + update_setting("decline_telemetry", True) + # Configure logging using the Typer options setup_logging(verbosity=verbose, log_file=log_file) diff --git a/src/edgewalker/core/config.py b/src/edgewalker/core/config.py index a7904c1..ac13f89 100644 --- a/src/edgewalker/core/config.py +++ b/src/edgewalker/core/config.py @@ -303,6 +303,28 @@ def handle_demo_mode(cls, v: Path) -> Path: description="User opt-in status for anonymous data sharing", ) + silent_mode: bool = Field( + default=False, + description="Run in non-interactive mode (bypass prompts)", + ) + + suppress_warnings: bool = Field( + default=False, + description="Suppress configuration and security warnings in the console", + ) + + accept_telemetry: bool = Field( + default=False, + description="Explicitly opt-in to telemetry (used in silent mode)", + exclude=True, + ) + + decline_telemetry: bool = Field( + default=False, + description="Explicitly opt-out of telemetry (used in silent mode)", + exclude=True, + ) + theme: str = Field( default="periphery", description="Active theme slug", @@ -319,8 +341,8 @@ def get_security_warnings(self) -> list[str]: Returns: A list of warning messages. """ - # Skip security warnings during tests to avoid confirmation prompts - if os.environ.get("PYTEST_CURRENT_TEST"): + # Skip security warnings during tests or if suppressed + if os.environ.get("PYTEST_CURRENT_TEST") or self.suppress_warnings: return [] warnings = [] @@ -417,7 +439,7 @@ def get_active_overrides() -> dict[str, str]: ('environment variable' or '.env file'). """ # Skip overrides during tests to ensure consistent behavior - if os.environ.get("PYTEST_CURRENT_TEST"): + if os.environ.get("PYTEST_CURRENT_TEST") and not os.environ.get("EW_ALLOW_OVERRIDES_IN_TESTS"): return {} overrides = {} @@ -438,7 +460,7 @@ def get_active_overrides() -> dict[str, str]: # Check environment variables (higher precedence) for key in os.environ: - if key.startswith("EW_"): + if key.startswith("EW_") and key != "EW_ALLOW_OVERRIDES_IN_TESTS": overrides[key] = "environment variable" return overrides diff --git a/src/edgewalker/modules/cve_scan/scanner.py b/src/edgewalker/modules/cve_scan/scanner.py index 9370cf2..8cc85ff 100644 --- a/src/edgewalker/modules/cve_scan/scanner.py +++ b/src/edgewalker/modules/cve_scan/scanner.py @@ -12,6 +12,7 @@ # Third Party import httpx +from loguru import logger # First Party from edgewalker import __version__, utils @@ -46,25 +47,32 @@ async def search_cves_async( async with semaphore: try: - if verbose: - # Use logger or print only if not using rich progress - pass - + logger.debug(f"Searching NVD for: {params['keywordSearch']}") response = await client.get( settings.nvd_api_url, params=params, headers=headers, timeout=30 ) + logger.debug(f"NVD Response: {response.status_code}") + if response.status_code == 403: # Rate limit hit, wait and retry once + logger.warning( + f"NVD Rate limit hit (403). Waiting {settings.nvd_rate_limit_delay * 2}s..." + ) await asyncio.sleep(settings.nvd_rate_limit_delay * 2) response = await client.get( settings.nvd_api_url, params=params, headers=headers, timeout=30 ) + logger.debug(f"NVD Retry Response: {response.status_code}") if response.status_code != 200: + logger.error(f"NVD API error: {response.status_code} - {response.text[:200]}") return [] data = response.json() vulnerabilities = data.get("vulnerabilities", []) + logger.debug( + f"Found {len(vulnerabilities)} vulnerabilities for {params['keywordSearch']}" + ) cves = [] for vuln in vulnerabilities: @@ -94,7 +102,8 @@ async def search_cves_async( "score": base_score, }) return cves - except Exception: + except Exception as e: + logger.error(f"Error searching CVEs for {product}: {e}") return [] @@ -224,6 +233,7 @@ async def _scan_service( rich_progress: Optional[tuple[utils.Progress, utils.TaskID]] = None, ) -> CveScanResultModel: """Scan a single service for CVEs asynchronously.""" + logger.debug(f"Checking {svc['product']} {svc['version']} on {svc['ip']}:{svc['port']}") if self.progress_callback: self.progress_callback( "cve_check", @@ -266,41 +276,6 @@ async def _scan_service( version=svc["version"], cves=cves, ) - """Scan a single service for CVEs asynchronously.""" - if self.progress_callback: - self.progress_callback( - "cve_check", - f"Checking {svc['product']} {svc['version']} on {svc['ip']}:{svc['port']}", - ) - - cve_dicts = await search_cves_async( - client, svc["product"], svc["version"], self.verbose, semaphore - ) - - cves = [ - CveModel( - id=c["id"], - description=c.get("description", ""), - severity=c["severity"], - score=c["score"], - ) - for c in cve_dicts - ] - - if cves and self.progress_callback: - self.progress_callback( - "cve_found", - f"{len(cves)} CVE(s) found for {svc['product']} {svc['version']}", - ) - - return CveScanResultModel( - ip=svc["ip"], - port=svc["port"], - service=svc["service"], - product=svc["product"], - version=svc["version"], - cves=cves, - ) def _build_empty_model(self, skipped_no_version: int) -> CveScanModel: return CveScanModel( diff --git a/src/edgewalker/modules/mac_lookup/scanner.py b/src/edgewalker/modules/mac_lookup/scanner.py index 9473519..75a2832 100644 --- a/src/edgewalker/modules/mac_lookup/scanner.py +++ b/src/edgewalker/modules/mac_lookup/scanner.py @@ -14,6 +14,7 @@ # Third Party import httpx +from loguru import logger # First Party from edgewalker.core.config import settings @@ -66,18 +67,22 @@ def _lookup_mac_api(mac: str) -> dict | None: params["apiKey"] = settings.mac_api_key try: + logger.debug(f"Looking up MAC: {mac} via API") with httpx.Client() as client: resp = client.get( f"{settings.mac_api_url}/{mac}", params=params, timeout=settings.api_timeout, ) - except Exception: + logger.debug(f"MAC API Response: {resp.status_code}") + except Exception as e: + logger.error(f"MAC API request failed for {mac}: {e}") return None if resp.status_code == 429: # Rate limited - respect Retry-After header retry_after = float(resp.headers.get("Retry-After", "1")) + logger.warning(f"MAC API Rate limit hit (429). Retrying after {retry_after}s...") time.sleep(retry_after) try: with httpx.Client() as client: @@ -86,12 +91,15 @@ def _lookup_mac_api(mac: str) -> dict | None: params=params, timeout=settings.api_timeout, ) - except Exception: + logger.debug(f"MAC API Retry Response: {resp.status_code}") + except Exception as e: + logger.error(f"MAC API retry failed for {mac}: {e}") return None if resp.status_code == 200: return resp.json() + logger.error(f"MAC API error: {resp.status_code} - {resp.text[:200]}") return None @@ -124,6 +132,7 @@ def _get_csv_vendors() -> dict: def _csv_fallback_vendor(normalized: str) -> str: """Look up vendor from local CSV fallback.""" + logger.debug(f"Falling back to local CSV for MAC: {normalized}") vendors = _get_csv_vendors() if len(normalized) < 6: diff --git a/src/edgewalker/modules/password_scan/scanner.py b/src/edgewalker/modules/password_scan/scanner.py index 28f7ee3..ae2cb6e 100644 --- a/src/edgewalker/modules/password_scan/scanner.py +++ b/src/edgewalker/modules/password_scan/scanner.py @@ -21,6 +21,7 @@ # Third Party import asyncssh from impacket.smbconnection import SMBConnection +from loguru import logger # First Party from edgewalker import __version__, theme, utils @@ -163,6 +164,10 @@ async def scan(self) -> PasswordScanResultModel: login_status = StatusEnum.failed for i, (user, pw) in enumerate(creds): + logger.debug( + f"Attempting {self.service_name().upper()} login on " + f"{self.ip}:{self.port} with {user}:{pw}" + ) if self.rich_progress: progress, task_id = self.rich_progress progress.update( @@ -188,6 +193,10 @@ async def scan(self) -> PasswordScanResultModel: result, kill_loop = StatusEnum.failed, False if result is True: + logger.success( + f"SUCCESS: {self.service_name().upper()} login on " + f"{self.ip}:{self.port} with {user}:{pw}" + ) login_status = StatusEnum.successful found_cred = CredentialsModel(user=user, password=pw) if self.verbose and not self.rich_progress: @@ -204,10 +213,22 @@ async def scan(self) -> PasswordScanResultModel: ) break elif result == StatusEnum.ratelimit: + logger.warning( + f"RATE LIMITED: {self.service_name().upper()} on {self.ip}:{self.port}" + ) login_status = StatusEnum.ratelimit break + else: + logger.debug( + f"FAILED: {self.service_name().upper()} login on " + f"{self.ip}:{self.port} with {user}:{pw}" + ) if kill_loop: + logger.debug( + f"KILL LOOP: Stopping {self.service_name().upper()} scan " + f"on {self.ip}:{self.port}" + ) break if self.rich_progress: @@ -247,7 +268,8 @@ async def attempt_login(self, username: str, password: str) -> tuple[bool, bool] login_timeout=settings.conn_timeout, ): return True, False - except Exception: + except Exception as e: + logger.debug(f"SSH login error for {username}@{self.ip}:{self.port}: {e}") return False, False @@ -273,7 +295,8 @@ def _ftp_login() -> bool: return True return await asyncio.to_thread(_ftp_login), False - except Exception: + except Exception as e: + logger.debug(f"FTP login error for {username}@{self.ip}:{self.port}: {e}") return False, False @@ -303,20 +326,33 @@ async def _read_until(patterns: list[bytes]) -> tuple[int, bytes]: reader.read(1024), timeout=settings.conn_timeout ) except asyncio.TimeoutError: + logger.debug( + f"Telnet read timeout on {self.ip}:{self.port}. Buffer: {buf!r}" + ) return -1, buf if not data: + logger.debug( + f"Telnet connection closed by {self.ip}:{self.port}. Buffer: {buf!r}" + ) return -1, buf buf += data + logger.debug(f"Telnet raw data from {self.ip}:{self.port}: {data!r}") for i, p in enumerate(patterns): if p in buf: return i, buf if len(buf) > 4096: + logger.debug( + f"Telnet buffer overflow on {self.ip}:{self.port}. Buffer: {buf!r}" + ) return -1, buf # Wait for login prompt - idx, _ = await _read_until([b"login:", b"user:", b"Username:"]) + idx, buf = await _read_until([b"login:", b"user:", b"Username:"]) if idx == -1: + logger.debug( + f"Telnet login prompt not found on {self.ip}:{self.port}. Buffer: {buf!r}" + ) writer.close() await writer.wait_closed() return False, False @@ -325,8 +361,11 @@ async def _read_until(patterns: list[bytes]) -> tuple[int, bytes]: await writer.drain() # Wait for password prompt - idx, _ = await _read_until([b"password:", b"Password:"]) + idx, buf = await _read_until([b"password:", b"Password:"]) if idx == -1: + logger.debug( + f"Telnet password prompt not found on {self.ip}:{self.port}. Buffer: {buf!r}" + ) writer.close() await writer.wait_closed() return False, False @@ -335,13 +374,16 @@ async def _read_until(patterns: list[bytes]) -> tuple[int, bytes]: await writer.drain() # Check for success - idx, _ = await _read_until([b"Welcome", b"$", b"#", b">", b"Login incorrect"]) + idx, buf = await _read_until([b"Welcome", b"$", b"#", b">", b"Login incorrect"]) success = idx != -1 and idx != 4 + if not success: + logger.debug(f"Telnet login failed on {self.ip}:{self.port}. Buffer: {buf!r}") writer.close() await writer.wait_closed() return success, False - except Exception: + except Exception as e: + logger.debug(f"Telnet login error for {username}@{self.ip}:{self.port}: {e}") return False, False @@ -367,7 +409,8 @@ def _smb_login() -> bool: return True return await asyncio.to_thread(_smb_login), False - except Exception: + except Exception as e: + logger.debug(f"SMB login error for {username}@{self.ip}:{self.port}: {e}") return False, False diff --git a/src/edgewalker/modules/port_scan/scanner.py b/src/edgewalker/modules/port_scan/scanner.py index e7890fe..cf2c67d 100644 --- a/src/edgewalker/modules/port_scan/scanner.py +++ b/src/edgewalker/modules/port_scan/scanner.py @@ -21,6 +21,7 @@ # Third Party import validators +from loguru import logger # First Party from edgewalker import __version__, theme, utils @@ -159,10 +160,14 @@ def _chunk_hosts(hosts: list[str], n: int) -> list[list[str]]: def parse_nmap_xml(xml_output: str) -> list[Host]: """Parse nmap XML output into host list.""" hosts = [] + if not xml_output: + logger.debug("Empty XML output from nmap") + return hosts try: # nosec: B314 - nmap output is trusted in this context root = ET.fromstring(xml_output) - except ET.ParseError: + except ET.ParseError as e: + logger.error(f"Failed to parse nmap XML output: {e}") return hosts for host_elem in root.findall(".//host"): @@ -260,9 +265,11 @@ async def _scan_batch( cmd += ["-oX", xml_path, "-v", "--stats-every", "10s", "--open"] cmd += hosts + logger.debug(f"Executing nmap command: {' '.join(cmd)}") hosts_with_ports: set[str] = set() try: + logger.debug(f"Starting nmap subprocess for batch {batch_label}") process = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, @@ -311,7 +318,9 @@ async def read_output() -> None: try: await asyncio.wait_for(asyncio.gather(process.wait(), read_output()), timeout=timeout) + logger.debug(f"nmap batch {batch_label} completed with exit code {process.returncode}") except asyncio.TimeoutError: + logger.warning(f"nmap batch {batch_label} timed out after {timeout}s") if verbose: print(f"\n Scan timed out after {timeout}s, terminating nmap...") sys.stdout.flush() @@ -528,6 +537,7 @@ async def scan(self, **kwargs: object) -> PortScanModel: async def quick_scan(self) -> PortScanModel: """Perform a quick scan of common IoT ports asynchronously.""" + logger.info(f"Starting quick scan on {self.target}") err = validate_target(self.target) if err: raise ValueError(err) @@ -625,6 +635,7 @@ async def quick_scan(self) -> PortScanModel: async def full_scan(self) -> PortScanModel: """Full scan using 3-phase hybrid approach asynchronously.""" + logger.info(f"Starting full scan on {self.target}") err = validate_target(self.target) if err: raise ValueError(err) @@ -831,6 +842,7 @@ async def ping_sweep( raise ValueError(err) cmd = get_nmap_command() + ["-sn", "-T4", target] + logger.debug(f"Executing ping sweep: {' '.join(cmd)}") live_hosts = [] try: process = await asyncio.create_subprocess_exec( diff --git a/src/edgewalker/utils.py b/src/edgewalker/utils.py index e3e7434..2e1e5e3 100644 --- a/src/edgewalker/utils.py +++ b/src/edgewalker/utils.py @@ -173,7 +173,7 @@ def print_logo() -> None: ) console.print() - if overrides: + if overrides and not settings.suppress_warnings: sources = ", ".join(sorted(set(overrides.values()))) keys = ", ".join(sorted(overrides.keys())) console.print( @@ -230,6 +230,9 @@ def print_error(msg: str) -> None: def get_input(prompt: str, default: str = None) -> str: """Get user input with optional default.""" + if settings.silent_mode: + return default + if default: prompt_text = ( f"[{theme.PRIMARY}]{theme.ICON_ARROW} {prompt}[/{theme.PRIMARY}] " @@ -248,6 +251,9 @@ def get_input(prompt: str, default: str = None) -> str: def press_enter() -> None: """Wait for user to press enter.""" + if settings.silent_mode: + return + console.print() console.print(f"[{theme.MUTED}]Press Enter to continue...[/{theme.MUTED}]", end="") try: @@ -269,7 +275,36 @@ def is_telemetry_enabled() -> bool: def ensure_telemetry_choice() -> None: """Ensure the user has seen the telemetry opt-in prompt and made a choice.""" telemetry = TelemetryManager(settings) + + # Handle silent mode flags first + if settings.accept_telemetry: + telemetry.set_telemetry_status(True) + return + if settings.decline_telemetry: + telemetry.set_telemetry_status(False) + return + if not telemetry.has_seen_telemetry_prompt(): + if settings.silent_mode: + # Third Party + import typer # noqa: PLC0415 + + console.print() + console.print( + Panel( + f"[bold {theme.DANGER}]ERROR: Telemetry choice required in silent mode." + f"[/bold {theme.DANGER}]\n\n" + "When running with [bold]--silent[/bold], you must explicitly provide a " + "telemetry choice if one has not been set yet.\n\n" + "Use [bold]--accept-telemetry[/bold] to opt-in or " + "[bold]--decline-telemetry[/bold] to opt-out.", + border_style=theme.DANGER, + box=theme.BOX_STYLE, + width=theme.get_ui_width(), + ) + ) + raise typer.Exit(code=1) + # First Party from edgewalker.display import build_telemetry_panel # noqa: PLC0415 diff --git a/tests/conftest.py b/tests/conftest.py index 47a5a81..f3252f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,12 @@ def global_telemetry_safety(): # First Party from edgewalker.core.config import settings - # Reset telemetry setting for every test + # Reset telemetry and silent mode settings for every test settings.telemetry_enabled = False + settings.silent_mode = False + settings.suppress_warnings = False + settings.accept_telemetry = False + settings.decline_telemetry = False # Mock both sync and async httpx post calls globally with ( diff --git a/tests/test_config_coverage.py b/tests/test_config_coverage.py index f12b2c9..1d15b0d 100644 --- a/tests/test_config_coverage.py +++ b/tests/test_config_coverage.py @@ -79,12 +79,20 @@ def test_get_active_overrides_with_env_and_file(tmp_path): env_file = tmp_path / ".env" env_file.write_text("EW_API_URL=https://env-file.com\n# Comment\nINVALID_LINE\nEW_THEME=dark") + # Clear any existing EW_ variables from the environment for this test + clean_env = {k: v for k, v in os.environ.items() if not k.startswith("EW_")} + clean_env.update({ + "EW_API_URL": "https://env-var.com", + "PYTEST_CURRENT_TEST": "1", + "EW_ALLOW_OVERRIDES_IN_TESTS": "1", + }) + with ( patch( "edgewalker.core.config.Path", side_effect=lambda *args: Path(*args) if args[0] != ".env" else env_file, ), - patch.dict(os.environ, {"EW_API_URL": "https://env-var.com", "PYTEST_CURRENT_TEST": ""}), + patch.dict(os.environ, clean_env, clear=True), ): overrides = get_active_overrides() assert overrides["EW_API_URL"] == "environment variable" @@ -181,13 +189,20 @@ def test_get_active_overrides_env_file_error(tmp_path): env_file = tmp_path / ".env" env_file.write_text("EW_THEME=dark") + # Clear any existing EW_ variables from the environment for this test + clean_env = {k: v for k, v in os.environ.items() if not k.startswith("EW_")} + clean_env.update({ + "PYTEST_CURRENT_TEST": "1", + "EW_ALLOW_OVERRIDES_IN_TESTS": "1", + }) + with ( patch( "edgewalker.core.config.Path", side_effect=lambda *args: Path(*args) if args[0] != ".env" else env_file, ), patch("builtins.open", side_effect=OSError("Read error")), - patch.dict(os.environ, {"PYTEST_CURRENT_TEST": ""}), + patch.dict(os.environ, clean_env, clear=True), ): overrides = get_active_overrides() assert overrides == {} # Should fail silently and return empty dict diff --git a/tests/test_utils.py b/tests/test_utils.py index 4c99b09..d6c8a06 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -50,13 +50,17 @@ def test_print_helpers(): @patch("builtins.input", return_value="test") def test_get_input(mock_input): - assert utils.get_input("prompt") == "test" - assert utils.get_input("prompt", "default") == "test" + with patch("edgewalker.utils.settings") as mock_settings: + mock_settings.silent_mode = False + assert utils.get_input("prompt") == "test" + assert utils.get_input("prompt", "default") == "test" @patch("builtins.input", return_value="") def test_get_input_default(mock_input): - assert utils.get_input("prompt", "default") == "default" + with patch("edgewalker.utils.settings") as mock_settings: + mock_settings.silent_mode = False + assert utils.get_input("prompt", "default") == "default" def test_is_physical_mac():