From 881988e4139d8a0a2a8bc25258498f2f5a98aa58 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Wed, 25 Mar 2026 00:09:52 -0700 Subject: [PATCH] makefile(streamlit the same, stlite added, bug for display fixed, output as 0 as TD), and AGENTS.md added --- AGENTS.md | 221 ++++++++++++++++++++++++++++++++++++++ Makefile | 82 ++++++++++++++ app.py | 203 +++++++++++++++++++++------------- scripts/build_index.py | 145 +++++++++++++++++++++++++ utils/parameter_loader.py | 34 ++++-- utils/parameter_ui.py | 129 ++++++---------------- 6 files changed, 638 insertions(+), 176 deletions(-) create mode 100644 AGENTS.md create mode 100644 Makefile create mode 100644 scripts/build_index.py diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..44e808e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,221 @@ +# AGENTS.md — EpiCON Cost Calculator + +This file describes the conventions, roles, and constraints for contributors working in this +repository. All agents — whether running in GitHub Actions or invoked interactively — should +read and follow this document before making changes. + +--- + +## Project overview + +**EpiCON** is a browser-based epidemiological cost calculator built with **Streamlit** and +distributed as a static **stlite** build for browser execution. + +The current app supports two model flows: + +1. **Python + YAML models** + - A Python module in `models/` implements model logic. + - A paired YAML file provides default parameters. + - `app.py` loads the Python module, loads YAML defaults, renders parameter inputs, + runs the model, and renders sections. + +2. **Excel-driven models** + - An uploaded `.xlsx` file is parsed by `utils/excel_model_runner.py`. + - Parameters and computed outputs are rendered from workbook contents. + +Current high-level flow: + +`discover_models() → load_model_from_file() / load_model_params() → render_parameters_with_indent() → run_model() → build_sections() → render_sections()` + +Persistence helpers: +- `store_model_state()` +- `save_current_model()` + +--- + +## Repository layout + +```text +epiworldPythonStreamlit/ +├── app.py +├── models/ # Python model modules + paired YAML parameter files +├── utils/ +│ ├── model_loader.py +│ ├── parameter_loader.py +│ ├── parameter_ui.py +│ ├── excel_model_runner.py +│ └── section_renderer.py +├── config/ +├── styles/ +├── scripts/ # stlite build helpers +├── build/ # generated static output +├── docs/ +├── pyproject.toml +├── Makefile +├── AGENTS.md +└── README.md +``` + +--- + +## Coding conventions + +All agent-generated code must follow these standards. + +### Style +- Follow **PEP 8** for all Python code. +- Use **type hints** on every function signature. +- Write **docstrings** for every public function and class. +- Maximum line length: **100 characters**. + +### Testing +- Tests use **pytest** only. +- Add or update pytest coverage for every changed public function whenever a `tests/` target + exists for that module. +- Prefer pure-Python tests compatible with the stlite/Pyodide target. +- Avoid test dependencies that require native binaries. +- Aim for at least **80% line coverage** on new code where CI coverage is enforced. + +### Dependencies +- Add dependencies via `pyproject.toml` only, managed by `uv`. +- Runtime dependencies must be **pure-Python** or available as **Pyodide wheels**. +- Dev-only dependencies are exempt from the Pyodide runtime constraint. + +### Security +- Equations and figure code in YAML files are untrusted input. +- Validate them with the CPython `ast` module before evaluation. +- Never use `eval()` or `exec()` on unvalidated user-supplied strings. +- Do not log or print parameter values that could contain PII. + +### Git workflow +- **No direct pushes to `main`.** +- Use branch names like: + - `feat/` + - `fix/` + - `chore/` +- Use **Conventional Commits** for commit messages. +- Each PR should address a single concern. + +--- + +## Function contracts (current) + +Agents must not change a function signature or return type without updating all callers and +tests. + +### App / utility layer +- `discover_models(path: str) -> dict[str, str]` +- `load_model_from_file(filepath: str) -> object` +- `load_model_params(model_file_path: str, uploaded_excel=None) -> dict` +- `flatten_dict(d, level=0)` +- `render_parameters_with_indent(param_dict, params, label_overrides) -> None` +- `reset_parameters_to_defaults(param_dict, params, model_id) -> None` +- `render_sections(sections) -> None` + +### Python model modules +Each Python model module in `models/` must expose: +- `model_title: str` +- `model_description: str` +- `run_model(params: dict, label_overrides: dict | None = None) -> list[dict]` +- `build_sections(results: list[dict]) -> list[dict]` + +--- + +## Agent roles + +### Interactive development agents +**Scope:** Code generation, refactoring, documentation, tests, YAML schema work, and code +review suggestions. + +**Constraints:** +- Always read `AGENTS.md` and relevant source files before proposing changes. +- Do not modify `pyproject.toml` dependencies without explaining Pyodide compatibility. +- Use AST validation for equation evaluation. +- Prefer Streamlit-native rendering over heavier plotting dependencies. +- For browser-only persistence, use browser storage patterns rather than server-side files. +- Propose tests alongside implementation changes. + +### GitHub Actions — CI agent +**Trigger:** Every push and every PR targeting `main`. + +**Expected steps:** +1. Check out the repository. +2. Install dependencies with `uv`. +3. Run `ruff`. +4. Run `mypy`. +5. Run `pytest` with coverage. +6. Fail if configured coverage thresholds are not met. + +**Constraints:** +- Must run in the project’s supported development environment. +- Do not upload artifacts unless explicitly configured. + +### GitHub Actions — agent environment +**Purpose:** Set up credentials, tools, and optional external service tokens. + +**Constraints:** +- Secrets must be stored in GitHub Actions Secrets. +- Fail loudly if required secrets are missing. + +--- + +## YAML model schema + +The app currently supports both of these YAML layouts. + +### 1. Flat key/value defaults +```yaml +Cost of measles hospitalization: 31168 +Proportion of cases hospitalized: 0.2 +``` + +### 2. Nested parameter dictionary +```yaml +parameters: + Cost of measles hospitalization: 31168 + Proportion of cases hospitalized: 0.2 +``` + +Agents must preserve compatibility with both layouts unless the app is explicitly migrated. + +### Reference schema for future structured models +```yaml +model: + metadata: + parameters: + equations: + table: + figures: + current_parameters: +``` + +If generating new structured YAML, keep it compatible with current loaders or update the +loaders and tests together. + +--- + +## Out-of-scope for agents + +The following are off-limits without explicit human approval in PR review: + +- Changing branch protection rules. +- Adding `eval()` or `exec()` on user-supplied strings. +- Introducing runtime dependencies that require native binaries. +- Modifying the public function signatures listed above. +- Writing outside the repository working directory. +- Disabling or skipping CI checks. + +--- + +## Getting started + +```bash +uv sync +make dev +make setup +make serve +make check +``` + +For architecture questions, refer to the development plan, inline source comments, and current +utility/model implementations. \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b36f286 --- /dev/null +++ b/Makefile @@ -0,0 +1,82 @@ +UV := uv + +STLITE_VER ?= 0.86.0 +PORT ?= 8000 + +STLITE_CSS := https://cdn.jsdelivr.net/npm/@stlite/browser@$(strip $(STLITE_VER))/build/stlite.css +STLITE_JS := https://cdn.jsdelivr.net/npm/@stlite/browser@$(strip $(STLITE_VER))/build/stlite.js +APP_PY := app.py +BUILD_DIR := build +INDEX_HTML := $(BUILD_DIR)/index.html + +STLITE_INPUTS := $(APP_PY) pyproject.toml scripts/build_index.py \ + $(shell find models utils config styles selected examples -type f 2>/dev/null) + +.DEFAULT_GOAL := help + +.PHONY: help +help: + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \ + awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}' + +.PHONY: setup +setup: install build-html ## Install dependencies and build the stlite app + @echo "" + @echo "EpiCON is ready." + @echo " Dev server (normal Streamlit): make dev" + @echo " Serve stlite build locally: make serve" + +.PHONY: install +install: ## Install Python dependencies with uv + $(UV) sync + +.PHONY: build-html +build-html: $(INDEX_HTML) ## Generate build/index.html (stlite WASM entry point) + +$(INDEX_HTML): $(STLITE_INPUTS) | $(BUILD_DIR) + @echo "Generating $(INDEX_HTML) (stlite v$(STLITE_VER))" + $(UV) run python scripts/build_index.py \ + --app $(APP_PY) \ + --out $(INDEX_HTML) \ + --css $(STLITE_CSS) \ + --js $(STLITE_JS) \ + --title "EpiCON Cost Calculator" + @echo "Copying static assets into $(BUILD_DIR)/" + @for dir in utils config styles models smelected examples; do \ + if [ -d "$$dir" ]; then cp -r "$$dir" "$(BUILD_DIR)/$$dir"; fi; \ + done + @echo "Build artefacts written to $(BUILD_DIR)/" + +$(BUILD_DIR): + mkdir -p $@ + +.PHONY: dev +dev: ## Run normal Streamlit locally + $(UV) run streamlit run $(APP_PY) + +.PHONY: serve +serve: build-html ## Serve the stlite static build + @echo "Serving stlite build at http://localhost:$(PORT) (Ctrl-C to stop)" + $(UV) run python -m http.server $(PORT) --directory $(BUILD_DIR) + +.PHONY: stlite +stlite: setup serve ## Install, build, and serve the stlite app + +.PHONY: lint +lint: ## Run ruff linter + $(UV) run ruff check . + +.PHONY: typecheck +typecheck: ## Run mypy type checker + $(UV) run mypy utils + +.PHONY: test +test: ## Run pytest + $(UV) run pytest + +.PHONY: check +check: lint typecheck test ## Run all quality checks + +.PHONY: clean +clean: ## Remove build artefacts + rm -rf $(BUILD_DIR) .mypy_cache .ruff_cache .pytest_cache __pycache__ diff --git a/app.py b/app.py index 650c4f4..3c5dbf4 100644 --- a/app.py +++ b/app.py @@ -2,9 +2,10 @@ import yaml import streamlit as st import inspect +from decimal import Decimal, InvalidOperation from utils.model_loader import discover_models, load_model_from_file -from utils.parameter_loader import load_model_params, flatten_dict +from utils.parameter_loader import load_model_params, flatten_dict, get_leaf_defaults from utils.section_renderer import render_sections from utils.parameter_ui import render_parameters_with_indent, reset_parameters_to_defaults from utils.excel_model_runner import ( @@ -24,12 +25,63 @@ # UI STYLES -def load_css(file_path: str): +def load_css(file_path: str) -> None: + """Load a CSS file into the Streamlit app if it exists.""" if os.path.exists(file_path): - with open(file_path) as f: - st.markdown(f"", unsafe_allow_html=True) + with open(file_path, encoding="utf-8") as file: + st.markdown(f"", unsafe_allow_html=True) +def normalize_yaml_defaults(raw_yaml: object) -> dict: + """Normalize supported YAML layouts into a flat parameter dictionary.""" + if not isinstance(raw_yaml, dict): + return {} + + parameter_block = raw_yaml.get("parameters", raw_yaml) + if not isinstance(parameter_block, dict): + return {} + + return flatten_dict(parameter_block) + + +def running_in_stlite() -> bool: + """Return True when the app is running inside stlite/Pyodide.""" + return os.path.abspath(__file__).startswith("/home/pyodide/") + + +def coerce_like_default(value: object, default: object) -> object: + """Coerce a widget value to the type implied by its default.""" + if value in ("", None): + return default + + if isinstance(default, Decimal): + try: + return Decimal(str(value).replace(",", "").strip()) + except (InvalidOperation, ValueError): + return default + + if isinstance(default, int) and not isinstance(default, bool): + try: + return int(float(str(value).replace(",", "").strip())) + except ValueError: + return default + + if isinstance(default, float): + try: + return float(str(value).replace(",", "").strip()) + except ValueError: + return default + + return value + + +def normalize_stlite_params(params: dict, defaults: dict) -> dict: + """Restore blank stlite values and coerce numeric strings before model execution.""" + normalized = dict(params) + for key, default in get_leaf_defaults(defaults).items(): + normalized[key] = coerce_like_default(normalized.get(key), default) + return normalized + load_css(os.path.join(base_dir, "styles/sidebar.css")) # MODEL SELECTION @@ -89,46 +141,39 @@ def load_css(file_path: str): uploaded_excel_model = st.sidebar.file_uploader( "Upload Excel model file (.xlsx)", type=["xlsx"], - key="excel_model_uploader" + key="excel_model_uploader", ) if uploaded_excel_model: - - # reset params if Excel file changes if ( - "excel_active_name" not in st.session_state - or st.session_state.excel_active_name != uploaded_excel_model.name + "excel_active_name" not in st.session_state + or st.session_state.excel_active_name != uploaded_excel_model.name ): st.session_state.excel_active_name = uploaded_excel_model.name st.session_state.params = {} params = st.session_state.params - # Load the defaults editable_defaults, _ = load_excel_params_defaults_with_computed( uploaded_excel_model, sheet_name=None, - start_row=3 + start_row=3, ) - # Load Headers - current_headers = get_scenario_headers(uploaded_excel_model) + for key, value in get_leaf_defaults(editable_defaults).items(): + params.setdefault(key, value) + current_headers = get_scenario_headers(uploaded_excel_model) - # RESET CALLBACK - def handle_reset_excel(): - # 1. Reset Parameters + def handle_reset_excel() -> None: + """Reset Excel parameters and output labels to defaults.""" reset_parameters_to_defaults(editable_defaults, params, uploaded_excel_model.name) - # 2. Reset Header Labels if current_headers: for col_letter, default_text in current_headers.items(): st.session_state[f"label_override_{col_letter}"] = default_text - - # Display Button with Callback st.sidebar.button("Reset Parameters", on_click=handle_reset_excel) - # Outcome Headers (Column B-E) if current_headers: with st.sidebar.expander("Output Scenario Headers", expanded=False): st.caption("Rename the output headers (B, C, D, E)") @@ -137,25 +182,22 @@ def handle_reset_excel(): default_text = current_headers[col_letter] widget_key = f"label_override_{col_letter}" - # Robust Widget Logic if widget_key in st.session_state: - new_text = st.text_input( - f"Column {col_letter} Label", - key=widget_key - ) + new_text = st.text_input(f"Column {col_letter} Label", key=widget_key) else: new_text = st.text_input( f"Column {col_letter} Label", value=default_text, - key=widget_key + key=widget_key, ) - label_overrides[col_letter] = new_text + label_overrides[col_letter] = ( + new_text if str(new_text).strip() else default_text + ) - # Render Main Parameters render_parameters_with_indent( editable_defaults, params, - model_id=uploaded_excel_model.name + model_id=model_key ) else: @@ -166,22 +208,30 @@ def handle_reset_excel(): param_source = st.sidebar.radio( "Parameter Source", ["Model Default (YAML)", "Excel (.xlsx)", "YAML (.yaml)"], - horizontal=True + horizontal=True, ) uploaded_excel = None - uploaded_yaml = None + uploaded_yaml_file = None if param_source == "Excel (.xlsx)": uploaded_excel = st.sidebar.file_uploader("Upload Excel parameter file", type=["xlsx"]) elif param_source == "YAML (.yaml)": - uploaded_yaml = st.sidebar.file_uploader("Upload YAML parameter file", type=["yaml", "yml"]) + uploaded_yaml_file = st.sidebar.file_uploader( + "Upload YAML parameter file", + type=["yaml", "yml"], + ) + + parameter_file = "Model Default (YAML)" + if uploaded_excel is not None: + parameter_file = uploaded_excel.name + elif uploaded_yaml_file is not None: + parameter_file = uploaded_yaml_file.name - # PARAMETER SOURCE RESET LOGIC param_identity = ( param_source, uploaded_excel.name if uploaded_excel else None, - uploaded_yaml.name if uploaded_yaml else None, + uploaded_yaml_file.name if uploaded_yaml_file else None, ) if "active_param_identity" not in st.session_state: @@ -193,56 +243,53 @@ def handle_reset_excel(): params = st.session_state.params - if param_source == "YAML (.yaml)" and uploaded_yaml: - raw = yaml.safe_load(uploaded_yaml) or {} - model_defaults = flatten_dict(raw) + if param_source == "YAML (.yaml)" and uploaded_yaml_file: + raw_yaml = yaml.safe_load(uploaded_yaml_file) or {} + model_defaults = normalize_yaml_defaults(raw_yaml) else: model_defaults = load_model_params( selected_model_file, - uploaded_excel=uploaded_excel + uploaded_excel=uploaded_excel, ) - # Load Python Model Module to check for Labels + for key, value in get_leaf_defaults(model_defaults).items(): + params.setdefault(key, value) + model_module = load_model_from_file(selected_model_file) - # Check for SCENARIO_LABELS constant in the python file current_headers = getattr(model_module, "SCENARIO_LABELS", None) - if model_defaults: - # Define Python Model Reset Callback - def handle_reset_python(): - # 1. Reset Parameters - reset_parameters_to_defaults(model_defaults, params, model_key) - # 2. Reset Header Labels - if current_headers: - for key, default_text in current_headers.items(): - # We use a unique key format for python models to avoid conflicts - st.session_state[f"py_label_{model_key}_{key}"] = default_text - + def handle_reset_python() -> None: + """Reset Python-model parameters and output labels to defaults.""" + reset_parameters_to_defaults(model_defaults, params, model_key) + if current_headers: + for key, default_text in current_headers.items(): + st.session_state[f"py_label_{model_key}_{key}"] = default_text - st.sidebar.button("Reset Parameters", on_click=handle_reset_python) + st.sidebar.button("Reset Parameters", on_click=handle_reset_python) - # SCENARIO LABELS (PYTHON) - if current_headers: - with st.sidebar.expander("Output Scenario Headers", expanded=False): - st.caption("Rename the output headers") + if current_headers: + with st.sidebar.expander("Output Scenario Headers", expanded=False): + st.caption("Rename the output headers") - for key, default_text in current_headers.items(): - widget_key = f"py_label_{model_key}_{key}" + for key, default_text in current_headers.items(): + widget_key = f"py_label_{model_key}_{key}" - if widget_key in st.session_state: - new_val = st.text_input(f"Label for '{default_text}'", key=widget_key) - else: - new_val = st.text_input(f"Label for '{default_text}'", value=default_text, key=widget_key) + if widget_key in st.session_state: + new_val = st.text_input(f"Label for '{default_text}'", key=widget_key) + else: + new_val = st.text_input( + f"Label for '{default_text}'", + value=default_text, + key=widget_key, + ) - label_overrides[key] = new_val + label_overrides[key] = new_val if str(new_val).strip() else default_text - render_parameters_with_indent( - model_defaults, - params, - model_id=model_key - ) - else: - st.sidebar.info("No default parameters defined for this model.") + render_parameters_with_indent( + model_defaults, + params, + model_id=model_key, + ) # RUN SIMULATION if st.sidebar.button("Run Simulation"): @@ -263,7 +310,7 @@ def handle_reset_python(): filename=uploaded_excel_model.name, params=params, sheet_name=None, - label_overrides=label_overrides + label_overrides=label_overrides, ) st.title(results.get("model_title", "Excel Driven Model")) @@ -280,12 +327,22 @@ def handle_reset_python(): st.title(getattr(model_module, "model_title", app_config["title"])) st.write(getattr(model_module, "model_description", app_config["description"])) + run_params = ( + normalize_stlite_params(params, model_defaults) + if running_in_stlite() + else params + ) + # Check if run_model accepts label_overrides sig = inspect.signature(model_module.run_model) if "label_overrides" in sig.parameters: - results = model_module.run_model(params, label_overrides=label_overrides) + results = model_module.run_model(run_params, label_overrides=label_overrides) else: - results = model_module.run_model(params) + results = model_module.run_model(run_params) sections = model_module.build_sections(results) render_sections(sections) + +def running_in_stlite() -> bool: + """Return True when the app is running inside stlite/Pyodide.""" + return os.path.abspath(__file__).startswith("/home/pyodide/") diff --git a/scripts/build_index.py b/scripts/build_index.py new file mode 100644 index 0000000..9100d07 --- /dev/null +++ b/scripts/build_index.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import argparse +import html +import json +from pathlib import Path + + +PYODIDE_PACKAGES: list[str] = ["pyyaml", "pandas", "openpyxl"] +MOUNT_DIRS: tuple[str, ...] = ("models", "utils", "config", "styles") +TEXT_SUFFIXES: tuple[str, ...] = ( + ".py", + ".yaml", + ".yml", + ".css", + ".json", + ".md", + ".txt", +) + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--app", required=True, help="Path to the Streamlit entrypoint.") + parser.add_argument("--out", required=True, help="Path to the generated index.html file.") + parser.add_argument("--css", required=True, help="stlite CSS URL.") + parser.add_argument("--js", required=True, help="stlite JS URL.") + parser.add_argument("--title", default="EpiCON Cost Calculator", help="HTML page title.") + return parser.parse_args() + + +def should_mount_file(path: Path) -> bool: + """Return True when a file should be embedded into the stlite bundle.""" + if not path.is_file(): + return False + + if path.name.startswith("."): + return False + + if "__pycache__" in path.parts: + return False + + if path.suffix == ".pyc": + return False + + return path.suffix.lower() in TEXT_SUFFIXES + + +def collect_files(project_root: Path, app_path: Path) -> dict[str, str]: + """Collect source files that must be mounted into the stlite virtual filesystem.""" + mounted_files: dict[str, str] = {} + + files_to_mount: list[Path] = [app_path] + + for dirname in MOUNT_DIRS: + directory = project_root / dirname + if directory.exists(): + files_to_mount.extend(sorted(directory.rglob("*"))) + + for path in files_to_mount: + if not should_mount_file(path): + continue + + relative_path = path.relative_to(project_root).as_posix() + mounted_files[relative_path] = path.read_text(encoding="utf-8") + + return mounted_files + + +def build_html( + *, + title: str, + css_url: str, + js_url: str, + entrypoint: str, + mounted_files: dict[str, str], +) -> str: + """Build the stlite HTML document.""" + title_html = html.escape(title, quote=True) + css_html = html.escape(css_url, quote=True) + js_json = json.dumps(js_url) + entrypoint_json = json.dumps(entrypoint) + files_json = json.dumps(mounted_files, ensure_ascii=False) + packages_json = json.dumps(PYODIDE_PACKAGES) + + return f""" + + + + + {title_html} + + + +
+ + + +""" + + +def main() -> None: + """Generate the stlite index.html file.""" + args = parse_args() + + script_path = Path(__file__).resolve() + project_root = script_path.parent.parent.resolve() + app_path = (project_root / args.app).resolve() + out_path = (project_root / args.out).resolve() + + mounted_files = collect_files(project_root, app_path) + html_text = build_html( + title=args.title, + css_url=args.css, + js_url=args.js, + entrypoint=app_path.relative_to(project_root).as_posix(), + mounted_files=mounted_files, + ) + + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(html_text, encoding="utf-8") + + print( + f"Written: {out_path.relative_to(project_root)} " + f"({len(html_text.encode('utf-8')):,} bytes, {len(mounted_files)} files mounted)" + ) + print(f"Pyodide packages: {', '.join(PYODIDE_PACKAGES)}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/utils/parameter_loader.py b/utils/parameter_loader.py index 902d0a6..2705511 100644 --- a/utils/parameter_loader.py +++ b/utils/parameter_loader.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import os import yaml import pandas as pd from decimal import Decimal import re +from typing import Any def load_params_from_excel(excel_file): @@ -31,11 +34,11 @@ def load_params_from_excel(excel_file): return params -def flatten_dict(d, level=0): - flat = {} +def flatten_dict(d: dict[str, Any], level: int = 0) -> dict[str, Any]: + """Flatten a nested dictionary for indented UI rendering.""" + flat: dict[str, Any] = {} for key, value in d.items(): indented_key = ("\t" * level) + str(key) - if isinstance(value, dict): flat[indented_key] = None flat.update(flatten_dict(value, level + 1)) @@ -44,8 +47,19 @@ def flatten_dict(d, level=0): return flat -def load_model_params(model_file_path, uploaded_excel=None): - if uploaded_excel: +def get_leaf_defaults(flat_params: dict[str, Any]) -> dict[str, Any]: + """Return editable leaf parameters with indentation removed.""" + cleaned: dict[str, Any] = {} + for key, value in flat_params.items(): + if value is None: + continue + cleaned[str(key).lstrip("\t")] = value + return cleaned + + +def load_model_params(model_file_path: str, uploaded_excel=None) -> dict[str, Any]: + """Load model parameters from Excel or the paired YAML file.""" + if uploaded_excel is not None: return load_params_from_excel(uploaded_excel) base = os.path.dirname(model_file_path) @@ -55,7 +69,13 @@ def load_model_params(model_file_path, uploaded_excel=None): if not os.path.exists(yaml_path): return {} - with open(yaml_path, "r") as f: - raw = yaml.safe_load(f) or {} + with open(yaml_path, "r", encoding="utf-8") as file: + raw = yaml.safe_load(file) or {} + + if isinstance(raw, dict) and isinstance(raw.get("parameters"), dict): + raw = raw["parameters"] + + if not isinstance(raw, dict): + return {} return flatten_dict(raw) diff --git a/utils/parameter_ui.py b/utils/parameter_ui.py index 445796f..a20a0ed 100644 --- a/utils/parameter_ui.py +++ b/utils/parameter_ui.py @@ -1,111 +1,48 @@ import streamlit as st -def reset_parameters_to_defaults(param_dict, params, model_id): - """ - Resets Streamlit session state widgets and params dict to values found in param_dict. - """ - items = list(param_dict.items()) - i = 0 - n = len(items) - - while i < n: - key, value = items[i] - level = len(key) - len(key.lstrip("\t")) - label = key.strip() +def reset_parameters_to_defaults( + param_dict: dict, params: dict, model_id: str +) -> None: + """Reset editable parameters to their default values.""" + for label, value in param_dict.items(): + clean_label = str(label).lstrip("\t") if value is None: - # Handle Children - j = i + 1 - while j < n: - subkey, subval = items[j] - sublevel = len(subkey) - len(subkey.lstrip("\t")) - - if sublevel <= level: - break - - # Reset logic for child - if sublevel == level + 1 and subval is not None: - sublabel = subkey.strip() - widget_key = f"{model_id}:{label}:{sublabel}" - - st.session_state[widget_key] = str(subval) - params[sublabel] = str(subval) - j += 1 - i = j continue - # Reset logic for Top-level - widget_key = f"{model_id}:{label}" - st.session_state[widget_key] = str(value) - params[label] = str(value) - i += 1 + params[clean_label] = value + st.session_state[f"{model_id}_{clean_label}"] = str(value) -def render_parameters_with_indent(param_dict, params, model_id): - """ - Render hierarchical parameters with indentation-based expanders. - Checks session_state before setting 'value' to avoid DuplicateWidgetID/API warnings. - """ - items = list(param_dict.items()) - i = 0 - n = len(items) - - while i < n: - key, value = items[i] - - level = len(key) - len(key.lstrip("\t")) - label = key.strip() +def render_parameters_with_indent( + param_dict: dict, params: dict, model_id: str +) -> None: + """Render parameter inputs without overwriting defaults with blank values.""" + for label, value in param_dict.items(): + indent_level = len(label) - len(label.lstrip("\t")) + clean_label = label.lstrip("\t") if value is None: - # Collect all children - children = [] - j = i + 1 - while j < n: - subkey, subval = items[j] - sublevel = len(subkey) - len(subkey.lstrip("\t")) - - if sublevel <= level: - break - - if sublevel == level + 1 and subval is not None: - sublabel = subkey.strip() - children.append((sublabel, subval)) - - j += 1 - - # Render the expander with all children inside - with st.sidebar.expander(label, expanded=False): - for sublabel, subval in children: - widget_key = f"{model_id}:{label}:{sublabel}" - if widget_key in st.session_state: - params[sublabel] = st.text_input( - sublabel, - key=widget_key - ) - else: - # First load: Pass the default value - params[sublabel] = st.text_input( - sublabel, - value=str(subval), - key=widget_key - ) - - i = j + st.sidebar.markdown( + ( + f"
{clean_label}
" + ), + unsafe_allow_html=True, + ) continue - # TOP-LEVEL PARAMS - widget_key = f"{model_id}:{label}" + widget_key = f"{model_id}_{clean_label}" + current_value = params.get(clean_label, value) - if widget_key in st.session_state: - params[label] = st.sidebar.text_input( - label, - key=widget_key - ) + user_value = st.sidebar.text_input( + clean_label, + value=str(current_value), + key=widget_key, + ) + + if str(user_value).strip(): + params[clean_label] = user_value else: - params[label] = st.sidebar.text_input( - label, - value=str(value), - key=widget_key - ) - i += 1 + params.setdefault(clean_label, value)