Support configurable ROCm path (ROCM_PATH / --rocm-path) #75
MAD previously assumed ROCm was under /opt/rocm. Newer Rock tar/whl packages often do not create that path, which caused 'Unable to determine gpu vendor' when amd-smi was installed elsewhere.
- Add get_rocm_path(override); resolution: override -> ROCM_PATH env -> /opt/rocm.
- Context accepts optional rocm_path for vendor detection, rocminfo, .info/version, docker env.
- ROCmValidator and detect_gpu_vendor use configurable rocm_path.
- RunModels passes args.rocm_path to Context and uses ctx for amd-smi paths.
- Add madengine run --rocm-path.
Tests: is_amd_gpu() uses get_rocm_path().
Backward compatible when ROCM_PATH and --rocm-path are unset.
Pull request overview
Adds support for non-default ROCm installations by introducing a configurable ROCm root path (CLI --rocm-path and ROCM_PATH env) and wiring it through vendor detection, ROCm validation, and model execution paths.
Changes:
- Add get_rocm_path() helper to resolve ROCm root via override -> ROCM_PATH -> /opt/rocm.
- Plumb resolved rocm_path through Context, GPU vendor detection/validation, and run_models (including Docker env var propagation).
- Update integration test AMD detection to respect configurable ROCm paths.
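The resolution order described above can be sketched as follows. This is a minimal illustration, not the code from the PR: the constant name `DEFAULT_ROCM_PATH` is hypothetical, and the `normpath` step is taken from the follow-up commit message rather than from any diff shown in this thread.

```python
import os

DEFAULT_ROCM_PATH = "/opt/rocm"  # hypothetical constant name; the default itself is from the PR


def get_rocm_path(override=None):
    """Resolve the ROCm root: explicit override, then ROCM_PATH, then /opt/rocm."""
    candidate = override or os.environ.get("ROCM_PATH") or DEFAULT_ROCM_PATH
    # normpath drops a trailing slash but leaves a bare "/" intact.
    return os.path.normpath(candidate)
```

With neither an override nor `ROCM_PATH` set, the function returns `/opt/rocm`, which is the backward-compatible behavior the description promises.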
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tests/test_gpu_renderD_nodes.py | Uses get_rocm_path() to locate amd-smi for AMD detection in integration tests. |
| src/madengine/utils/gpu_validator.py | Makes ROCm validation and vendor detection accept an optional rocm_path. |
| src/madengine/tools/run_models.py | Passes --rocm-path into Context and uses context ROCm path for amd-smi invocation. |
| src/madengine/mad.py | Adds madengine run --rocm-path CLI option. |
| src/madengine/core/context.py | Stores/resolves ROCm path in context, exports ROCM_PATH into Docker env, and uses it for vendor detection/rocminfo/version file. |
| src/madengine/core/constants.py | Introduces get_rocm_path() helper for ROCm root resolution. |
src/madengine/utils/gpu_validator.py
| # Try version file | ||
| version_file = '/opt/rocm/.info/version' | ||
| version_file = os.path.join(self.rocm_path, '.info', 'version') | ||
| if os.path.exists(version_file): |
ROCm path is now configurable, but ROCmValidator still runs hipconfig from PATH. If ROCm is installed under a non-default location and the user only sets ROCM_PATH/--rocm-path, version detection will fail even though {rocm_path}/bin/hipconfig exists. Use the resolved self.rocm_path (e.g., run the absolute hipconfig path or prepend {rocm_path}/bin to PATH for subprocess calls).
src/madengine/core/context.py
| # ROCM_PATH allows non-default ROCm installs (e.g. Rock tar/whl) when /opt/rocm is not present. | ||
| rocm_path_escaped = self._rocm_path.replace("'", "'\"'\"'") | ||
| return self.console.sh( | ||
| 'bash -c \'if [[ -f /usr/bin/nvidia-smi ]] && $(/usr/bin/nvidia-smi > /dev/null 2>&1); then echo "NVIDIA"; elif [[ -f /opt/rocm/bin/amd-smi ]]; then echo "AMD"; elif [[ -f /usr/local/bin/amd-smi ]]; then echo "AMD"; else echo "Unable to detect GPU vendor"; fi || true\'' | ||
| f'ROCM_PATH="{rocm_path_escaped}" bash -c \'if [[ -f /usr/bin/nvidia-smi ]] && $(/usr/bin/nvidia-smi > /dev/null 2>&1); then echo "NVIDIA"; elif [[ -f "${{ROCM_PATH}}/bin/amd-smi" ]]; then echo "AMD"; elif [[ -f /usr/local/bin/amd-smi ]]; then echo "AMD"; else echo "Unable to detect GPU vendor"; fi || true\'' | ||
| ) |
get_gpu_vendor() builds a shell command by interpolating self._rocm_path into the command string. Since --rocm-path / ROCM_PATH is user-controlled and Console.sh runs with shell=True, this is vulnerable to shell injection / quoting breakage (escaping only ' is not sufficient). Prefer passing ROCM_PATH via the env= parameter to Console.sh (or use shlex.quote and avoid nested shell where possible).
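A minimal sketch of the env-based approach this comment recommends, assuming plain `subprocess` rather than the project's `Console.sh` (whose signature is not shown here). The function name `detect_amd` is hypothetical, and the snippet uses POSIX `sh` with `[ ... ]` for portability where the PR uses `bash` with `[[ ... ]]`. The point is that the shell script is a constant string and the user-controlled path only ever travels through the environment.

```python
import os
import subprocess

# Constant script: the path is never pasted into it, only read from the environment.
DETECT_SNIPPET = (
    'if [ -f "${ROCM_PATH}/bin/amd-smi" ]; then echo "AMD"; '
    'else echo "Unable to detect GPU vendor"; fi'
)


def detect_amd(rocm_path):
    """Return "AMD" when {rocm_path}/bin/amd-smi exists, else the fallback message."""
    env = {**os.environ, "ROCM_PATH": rocm_path}
    result = subprocess.run(
        ["sh", "-c", DETECT_SNIPPET],
        env=env,
        capture_output=True,
        text=True,
        check=False,
    )
    return result.stdout.strip()
```

Because the expansion happens inside double quotes in the script, a path containing spaces, quotes, or `$(...)` is treated as data and is not re-evaluated by the shell.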
src/madengine/core/context.py
| # Get ROCm version | ||
| rocm_version_str = self.console.sh("cat /opt/rocm/.info/version | cut -d'-' -f1") | ||
| version_file = os.path.join(self._rocm_path, ".info", "version") | ||
| rocm_version_str = self.console.sh(f"cat {version_file} | cut -d'-' -f1") |
version_file comes from configurable self._rocm_path but is interpolated into cat {version_file} without quoting. A ROCm path containing spaces/metacharacters will break the command and can lead to shell injection because Console.sh uses shell=True. Quote version_file (or read the file directly in Python and avoid shell execution).
| rocm_version_str = self.console.sh(f"cat {version_file} | cut -d'-' -f1") | |
| try: | |
| with open(version_file, "r", encoding="utf-8") as vf: | |
| rocm_version_str = vf.read() | |
| except OSError as e: | |
| raise RuntimeError(f"Failed to read ROCm version file '{version_file}': {e}") |
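Reading the file directly can also keep the `cut -d'-' -f1` step of the original command. The following is a sketch under assumptions: the function name `read_rocm_version` is hypothetical, and the version file is assumed to hold a single line.

```python
import os


def read_rocm_version(rocm_path):
    """Read {rocm_path}/.info/version and keep the text before the first '-'."""
    version_file = os.path.join(rocm_path, ".info", "version")
    try:
        with open(version_file, "r", encoding="utf-8") as vf:
            raw = vf.read().strip()
    except OSError as e:
        raise RuntimeError(
            f"Failed to read ROCm version file '{version_file}': {e}"
        ) from e
    # Equivalent of `cut -d'-' -f1` applied to the first line.
    first_line = raw.splitlines()[0] if raw else ""
    return first_line.split("-", 1)[0]
```

No shell is involved, so spaces or metacharacters in the ROCm path cannot change what is executed.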
| self.rocm_path = get_rocm_path(rocm_path) | ||
| self.ESSENTIAL_PATHS = { | ||
| 'rocm_root': self.rocm_path, | ||
| 'hip_path': os.path.join(self.rocm_path, 'bin', 'hipconfig'), | ||
| 'rocminfo': os.path.join(self.rocm_path, 'bin', 'rocminfo'), | ||
| } | ||
| self.RECOMMENDED_PATHS = { | ||
| 'amd_smi': os.path.join(self.rocm_path, 'bin', 'amd-smi'), | ||
| 'rocm_smi': os.path.join(self.rocm_path, 'bin', 'rocm-smi'), | ||
| } |
ROCmValidator now accepts rocm_path and builds absolute paths in ESSENTIAL_PATHS/RECOMMENDED_PATHS, but later validation still executes tools via PATH (e.g., rocminfo, amd-smi, rocm-smi). This makes rocm_path ineffective for non-default installs unless the user also modifies PATH. Consider having _run_command inject {self.rocm_path}/bin into PATH (or call the absolute tool paths from these dicts) so validation is consistent with the configured ROCm root.
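One of the two options mentioned here, injecting `{rocm_path}/bin` into `PATH`, could look roughly like this. The helper name `env_with_rocm_bin` is hypothetical, and how `_run_command` actually invokes subprocesses is not shown in this thread.

```python
import os


def env_with_rocm_bin(rocm_path, base_env=None):
    """Return a copy of the environment with {rocm_path}/bin first on PATH."""
    env = dict(os.environ if base_env is None else base_env)
    rocm_bin = os.path.join(rocm_path, "bin")
    existing = env.get("PATH", "")
    # Prepend so tools from the configured ROCm root win over other installs.
    env["PATH"] = rocm_bin + (os.pathsep + existing if existing else "")
    return env
```

The returned dict can be passed as `env=` to `subprocess.run`, so bare tool names such as `rocminfo` resolve against the configured root first while everything else on `PATH` remains reachable.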
src/madengine/core/context.py
| rocminfo_path = os.path.join(self._rocm_path, "bin", "rocminfo") | ||
| arch = self.console.sh(f"{rocminfo_path} |grep -o -m 1 'gfx.*'") | ||
| if not arch or arch.strip() == "": |
rocminfo_path is derived from ROCM_PATH/--rocm-path but is injected into a shell pipeline without quoting. If the path contains spaces or shell metacharacters, the command can break or be exploited. Quote the path (e.g., via shlex.quote) or avoid the shell pipeline by executing rocminfo directly and parsing output in Python.
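The second option, running rocminfo directly and parsing its output in Python, might be sketched as below. Both function names are hypothetical. The regular expression is an assumption: it is narrower than the original `grep -o -m 1 'gfx.*'`, which keeps the rest of the matching line, because it extracts only the first `gfx` token.

```python
import re
import subprocess

GFX_PATTERN = re.compile(r"gfx[0-9a-z]+")  # assumed token shape, e.g. a gfx target name


def first_gfx_arch(rocminfo_output):
    """Return the first gfx token found in rocminfo output, or None."""
    match = GFX_PATTERN.search(rocminfo_output)
    return match.group(0) if match else None


def query_gfx_arch(rocminfo_path):
    """Run rocminfo as an argv list (no shell) and parse its stdout."""
    # Raises FileNotFoundError if rocminfo_path does not exist.
    result = subprocess.run(
        [rocminfo_path], capture_output=True, text=True, check=False
    )
    return first_gfx_arch(result.stdout)
```

Since the path is a single argv element, quoting is not needed at all.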
src/madengine/tools/run_models.py
| amd_smi_path = os.path.join(self.context.ctx["rocm_path"], "bin", "amd-smi") | ||
| self.console.sh(f"{amd_smi_path} || true") | ||
| elif gpu_vendor.find("NVIDIA") != -1: |
amd_smi_path is constructed from configurable rocm_path and then concatenated into a shell command without quoting. If the path contains spaces/metacharacters, this can break execution or enable shell injection. Quote the path (or pass it as an argv list via subprocess rather than a shell string).
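Where a shell string is unavoidable, quoting with `shlex.quote` is the smaller change. A minimal sketch, with the hypothetical helper name `smi_shell_command`:

```python
import os
import shlex


def smi_shell_command(rocm_path):
    """Build the `amd-smi || true` shell string with the tool path quoted."""
    amd_smi = os.path.join(rocm_path, "bin", "amd-smi")
    return f"{shlex.quote(amd_smi)} || true"


print(smi_shell_command("/opt/my rocm"))  # '/opt/my rocm/bin/amd-smi' || true
```

The quoted path is parsed by a POSIX shell as one word, so a space or `;` in `rocm_path` no longer splits the command.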
src/madengine/tools/run_models.py
| amd_smi_path = os.path.join(self.context.ctx["rocm_path"], "bin", "amd-smi") | ||
| smi = model_docker.sh(f"{amd_smi_path} || true") | ||
| elif gpu_vendor.find("NVIDIA") != -1: |
Same issue here: amd_smi_path comes from configurable rocm_path and is injected into a shell command without quoting. This can break for paths with spaces and is a shell-injection risk. Quote the path or execute via subprocess without shell=True.
| rocm_path = get_rocm_path() | ||
| amd_smi_path = os.path.join(rocm_path, "bin", "amd-smi") | ||
| vendor = console.sh( | ||
| 'bash -c \'if [[ -f /opt/rocm/bin/amd-smi ]]; then echo "AMD"; elif [[ -f /usr/local/bin/amd-smi ]]; then echo "AMD"; else echo "OTHER"; fi || true\'' | ||
| f'bash -c \'if [[ -f "{amd_smi_path}" ]]; then echo "AMD"; elif [[ -f /usr/local/bin/amd-smi ]]; then echo "AMD"; else echo "OTHER"; fi || true\'' | ||
| ) |
amd_smi_path is interpolated into a bash -c string. If ROCM_PATH contains a double quote, the test command can break (and, since Console.sh uses shell=True, it can also become an injection vector). Consider using shlex.quote or passing the path via an env var to the shell snippet instead of embedding it directly.
- Resolve conflicts in context.py and run_models.py: keep configurable ROCM_PATH and main's rocm-smi fallback for vendor detection and SMI echo.
- Pass ROCM_PATH via subprocess env in get_gpu_vendor (no shell interpolation).
- Read ROCm version from .info/version in Python; run rocminfo via subprocess for architecture; quote ROCm bin paths (shlex) for shell pipelines.
- Console.sh: merge partial env with os.environ so inherited PATH etc. remain.
- get_rocm_path: use normpath so '/' is not stripped to empty.
- ROCmValidator: run hipconfig/rocminfo/amd-smi/rocm-smi from rocm_path/bin when those binaries exist.
- Integration test is_amd_gpu: detect AMD via ROCM_PATH env in bash snippet.
Made-with: Cursor
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| ro = self.context.ctx["rocm_path"] | ||
| amd_smi = shlex.quote(os.path.join(ro, "bin", "amd-smi")) | ||
| rocm_smi = shlex.quote(os.path.join(ro, "bin", "rocm-smi")) |
shlex.quote() is safe for POSIX shells, but Docker.sh() embeds the entire command inside a double-quoted bash -c "..." string. If rocm_path contains a single quote, shlex.quote() will emit a '"'"' sequence (which includes double quotes); that can break the surrounding quoting and potentially allow command injection. Consider avoiding shlex.quote() here (or escaping for a double-quoted context), or updating Docker.sh() to pass the command without wrapping it in unescaped double quotes (e.g., pass via stdin / use an argv form).
| ro = self.context.ctx["rocm_path"] | |
| amd_smi = shlex.quote(os.path.join(ro, "bin", "amd-smi")) | |
| rocm_smi = shlex.quote(os.path.join(ro, "bin", "rocm-smi")) | |
| def shell_double_quoted(value): | |
| return '"' + value.replace("\\", "\\\\").replace('"', '\\"').replace("$", "\\$").replace("`", "\\`") + '"' | |
| ro = self.context.ctx["rocm_path"] | |
| amd_smi = shell_double_quoted(os.path.join(ro, "bin", "amd-smi")) | |
| rocm_smi = shell_double_quoted(os.path.join(ro, "bin", "rocm-smi")) |
| """ | ||
| # Check if the GPU vendor is NVIDIA or AMD, and if it is unable to detect the GPU vendor. | ||
| # ROCM_PATH via subprocess env avoids embedding user-controlled paths in shell strings. | ||
| vendor_env = {**os.environ, "ROCM_PATH": self._rocm_path} |
Console.sh() now merges any provided env with os.environ, so building vendor_env by copying os.environ here is redundant. You can pass just {"ROCM_PATH": self._rocm_path} (and/or any overrides) to keep the intent clear and avoid an extra full env copy.
| vendor_env = {**os.environ, "ROCM_PATH": self._rocm_path} | |
| vendor_env = {"ROCM_PATH": self._rocm_path} |
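The merge behavior that makes the full copy redundant can be sketched as follows. This is an assumption about what `Console.sh()` does internally, inferred from the commit message; `merged_env` is a hypothetical name.

```python
import os


def merged_env(partial=None):
    """Overlay a partial env dict on the inherited environment."""
    env = dict(os.environ)  # inherited PATH, HOME, etc. stay available
    if partial:
        env.update(partial)  # caller-supplied keys take precedence
    return env
```

Under that behavior, a caller only needs to pass `{"ROCM_PATH": ...}` and the subprocess still sees the inherited variables.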
| rocm_path = get_rocm_path() | ||
| vendor_env = {**os.environ, "ROCM_PATH": rocm_path} | ||
| vendor = console.sh( | ||
| 'bash -c \'if [[ -f /opt/rocm/bin/amd-smi ]]; then echo "AMD"; elif [[ -f /usr/local/bin/amd-smi ]]; then echo "AMD"; else echo "OTHER"; fi || true\'' | ||
| 'bash -c \'if [[ -f "${ROCM_PATH}/bin/amd-smi" ]]; then echo "AMD"; ' | ||
| 'elif [[ -f /usr/local/bin/amd-smi ]]; then echo "AMD"; ' | ||
| 'else echo "OTHER"; fi || true\'', | ||
| env=vendor_env, |
Console.sh() now merges a provided env dict with the inherited environment, so {**os.environ, "ROCM_PATH": rocm_path} is redundant here. Passing just {"ROCM_PATH": rocm_path} keeps the test simpler and avoids copying the full environment.
| """Prefer ``{rocm_path}/bin/{name}`` when present, else bare *name* on PATH.""" | ||
| full = os.path.join(self.rocm_path, "bin", name) | ||
| if os.path.isfile(full): |
_rocm_tool_cmd() only checks os.path.isfile(full) before returning the absolute path. If the file exists but is not executable (or is a directory), the subsequent subprocess.run() will fail with a less clear error. Consider using os.access(full, os.X_OK) (or Path(full).is_file() + os.access) so the fallback to PATH is used unless the local binary is actually runnable.
| """Prefer ``{rocm_path}/bin/{name}`` when present, else bare *name* on PATH.""" | |
| full = os.path.join(self.rocm_path, "bin", name) | |
| if os.path.isfile(full): | |
| """Prefer ``{rocm_path}/bin/{name}`` when runnable, else bare *name* on PATH.""" | |
| full = os.path.join(self.rocm_path, "bin", name) | |
| if os.path.isfile(full) and os.access(full, os.X_OK): |
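As a standalone illustration of the suggested check, here is a free-function version of the helper (the PR's version is a method on `ROCmValidator`; the name `rocm_tool_cmd` and the explicit `rocm_path` parameter are adaptations for this sketch).

```python
import os


def rocm_tool_cmd(rocm_path, name):
    """Prefer {rocm_path}/bin/{name} when runnable, else the bare name on PATH."""
    full = os.path.join(rocm_path, "bin", name)
    # isfile rules out directories; os.access rules out non-executable files.
    if os.path.isfile(full) and os.access(full, os.X_OK):
        return full
    return name
```

A file that exists but lacks execute permission now falls back to the bare name, so the later `subprocess.run()` either finds a runnable tool on `PATH` or fails with a clearer "not found" error.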