diff --git a/src/operators.jl b/src/operators.jl index 311afb73714..7c0ec273ae4 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -197,24 +197,24 @@ end function Base.:^(lhs::AbstractVariableRef, rhs::Integer) T = value_type(typeof(lhs)) - if rhs == 2 - return lhs * lhs + if rhs == 0 + return one(T) elseif rhs == 1 - return convert(GenericQuadExpr{T,variable_ref_type(lhs)}, lhs) - elseif rhs == 0 - return one(GenericQuadExpr{T,variable_ref_type(lhs)}) + return lhs + elseif rhs == 2 + return lhs * lhs else return GenericNonlinearExpr(:^, Any[lhs, rhs]) end end function Base.:^(lhs::GenericAffExpr{T}, rhs::Integer) where {T} - if rhs == 2 - return lhs * lhs + if rhs == 0 + return one(T) elseif rhs == 1 - return convert(GenericQuadExpr{T,variable_ref_type(lhs)}, lhs) - elseif rhs == 0 - return one(GenericQuadExpr{T,variable_ref_type(lhs)}) + return lhs + elseif rhs == 2 + return lhs * lhs else return GenericNonlinearExpr(:^, Any[lhs, rhs]) end diff --git a/test/test_macros.jl b/test/test_macros.jl index 993fe62a701..3c58375340b 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -704,9 +704,6 @@ function test_Nonliteral_exponents_in_constraint() @variable(model, x) foo() = 2 con1 = @build_constraint(x^(foo()) + x^(foo() - 1) + x^(foo() - 2) == 0) - # TODO(odow): `con2` fails to build due to a bug in MutableArithmetics. To - # fix, we need MutableArithmetics in the current scope. - MutableArithmetics = JuMP._MA con2 = @build_constraint( (x - 1)^(foo()) + (x - 1)^2 + (x - 1)^1 + (x - 1)^0 == 0 ) @@ -715,7 +712,7 @@ function test_Nonliteral_exponents_in_constraint() @test con1.func == x^2 + x @test con2.func == 2 * x^2 - 3 * x @test con3.func == 9 * x^2 - @test con4.func == convert(QuadExpr, 3 * x) + @test con4.func == 3 * x return end diff --git a/test/test_mutable_arithmetics.jl b/test/test_mutable_arithmetics.jl index d3bc9a33475..9e7489fec29 100644 --- a/test/test_mutable_arithmetics.jl +++ b/test/test_mutable_arithmetics.jl @@ -117,11 +117,9 @@ function test_extension_quadratic( VariableRefType = VariableRef, ) model = ModelType() - @variable(model, w) - @variable(model, x) - @variable(model, y) - @variable(model, z) - JuMP._MA.Test.quadratic_test(w, x, y, z) + @variable(model, x[1:4]) + # Test is excluded because of https://github.com/jump-dev/MutableArithmetics.jl/issues/227 + JuMP._MA.Test.quadratic_test(x...; exclude = ["quadratic_add_canonical"]) return end diff --git a/test/test_operator.jl b/test/test_operator.jl index 3903c3818e0..ebda3be21b5 100644 --- a/test/test_operator.jl +++ b/test/test_operator.jl @@ -376,7 +376,7 @@ function test_extension_basic_operators_variable( @test_expression_with_string x * y - 1 "x*y - 1" @test_expression_with_string(x^2, "x²", interrable = false) @test_expression_with_string(x^1, "x", interrable = false) - @test_expression_with_string(x^0, "1", interrable = false) + @test x^0 === one(T) @test_expression_with_string(x^3, "x ^ 3", interrable = false) @test_expression_with_string x^(T(15) / T(10)) "x ^ 1.5" # 2-2 Variable--Variable @@ -444,8 +444,8 @@ function test_extension_basic_operators_affexpr( "7.1 x + 2.5", inferrable = false ) - @test_expression_with_string(aff^0, "1", inferrable = false) - @test_expression_with_string((7.1 * x + 2.5)^0, "1", inferrable = false) + @test aff^0 === one(T) + @test (7.1 * x + 2.5)^0 === one(T) @test_expression_with_string(aff^3, "(7.1 x + 2.5) ^ 3", inferrable = false) @test_expression_with_string( (7.1 * x + 2.5)^3, @@ -619,7 +619,7 @@ function test_complex_pow() @variable(model, x) y = (1.0 + 2.0im) * x @test y^0 == (1.0 + 0im) - @test y^1 == 0 * y * y + y + @test y^1 == y @test y^2 == y * y @test isequal_canonical(y^3, GenericNonlinearExpr(:^, Any[y, 3])) return