def _find_closing_paren(text: str, pos: int) -> int:
    """
    Locate the end of a ``$(`` substitution whose body starts at ``pos``.

    Parentheses are only counted outside of single/double quotes, and a
    backslash escapes the following character everywhere except inside single
    quotes.

    Returns the index just past the matching ``)``, or -1 (mirroring
    ``str.find``) when the text ends before the parenthesis is closed.
    """
    length = len(text)
    depth = 1
    in_single_quote = False
    in_double_quote = False
    while pos < length:
        char = text[pos]
        # A backslash hides the next character (quote, paren, ...) from the
        # scanner, except inside single quotes where it is a literal
        if char == "\\" and not in_single_quote and pos + 1 < length:
            pos += 2
            continue
        if char == "'" and not in_double_quote:
            in_single_quote = not in_single_quote
        elif char == '"' and not in_single_quote:
            in_double_quote = not in_double_quote
        elif not in_single_quote and not in_double_quote:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    return pos + 1
        pos += 1
    return -1


def _find_closing_backtick(text: str, pos: int) -> int:
    """
    Locate the end of a backtick substitution whose body starts at ``pos``.

    Backticks do not nest; a backslash skips the character that follows it so
    that an escaped backtick does not terminate the substitution.

    Returns the index just past the closing backtick, or -1 (mirroring
    ``str.find``) when no closing backtick exists.
    """
    length = len(text)
    while pos < length:
        char = text[pos]
        if char == "`":
            return pos + 1
        pos += 2 if char == "\\" and pos + 1 < length else 1
    return -1


def find_command_substitutions(text: str) -> list[tuple[int, int, str]]:
    """
    Find all command substitutions $(cmd) and `cmd` in the text, handling
    nested parentheses and quoted strings correctly.

    Only outermost substitutions are reported - anything nested inside a
    ``$(...)`` is part of the enclosing match. Anything opened by ``$(`` is
    matched, so an arithmetic expansion such as ``$((1 + 2))`` is reported
    too. Quotes and backslashes are only interpreted *inside* a ``$(...)``
    body; the surrounding text is scanned literally.

    An opener that is never closed is not reported, and scanning resumes
    immediately after that opener so that well-formed substitutions later in
    the text are still found.

    :param text: The string to scan
    :returns:    A list of (start_pos, end_pos, original_text) tuples in order
                 of appearance, where ``text[start_pos:end_pos]`` equals
                 ``original_text``
    """
    substitutions = []
    length = len(text)
    i = 0
    while i < length:
        if text.startswith("$(", i):
            opener_len, find_end = 2, _find_closing_paren
        elif text[i] == "`":
            opener_len, find_end = 1, _find_closing_backtick
        else:
            i += 1
            continue
        end = find_end(text, i + opener_len)
        if end == -1:
            # Unterminated - step over just the opener (rather than abandoning
            # the rest of the text) so later substitutions are still detected
            i += opener_len
        else:
            substitutions.append((i, end, text[i:end]))
            i = end
    return substitutions
