diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 0d2b6ee..867a20f 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -39,18 +39,17 @@ This most commonly happens in situations like `x^2`. """ _rewrite_generic(::Expr, x::Number) = x, true -function _is_generator(expr) - return Meta.isexpr(expr, :call, 2) && Meta.isexpr(expr.args[2], :generator) -end - -function _is_flatten(expr) - return Meta.isexpr(expr, :call, 2) && Meta.isexpr(expr.args[2], :flatten) +function _is_generic_generator_call(expr) + if Meta.isexpr(expr, :call, 2) + return _is_generator_or_flatten(expr.args[2]) + end + return Meta.isexpr(expr, :call, 3) && + Meta.isexpr(expr.args[2], :parameters) && + _is_generator_or_flatten(expr.args[3]) end -_is_generator_or_flatten(expr) = _is_generator(expr) || _is_flatten(expr) - -function _is_parameters(expr) - return Meta.isexpr(expr, :call, 3) && Meta.isexpr(expr.args[2], :parameters) +function _is_generator_or_flatten(expr) + return Meta.isexpr(expr, :generator) || Meta.isexpr(expr, :flatten) end function _is_kwarg(expr, kwarg::Symbol) @@ -113,7 +112,7 @@ function _rewrite_generic(stack::Expr, expr::Expr) elseif Meta.isexpr(expr.args[2], :(...)) # If the first argument is a splat. return esc(expr), false - elseif _is_generator_or_flatten(expr) || _is_parameters(expr) + elseif _is_generic_generator_call(expr) if !(expr.args[1] in (:sum, :Σ, :∑)) # We don't know what this is. Return the expression and don't let # future callers mutate. diff --git a/test/test_rewrite_generic.jl b/test/test_rewrite_generic.jl index 2686983..063b18e 100644 --- a/test/test_rewrite_generic.jl +++ b/test/test_rewrite_generic.jl @@ -675,6 +675,22 @@ function test_issue_76_trailing_dimensions() return end +function test_issue_359() + for (expr, has_zero) in Pair{Expr,Bool}[ + :(sum(i for i in I))=>true, + :(sum(i for i in I; init = 0))=>false, + :(sum(i for i in I, init in 0))=>false, + ] + stack = quote end + root, is_mutable = MA._rewrite_generic(stack, expr) + @test root isa Symbol + @test is_mutable + @test length(stack.args) == 3 + @test occursin(".Zero()", string(stack.args[2])) == has_zero + end + return +end + end # module TestRewriteGeneric.runtests()