Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,26 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
DiffRules = "1.15.1"
DifferentialEquations = "7.17.0"
ForwardDiff = "1.2.2"
Lux = "1.29.1"
Plots = "1.41.1"
ProfileView = "1.10.2"
SparseArrays = "1.10"
StaticArrays = "1.9.16"

[extras]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
dev = ["Revise"]
examples = ["Plots", "DifferentialEquations", "ForwardDiff", "BandedMatrices", "Lux"]
test = ["Test", "ForwardDiff", "BandedMatrices", "SparseArrays"]
2 changes: 1 addition & 1 deletion src/DualArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module DualArrays

export DualVector, Dual, jacobian

import Base: +, -, ==, getindex, size, axes, broadcasted, show, sum, vcat, convert, *, isapprox
import Base: +, -, ==, getindex, size, axes, broadcasted, show, sum, vcat, convert, *, isapprox, promote_type

using LinearAlgebra, ArrayLayouts, FillArrays, DiffRules

Expand Down
10 changes: 10 additions & 0 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ for (_, f, n) in DiffRules.diffrules(filter_modules=(:Base,))
jac = $p1.(x.value, y.value) .* x.jacobian .+ $p2.(x.value, y.value) .* y.jacobian
return DualVector(val, jac)
end
@eval function broadcasted(::typeof($f), x::Dual, y::AbstractVector)
val = $f.(x.value, y)
jac = $p1.(x.value, y) .* transpose(x.partials)
return DualVector(val, jac)
end
@eval function broadcasted(::typeof($f), x::AbstractVector, y::Dual)
val = $f.(x, y.value)
jac = $p2.(x, y.value) .* transpose(y.partials)
return DualVector(val, jac)
end
# Must have Base.$f in order not to import everything
@eval Base.$f(x::Dual, y::Dual) = Dual($f(x.value, y.value), $p1(x.value, y.value) * x.partials + $p2(x.value, y.value) * y.partials)
@eval Base.$f(x::Dual, y::Real) = Dual($f(x.value, y), $p1(x.value, y) * x.partials)
Expand Down
21 changes: 15 additions & 6 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ A dual number type that stores a value and its partials (derivatives).
- `partials::Partials`: The partial derivatives as a vector
"""
struct Dual{T, Partials <: AbstractVector{T}} <: Real
value::T
value::Union{T, Dual{T}}
partials::Partials
end



"""
DualVector{T, M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}}

Expand All @@ -32,11 +34,11 @@ For now the entries just return the values when indexed.

Constructs a DualVector, ensuring that the vector length matches the number of rows in the Jacobian.
"""
struct DualVector{T, V <: AbstractVector{T},M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}}
struct DualVector{T, V <: AbstractVector{<:Union{T, Dual{T}}},M <: AbstractMatrix{T}} <: AbstractVector{Dual{T}}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks wrong. I don't see the role of the Union, and if the elments of value are of type Dual{T} then shouldn't the elements of jacobian also be of type Dual{T}?

value::V
jacobian::M

function DualVector(value::V, jacobian::M) where {T, V <: AbstractVector{T}, M <: AbstractMatrix{T}}
function DualVector(value::V, jacobian::M) where {T, V <: AbstractVector{<:Union{T, Dual{T}}}, M <: AbstractMatrix{T}}
if size(jacobian, 1) != length(value)
x, y = length(value), size(jacobian, 1)
throw(ArgumentError("vector length must match number of rows in jacobian.\n" *
Expand All @@ -47,12 +49,19 @@ struct DualVector{T, V <: AbstractVector{T},M <: AbstractMatrix{T}} <: AbstractV
end
end

promote_type_dual(T, S) = promote_type(T, S)

promote_type_dual(::Type{<:Dual{T}}, S) where {T} = promote_type(T, S)
promote_type_dual(S, ::Type{<:Dual{T}}) where {T} = promote_type(T, S)

broadcasted(::Type{T}, d::DualVector) where {T} = DualVector(T.(d.value), T.(d.jacobian))

"""
Constructor that forces type compatibility
"""
function DualVector(value::AbstractVector, jacobian::AbstractMatrix)
T = promote_type(eltype(value), eltype(jacobian))
DualVector(convert(Vector{T}, value), convert(AbstractMatrix{T}, jacobian))
function DualVector(value::AbstractVector{T}, jacobian::AbstractMatrix{S}) where {T, S}
U = promote_type_dual(T, S)
DualVector(U.(value), U.(jacobian))
end

# Basic equality for Dual numbers
Expand Down
9 changes: 7 additions & 2 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ end
"""
Custom display method for DualVectors.
"""
Base.show(io::IO, ::MIME"text/plain", x::DualVector) =
(print(io, x.value); print(io, " + "); print(io, x.jacobian); print(io, "𝛜"))
show_dual_vector(io::IO, x::AbstractArray, ::Any) = print(io, x)
function show_dual_vector(io::IO, x::DualVector, i = 0)
show_dual_vector(io, x.value, i+1)
print(io, " + $(x.jacobian)ϵ" * repeat('\'', i))
end
Base.show(io::IO, x::DualVector) = show_dual_vector(io, x)
Base.show(io::IO, ::MIME"text/plain", x::DualVector) = Base.show(io, x)

"""
Utility function to compute the jacobian of a function `f` at point `x`.
Expand Down
21 changes: 21 additions & 0 deletions test/broadcast_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,25 @@ using DualArrays, Test, SparseArrays, LinearAlgebra
@test s.jacobian ≈ [0.0 -2.0 -3.0; -1.0 -1.0 -3.0; -1.0 -2.0 -2.0]
@test m.jacobian ≈ [3.0 2.0 3.0; 2.0 6.0 6.0; 3.0 6.0 11.0]
@test div.jacobian ≈ [0.25 -0.5 -0.75; -0.5 -0.5 -1.5; -0.75 -1.5 -1.75]

# Broadcasting between Dual and AbstractVector
a = x .+ [1.0, 2.0, 3.0]
s = [1.0, 2.0, 3.0] .- x
m = x .* [1.0, 2.0, 3.0]
div = [1.0, 2.0, 3.0] ./ x

@test a isa DualVector
@test s isa DualVector
@test m isa DualVector
@test div isa DualVector

@test a.value ≈ [3.0, 4.0, 5.0]
@test s.value ≈ [-1.0, 0.0, 1.0]
@test m.value ≈ [2.0, 4.0, 6.0]
@test div.value ≈ [0.5, 1.0, 1.5]

@test a.jacobian ≈ [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0]
@test s.jacobian ≈ [-1.0 -2.0 -3.0; -1.0 -2.0 -3.0; -1.0 -2.0 -3.0]
@test m.jacobian ≈ [1.0 2.0 3.0; 2.0 4.0 6.0; 3.0 6.0 9.0]
@test div.jacobian ≈ [-0.25 -0.5 -0.75; -0.5 -1.0 -1.5; -0.75 -1.5 -2.25]
end
11 changes: 11 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ using DualArrays: Dual
@test vcat(x, x) == DualVector([1, 1], [1 2 3;1 2 3])
@test vcat(x, y) == DualVector([1, 2, 3], [1 2 3;4 5 6;7 8 9])
end

@testset "Hessian" begin
d = DualVector(
DualVector([1, 2], [1 0;0 1]),
[1 0;0 1]
)
f(x) = x[1] * x[2]
@test f(d) isa Dual
@test f(d).partials isa DualVector
@test f(d).partials.jacobian == [0 1; 1 0]
end

include("broadcast_test.jl")
end
Loading