diff --git a/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py b/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py index 8daa50c08..f382caa4f 100644 --- a/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py +++ b/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py @@ -1,5 +1,6 @@ import json import os +import time from dataclasses import dataclass from functools import wraps from typing import ClassVar @@ -23,7 +24,8 @@ def opt_oidc(f): @click.option("--username", help="OIDC username") @click.option("--password", help="OIDC password") @click.option("--connector-id", "connector_id", help="OIDC token exchange connector id (Dex specific)") - @click.option("--callback-port", + @click.option( + "--callback-port", "callback_port", type=click.IntRange(0, 65535), default=None, @@ -93,7 +95,7 @@ async def authorization_code_grant(self, callback_port: int | None = None): elif env_value.isdigit() and int(env_value) <= 65535: port = int(env_value) else: - raise click.ClickException(f"Invalid {JMP_OIDC_CALLBACK_PORT} \"{env_value}\": must be a valid port") + raise click.ClickException(f'Invalid {JMP_OIDC_CALLBACK_PORT} "{env_value}": must be a valid port') tx, rx = create_memory_object_stream() @@ -133,8 +135,75 @@ async def callback(request): def decode_jwt(token: str): - return json.loads(extract_compact(token.encode()).payload) + try: + return json.loads(extract_compact(token.encode()).payload) + except (ValueError, KeyError, TypeError) as e: + raise ValueError(f"Invalid JWT format: {e}") from e def decode_jwt_issuer(token: str): return decode_jwt(token).get("iss") + + +def get_token_expiry(token: str) -> int | None: + """Get token expiry timestamp (Unix epoch seconds) from JWT. + + Returns None if token doesn't have an exp claim. + """ + return decode_jwt(token).get("exp") + + +def get_token_remaining_seconds(token: str) -> float | None: + """Get seconds remaining until token expires. + + Returns: + Positive value if token is still valid + Negative value if token is expired (magnitude = how long ago) + None if token doesn't have an exp claim + """ + exp = get_token_expiry(token) + if exp is None: + return None + return exp - time.time() + + +# Token expiry warning threshold in seconds (5 minutes) +TOKEN_EXPIRY_WARNING_SECONDS = 300 + + +def is_token_expired(token: str, buffer_seconds: int = 0) -> bool: + """Check if token is expired or will expire within buffer_seconds. + + Args: + token: JWT token string + buffer_seconds: Consider expired if less than this many seconds remain + + Returns: + True if token is expired or will expire within buffer + False if token is still valid (or has no exp claim) + """ + remaining = get_token_remaining_seconds(token) + if remaining is None: + return False + return remaining < buffer_seconds + + +def format_duration(seconds: float) -> str: + """Format a duration in seconds as a human-readable string. + + Args: + seconds: Duration in seconds (can be negative for past times) + + Returns: + Formatted string like "2h 30m" or "5m 10s" + """ + abs_seconds = abs(seconds) + hours = int(abs_seconds // 3600) + mins = int((abs_seconds % 3600) // 60) + secs = int(abs_seconds % 60) + + if hours > 0: + return f"{hours}h {mins}m" + if mins > 0: + return f"{mins}m {secs}s" + return f"{secs}s" diff --git a/packages/jumpstarter-cli/jumpstarter_cli/auth.py b/packages/jumpstarter-cli/jumpstarter_cli/auth.py new file mode 100644 index 000000000..d06142519 --- /dev/null +++ b/packages/jumpstarter-cli/jumpstarter_cli/auth.py @@ -0,0 +1,67 @@ +from datetime import datetime, timezone + +import click +from jumpstarter_cli_common.config import opt_config +from jumpstarter_cli_common.oidc import ( + TOKEN_EXPIRY_WARNING_SECONDS, + decode_jwt, + format_duration, + get_token_remaining_seconds, +) + + +@click.group() +def auth(): + """Authentication and token management commands.""" + + +def _print_token_status(remaining: float) -> None: + """Print token status message based on remaining time.""" + duration = format_duration(remaining) + + if remaining < 0: + click.echo(click.style(f"Status: EXPIRED ({duration} ago)", fg="red", bold=True)) + click.echo(click.style("Run 'jmp login' to refresh your credentials.", fg="yellow")) + elif remaining < TOKEN_EXPIRY_WARNING_SECONDS: + click.echo(click.style(f"Status: EXPIRING SOON ({duration} remaining)", fg="red", bold=True)) + click.echo(click.style("Run 'jmp login' to refresh your credentials.", fg="yellow")) + elif remaining < 3600: + click.echo(click.style(f"Status: Valid ({duration} remaining)", fg="yellow")) + else: + click.echo(click.style(f"Status: Valid ({duration} remaining)", fg="green")) + + +@auth.command(name="status") +@opt_config(exporter=False) +def token_status(config): + """Display token status and expiry information.""" + token_str = getattr(config, "token", None) + + if not token_str: + click.echo(click.style("No token found in config", fg="yellow")) + return + + try: + payload = decode_jwt(token_str) + except ValueError as e: + click.echo(click.style(f"Failed to decode token: {e}", fg="red")) + return + + remaining = get_token_remaining_seconds(token_str) + if remaining is None: + click.echo(click.style("Token has no expiry claim", fg="yellow")) + return + + exp = payload.get("exp") + exp_dt = datetime.fromtimestamp(exp, tz=timezone.utc) + click.echo(f"Token expiry: {exp_dt.strftime('%Y-%m-%d %H:%M:%S %Z')}") + + _print_token_status(remaining) + + # Show additional token info + sub = payload.get("sub") + iss = payload.get("iss") + if sub: + click.echo(f"Subject: {sub}") + if iss: + click.echo(f"Issuer: {iss}") diff --git a/packages/jumpstarter-cli/jumpstarter_cli/jmp.py b/packages/jumpstarter-cli/jumpstarter_cli/jmp.py index d27fe1a10..219286aec 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/jmp.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/jmp.py @@ -5,6 +5,7 @@ from jumpstarter_cli_common.version import version from jumpstarter_cli_driver import driver +from .auth import auth from .config import config from .create import create from .delete import delete @@ -21,6 +22,7 @@ def jmp(): """The Jumpstarter CLI""" +jmp.add_command(auth) jmp.add_command(create) jmp.add_command(delete) jmp.add_command(update) diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py index f405b9f49..41155a4fa 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -6,6 +6,11 @@ from anyio import create_task_group, get_cancelled_exc_class from jumpstarter_cli_common.config import opt_config from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication +from jumpstarter_cli_common.oidc import ( + TOKEN_EXPIRY_WARNING_SECONDS, + format_duration, + get_token_remaining_seconds, +) from jumpstarter_cli_common.signal import signal_handler from .common import opt_acquisition_timeout, opt_duration_partial, opt_selector @@ -15,12 +20,59 @@ from jumpstarter.config.exporter import ExporterConfigV1Alpha1 +def _warn_about_expired_token(lease_name: str, selector: str) -> None: + """Warn user that lease won't be cleaned up due to expired token.""" + click.echo(click.style("\nToken expired - lease cleanup will fail.", fg="yellow", bold=True)) + click.echo(click.style(f"Lease '{lease_name}' will remain active.", fg="yellow")) + click.echo(click.style(f"To reconnect: JMP_LEASE={lease_name} jmp shell", fg="cyan")) + + +async def _monitor_token_expiry(config, cancel_scope) -> None: + """Monitor token expiry and warn user.""" + token = getattr(config, "token", None) + if not token: + return + + warned = False + while not cancel_scope.cancel_called: + try: + remaining = get_token_remaining_seconds(token) + if remaining is None: + return + + if remaining <= 0: + click.echo(click.style("\nToken expired! Exiting shell.", fg="red", bold=True)) + cancel_scope.cancel() + return + + if remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warned: + duration = format_duration(remaining) + click.echo( + click.style( + f"\nToken expires in {duration}. Session will continue but cleanup may fail on exit.", + fg="yellow", + bold=True, + ) + ) + warned = True + + await anyio.sleep(30) + except Exception: + return + + def _run_shell_with_lease(lease, exporter_logs, config, command): """Run shell with lease context managers.""" + def launch_remote_shell(path: str) -> int: return launch_shell( - path, lease.exporter_name, config.drivers.allow, config.drivers.unsafe, - config.shell.use_profiles, command=command, lease=lease + path, + lease.exporter_name, + config.drivers.allow, + config.drivers.unsafe, + config.shell.use_profiles, + command=command, + lease=lease, ) with lease.serve_unix() as path: @@ -39,13 +91,28 @@ async def _shell_with_signal_handling( """Handle lease acquisition and shell execution with signal handling.""" exit_code = 0 cancelled_exc_class = get_cancelled_exc_class() + lease_used = None + + # Check token before starting + token = getattr(config, "token", None) + if token: + remaining = get_token_remaining_seconds(token) + if remaining is not None and remaining <= 0: + from jumpstarter.common.exceptions import ConnectionError + raise ConnectionError("token is expired") async with create_task_group() as tg: tg.start_soon(signal_handler, tg.cancel_scope) + try: try: async with anyio.from_thread.BlockingPortal() as portal: async with config.lease_async(selector, lease_name, duration, portal, acquisition_timeout) as lease: + lease_used = lease + + # Start token monitoring only once we're in the shell + tg.start_soon(_monitor_token_expiry, config, tg.cancel_scope) + exit_code = await anyio.to_thread.run_sync( _run_shell_with_lease, lease, exporter_logs, config, command ) @@ -55,6 +122,13 @@ async def _shell_with_signal_handling( raise exc from None raise except cancelled_exc_class: + # Check if cancellation was due to token expiry + token = getattr(config, "token", None) + if lease_used and token: + remaining = get_token_remaining_seconds(token) + if remaining is not None and remaining <= 0: + _warn_about_expired_token(lease_used.name, selector) + return 3 # Exit code for token expiry exit_code = 2 finally: if not tg.cancel_scope.cancel_called: