diff --git a/.gitignore b/.gitignore index 10d8a5f69..723790dec 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ Manifest.toml benchmarks/*.json LocalPreferences.toml + diff --git a/HISTORY.md b/HISTORY.md index a1d040ec8..a1eef96b2 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,7 @@ # 0.41.7 +Enable usage of `ComponentVector`s on the left-hand side of tilde-statements. + Accessing a nonexistent variable in a `VarNamedTuple` now throws a `KeyError` with the original `VarName`, instead of an opaque `type NamedTuple has no field ...` error. # 0.41.6 diff --git a/Project.toml b/Project.toml index f7820a2c5..f1be81a23 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -39,6 +40,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] +DynamicPPLComponentArraysExt = ["ComponentArrays"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] @@ -55,6 +57,7 @@ BangBang = "0.4.1" Bijectors = "0.15.17" Chairmarks = "1.3.1" Compat = "4" +ComponentArrays = "0.15" ConstructionBase = "1.5.4" DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" diff --git a/ext/DynamicPPLComponentArraysExt.jl b/ext/DynamicPPLComponentArraysExt.jl new file mode 100644 index 000000000..0e239ee63 --- /dev/null +++ b/ext/DynamicPPLComponentArraysExt.jl @@ -0,0 +1,64 @@ +module DynamicPPLComponentArraysExt +using DynamicPPL: DynamicPPL +using DynamicPPL.VarNamedTuples: + PartialArray, + AllowAll, + SetPermissions, + _setindex_optic!!, + _getindex_optic, + make_leaf, + make_leaf_singleindex, + _is_multiindex, + make_leaf_multiindex +using ComponentArrays: ComponentArrays, ComponentArray, ComponentVector +using AbstractPPL + +# Helper: convert a Property optic label S to an integer Index optic +function _property_to_index( + template::ComponentVector, optic::AbstractPPL.Property{S} +) where {S} + ax = ComponentArrays.getaxes(template)[1] + idx = first(ax[S].idx) + return AbstractPPL.Index((idx,), NamedTuple(), optic.child) +end + +function DynamicPPL.VarNamedTuples.make_leaf( + value, optic::AbstractPPL.Property{S}, template::ComponentVector +) where {S} + return if optic.child isa AbstractPPL.Iden + index_optic = _property_to_index(template, optic) + make_leaf(value, index_optic, template) + else + # This branch is needed to handle nested axes in ComponentArrays: the idea is that + # if x is e.g. ComponentArray(a=(b=1)) and we are trying to set `x.a.b`, then we + # first index into `x.a` to get the slice of the ComponentArray. The easiest way to + # handle this is to call the default method. + invoke( + make_leaf, + Tuple{Any,AbstractPPL.Property{S},AbstractArray}, + value, + optic, + template, + ) + end +end + +function DynamicPPL.VarNamedTuples._setindex_optic!!( + pa::PartialArray{<:Any,<:Any,<:ComponentVector}, + value, + optic::AbstractPPL.Property{S}, + template, + permissions::SetPermissions=AllowAll(), +) where {S} + index_optic = _property_to_index(pa.data, optic) + return _setindex_optic!!(pa, value, index_optic, template, permissions) +end + +function DynamicPPL.VarNamedTuples._getindex_optic( + pa::PartialArray{<:Any,<:Any,<:ComponentVector}, optic::AbstractPPL.Property{S}, orig_vn +) where {S} + index_optic = _property_to_index(pa.data, optic) + return _getindex_optic(pa, index_optic, orig_vn) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 73cff23ed..944c32102 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,12 +30,14 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" [compat] ADTypes = "1" AbstractMCMC = "5.10" AbstractPPL = "0.14" Accessors = "0.1" Aqua = "0.8" +ComponentArrays = "0.15" BangBang = "0.4" Bijectors = "0.15.17" Chairmarks = "1" diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index cc5f0dac3..2714bd17a 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -28,6 +28,7 @@ using BangBang: setindex!!, empty!! using DimensionalData: DimensionalData as DD using InvertedIndices: InvertedIndices as II using OffsetArrays: OffsetArrays as OA +using ComponentArrays: ComponentArrays as CA struct GetSetTestCase # The VarName being set. @@ -309,6 +310,60 @@ Base.size(st::SizedThing) = st.size ) end + @testset "ComponentArray" begin + ca = CA.ComponentArray(; a=1.0, b=2.0) + test_get_set(GetSetTestCase(@varname(x[1]), 1.0, ca, [])) + test_get_set(GetSetTestCase(@varname(x[2]), 2.0, ca, [])) + test_get_set(GetSetTestCase(@varname(x.a), 1.0, ca, [])) + test_get_set(GetSetTestCase(@varname(x.b), 2.0, ca, [])) + test_get_set(GetSetTestCase(@varname(x[1:2]), [1.0, 2.0], ca, [])) + + # ComponentVector with array-valued fields + ca3 = CA.ComponentArray(; a=[1.0, 2.0], b=[3.0, 4.0]) + test_get_set(GetSetTestCase(@varname(x.a), [1.0, 2.0], ca3, [])) + test_get_set(GetSetTestCase(@varname(x.b), [3.0, 4.0], ca3, [])) + test_get_set(GetSetTestCase(@varname(x.a[1]), 1.0, ca3, [])) + + # with nested fields + ca4 = CA.ComponentArray(; a=(; x=1.0, y=2.0)) + test_get_set(GetSetTestCase(@varname(x.a.x), 10.0, ca4, [])) + test_get_set(GetSetTestCase(@varname(x.a.y), 20.0, ca4, [])) + test_get_set(GetSetTestCase(@varname(x[1]), 10.0, ca4, [])) + test_get_set(GetSetTestCase(@varname(x[2]), 20.0, ca4, [])) + + # Mixed index/property access + val = rand() + vns = (@varname(x[1]), @varname(x.a)) + for set_vn in vns + vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), val, set_vn, ca) + for get_vn in vns + @test vnt[get_vn] == val + end + end + + # Check that setting one and overwriting with the other works + val = rand() + new_val = val + 1 + for (vn1, vn2) in + ((@varname(x[1]), @varname(x.a)), (@varname(x.a), @varname(x[1]))) + vnt = VarNamedTuple() + vnt = DynamicPPL.templated_setindex!!(vnt, val, vn1, ca) + @test vnt[vn1] == vnt[vn2] == val # Sanity check. + vnt = DynamicPPL.templated_setindex!!(vnt, new_val, vn2, ca) + @test vnt[vn1] == vnt[vn2] == new_val + end + + # Check that MustNotOverwrite is respected. + for vn1 in vns + vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), val, vn1, ca) + for vn2 in vns + @test_throws MustNotOverwriteError DynamicPPL.VarNamedTuples.templated_setindex_no_overwrite!!( + vnt, new_val, vn2, ca + ) + end + end + end + @testset "InvertedIndices" begin # TODO(penelopeysm): Templated setindex fails for II.Not(). I really don't know # why but there is some failure in constant propagation when setting the mask @@ -2029,6 +2084,15 @@ Base.size(st::SizedThing) = st.size x[2:3] := SizedThing((2,)) end @test densify!!(vnt) == vnt + + # Check with ComponentArrays + x = CA.ComponentArray(; a=0.0, b=0.0) + vnt = @vnt begin + @template x + x.a := 1.0 + x.b := 2.0 + end + @test densify!!(vnt) == VarNamedTuple(; x=CA.ComponentArray(; a=1.0, b=2.0)) end @testset "skeleton" begin @@ -2147,6 +2211,13 @@ Base.size(st::SizedThing) = st.size end v12s = VarNamedTuple(; x=DD.DimArray(fill(nothing, 2, 3), (:a, :b))) test_skeleton(v12, v12s) + + v13 = @vnt begin + @template x = CA.ComponentArray(; a=0.0, b=0.0) + x.a := 1.0 + end + v13s = VarNamedTuple(; x=CA.ComponentArray(; a=nothing, b=nothing)) + test_skeleton(v13, v13s) end end