diff --git a/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py b/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py index 382b98ec9..8daa50c08 100644 --- a/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py +++ b/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py @@ -1,4 +1,5 @@ import json +import os from dataclasses import dataclass from functools import wraps from typing import ClassVar @@ -12,6 +13,8 @@ from joserfc.jws import extract_compact from yarl import URL +from jumpstarter.config.env import JMP_OIDC_CALLBACK_PORT + def opt_oidc(f): @click.option("--issuer", help="OIDC issuer") @@ -20,6 +23,12 @@ def opt_oidc(f): @click.option("--username", help="OIDC username") @click.option("--password", help="OIDC password") @click.option("--connector-id", "connector_id", help="OIDC token exchange connector id (Dex specific)") + @click.option("--callback-port", + "callback_port", + type=click.IntRange(0, 65535), + default=None, + help="Port for OIDC callback server (0=random port)", + ) @wraps(f) def wrapper(*args, **kwds): return f(*args, **kwds) @@ -71,9 +80,21 @@ async def password_grant(self, username: str, password: str): ) ) - async def authorization_code_grant(self): + async def authorization_code_grant(self, callback_port: int | None = None): config = await self.configuration() + # Use provided port, fall back to env var, then default to 0 (OS picks) + if callback_port is not None: + port = callback_port + else: + env_value = os.environ.get(JMP_OIDC_CALLBACK_PORT) + if env_value is None: + port = 0 + elif env_value.isdigit() and int(env_value) <= 65535: + port = int(env_value) + else: + raise click.ClickException(f"Invalid {JMP_OIDC_CALLBACK_PORT} \"{env_value}\": must be a valid port") + tx, rx = create_memory_object_stream() async def callback(request): @@ -86,8 +107,12 @@ async def callback(request): runner = web.AppRunner(app, access_log=None) await runner.setup() - site = web.TCPSite(runner, "localhost", 0) - await site.start() + site = web.TCPSite(runner, "localhost", port) + try: + await site.start() + except OSError as e: + await runner.cleanup() + raise click.ClickException(f"Failed to start callback server on port {port}: {e}") from None redirect_uri = "http://localhost:%d/callback" % site._server.sockets[0].getsockname()[1] diff --git a/packages/jumpstarter-cli/jumpstarter_cli/login.py b/packages/jumpstarter-cli/jumpstarter_cli/login.py index 57bf88bba..c03800239 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/login.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/login.py @@ -47,6 +47,7 @@ async def login( # noqa: C901 issuer: str, client_id: str, connector_id: str, + callback_port: int | None, unsafe, insecure_tls_config: bool, nointeractive: bool, @@ -123,7 +124,7 @@ async def login( # noqa: C901 elif username is not None and password is not None: tokens = await oidc.password_grant(username, password) else: - tokens = await oidc.authorization_code_grant() + tokens = await oidc.authorization_code_grant(callback_port=callback_port) config.token = tokens["access_token"] diff --git a/packages/jumpstarter/jumpstarter/config/env.py b/packages/jumpstarter/jumpstarter/config/env.py index ace6f1cc8..145966d27 100644 --- a/packages/jumpstarter/jumpstarter/config/env.py +++ b/packages/jumpstarter/jumpstarter/config/env.py @@ -10,3 +10,4 @@ JMP_LEASE = "JMP_LEASE" JMP_DISABLE_COMPRESSION = "JMP_DISABLE_COMPRESSION" +JMP_OIDC_CALLBACK_PORT = "JMP_OIDC_CALLBACK_PORT"