diff --git a/Project.toml b/Project.toml index 43f0322..84b1e72 100644 --- a/Project.toml +++ b/Project.toml @@ -8,12 +8,8 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] DiffRules = "1.15.1" @@ -21,9 +17,7 @@ DifferentialEquations = "7.17.0" ForwardDiff = "1.2.2" Lux = "1.29.1" Plots = "1.41.1" -ProfileView = "1.10.2" SparseArrays = "1.10" -StaticArrays = "1.9.16" [extras] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" @@ -31,11 +25,9 @@ DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -dev = ["Revise"] examples = ["Plots", "DifferentialEquations", "ForwardDiff", "BandedMatrices", "Lux"] test = ["Test", "ForwardDiff", "BandedMatrices", "SparseArrays"] diff --git a/src/DualArrays.jl b/src/DualArrays.jl index 2d30c3d..bbf3392 100644 --- a/src/DualArrays.jl +++ b/src/DualArrays.jl @@ -13,7 +13,7 @@ module DualArrays export DualVector, Dual, jacobian -import Base: +, -, ==, getindex, size, axes, broadcasted, show, sum, vcat, convert, *, isapprox +import Base: +, -, ==, getindex, size, axes, broadcasted, show, sum, vcat, convert, *, isapprox, promote_type using LinearAlgebra, ArrayLayouts, FillArrays, DiffRules diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 998b50b..0749552 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -77,6 +77,16 @@ for (_, f, n) in DiffRules.diffrules(filter_modules=(:Base,)) jac = $p1.(x.value, y.value) .* x.jacobian .+ $p2.(x.value, y.value) .* y.jacobian return DualVector(val, jac) end + @eval function broadcasted(::typeof($f), x::Dual, y::AbstractVector) + val = $f.(x.value, y) + jac = $p1.(x.value, y) .* transpose(x.partials) + return DualVector(val, jac) + end + @eval function broadcasted(::typeof($f), x::AbstractVector, y::Dual) + val = $f.(x, y.value) + jac = $p2.(x, y.value) .* transpose(y.partials) + return DualVector(val, jac) + end # Must have Base.$f in order not to import everything @eval Base.$f(x::Dual, y::Dual) = Dual($f(x.value, y.value), $p1(x.value, y.value) * x.partials + $p2(x.value, y.value) * y.partials) @eval Base.$f(x::Dual, y::Real) = Dual($f(x.value, y), $p1(x.value, y) * x.partials) diff --git a/src/types.jl b/src/types.jl index 5367511..62712fb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -10,10 +10,12 @@ A dual number type that stores a value and its partials (derivatives). - `partials::Partials`: The partial derivatives as a vector """ struct Dual{T, Partials <: AbstractVector{T}} <: Real - value::T + value::Union{T, Dual{T}} partials::Partials end + + """ DualVector{T, M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}} @@ -32,11 +34,11 @@ For now the entries just return the values when indexed. Constructs a DualVector, ensuring that the vector length matches the number of rows in the Jacobian. """ -struct DualVector{T, V <: AbstractVector{T},M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}} +struct DualVector{T, V <: AbstractVector{<:Union{T, Dual{T}}},M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}} value::V jacobian::M - function DualVector(value::V, jacobian::M) where {T, V <: AbstractVector{T}, M <: AbstractMatrix{T}} + function DualVector(value::V, jacobian::M) where {T, V <: AbstractVector{<:Union{T, Dual{T}}}, M <: AbstractMatrix{T}} if size(jacobian, 1) != length(value) x, y = length(value), size(jacobian, 1) throw(ArgumentError("vector length must match number of rows in jacobian.\n" * @@ -47,12 +49,19 @@ struct DualVector{T, V <: AbstractVector{T},M <: AbstractMatrix{T}} <: AbstractV end end +promote_type_dual(T, S) = promote_type(T, S) + +promote_type_dual(::Type{<:Dual{T}}, S) where {T} = promote_type(T, S) +promote_type_dual(S, ::Type{<:Dual{T}}) where {T} = promote_type(T, S) + +broadcasted(::Type{T}, d::DualVector) where {T} = DualVector(T.(d.value), T.(d.jacobian)) + """ Constructor that forces type compatibility """ -function DualVector(value::AbstractVector, jacobian::AbstractMatrix) - T = promote_type(eltype(value), eltype(jacobian)) - DualVector(convert(Vector{T}, value), convert(AbstractMatrix{T}, jacobian)) +function DualVector(value::AbstractVector{T}, jacobian::AbstractMatrix{S}) where {T, S} + U = promote_type_dual(T, S) + DualVector(U.(value), U.(jacobian)) end # Basic equality for Dual numbers diff --git a/src/utilities.jl b/src/utilities.jl index 080155f..c656964 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -49,8 +49,13 @@ end """ Custom display method for DualVectors. """ -Base.show(io::IO, ::MIME"text/plain", x::DualVector) = - (print(io, x.value); print(io, " + "); print(io, x.jacobian); print(io, "𝛜")) +show_dual_vector(io::IO, x::AbstractArray, ::Any) = print(io, x) +function show_dual_vector(io::IO, x::DualVector, i = 0) + show_dual_vector(io, x.value, i+1) + print(io, " + $(x.jacobian)ϵ" * repeat('\'', i)) +end +Base.show(io::IO, x::DualVector) = show_dual_vector(io, x) +Base.show(io::IO, ::MIME"text/plain", x::DualVector) = Base.show(io, x) """ Utility function to compute the jacobian of a function `f` at point `x`. diff --git a/test/broadcast_test.jl b/test/broadcast_test.jl index f87d9f4..7358d9d 100644 --- a/test/broadcast_test.jl +++ b/test/broadcast_test.jl @@ -71,4 +71,25 @@ using DualArrays, Test, SparseArrays, LinearAlgebra @test s.jacobian ≈ [0.0 -2.0 -3.0; -1.0 -1.0 -3.0; -1.0 -2.0 -2.0] @test m.jacobian ≈ [3.0 2.0 3.0; 2.0 6.0 6.0; 3.0 6.0 11.0] @test div.jacobian ≈ [0.25 -0.5 -0.75; -0.5 -0.5 -1.5; -0.75 -1.5 -1.75] + + # Broadcasting between Dual and AbstractVector + a = x .+ [1.0, 2.0, 3.0] + s = [1.0, 2.0, 3.0] .- x + m = x .* [1.0, 2.0, 3.0] + div = [1.0, 2.0, 3.0] ./ x + + @test a isa DualVector + @test s isa DualVector + @test m isa DualVector + @test div isa DualVector + + @test a.value ≈ [3.0, 4.0, 5.0] + @test s.value ≈ [-1.0, 0.0, 1.0] + @test m.value ≈ [2.0, 4.0, 6.0] + @test div.value ≈ [0.5, 1.0, 1.5] + + @test a.jacobian ≈ [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0] + @test s.jacobian ≈ [-1.0 -2.0 -3.0; -1.0 -2.0 -3.0; -1.0 -2.0 -3.0] + @test m.jacobian ≈ [1.0 2.0 3.0; 2.0 4.0 6.0; 3.0 6.0 9.0] + @test div.jacobian ≈ [-0.25 -0.5 -0.75; -0.5 -1.0 -1.5; -0.75 -1.5 -2.25] end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 729b4ed..30ccb21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,6 +81,17 @@ using DualArrays: Dual @test vcat(x, x) == DualVector([1, 1], [1 2 3;1 2 3]) @test vcat(x, y) == DualVector([1, 2, 3], [1 2 3;4 5 6;7 8 9]) end + + @testset "Hessian" begin + d = DualVector( + DualVector([1, 2], [1 0;0 1]), + [1 0;0 1] + ) + f(x) = x[1] * x[2] + @test f(d) isa Dual + @test f(d).partials isa DualVector + @test f(d).partials.jacobian == [0 1; 1 0] + end include("broadcast_test.jl") end \ No newline at end of file