diff --git a/src/varnamedtuple/map.jl b/src/varnamedtuple/map.jl index 58f17b726..f9130ca88 100644 --- a/src/varnamedtuple/map.jl +++ b/src/varnamedtuple/map.jl @@ -47,6 +47,36 @@ function DynamicPPL.subset(parent_vnt::VarNamedTuple, vns) init=VarNamedTuple(), ) end +""" + delete!!(vnt::VarNamedTuple, vn::VarName) + +Return a new `VarNamedTuple` with the variable `vn` removed. + +# Examples +```jldoctest +julia> using DynamicPPL, BangBang + +julia> vnt = VarNamedTuple() +VarNamedTuple() + +julia> vnt = setindex!!(vnt, 1.0, @varname(a)) +VarNamedTuple +└─ a => 1.0 + +julia> vnt = setindex!!(vnt, 2.0, @varname(b)) +VarNamedTuple +├─ a => 1.0 +└─ b => 2.0 + +julia> delete!!(vnt, @varname(a)) +VarNamedTuple +└─ b => 2.0 +``` +""" +function BangBang.delete!!(vnt::VarNamedTuple, vn::VarName) + remaining_vns = filter(k -> !subsumes(vn, k), keys(vnt)) + return DynamicPPL.subset(vnt, remaining_vns) +end """ apply!!(func, vnt::VarNamedTuple, name::VarName) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 2714bd17a..110ca0d92 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -24,7 +24,7 @@ using DynamicPPL.VarNamedTuples: NoTemplate, templated_setindex_no_overwrite!! using AbstractPPL: AbstractPPL, VarName, concretize, prefix, @opticof -using BangBang: setindex!!, empty!! +using BangBang: setindex!!, empty!!, delete!! using DimensionalData: DimensionalData as DD using InvertedIndices: InvertedIndices as II using OffsetArrays: OffsetArrays as OA @@ -2094,7 +2094,27 @@ Base.size(st::SizedThing) = st.size end @test densify!!(vnt) == VarNamedTuple(; x=CA.ComponentArray(; a=1.0, b=2.0)) end + @testset "delete!!" begin + vnt = VarNamedTuple() + vnt = setindex!!(vnt, 1.0, @varname(a)) + vnt = setindex!!(vnt, 2.0, @varname(b)) + vnt = setindex!!(vnt, 3.0, @varname(c)) + + # Delete a single variable + vnt2 = delete!!(vnt, @varname(a)) + @test !haskey(vnt2, @varname(a)) + @test haskey(vnt2, @varname(b)) + @test haskey(vnt2, @varname(c)) + # Delete another variable + vnt3 = delete!!(vnt2, @varname(b)) + @test !haskey(vnt3, @varname(b)) + @test haskey(vnt3, @varname(c)) + + # Original is unchanged + @test haskey(vnt, @varname(a)) + @test length(keys(vnt)) == 3 + end @testset "skeleton" begin function test_skeleton(orig_vnt, expected_skeleton) @test (@inferred skeleton(orig_vnt)) == expected_skeleton