"zsh", "ksh", "csh", "tcsh", "fish", "dash"} + is_shell_command = Path(command).name in common_shells + if not is_shell_command: + detected_substitutions = [] + for substitution in find_command_substitutions(command): + detected_substitutions.append(substitution[2]) # Extract original_text + for arg in args: + for substitution in find_command_substitutions(arg): + detected_substitutions.append(substitution[2]) # Extract original_text + if detected_substitutions: + substitutions_list = "\n".join(f" {sub}" for sub in detected_substitutions) + await self.logger.warning( + "Gator only supports simple environment variable substitutions. " + "Command substitutions will pass through to the command without being " + "evaluated by Gator.\n" + f"Detected command substitutions:\n{substitutions_list}" + ) full_cmd = shlex.join((command, *args)) # Ensure the tracking directory exists self.tracking.mkdir(parents=True, exist_ok=True) diff --git a/pyproject.toml b/pyproject.toml index 1e1089c..16dbda3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ psutil = "^5.9.4" rich = "^13.3.4" tabulate = "^0.9.0" pyyaml = "^6.0" -expandvars = "^0.9.0" +expandvars = "^1.1.2" websockets = "^11.0.2" aiosqlite = "^0.19.0" aiohttp = "^3.12.13" diff --git a/tests/test_command_substitution.py b/tests/test_command_substitution.py new file mode 100644 index 0000000..ceb525e --- /dev/null +++ b/tests/test_command_substitution.py @@ -0,0 +1,294 @@ +# Copyright 2024, Peter Birch, mailto:peter@lightlogic.co.uk +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class TestFindCommandSubstitutions:
    """Behavioural checks for locating command substitutions in a string"""

    @staticmethod
    def _matched(text):
        """Return only the matched source text of each substitution found in ``text``"""
        return [original for _start, _end, original in find_command_substitutions(text)]

    def test_find_simple_command_substitution(self):
        """A lone $(cmd) is located with exact offsets"""
        found = find_command_substitutions("echo $(hostname)")
        assert found == [(5, 16, "$(hostname)")]

    def test_find_nested_command_substitution(self):
        """A nested $(cmd $(cmd)) is reported once, as the outer match"""
        found = find_command_substitutions("echo $(echo $(whoami))")
        assert found == [(5, 22, "$(echo $(whoami))")]

    def test_find_complex_nested_command_substitution(self):
        """Nesting in the middle of an argument, e.g. $(date +%Y-$(date +%m))"""
        found = find_command_substitutions("echo $(date +%Y-$(date +%m))")
        assert found == [(5, 28, "$(date +%Y-$(date +%m))")]

    def test_find_backticks(self):
        """A lone `cmd` is located with exact offsets"""
        found = find_command_substitutions("echo `hostname`")
        assert found == [(5, 15, "`hostname`")]

    def test_find_escaped_backticks(self):
        """Escaped backticks do not terminate the backtick substitution"""
        # The whole span, escaped backticks included, must be captured
        assert self._matched(r"echo `echo \`nested\``") == [r"`echo \`nested\``"]

    def test_find_multiple_substitutions(self):
        """Both styles in one string are each reported, in order"""
        found = find_command_substitutions("echo $(cmd) and `another`")
        assert found == [(5, 11, "$(cmd)"), (16, 25, "`another`")]

    def test_find_with_variables(self):
        """Plain $VAR references are not mistaken for substitutions"""
        found = find_command_substitutions("echo $HOME and $(hostname)")
        assert found == [(15, 26, "$(hostname)")]

    def test_find_unmatched_paren(self):
        """An unterminated $( yields no match rather than an error"""
        assert find_command_substitutions("echo $(incomplete") == []

    def test_multiple_nested_levels(self):
        """Three levels of nesting collapse into the outermost match"""
        assert self._matched("echo $(outer $(middle $(inner)))") == [
            "$(outer $(middle $(inner)))"
        ]

    def test_parentheses_in_double_quotes(self):
        """A ) inside double quotes does not close the substitution"""
        text = '$(echo ")")'
        assert find_command_substitutions(text) == [(0, 11, '$(echo ")")')]

    def test_parentheses_in_single_quotes(self):
        """A ) inside single quotes does not close the substitution"""
        text = "$(echo ')')"
        assert find_command_substitutions(text) == [(0, 11, "$(echo ')')")]

    def test_escaped_parentheses(self):
        """A backslash-escaped ) does not close the substitution"""
        text = r"$(echo \))"
        assert find_command_substitutions(text) == [(0, 10, r"$(echo \))")]

    def test_complex_quoting(self):
        """Parentheses spread across both quote styles are ignored"""
        text = """$(echo "foo (bar)" 'baz (qux)' end)"""
        assert self._matched(text) == [text]

    def test_nested_quotes(self):
        """A nested substitution inside double quotes stays within the outer match"""
        text = '$(echo "outer $(echo inner)")'
        assert self._matched(text) == [text]

    def test_mixed_quotes_in_nested(self):
        """Single quotes inside a double-quoted nested substitution"""
        text = """$(echo "$(echo 'nested')")"""
        assert self._matched(text) == [text]

    def test_backslash_in_double_quotes(self):
        """An escaped double quote does not end the double-quoted region"""
        text = r'$(echo "foo \" bar")'
        assert self._matched(text) == [text]

    def test_backslash_in_single_quotes(self):
        """Backslashes are literal inside single quotes"""
        text = r"$(echo 'foo \ bar')"
        assert self._matched(text) == [text]

    def test_dollar_in_single_quotes(self):
        """A $ inside single quotes is carried through untouched"""
        text = "$(echo '$HOME')"
        assert self._matched(text) == [text]

    def test_dollar_in_double_quotes(self):
        """A $ inside double quotes is carried through untouched"""
        text = '$(echo "$HOME")'
        assert self._matched(text) == [text]

    def test_multiple_levels_with_quotes(self):
        """Deep nesting where each level uses a different quote style"""
        text = "$(outer \"$(middle '$(inner)')\")"
        assert self._matched(text) == [text]

    def test_backticks_with_quotes(self):
        """Quoted parentheses inside backticks do not disturb the match"""
        text = '`echo "hello (world)"`'
        assert self._matched(text) == [text]

    def test_backticks_with_nested_backticks_escaped(self):
        """Escaped backticks nested inside a backtick substitution"""
        text = r"`echo \`nested\``"
        assert self._matched(text) == [text]

    def test_empty_quotes(self):
        """Empty quoted strings of both styles"""
        text = "$(echo \"\" '')"
        assert self._matched(text) == [text]

    def test_adjacent_quotes(self):
        """Back-to-back quoted strings with no separator"""
        text = "$(echo \"foo\"'bar')"
        assert self._matched(text) == [text]

    def test_escaped_dollar_outside_quotes(self):
        """An escaped $ inside the substitution body is preserved"""
        text = r"$(echo \$HOME)"
        assert self._matched(text) == [text]

    def test_multiple_commands_with_quotes(self):
        """Three substitutions, each hiding a ) behind some form of quoting"""
        text = """$(echo "foo)") and $(echo ')') and `echo ")"`"""
        assert self._matched(text) == [
            '$(echo "foo)")',
            "$(echo ')')",
            '`echo ")"`',
        ]

    def test_ansi_c_quoting(self):
        """ANSI-C quoting $'...' containing an escape sequence"""
        text = r"$(echo $'hello\nworld')"
        assert self._matched(text) == [text]

    def test_arithmetic_expansion(self):
        """Arithmetic expansion $((...)) with its doubled parentheses"""
        text = "$(echo $((1 + 2)))"
        assert self._matched(text) == [text]

    def test_arithmetic_with_nested_parens(self):
        """Arithmetic expansion containing a further substitution"""
        text = "$(echo $(($(echo 5) + 3)))"
        assert self._matched(text) == [text]

    def test_parameter_expansion(self):
        """Parameter expansion with a default value"""
        text = "$(echo ${var:-default})"
        assert self._matched(text) == [text]

    def test_parameter_expansion_with_parens(self):
        """Parameter expansion whose quoted default contains parentheses"""
        text = '$(echo ${var:-"default (value)"})'
        assert self._matched(text) == [text]

    def test_command_with_semicolons(self):
        """Command lists separated by semicolons"""
        text = "$(echo foo; echo bar)"
        assert self._matched(text) == [text]

    def test_command_with_pipes(self):
        """Pipelines inside the substitution"""
        text = "$(echo foo | grep f)"
        assert self._matched(text) == [text]

    def test_command_with_redirects(self):
        """Input and output redirections inside the substitution"""
        text = "$(cat < file.txt > output.txt)"
        assert self._matched(text) == [text]

    def test_glob_patterns(self):
        """Glob characters inside the substitution"""
        text = "$(ls *.txt)"
        assert self._matched(text) == [text]

    def test_brace_expansion(self):
        """Brace expansion inside the substitution"""
        text = "$(echo {a,b,c})"
        assert self._matched(text) == [text]

    def test_tilde_expansion(self):
        """Tilde paths inside the substitution"""
        text = "$(ls ~/Documents)"
        assert self._matched(text) == [text]

    def test_no_command_substitutions(self):
        """Ordinary text produces no matches"""
        assert find_command_substitutions("echo hello world") == []

    def test_empty_string(self):
        """The empty string produces no matches"""
        assert find_command_substitutions("") == []
