Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class FlashRetryableError(FlashError):
class FlashNonRetryableError(FlashError):
"""Exception for non-retryable flash errors (configuration, file system, etc.)."""


debug_console_option = click.option("--console-debug", is_flag=True, help="Enable console debug mode")

EXPECT_TIMEOUT_DEFAULT = 60
Expand Down Expand Up @@ -100,6 +101,7 @@ def flash( # noqa: C901
retries: int = 3,
method: str = "fls",
fls_version: str = "",
fls_binary_url: str | None = None,
):
if bearer_token:
bearer_token = self._validate_bearer_token(bearer_token)
Expand All @@ -112,8 +114,11 @@ def flash( # noqa: C901
image_url = ""
original_http_url = None
operator_scheme = None
# initrmafs cannot handle https yet, fallback to using the exporter's http server
if path.startswith(("http://", "https://")) and not force_exporter_http:
if path.startswith("oci://"):
# OCI URLs are always passed directly to fls
image_url = path
should_download_to_httpd = False
elif path.startswith(("http://", "https://")) and not force_exporter_http:
# the flasher image can handle the http(s) from a remote directly, unless target is isolated
image_url = path
should_download_to_httpd = False
Expand Down Expand Up @@ -171,9 +176,19 @@ def flash( # noqa: C901
for attempt in range(retries + 1): # +1 for initial attempt
try:
self._perform_flash_operation(
partition, path, image_url, should_download_to_httpd,
storage_thread, error_queue, cacert_file, insecure_tls,
headers, bearer_token, method, fls_version
partition,
path,
image_url,
should_download_to_httpd,
storage_thread,
error_queue,
cacert_file,
insecure_tls,
headers,
bearer_token,
method,
fls_version,
fls_binary_url,
)
self.logger.info(f"Flash operation succeeded on attempt {attempt + 1}")
break
Expand All @@ -193,15 +208,14 @@ def flash( # noqa: C901
)
self.logger.info(f"Retrying flash operation (attempt {attempt + 2}/{retries + 1})")
# Wait a bit before retrying
time.sleep(2 ** attempt) # Exponential backoff
time.sleep(2**attempt) # Exponential backoff
continue
else:
self.logger.error(f"Flash operation failed after {retries + 1} attempts")
raise FlashError(
f"Flash operation failed after {retries + 1} attempts. Last error: {categorized_error}"
) from e


total_time = time.time() - start_time
# total time in minutes:seconds
minutes, seconds = divmod(total_time, 60)
Expand Down Expand Up @@ -261,7 +275,7 @@ def _find_exception_in_chain(self, exception: Exception, target_type: type) -> E
The found exception instance if found, None otherwise
"""
# Check if this is an ExceptionGroup and look through its exceptions
if hasattr(exception, 'exceptions'):
if hasattr(exception, "exceptions"):
for sub_exc in exception.exceptions:
result = self._find_exception_in_chain(sub_exc, target_type)
if result is not None:
Expand All @@ -272,17 +286,17 @@ def _find_exception_in_chain(self, exception: Exception, target_type: type) -> E
return exception

# Check the cause chain
current = getattr(exception, '__cause__', None)
current = getattr(exception, "__cause__", None)
while current is not None:
if isinstance(current, target_type):
return current
# Also check if the cause is an ExceptionGroup
if hasattr(current, 'exceptions'):
if hasattr(current, "exceptions"):
for sub_exc in current.exceptions:
result = self._find_exception_in_chain(sub_exc, target_type)
if result is not None:
return result
current = getattr(current, '__cause__', None)
current = getattr(current, "__cause__", None)
return None

def _perform_flash_operation(
Expand All @@ -299,6 +313,7 @@ def _perform_flash_operation(
bearer_token: str | None,
method: str,
fls_version: str,
fls_binary_url: str | None,
):
"""Perform the actual flash operation with console setup.

Expand Down Expand Up @@ -351,7 +366,6 @@ def _perform_flash_operation(

header_args = self._prepare_headers(headers, bearer_token)


if method == "fls":
self._flash_with_fls(
console,
Expand All @@ -363,6 +377,7 @@ def _perform_flash_operation(
stored_cacert,
header_args,
fls_version,
fls_binary_url,
)
elif method == "shell":
self._flash_with_progress(
Expand Down Expand Up @@ -453,6 +468,34 @@ def _sq(s: str) -> str:

return " ".join(parts)

def _download_fls_binary(self, console, prompt: str, download_url: str, error_message_prefix: str):
"""Download FLS binary to the target device.

Args:
console: Console object for device interaction
prompt: Login prompt for console interaction
download_url: URL to download the FLS binary from
error_message_prefix: Prefix for error message if download fails

Raises:
FlashRetryableError: If download fails or binary cannot be made executable
"""
console.sendline(f"curl -L {download_url} -o /sbin/fls")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
console.sendline("echo $?")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

try:
lines = console.before.decode(errors="ignore").strip().splitlines()
exit_code = int(lines[-1]) if lines else -1
except (IndexError, ValueError) as e:
raise FlashRetryableError(f"{error_message_prefix}, failed to parse exit code") from e

if exit_code != 0:
raise FlashRetryableError(f"{error_message_prefix}, exit code: {exit_code}")
console.sendline("chmod +x /sbin/fls")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

def _flash_with_fls(
self,
console,
Expand All @@ -464,6 +507,7 @@ def _flash_with_fls(
stored_cacert,
header_args: str,
fls_version: str,
fls_binary_url: str | None,
):
"""Flash image to target device with progress monitoring.

Expand All @@ -477,30 +521,21 @@ def _flash_with_fls(
stored_cacert: Path to the stored CA certificate in the DUT flasher
header_args: Header arguments for curl command
fls_version: Version of FLS to use
fls_binary_url: Custom URL to download FLS binary from (overrides fls_version)
"""

# Calculate decompress and tls arguments for curl
prompt = manifest.spec.login.prompt
tls_args = self._cmdline_tls_args(insecure_tls, stored_cacert)

if fls_version != "":
if fls_binary_url:
self.logger.info(f"Downloading FLS binary from custom URL: {fls_binary_url}")
self._download_fls_binary(console, prompt, fls_binary_url, f"Failed to download FLS from {fls_binary_url}")
elif fls_version != "":
self.logger.info(f"Downloading FLS version {fls_version} from GitHub releases")
# Download fls binary to the target device (until it is available on the target device)
fls_url = (
f"https://github.com/jumpstarter-dev/fls/releases/download/{fls_version}/"
f"fls-aarch64-linux"
)
console.sendline(f"curl -L {fls_url} -o /sbin/fls")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
console.sendline("echo $?")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

exit_code = int(console.before.decode(errors="ignore").strip().splitlines()[-1])

if exit_code != 0:
raise FlashRetryableError(f"Failed to download FLS from {fls_url}, exit code: {exit_code}")
console.sendline("chmod +x /sbin/fls")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
fls_url = f"https://github.com/jumpstarter-dev/fls/releases/download/{fls_version}/fls-aarch64-linux"
self._download_fls_binary(console, prompt, fls_url, f"Failed to download FLS from {fls_url}")

# Flash the image
flash_cmd = f'fls from-url -i 1.0 -n {tls_args} {header_args} --o-direct "{image_url}" {target_path}'
Expand Down Expand Up @@ -529,7 +564,7 @@ def _monitor_fls_progress(self, console, prompt):
if len(current_output) > last_printed_length:
new_output = current_output[last_printed_length:]
if new_output:
print(new_output, end='', flush=True)
print(new_output, end="", flush=True)
last_printed_length = len(current_output)

# Check if we matched the prompt (index 0 means prompt matched)
Expand All @@ -538,7 +573,7 @@ def _monitor_fls_progress(self, console, prompt):
break
# If match_index is 1, it means TIMEOUT was matched, so we continue the loop

if 'panicked at' in current_output:
if "panicked at" in current_output:
raise FlashRetryableError(f"FLS panicked: {current_output}")

except pexpect.EOF as err:
Expand Down Expand Up @@ -592,8 +627,8 @@ def _flash_with_progress(
flash_cmd = (
f'( set -o pipefail; curl -fsSL {tls_args} {header_args} "{image_url}" | '
f"{decompress_cmd} "
f'dd of={target_path} bs=64k iflag=fullblock oflag=direct ' +
'&& echo "F""LASH_COMPLETE" || echo "F""LASH_FAILED" ) &'
f"dd of={target_path} bs=64k iflag=fullblock oflag=direct "
+ '&& echo "F""LASH_COMPLETE" || echo "F""LASH_FAILED" ) &'
)
console.sendline(flash_cmd)
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT * 2)
Expand All @@ -611,7 +646,7 @@ def _monitor_flash_progress(self, console, prompt):
console.sendline("pidof dd")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
pidof_output = console.before.decode(errors="ignore")
accumulated_output = pidof_output # just in case we get the FLASH_COMPLETE or FLASH_FAILED markers soon
accumulated_output = pidof_output # just in case we get the FLASH_COMPLETE or FLASH_FAILED markers soon

# Extract the actual process ID from the output, handling potential error messages
lines = pidof_output.splitlines()
Expand Down Expand Up @@ -691,8 +726,8 @@ def _update_accumulated_output(self, accumulated_output, data):
"""Update accumulated output with new data, keeping only last 64KB."""
accumulated_output += data
# Keep only the last 64KB to prevent memory growth
if len(accumulated_output) > 64*1024:
accumulated_output = accumulated_output[-64*1024:]
if len(accumulated_output) > 64 * 1024:
accumulated_output = accumulated_output[-64 * 1024 :]
return accumulated_output

def _update_progress_stats(self, data, last_pos, last_time):
Expand Down Expand Up @@ -925,7 +960,16 @@ def dump(

def _filename(self, path: PathBuf) -> str:
"""Extract filename from url or path"""
if path.startswith(("http://", "https://")):
if path.startswith("oci://"):
oci_path = path[6:] # Remove "oci://" prefix
if ":" in oci_path:
repository, tag = oci_path.rsplit(":", 1)
repo_name = repository.split("/")[-1] if "/" in repository else repository
return f"{repo_name}-{tag}"
else:
repo_name = oci_path.split("/")[-1] if "/" in oci_path else oci_path
return repo_name
elif path.startswith(("http://", "https://")):
return urlparse(path).path.split("/")[-1]
else:
return Path(path).name
Expand Down Expand Up @@ -1163,9 +1207,14 @@ def base():
@click.option(
"--fls-version",
type=str,
default="0.1.9", # TODO(majopela): set default to "" once fls is included in our images
default="0.1.9", # TODO(majopela): set default to "" once fls is included in our images
help="Download an specific fls version from the github releases",
)
@click.option(
"--fls-binary-url",
type=str,
help="Custom URL to download FLS binary from (overrides --fls-version)",
)
@debug_console_option
def flash(
file,
Expand All @@ -1182,6 +1231,7 @@ def flash(
retries,
method,
fls_version,
fls_binary_url,
):
"""Flash image to DUT from file"""
if os_image_checksum_file and os.path.exists(os_image_checksum_file):
Expand All @@ -1205,6 +1255,7 @@ def flash(
retries=retries,
method=method,
fls_version=fls_version,
fls_binary_url=fls_binary_url,
)

@base.command()
Expand Down
Loading