diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e698824..a687285 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,10 @@ repos: rev: v2.3.2 hooks: - id: autopep8 +- repo: https://github.com/jkittner/double-indent + rev: 0.1.5 + hooks: + - id: double-indent - repo: https://github.com/asottile/reorder-python-imports rev: v3.16.0 hooks: diff --git a/README.md b/README.md index c2f7509..e7bf9a5 100644 --- a/README.md +++ b/README.md @@ -3,14 +3,15 @@ # flake8-timeout -flake8 plugin which checks that a timeout is set in all `requests` and `urllib.request.open` calls. +A flake8 plugin that checks for missing `timeout` parameters in network calls. -- For example: `requests.post('https://example.com')` or `urllib.request.open('https://example.com')` will trigger `TIM100` -- `requests.post('https://example.com', timeout=5)` or `urllib.request.open('https://example.com', timeout=5)` is expected instead +By default, the plugin checks common HTTP libraries but can be configured to track any function that accepts a timeout parameter. ## installation -`pip install flake8-timeout` +```bash +pip install flake8-timeout +``` ## flake8 code @@ -18,7 +19,23 @@ flake8 plugin which checks that a timeout is set in all `requests` and `urllib.r | ------ | -------------------------------- | | TIM100 | timeout missing for request call | -## as a pre-commit hook +## default tracked functions + +The plugin tracks these functions by default: + +- `requests.get` +- `requests.post` +- `requests.put` +- `requests.delete` +- `requests.head` +- `requests.patch` +- `requests.options` +- `requests.request` +- `urllib.request.urlopen` (timeout at positional index 2) + +## configuration + +### as a pre-commit hook See [pre-commit](https://pre-commit.com) for instructions @@ -26,8 +43,67 @@ Sample `.pre-commit-config.yaml`: ```yaml - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-timeout==2.0.0] +``` + +### extending the defaults + +Use `--timeout-extend-funcs` to add custom functions while keeping the defaults: + +**Command line:** +```bash +flake8 --timeout-extend-funcs=my_http_lib.request,custom.api.call:1 +``` + +**Pre-commit:** +```yaml +- repo: https://github.com/pycqa/flake8 + rev: 7.0.0 hooks: - id: flake8 - additional_dependencies: [flake8-timeout==0.3.0] + additional_dependencies: [flake8-timeout==2.0.0] + args: [--timeout-extend-funcs=my_http_lib.request,custom.api.call:1] +``` + +This will check the defaults plus your custom functions. + +### overriding the defaults + +Use `--timeout-funcs` to replace the defaults entirely: + +**Command line:** +```bash +flake8 --timeout-funcs=custom.http.get,custom.http.post:2 +``` + +**Pre-commit:** +```yaml +- repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-timeout==2.0.0] + args: [--timeout-funcs=custom.http.get,custom.http.post:2] +``` + +This will only check the functions you specify, ignoring the defaults. + +### positional timeout arguments + +Some functions accept timeout as a positional argument. Specify the 0-based index after a colon: + +``` +my_lib.fetch:2 # timeout is at index 2 (3rd argument) +other.call:0 # timeout is at index 0 (1st argument) +``` + +Example with positional timeout: + +```python +# my_lib.fetch(url, data, timeout) +my_lib.fetch('https://api.example.com', None, 30) # OK - timeout at index 2 +my_lib.fetch('https://api.example.com', None) # TIM100 - missing timeout ``` diff --git a/flake8_timeout.py b/flake8_timeout.py index ac792e7..e585e11 100644 --- a/flake8_timeout.py +++ b/flake8_timeout.py @@ -1,71 +1,246 @@ +import argparse import ast import importlib.metadata as importlib_metadata from collections.abc import Generator from typing import Any -MSG = 'TIM100 request call has no timeout' +from flake8.options.manager import OptionManager -METHODS = [ - 'request', 'get', 'head', 'post', - 'patch', 'put', 'delete', 'options', +MSG = 'TIM100 request call has no timeout' +# Format: 'module.function' or 'module.function:positional_index' +DEFAULT_TRACKED_FUNCTIONS = [ + 'urllib.request.urlopen:2', # urlopen(url, data=None, timeout=...) + 'requests.get', # get(url, **kwargs) + 'requests.post', + 'requests.put', + 'requests.delete', + 'requests.head', + 'requests.patch', + 'requests.options', + 'requests.request', ] +def parse_function_spec(spec: str) -> tuple[tuple[str, str], int | None]: + # Split off positional index if present + if ':' in spec: + func_part, index_str = spec.rsplit(':', 1) + if not index_str.isdigit(): + raise ValueError( + f"Positional index must be an integer in spec: {spec}", + ) + positional_index = int(index_str) + else: + func_part = spec + positional_index = None + + # Parse the function part + parts = func_part.split('.') + if len(parts) < 2: + raise ValueError( + f"Function spec must be at least 'module.function': {spec}", + ) + + return ('.'.join(parts[:-1]), parts[-1]), positional_index + + class Visitor(ast.NodeVisitor): - def __init__(self) -> None: + def __init__( + self, + tracked_functions: set[tuple[str, str]], + timeout_positional: dict[str, int], + ) -> None: self.assignments: list[tuple[int, int]] = [] + # map local names to (module, attr) tuples + # 'urlopen': ('urllib.request', 'urlopen') + # 'request': ('urllib', 'request') for module imports + self.imports: dict[str, tuple[str, str | None]] = {} + self.tracked_functions = tracked_functions + self.timeout_positional = timeout_positional + + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + local_name = ( + alias.asname if alias.asname else alias.name.split('.')[-1] + ) + self.imports[local_name] = (alias.name, None) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module is None: + self.generic_visit(node) + return + + for alias in node.names: + local_name = alias.asname if alias.asname else alias.name + self.imports[local_name] = (node.module, alias.name) + self.generic_visit(node) + + def _check_timeout( + self, + node: ast.Call, + func_spec: str | None = None, + ) -> bool: + # Check keyword arguments + for kwarg in node.keywords: + if ( + kwarg.arg == 'timeout' and + isinstance(kwarg.value, ast.Constant) and + kwarg.value.value is not None + ): + return True + + # Check posargs if function has a known positional timeout index + if func_spec and func_spec in self.timeout_positional: + pos_index = self.timeout_positional[func_spec] + if len(node.args) > pos_index: + return True + + return False def visit_Call(self, node: ast.Call) -> None: - if ( - isinstance(node, ast.Call) and - isinstance(node.func, ast.Attribute) and - isinstance(node.func.value, ast.Name) and - node.func.value.id == 'requests' and - node.func.attr in METHODS - ): - for kwarg in node.keywords: - if ( - kwarg.arg == 'timeout' and - isinstance(kwarg.value, ast.Constant) and - kwarg.value.value is not None - ): - break - else: + func_spec: str | None = None + + # direct function call + if isinstance(node.func, ast.Name): + func_name = node.func.id + if func_name in self.imports: + module, attr = self.imports[func_name] + # attr should be the function name for 'from X import Y' + if attr and (module, attr) in self.tracked_functions: + func_spec = f"{module}.{attr}" + + # attribute call + elif isinstance(node.func, ast.Attribute): + attr_name = node.func.attr + + # Check if the base is a Name (requests.get or request.urlopen) + if isinstance(node.func.value, ast.Name): + base_name = node.func.value.id + + if base_name in self.imports: + module, imported_attr = self.imports[base_name] + + # If imported_attr is None, it's a module import + if imported_attr is None: + # Check if module.attr is tracked + if (module, attr_name) in self.tracked_functions: + func_spec = f"{module}.{attr_name}" + else: + full_module = f"{module}.{imported_attr}" + if (full_module, attr_name) in self.tracked_functions: + func_spec = f"{full_module}.{attr_name}" + + # nested attribute: urllib.request.urlopen('url') + elif isinstance(node.func.value, ast.Attribute): + if isinstance(node.func.value.value, ast.Name): + base = node.func.value.value.id + middle = node.func.value.attr + func = attr_name + full_spec = f"{base}.{middle}.{func}" + module_part = f"{base}.{middle}" + + if (module_part, func) in self.tracked_functions: + func_spec = full_spec + + if func_spec: + if not self._check_timeout(node, func_spec): self.assignments.append((node.lineno, node.col_offset)) - elif ( - isinstance(node, ast.Call) and - isinstance(node.func, ast.Attribute) and - isinstance(node.func.value, ast.Attribute) and - isinstance(node.func.value.value, ast.Name) and - node.func.value.value.id == 'urllib' and - node.func.value.attr == 'request' and - node.func.attr == 'urlopen' - ): - for kwarg in node.keywords: - if ( - kwarg.arg == 'timeout' and - isinstance(kwarg.value, ast.Constant) and - kwarg.value.value is not None - ): - break - else: - # check if it was passed as a positional argument instead - # args are: (url, data=None, [timeout, ]*, cafile=None ... - if len(node.args) < 3: - self.assignments.append((node.lineno, node.col_offset)) self.generic_visit(node) +class Namespace(argparse.Namespace): + timeout_funcs: list[str] = [] + timeout_extend_funcs: list[str] = [] + + class Plugin: name = __name__ version = importlib_metadata.version(__name__) def __init__(self, tree: ast.AST): self._tree = tree + self.tracked_functions = getattr(Plugin, 'tracked_functions', None) + self.extend_tracked_functions = getattr( + Plugin, 'extend_tracked_functions', [], + ) + + @classmethod + def add_options(cls, option_manager: OptionManager) -> None: + option_manager.add_option( + '--timeout-funcs', + default=DEFAULT_TRACKED_FUNCTIONS, + parse_from_config=True, + comma_separated_list=True, + help=( + 'Comma-separated list of fully qualified function names to ' + 'check for timeout. This OVERRIDES the defaults. ' + 'Format: "module.function" or "module.function:index" where ' + 'index is the positional argument index for timeout ' + '(e.g., "foo.bar.baz,my.func:2"). ' + ), + ) + option_manager.add_option( + '--timeout-extend-funcs', + default='', + parse_from_config=True, + comma_separated_list=True, + help=( + 'Comma-separated list of additional fully qualified function ' + 'names to check for timeout. This EXTENDS the default list. ' + 'Format: "module.function" or "module.function:index" where ' + 'index is the positional argument index for timeout ' + '(e.g., "foo.bar.baz,my.func:2").' + ), + ) + + @classmethod + def parse_options(cls, options: Namespace) -> None: + # Validate tracked_functions specs + for spec in options.timeout_funcs: + parse_function_spec(spec) + + # Validate extend_tracked_functions specs + for spec in options.timeout_extend_funcs: + parse_function_spec(spec) + + cls.tracked_functions = options.timeout_funcs + cls.extend_tracked_functions = options.timeout_extend_funcs + + def _parse_tracked_functions( + self, + specs: list[str], + ) -> tuple[set[tuple[str, str]], dict[str, int]]: + tracked = set() + positional = {} + + for spec in specs: + (module, func), pos_index = parse_function_spec(spec) + tracked.add((module, func)) + if pos_index is not None: + positional[f"{module}.{func}"] = pos_index + + return tracked, positional def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]: - visitor = Visitor() + # Determine which functions to track + if self.extend_tracked_functions: + # Extension mode: use defaults + extensions + specs = list(DEFAULT_TRACKED_FUNCTIONS) + specs.extend(self.extend_tracked_functions) + elif self.tracked_functions is not None: + # Override mode: use only specified functions + specs = self.tracked_functions + else: + # Default mode: use defaults + specs = list(DEFAULT_TRACKED_FUNCTIONS) + + tracked_functions, timeout_positional = self._parse_tracked_functions( + specs, + ) + + visitor = Visitor(tracked_functions, timeout_positional) visitor.visit(self._tree) for line, col in visitor.assignments: yield line, col, MSG, type(self) diff --git a/tests/flake8_timeout_test.py b/tests/flake8_timeout_test.py index 8d5f69c..54e98a5 100644 --- a/tests/flake8_timeout_test.py +++ b/tests/flake8_timeout_test.py @@ -1,7 +1,9 @@ import ast import pytest +from flake8.options.manager import OptionManager +from flake8_timeout import parse_function_spec from flake8_timeout import Plugin @@ -12,10 +14,45 @@ def results(s): @pytest.mark.parametrize( 's', ( - 'print("hello hello world")', - '', - 'a = 5', - 'a = foo(x=5, timeout=None)', + pytest.param('print("hello hello world")', id='print-statement'), + pytest.param('', id='empty-string'), + pytest.param('a = 5', id='simple-assignment'), + pytest.param( + 'a = foo(x=5, timeout=None)', + id='untracked-call-with-timeout', + ), + pytest.param( + 'from . import something\nsomething.func("url")', + id='relative-import', + ), + pytest.param( + 'from foo import bar\nbar("url")', + id='direct-call-not-tracked', + ), + pytest.param( + 'unknown.method("url")', + id='unimported-attribute-call', + ), + pytest.param( + 'urllib.request.unknown("url")', + id='nested-untracked-attribute', + ), + pytest.param( + 'some_object.method("url")', + id='unimported-base-object', + ), + pytest.param( + 'funcs = [lambda x: x]\nfuncs[0]("url")', + id='call-on-subscript', + ), + pytest.param( + 'def get_client():\n import requests\n return requests\nget_client().get("url")', # noqa: E501 + id='call-on-call-result', + ), + pytest.param( + 'def get_obj():\n class Obj:\n attr = None\n return Obj()\nget_obj().attr.method("url")', # noqa: E501 + id='nested-attribute-on-call', + ), ), ) def test_no_requests_expression(s): @@ -25,8 +62,32 @@ def test_no_requests_expression(s): @pytest.mark.parametrize( 's', ( - 'a = requests.session(foo="bar")', - 'a = urllib.request.Request(foo="bar")', + pytest.param( + '''\ +import requests +a = requests.session(foo="bar") +''', + id='requests-session', + ), + pytest.param( + '''\ +import urllib.request +a = urllib.request.Request(foo="bar") +''', + id='urllib-request-constructor', + ), + pytest.param( + 'import requests\nrequests.unknown_method("url")', + id='requests-unknown-method', + ), + pytest.param( + 'from urllib import request\n\nrequest.unknown_method("url")', + id='urllib-request-unknown-method', + ), + pytest.param( + 'from urllib import request\n\nrequest.untracked_function("url")', + id='urllib-request-untracked-function', + ), ), ) def test_unknown_method(s): @@ -36,14 +97,24 @@ def test_unknown_method(s): @pytest.mark.parametrize( 's', ( - 'a = requests.post("https://example.com", timeout=5, foo="bar")', - '''\ + pytest.param( + '''\ +import requests +a = requests.post("https://example.com", timeout=5, foo="bar") +''', + id='requests-post-with-timeout', + ), + pytest.param( + '''\ +import urllib.request a = urllib.request.urlopen( "https://example.com", timeout=5, bar="baz", ) ''', + id='urllib-urlopen-with-timeout', + ), ), ) def test_timout_is_kwarg(s): @@ -51,70 +122,453 @@ def test_timout_is_kwarg(s): def test_timout_is_arg(): - s = 'a = urllib.request.urlopen("https://example.com", None, 5, arg="t")' + s = '''\ +import urllib.request +a = urllib.request.urlopen("https://example.com", None, 5, arg="t") +''' assert not results(s) @pytest.mark.parametrize( 's', ( - 'a = requests.post("https://example.com", params={"bar": "baz"})', - 'a = requests.post("https://example.com", timeout=None)', - 'a = requests.get("https://example.com", timeout=None)', - 'a = requests.put("https://example.com", timeout=None)', - 'a = requests.delete("https://example.com", timeout=None)', - 'a = urllib.request.urlopen("https://example.com", bar="baz")', + pytest.param( + '''\ +import requests +a = requests.post("https://example.com", params={"bar": "baz"}) +''', + id='requests-post-no-timeout', + ), + pytest.param( + '''\ +import requests +a = requests.post("https://example.com", timeout=None) +''', + id='requests-post-timeout-none', + ), + pytest.param( + '''\ +import requests +a = requests.get("https://example.com", timeout=None) +''', + id='requests-get-timeout-none', + ), + pytest.param( + '''\ +import requests +a = requests.put("https://example.com", timeout=None) +''', + id='requests-put-timeout-none', + ), + pytest.param( + '''\ +import requests +a = requests.delete("https://example.com", timeout=None) +''', + id='requests-delete-timeout-none', + ), + pytest.param( + '''\ +import urllib.request +a = urllib.request.urlopen("https://example.com", bar="baz") +''', + id='urllib-urlopen-no-timeout', + ), ), ) def test_timeout_missing(s): msg, = results(s) - assert msg == '1:4: TIM100 request call has no timeout' + assert msg == '2:4: TIM100 request call has no timeout' @pytest.mark.parametrize( 's', ( - 'a = foo(bar=requests.get("https://example.com"))', - 'a = foo(bar=urllib.request.urlopen("https://example.com"))', + pytest.param( + '''\ +import requests +a = foo(bar=requests.get("https://example.com")) +''', + id='requests-as-kwarg', + ), + pytest.param( + '''\ +import urllib.request +a = foo(bar=urllib.request.urlopen("https://example.com")) +''', + id='urllib-as-kwarg', + ), ), ) -def test_call_as_kwarg(s): +def test_call_as_func_kwarg(s): msg, = results(s) - assert msg == '1:12: TIM100 request call has no timeout' + assert msg == '2:12: TIM100 request call has no timeout' @pytest.mark.parametrize( 's', ( - 'a = foo(requests.get("https://example.com"))', - 'a = foo(urllib.request.urlopen("https://example.com"))', + pytest.param( + '''\ +import requests +a = foo(requests.get("https://example.com")) +''', + id='requests-as-positional', + ), + pytest.param( + '''\ +import urllib.request +a = foo(urllib.request.urlopen("https://example.com")) +''', + id='urllib-as-positional', + ), ), ) -def test_call_as_arg(s): +def test_call_as_func_pos_arg(s): msg, = results(s) - assert msg == '1:8: TIM100 request call has no timeout' + assert msg == '2:8: TIM100 request call has no timeout' @pytest.mark.parametrize( 's', ( - 'foo(bar=requests.get("https://example.com"))', - 'foo(bar=urllib.request.urlopen("https://example.com"))', + pytest.param( + '''\ +import requests +foo(bar=requests.get("https://example.com")) +''', + id='requests-no-assignment', + ), + pytest.param( + '''\ +import urllib.request +foo(bar=urllib.request.urlopen("https://example.com")) +''', + id='urllib-no-assignment', + ), ), ) -def test_call_as_kwarg_no_assing(s): +def test_call_as_func_kwarg_no_assing(s): msg, = results(s) - assert msg == '1:8: TIM100 request call has no timeout' + assert msg == '2:8: TIM100 request call has no timeout' @pytest.mark.parametrize( 's', ( - 'def foo(bar=requests.get("https://example.com")):\n ...', - 'def foo(bar=urllib.request.urlopen("https://example.com")):\n ...', + pytest.param( + '''\ +import requests +def foo(bar=requests.get("https://example.com")): ... +''', + id='requests-function-default', + ), + pytest.param( + '''\ +import urllib.request +def foo(bar=urllib.request.urlopen("https://example.com")): ... +''', + id='urllib-function-default', + ), ), ) -def test_call_as_function_argument_default(s): - s = 'def foo(bar=requests.get("https://example.com")):\n ...' +def test_call_as_func_arg_default(s): + msg, = results(s) + assert msg == '2:12: TIM100 request call has no timeout' + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'from urllib.request import urlopen\nurlopen("google.com")', + id='from-import-direct-call', + ), + pytest.param( + 'from urllib import request\nrequest.urlopen("google.com")', + id='from-import-module-then-attr', + ), + pytest.param( + 'from urllib.request import urlopen as _urlopen\n_urlopen("t.de")', + id='from-import-with-alias', + ), + pytest.param( + 'from requests import get\nget("https://example.com")', + id='requests-from-import', + ), + pytest.param( + 'from requests import post\npost("https://example.com", data={"key": "value"})', # noqa: E501 + id='requests-from-import-post', + ), + pytest.param( + 'import requests as req\nreq.get("https://example.com")', + id='import-requests-with-alias', + ), + pytest.param( + 'import urllib.request as ur\nur.urlopen("google.com")', + id='import-urllib-request-with-alias', + ), + ), +) +def test_different_import_styles_no_timeout(s): + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'from urllib.request import urlopen\nurlopen("google.com", timeout=5)', # noqa: E501 + id='from-import-direct-call', + ), + pytest.param( + 'from urllib import request\nrequest.urlopen("google.com", timeout=5)', # noqa: E501 + id='from-import-module-then-attr', + ), + pytest.param( + 'from urllib.request import urlopen as _urlopen\n_urlopen("google.com", timeout=5)', # noqa: E501 + id='from-import-with-alias', + ), + pytest.param( + 'from requests import get\nget("https://example.com", timeout=5)', + id='requests-from-import', + ), + pytest.param( + 'import requests as req\nreq.get("https://t.com", timeout=5)', + id='import-requests-with-alias', + ), + pytest.param( + 'import urllib.request as ur\nur.urlopen("google.com", timeout=5)', + id='import-urllib-request-with-alias', + ), + ), +) +def test_import_styles_with_timeout(s): + assert not results(s) + + +@pytest.fixture +def manager(): + mgr = OptionManager( + version='0', + plugin_versions='', + formatter_names=(), + parents=[], + ) + Plugin.add_options(mgr) + return mgr + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'import requests\nrequests.get("url")', + id='requests-get', + ), + pytest.param( + 'import requests\nrequests.post("url")', + id='requests-post', + ), + pytest.param( + 'from requests import get\nget("url")', + id='requests-from-import-get', + ), + pytest.param( + 'from requests import post\npost("url")', + id='requests-from-import-post', + ), + pytest.param( + 'import urllib.request\nurllib.request.urlopen("url")', + id='urllib-import-urlopen', + ), + pytest.param( + 'from urllib.request import urlopen\nurlopen("url")', + id='urllib-from-import-urlopen', + ), + ), +) +def test_option_parsing_no_custom_functions_use_defaults( + s: str, + manager: OptionManager, +) -> None: + Plugin.parse_options(manager.parse_args([])) + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + +def test_option_parsing_extend_functions(manager: OptionManager) -> None: + options = manager.parse_args([ + '--timeout-extend-funcs=foo.bar.baz,my.module.request', + ]) + Plugin.parse_options(options) + + # default functions should still work + s = 'import requests\nrequests.get("url")' + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + # extended function should also work + s = 'from foo import bar\nbar.baz("url")' + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + +def test_option_parsing_override_functions(manager: OptionManager) -> None: + options = manager.parse_args(['--timeout-funcs=foo.bar.baz']) + Plugin.parse_options(options) + + # default functions should not be detected + s = 'import requests\nrequests.get("url")' + assert not results(s) + + # new function is detected + s = 'from foo import bar\nbar.baz("url")' + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + +def test_option_parsing_positional_arg_timeout(manager: OptionManager) -> None: + options = manager.parse_args([ + '--timeout-extend-funcs=my.func:3,other.func:1', + ]) + Plugin.parse_options(options) + + # timeout at idx 3 works + s = 'from my import func\nfunc("url", None, None, 10)' + assert not results(s) + + # timeout at idx 1 works + s = 'from other import func\nfunc("url", 10)' + assert not results(s) + + +def test_custom_tracked_function_no_timeout(manager: OptionManager) -> None: + options = manager.parse_args( + ['--timeout-extend-funcs=foo.bar.baz'], + ) + Plugin.parse_options(options) + + s = 'from foo import bar\nbar.baz("url")' + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + +def test_custom_tracked_function_with_timeout(manager: OptionManager) -> None: + options = manager.parse_args( + ['--timeout-extend-funcs=foo.bar.baz'], + ) + Plugin.parse_options(options) + + s = 'from foo import bar\nbar.baz("url", timeout=5)' + assert not results(s) + + +def test_custom_positional_timeout(manager: OptionManager) -> None: + options = manager.parse_args([ + '--timeout-extend-funcs=my.module.func:2', + ]) + Plugin.parse_options(options) + + s = '''\ +from my import module +module.func('url', None, 10) +''' + assert not results(s) + + +def test_custom_positional_timeout_missing(manager: OptionManager) -> None: + options = manager.parse_args([ + '--timeout-extend-funcs=my.module.func:2', + ]) + Plugin.parse_options(options) + + s = 'from my import module\nmodule.func("url", None)' + # Should fail because only 2 args but needs 3 + msg, = results(s) + assert msg == '2:0: TIM100 request call has no timeout' + + +@pytest.mark.parametrize( + ('spec', 'expected'), + ( + pytest.param( + 'urllib.request.urlopen', + (('urllib.request', 'urlopen'), None), + id='urllib-no-index', + ), + pytest.param( + 'requests.get', + (('requests', 'get'), None), + id='requests-no-index', + ), + pytest.param( + 'foo.bar.baz', + (('foo.bar', 'baz'), None), + id='nested-module-no-index', + ), + pytest.param( + 'a.b.c.d.e', + (('a.b.c.d', 'e'), None), + id='deeply-nested-no-index', + ), + pytest.param( + 'urllib.request.urlopen:2', + (('urllib.request', 'urlopen'), 2), + id='urllib-with-index', + ), + pytest.param( + 'my.func:0', + (('my', 'func'), 0), + id='index-zero', + ), + pytest.param( + 'foo.bar:5', + (('foo', 'bar'), 5), + id='index-five', + ), + ), +) +def test_parse_function_spec_valid(spec, expected): + assert parse_function_spec(spec) == expected + + +@pytest.mark.parametrize( + ('spec', 'error_msg'), + ( + pytest.param( + 'single', + "Function spec must be at least 'module.function': single", + id='missing-module', + ), + pytest.param( + 'my.func:abc', + 'Positional index must be an integer in spec: my.func:abc', + id='invalid-index-non-digit', + ), + pytest.param( + 'foo.bar:notanumber', + 'Positional index must be an integer in spec: foo.bar:notanumber', + id='invalid-index-text', + ), + ), +) +def test_parse_function_spec_invalid(spec, error_msg): + with pytest.raises(ValueError) as excinfo: + parse_function_spec(spec) + msg, = excinfo.value.args + assert msg == error_msg + + +def test_complex_nested_expression(): + s = '''\ +import requests + +def outer(): + def inner(): + requests.get('url') + return inner +''' msg, = results(s) - assert msg == '1:12: TIM100 request call has no timeout' + assert msg == '5:8: TIM100 request call has no timeout'