async def test_scheduler_argument_safety(self, mocker, tmp_path):
    """Check the scheduler spawns its child via exec with discrete arguments"""
    scheduler = LocalScheduler(
        tracking=tmp_path / "tracking",
        parent="test:1234",
        interval=5,
        quiet=True,
        logger=self.logger,
    )
    # Replace the asyncio spawn call so that nothing is really executed
    exec_mock = mocker.patch(
        "gator.scheduler.local.asyncio.create_subprocess_exec",
        new=AsyncMock(),
    )
    mocker.patch.object(
        scheduler,
        "_LocalScheduler__monitor",
        new=AsyncMock(wraps=scheduler._LocalScheduler__monitor),
    )
    # Every would-be spawn is recorded as (process, positional args, keyword args)
    spawned = []

    def _record_spawn(*args, **kwargs):
        handle = AsyncMock()
        spawned.append((handle, args, kwargs))
        return handle

    exec_mock.side_effect = _record_spawn

    # Arguments carrying shell metacharacters - these must only ever be
    # treated as literal text, never interpreted by a shell
    hostile_args = [
        "hello; rm -rf /",
        "$(whoami)",
        "`id`",
        "&& cat /etc/passwd",
    ]
    hostile_job = Job(
        "dangerous_test",
        cwd=tmp_path.as_posix(),
        command="echo",
        args=hostile_args,
    )

    # NOTE: As recorded by the original author, the scheduler launches a
    #       "python3 -m gator" child rather than handing the Job's own command
    #       and args to the subprocess call, so the assertions below inspect
    #       that base command instead of the hostile strings themselves.
    task = Child(
        spec=hostile_job,
        ident="dangerous",
        entry=MagicMock(),
        tracking=tmp_path / "dangerous",
    )
    await scheduler.launch([task])

    # Let the launch task run to completion
    await scheduler.launch_task

    # Exactly one spawn must have gone through create_subprocess_exec
    assert exec_mock.call_count == 1

    _, argv, options = spawned[0]

    # The command must arrive as separate positional arguments
    assert len(argv) > 0, "Command should be passed as positional arguments"

    # ...beginning with the python3 -m gator base command
    for position, expected in enumerate(("python3", "-m", "gator")):
        assert argv[position] == expected

    # Standard streams must be redirected
    assert options["stdin"] == subprocess.DEVNULL
    assert options["stdout"] == subprocess.DEVNULL
    assert options["stderr"] == subprocess.STDOUT
