Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.41.7

Accessing a nonexistent variable in a `VarNamedTuple` now throws a `KeyError` with the original `VarName`, instead of an opaque `type NamedTuple has no field ...` error.

# 0.41.6

Add a `factorize::Bool` keyword argument for `pointwise_logdensities(model, values)`, which controls whether pointwise logdensities for factorisable distributions (e.g. `MvNormal`, `product_distribution`, etc.) are returned as a single log-density for the whole distribution, or as an array of log-densities for each factor.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.41.6"
version = "0.41.7"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions src/varnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, dist::LKJCholesky
for k in keys(val)
# VarNamedTuples have VarNames as keys, PartialArrays have Index optics.
subvn = val isa VarNamedTuple ? prefix(k, vn) : AbstractPPL.append_optic(vn, k)
dval[subvn] = _getindex_optic(val, k)
dval[subvn] = _getindex_optic(val, k, subvn)
end
return AbstractPPL.hasvalue(dval, vn, dist)
end
Expand All @@ -244,7 +244,7 @@ function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName, dist::LKJCholesky
for k in keys(val)
# VarNamedTuples have VarNames as keys, PartialArrays have Index optics.
subvn = val isa VarNamedTuple ? prefix(k, vn) : AbstractPPL.append_optic(vn, k)
dval[subvn] = _getindex_optic(val, k)
dval[subvn] = _getindex_optic(val, k, subvn)
end
return AbstractPPL.getvalue(dval, vn, dist)
end
Expand Down
30 changes: 21 additions & 9 deletions src/varnamedtuple/getset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const IndexWithoutChild = AbstractPPL.Index{<:Tuple,<:NamedTuple,AbstractPPL.Ide
_unimplemented() = error("Not implemented")

"""
DynamicPPL._getindex_optic(collection, optic::AbstractPPL.Optic)
DynamicPPL._getindex_optic(collection, optic::AbstractPPL.Optic, orig_vn::VarName)
DynamicPPL._getindex_optic(collection, vn::VarName)

Access the value in `collection` at the location specified by the given `optic`. If a `VarName`
Expand All @@ -27,16 +27,28 @@ Note that it is only valid to index into a `VarNamedTuple` with a `Property` opt
`PartialArray` with an `Index` optic. Other combinations are not valid. When we have reached
the leaf of the VNT i.e. a value, we could still handle pure `Index` optics if the value is
an `AbstractArray`, but otherwise the only valid optic is `Iden`.
Comment thread
penelopeysm marked this conversation as resolved.

`orig_vn` is used to keep track of the original VarName used to index into a VarNamedTuple,
and is only for error reporting purposes.
"""
function _getindex_optic(vnt::VarNamedTuple, vn::VarName)
return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn))
return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn), vn)
end
function _getindex_optic(vnt::VarNamedTuple, vn::VarName, orig_vn)
return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn), orig_vn)
end
@inline _getindex_optic(@nospecialize(x::Any), ::AbstractPPL.Iden) = x
@inline _getindex_optic(x::Any, o::AbstractPPL.AbstractOptic) = o(x)
function _getindex_optic(vnt::VarNamedTuple, optic::AbstractPPL.Property{S}) where {S}
return _getindex_optic(getindex(vnt.data, S), optic.child)

@inline _getindex_optic(@nospecialize(x::Any), ::AbstractPPL.Iden, orig_vn) = x
@inline _getindex_optic(x::Any, o::AbstractPPL.AbstractOptic, orig_vn) = o(x)
function _getindex_optic(
vnt::VarNamedTuple, optic::AbstractPPL.Property{S}, orig_vn
) where {S}
if !haskey(vnt.data, S)
throw(KeyError(orig_vn))
end
return _getindex_optic(getindex(vnt.data, S), optic.child, orig_vn)
end
function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index)
function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index, orig_vn)
coptic = AbstractPPL.concretize_top_level(optic, pa.data)
child_value =
if _is_multiindex(pa, coptic.ix...; coptic.kw...) &&
Expand All @@ -49,9 +61,9 @@ function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index)
else
getindex(pa, coptic.ix...; coptic.kw...)
end
return _getindex_optic(child_value, optic.child)
return _getindex_optic(child_value, optic.child, orig_vn)
end
function _getindex_optic(arr::AbstractArray, optic::IndexWithoutChild)
function _getindex_optic(arr::AbstractArray, optic::IndexWithoutChild, orig_vn)
coptic = AbstractPPL.concretize_top_level(optic, arr)
return Base.getindex(arr, coptic.ix...; coptic.kw...)
end
Expand Down
2 changes: 1 addition & 1 deletion test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ end
@test_throws ArgumentError to_vector_params(vecvals, ldf)

accs = OnlyAccsVarInfo(VectorParamAccumulator(ldf))
@test_throws ErrorException init!!(
@test_throws KeyError init!!(
extra_model, accs, InitFromPrior(), transform_strategy
)
end
Expand Down
13 changes: 13 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ end
vi, TransformedValue(x, NoTransform()), Normal(), vn, x
)
@test !isempty(vi)

@testset "KeyError for missing varname" begin
@model function test_model()
x ~ Normal()
return nothing
end
vi2 = VarInfo(test_model())
# KeyError propagates from VarNamedTuple through VarInfo
@test_throws KeyError DynamicPPL.getindex_internal(vi2, @varname(y))
@test_throws KeyError DynamicPPL.get_transformed_value(vi2, @varname(y))
# Direct VarNamedTuple access also throws KeyError
@test_throws KeyError vi2.values[@varname(y)]
end
end

@testset "get/set/acclogp" begin
Expand Down
12 changes: 12 additions & 0 deletions test/varnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,18 @@ Base.size(st::SizedThing) = st.size
end
end

@testset "KeyError for missing properties" begin
vnt = @vnt begin
x.a := 1.0
end
# Should throw KeyError for missing top-level symbol
@test_throws KeyError vnt[@varname(y)]
# Should throw KeyError for missing nested property
@test_throws KeyError vnt[@varname(x.b)]
# Sanity check: accessing existing property should work
@test vnt[@varname(x.a)] == 1.0
end

@testset "haskey on PartialArray" begin
@testset "no ALBs" begin
vnt = @vnt begin
Expand Down
Loading