From 6d2e9eae0aeecaf7560e59d2d493aef0a5767f67 Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 18 Oct 2022 12:59:44 +1300 Subject: [PATCH] [Nonlinear] remove recursion when parsing expressions --- src/Nonlinear/parse.jl | 61 +++++++++++++++++++++++-------------- test/Nonlinear/Nonlinear.jl | 14 +++++++++ 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/src/Nonlinear/parse.jl b/src/Nonlinear/parse.jl index 8beb888bae..7f0d92549a 100644 --- a/src/Nonlinear/parse.jl +++ b/src/Nonlinear/parse.jl @@ -42,25 +42,38 @@ function parse_expression( x::Expr, parent_index::Int, ) + stack = Tuple{Int,Any}[] + push!(stack, (parent_index, x)) + while !isempty(stack) + parent, item = pop!(stack) + if item isa Expr + _parse_expression(stack, data, expr, item, parent) + else + parse_expression(data, expr, item, parent) + end + end + return +end + +function _parse_expression(stack, data, expr, x, parent_index) if isexpr(x, :call) if length(x.args) == 2 && !isexpr(x.args[2], :...) - _parse_univariate_expression(data, expr, x, parent_index) + _parse_univariate_expression(stack, data, expr, x, parent_index) else - _parse_multivariate_expression(data, expr, x, parent_index) + _parse_multivariate_expression(stack, data, expr, x, parent_index) end elseif isexpr(x, :comparison) - _parse_comparison_expression(data, expr, x, parent_index) + _parse_comparison_expression(stack, data, expr, x, parent_index) elseif isexpr(x, :...) - _parse_splat_expression(data, expr, x, parent_index) + _parse_splat_expression(stack, data, expr, x, parent_index) elseif isexpr(x, :&&) || isexpr(x, :||) - _parse_logic_expression(data, expr, x, parent_index) + _parse_logic_expression(stack, data, expr, x, parent_index) else error("Unsupported expression: $x") end - return end -function _parse_splat_expression(data, expr, x, parent_index) +function _parse_splat_expression(stack, data, expr, x, parent_index) @assert isexpr(x, :...) && length(x.args) == 1 if parent_index == -1 error( @@ -74,13 +87,14 @@ function _parse_splat_expression(data, expr, x, parent_index) "`(x + 1)...`, `[x; y]...` and `g(f(y)...)` are not.", ) end - for xi in x.args[1] - parse_expression(data, expr, xi, parent_index) + for xi in reverse(x.args[1]) + push!(stack, (parent_index, xi)) end return end function _parse_univariate_expression( + stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, x::Expr, @@ -91,17 +105,18 @@ function _parse_univariate_expression( if id === nothing # It may also be a multivariate operator like * with one argument. if haskey(data.operators.multivariate_operator_to_id, x.args[1]) - _parse_multivariate_expression(data, expr, x, parent_index) + _parse_multivariate_expression(stack, data, expr, x, parent_index) return end error("Unable to parse: $x") end push!(expr.nodes, Node(NODE_CALL_UNIVARIATE, id, parent_index)) - parse_expression(data, expr, x.args[2], length(expr.nodes)) + push!(stack, (length(expr.nodes), x.args[2])) return end function _parse_multivariate_expression( + stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, x::Expr, @@ -111,13 +126,12 @@ function _parse_multivariate_expression( id = get(data.operators.multivariate_operator_to_id, x.args[1], nothing) if id === nothing @assert x.args[1] in data.operators.comparison_operators - _parse_inequality_expression(data, expr, x, parent_index) + _parse_inequality_expression(stack, data, expr, x, parent_index) return end push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, id, parent_index)) - parent_var = length(expr.nodes) - for i in 2:length(x.args) - parse_expression(data, expr, x.args[i], parent_var) + for i in length(x.args):-1:2 + push!(stack, (length(expr.nodes), x.args[i])) end return end @@ -126,6 +140,7 @@ end # confused with `_parse_comparison_expression`, which handles things like # `a <= b <= c`. function _parse_inequality_expression( + stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, x::Expr, @@ -133,9 +148,8 @@ function _parse_inequality_expression( ) operator_id = data.operators.comparison_operator_to_id[x.args[1]] push!(expr.nodes, Node(NODE_COMPARISON, operator_id, parent_index)) - parent_var = length(expr.nodes) - for i in 2:length(x.args) - parse_expression(data, expr, x.args[i], parent_var) + for i in length(x.args):-1:2 + push!(stack, (length(expr.nodes), x.args[i])) end return end @@ -144,6 +158,7 @@ end # confused with `_parse_inequality_expression`, which handles things like # `a <= b`. function _parse_comparison_expression( + stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, x::Expr, @@ -154,14 +169,14 @@ function _parse_comparison_expression( end operator_id = data.operators.comparison_operator_to_id[x.args[2]] push!(expr.nodes, Node(NODE_COMPARISON, operator_id, parent_index)) - parent_var = length(expr.nodes) - for i in 1:2:length(x.args) - parse_expression(data, expr, x.args[i], parent_var) + for i in length(x.args):-2:1 + push!(stack, (length(expr.nodes), x.args[i])) end return end function _parse_logic_expression( + stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, x::Expr, @@ -170,8 +185,8 @@ function _parse_logic_expression( id = data.operators.logic_operator_to_id[x.head] push!(expr.nodes, Node(NODE_LOGIC, id, parent_index)) parent_var = length(expr.nodes) - parse_expression(data, expr, x.args[1], parent_var) - parse_expression(data, expr, x.args[2], parent_var) + push!(stack, (parent_var, x.args[2])) + push!(stack, (parent_var, x.args[1])) return end diff --git a/test/Nonlinear/Nonlinear.jl b/test/Nonlinear/Nonlinear.jl index 66e4505244..edfa108caa 100644 --- a/test/Nonlinear/Nonlinear.jl +++ b/test/Nonlinear/Nonlinear.jl @@ -870,6 +870,20 @@ function test_eval_atan2() return end +function test_deep_recursion() + model = Nonlinear.Model() + x = MOI.VariableIndex(1) + y = Expr(:call, :sin, x) + for _ in 1:20_000 + y = Expr(:call, :^, Expr(:call, :sqrt, y), 2) + end + start = time() + @test Nonlinear.parse_expression(model, y) isa Nonlinear.Expression + # A conservative bound to check we're not doing something expensive. + @test time() - start < 1.0 + return +end + end TestNonlinear.runtests()