async def test_wrapper_command_substitution_warning(self, tmp_path) -> None:
    """Check a warning is logged when a non-shell command carries command
    substitutions"""
    # A job whose arguments hold both substitution styles, run by a non-shell
    spec = Job(
        "test",
        cwd=tmp_path.as_posix(),
        command="echo",
        args=["Hello from $(hostname)", "and `date`"],
    )
    tracking = tmp_path / "tracking"
    wrapper = Wrapper(spec=spec, client=self.client, tracking=tracking, logger=self.logger)

    # Execute the job
    await wrapper.launch()

    # Gather every WARNING severity entry pushed to the database
    warnings = []
    for logged in self.mk_db.push_logentry.mock_calls:
        entry = logged.args[0]
        if entry.severity is LogSeverity.WARNING:
            warnings.append(entry)
    assert len(warnings) > 0, "Expected at least one warning log entry"

    # The first warning must explain the situation and list both substitutions
    message = warnings[0].message
    for fragment in (
        "Detected command substitutions:",
        "$(hostname)",
        "`date`",
        "Gator only supports simple environment variable substitutions",
        "will pass through to the command without being evaluated",
    ):
        assert fragment in message


async def test_wrapper_command_substitution_no_warning_for_shells(self, tmp_path) -> None:
    """Check NO warning is logged when the command substitutions are handed to
    a shell command"""
    # A job that passes the substitution to bash
    spec = Job(
        "test",
        cwd=tmp_path.as_posix(),
        command="bash",
        args=["-c", "echo Hello from $(hostname)"],
    )
    tracking = tmp_path / "tracking"
    wrapper = Wrapper(spec=spec, client=self.client, tracking=tracking, logger=self.logger)

    # Execute the job
    await wrapper.launch()

    # Look for any WARNING entry that talks about command substitutions
    offending = []
    for logged in self.mk_db.push_logentry.mock_calls:
        entry = logged.args[0]
        if entry.severity is LogSeverity.WARNING and "Command substitution" in entry.message:
            offending.append(entry)
    assert len(offending) == 0, "Expected no warning for shell commands with substitutions"