diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index c9b01e2..b0f2424 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -25,6 +25,7 @@ def __init__( use_math_symbols: bool = False, use_signature: bool = True, use_set_symbols: bool = False, + _match_subject_stack: list[str] = [], ) -> None: """Initializer. @@ -34,6 +35,7 @@ def __init__( use_signature: Whether to add the function signature before the expression or not. use_set_symbols: Whether to use set symbols or not. + _match_subject_stack: a stack of subject names that are used in match """ self._expression_codegen = expression_codegen.ExpressionCodegen( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols @@ -42,6 +44,7 @@ def __init__( use_math_symbols=use_math_symbols ) self._use_signature = use_signature + self._match_subject_stack = _match_subject_stack def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( @@ -140,7 +143,10 @@ def visit_If(self, node: ast.If) -> str: return latex + r", & \mathrm{otherwise} \end{array} \right." def visit_Match(self, node: ast.Match) -> str: - """Visit a Match node""" + """Visit a Match node.""" + subject_latex = self._expression_codegen.visit(node.subject) + self._match_subject_stack.append(subject_latex) + if not ( len(node.cases) >= 2 and isinstance(node.cases[-1].pattern, ast.MatchAs) @@ -162,21 +168,27 @@ def visit_Match(self, node: ast.Match) -> str: if i < len(node.cases) - 1: body_latex = self.visit(case.body[0]) cond_latex = self.visit(case.pattern) - case_latexes.append( - body_latex + r", & \mathrm{if} \ " + subject_latex + cond_latex - ) + + case_latexes.append(body_latex + r", & \mathrm{if} \ " + cond_latex) else: case_latexes.append( - self.visit(node.cases[-1].body[0]) + r", & \mathrm{otherwise}" + self.visit(case.body[0]) + r", & \mathrm{otherwise}" ) - return ( + latex = ( r"\left\{ \begin{array}{ll} " + r" \\ ".join(case_latexes) + r" \end{array} \right." ) + self._match_subject_stack.pop() + return latex + def visit_MatchValue(self, node: ast.MatchValue) -> str: - """Visit a MatchValue node""" + """Visit a MatchValue node.""" latex = self._expression_codegen.visit(node.value) - return " = " + latex + return self._match_subject_stack[-1] + " = " + latex + + def visit_MatchOr(self, node: ast.MatchOr) -> str: + """Visit a MatchOr node.""" + return r" \lor ".join(self.visit(p) for p in node.patterns) diff --git a/src/latexify/codegen/function_codegen_match_test.py b/src/latexify/codegen/function_codegen_match_test.py index 2467f94..1bc6d07 100644 --- a/src/latexify/codegen/function_codegen_match_test.py +++ b/src/latexify/codegen/function_codegen_match_test.py @@ -12,7 +12,7 @@ @test_utils.require_at_least(10) -def test_functiondef_match() -> None: +def test_visit_functiondef_match() -> None: tree = ast.parse( textwrap.dedent( """ @@ -27,8 +27,8 @@ def f(x): ) expected = ( r"f(x) =" - r" \left\{ \begin{array}{ll}" - r" 1, & \mathrm{if} \ x = 0 \\" + r" \left\{ \begin{array}{ll} " + r"1, & \mathrm{if} \ x = 0 \\" r" 3 \cdot x, & \mathrm{otherwise}" r" \end{array} \right." ) @@ -36,7 +36,7 @@ def f(x): @test_utils.require_at_least(10) -def test_matchvalue() -> None: +def test_visit_match() -> None: tree = ast.parse( textwrap.dedent( """ @@ -49,8 +49,8 @@ def test_matchvalue() -> None: ) ).body[0] expected = ( - r"\left\{ \begin{array}{ll}" - r" 1, & \mathrm{if} \ x = 0 \\" + r"\left\{ \begin{array}{ll} " + r"1, & \mathrm{if} \ x = 0 \\" r" 2, & \mathrm{otherwise}" r" \end{array} \right." ) @@ -58,7 +58,7 @@ def test_matchvalue() -> None: @test_utils.require_at_least(10) -def test_multiple_matchvalue() -> None: +def test_visit_multiple_match_cases() -> None: tree = ast.parse( textwrap.dedent( """ @@ -73,8 +73,8 @@ def test_multiple_matchvalue() -> None: ) ).body[0] expected = ( - r"\left\{ \begin{array}{ll}" - r" 1, & \mathrm{if} \ x = 0 \\" + r"\left\{ \begin{array}{ll} " + r"1, & \mathrm{if} \ x = 0 \\" r" 2, & \mathrm{if} \ x = 1 \\" r" 3, & \mathrm{otherwise}" r" \end{array} \right." @@ -83,7 +83,7 @@ def test_multiple_matchvalue() -> None: @test_utils.require_at_least(10) -def test_single_matchvalue_no_wildcards() -> None: +def test_visit_single_match_case_no_wildcards() -> None: tree = ast.parse( textwrap.dedent( """ @@ -102,7 +102,7 @@ def test_single_matchvalue_no_wildcards() -> None: @test_utils.require_at_least(10) -def test_multiple_matchvalue_no_wildcards() -> None: +def test_visit_multiple_match_cases_no_wildcards() -> None: tree = ast.parse( textwrap.dedent( """ @@ -123,34 +123,35 @@ def test_multiple_matchvalue_no_wildcards() -> None: @test_utils.require_at_least(10) -def test_matchas_nonempty() -> None: +def test_visit_match_case_no_return() -> None: tree = ast.parse( textwrap.dedent( """ match x: - case [x] as y: - return 1 + case 0: + x = 5 case _: - return 2 + return 0 """ ) ).body[0] with pytest.raises( exceptions.LatexifyNotSupportedError, - match=r"^Unsupported AST: MatchAs$", + match=r"^Match cases must contain exactly 1 return statement\.$", ): function_codegen.FunctionCodegen().visit(tree) @test_utils.require_at_least(10) -def test_matchvalue_no_return() -> None: +def test_visit_match_case_mutliple_statements() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: x = 5 + return 1 case _: return 0 """ @@ -165,22 +166,21 @@ def test_matchvalue_no_return() -> None: @test_utils.require_at_least(10) -def test_matchvalue_mutliple_statements() -> None: +def test_visit_match_case_or() -> None: tree = ast.parse( textwrap.dedent( """ match x: - case 0: - x = 5 + case 0 | 1: return 1 case _: - return 0 + return 2 """ ) ).body[0] - with pytest.raises( - exceptions.LatexifyNotSupportedError, - match=r"^Match cases must contain exactly 1 return statement\.$", - ): + assert ( function_codegen.FunctionCodegen().visit(tree) + == r"\left\{ \begin{array}{ll} 1, & \mathrm{if} \ x = 0 \lor x = 1 \\" + + r" 2, & \mathrm{otherwise} \end{array} \right." + )