Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions src/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,17 @@ This most commonly happens in situations like `x^2`.
"""
_rewrite_generic(::Expr, x::Number) = x, true

function _is_generator(expr)
return Meta.isexpr(expr, :call, 2) && Meta.isexpr(expr.args[2], :generator)
end

function _is_flatten(expr)
return Meta.isexpr(expr, :call, 2) && Meta.isexpr(expr.args[2], :flatten)
function _is_generic_generator_call(expr)
if Meta.isexpr(expr, :call, 2)
return _is_generator_or_flatten(expr.args[2])
end
return Meta.isexpr(expr, :call, 3) &&
Meta.isexpr(expr.args[2], :parameters) &&
_is_generator_or_flatten(expr.args[3])
end

_is_generator_or_flatten(expr) = _is_generator(expr) || _is_flatten(expr)

function _is_parameters(expr)
return Meta.isexpr(expr, :call, 3) && Meta.isexpr(expr.args[2], :parameters)
function _is_generator_or_flatten(expr)
return Meta.isexpr(expr, :generator) || Meta.isexpr(expr, :flatten)
end

function _is_kwarg(expr, kwarg::Symbol)
Expand Down Expand Up @@ -113,7 +112,7 @@ function _rewrite_generic(stack::Expr, expr::Expr)
elseif Meta.isexpr(expr.args[2], :(...))
# If the first argument is a splat.
return esc(expr), false
elseif _is_generator_or_flatten(expr) || _is_parameters(expr)
elseif _is_generic_generator_call(expr)
if !(expr.args[1] in (:sum, :Σ, :∑))
# We don't know what this is. Return the expression and don't let
# future callers mutate.
Expand Down
16 changes: 16 additions & 0 deletions test/test_rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,22 @@ function test_issue_76_trailing_dimensions()
return
end

function test_issue_359()
for (expr, has_zero) in Pair{Expr,Bool}[
:(sum(i for i in I))=>true,
:(sum(i for i in I; init = 0))=>false,
:(sum(i for i in I, init in 0))=>false,
]
stack = quote end
root, is_mutable = MA._rewrite_generic(stack, expr)
@test root isa Symbol
@test is_mutable
@test length(stack.args) == 3
@test occursin(".Zero()", string(stack.args[2])) == has_zero
end
return
end

end # module

TestRewriteGeneric.runtests()
Loading