diff --git a/HISTORY.md b/HISTORY.md index 74becb227..55bdd406d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,12 @@ +# 0.41.3 + +Add a lower-level constructor for `LogDensityFunction` which directly takes a `VarNamedTuple` of `RangeAndTransform`s plus a sample vectorised input. +This is only intended for use in Turing: users should not need to use this directly. + +All other constructors are still available and unchanged in behaviour. + +To facilitate the functionality needed for Turing, this also adds more accessor functions for `LogDensityFunction`, namely `get_all_ranges_and_transforms` and `get_sample_input_vector`. + # 0.41.2 Export the accessor functions `get_values(::VarInfo)` and `get_logdensity_callable(::LogDensityFunction)`, so that users do not need to access internal fields of these types directly. diff --git a/Project.toml b/Project.toml index b570d9015..36a590677 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.41.2" +version = "0.41.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index f57433bbc..2bb455b66 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -64,10 +64,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte ```@docs LogDensityFunction -get_input_vector_type RangeAndTransform get_range_and_transform +get_all_ranges_and_transforms get_logdensity_callable +get_input_vector_type +get_sample_input_vector ``` Internally, this is accomplished using [`init!!`](@ref) on: diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bc735c4c4..3001210c8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,8 +130,10 @@ export AbstractVarInfo, OnlyAccsVarInfo, to_vector_params, get_input_vector_type, + get_sample_input_vector, RangeAndTransform, get_range_and_transform, + get_all_ranges_and_transforms, get_logdensity_callable, # Leaf contexts AbstractContext, diff --git 
a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 275da3057..da9d5f22a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -29,11 +29,11 @@ using Random: Random """ DynamicPPL.LogDensityFunction( model::Model, - getlogdensity::Any=getlogjoint_internal, - vi_vnt_or_tfm_strategy=_default_vnt(model, UnlinkAll()), - accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity); + getlogdensity::Any, + ranges_and_transforms::VarNamedTuple, + x::AbstractVector{<:Real}, + accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs(getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - fix_transform::Bool=false, ) A struct which contains a model, along with all the information necessary to: @@ -48,6 +48,11 @@ backend type, then `logdensity_and_gradient` is also implemented. ## Positional arguments +!!! note + You should almost never need to call this particular constructor of + `LogDensityFunction`. Instead you should prefer the more convenient constructors that do + not need an input vector `x`. + The first argument is the DynamicPPL model. The second argument, `getlogdensity` should be a callable which takes a single argument: an @@ -71,53 +76,37 @@ are several functions in DynamicPPL that are 'supported' out of the box: was created with a linked or unlinked VarInfo. This is done primarily to ease interoperability with MCMC samplers. -**The third argument** can take many forms, but it essentially specifies a set of vectorised -parameters to be used for constructing the single vectorised representation of the model. -The parameters stored in this argument determine whether the resulting `LogDensityFunction` -will be linked, unlinked, or mixed. For example, if you pass a `VarNamedTuple` consisting -entirely of `TransformedValue{T,DynamicLink}`s, then the resulting `LogDensityFunction` will -be fully linked. 
+The third argument is a VarNamedTuple which maps VarNames seen in the model to their +corresponding [`RangeAndTransform`](@ref). Each `RangeAndTransform`, as the name suggests, +contains a *range* which says which indices in the vectorised parameters correspond to that +variable, and a *transform* which says how to obtain the original (raw) value of that +variable from the slice. -You can pass either: +The fourth argument is a sample vector of parameters, which should be consistent with the +ranges specified in the previous argument. This is used to determine the dimension and the +expected element type of the vectorised parameters, and is also used in AD preparation. The +values in the vector are not important. -- **`vnt`**: a `VarNamedTuple` which contains vectorised representations of all the random - variables in the model. This is useful if you already have one, either by creating a full - `VarInfo` and accessing its `values` field, or by creating a `OnlyAccsVarInfo` with a - `VectorValueAccumulator` and calling `get_vector_values` on it. - -- **`vi`**: a `VarInfo`, in which case the `vi.values` field is used. - -- **`oavi`**: an [`OnlyAccsVarInfo`](@ref), in which case the [`get_vector_values`](@ref) - function is used to extract a VarNamedTuple of vector values from the - [`VectorValueAccumulator`](@ref) inside it. If the `OnlyAccsVarInfo` does not contain a - `VectorValueAccumulator`, then an error is thrown. +!!! warning "Compiled ReverseDiff" + For compiled ReverseDiff, the values in the vector are used to compile the tape, and so + if your model has control flow that depends on the values of the parameters, then you + may find that the resulting `LogDensityFunction` only yields correct results for parameters + that trigger the same control flow as the sample vector. In general, functions with + parameter-dependent control flow should not be differentiated with compiled ReverseDiff. 
-**But by far the most convenient way to specify this argument is to just pass a transform -strategy.** In this case, the transform strategy will be first used to generate a set of -vectorised parameters, from which the relevant information will be extracted. This does come -at the cost of doing one extra model evaluation. Whilst `LogDensityFunction` construction is -unlikely to occur in performance-sensitive code paths, if you absolutely cannot pay this -price, then you should generate the vectorised parameters yourself and pass them here -instead. +The last positional argument, `accs`, allows you to specify an `AccumulatorTuple` or a tuple +of `AbstractAccumulator`s which will be used _when evaluating the log density_. (Note that +any accumulators from the previous argument are discarded.) This argument is not mandatory: +by default, this uses an internal function, `DynamicPPL.ldf_accs`, which attempts to choose +an appropriate set of accumulators based on which kind of log-density is being calculated. -The last argument, `accs`, allows you to specify an `AccumulatorTuple` or a tuple of -`AbstractAccumulator`s which will be used _when evaluating the log density_`. (Note that any -accumulators from the previous argument are discarded.) By default, this uses an internal -function, `DynamicPPL.ldf_accs`, which attempts to choose an appropriate set of accumulators -based on which kind of log-density is being calculated. +## Keyword arguments If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the gradient of the log density. Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD backend itself to have been loaded (e.g. with `import Backend`). -Finally, the `fix_transforms` keyword argument allows you to specify whether the transforms -used in the `LogDensityFunction` should be cached at the time of construction. 
If so, the -model is evaluated once using the provided transform strategy, and the transforms used for -each variable are stored in the `LogDensityFunction`. This allows you to avoid the overhead -of recalculating transforms during each log-density evaluation. See [the documentation on -fixed transforms](@ref fixed-transforms) for more information. - ## Fields Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart @@ -129,6 +118,14 @@ from: - `ldf.transform_strategy`: The transform strategy that specifies the transforms for all variables in the model. +For all other fields, please use the corresponding getter functions provided in the API: + +- [`get_logdensity_callable`](@ref) +- [`get_input_vector_type`](@ref) +- [`get_sample_input_vector`](@ref) +- [`get_range_and_transform`](@ref) +- [`get_all_ranges_and_transforms`](@ref) + # Extended help Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a @@ -190,7 +187,7 @@ struct LogDensityFunction{ # type of the vector passed to logdensity functions X<:AbstractVector, AC<:AccumulatorTuple, - # whether all transforms are FixedTransforms, enabling fast parameter extraction + # whether all transforms are FixedTransforms AllFixed, } model::M @@ -200,59 +197,27 @@ struct LogDensityFunction{ _varname_ranges::VNT _adprep::ADP _dim::Int + _x::X _accs::AC function LogDensityFunction( model::Model, - getlogdensity::Any=getlogjoint_internal, - vnt::VarNamedTuple=_default_vnt(model, UnlinkAll()), + getlogdensity::Any, + ranges_and_transforms::VarNamedTuple, + x::AbstractVector{<:Real}, accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs( getlogdensity ); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - fix_transforms::Bool=false, ) + dim = length(x) # Determine LDF transform strategy. 
- dynamic_transform_strategy = infer_transform_strategy_from_values(vnt) - # `dynamic_transform_strategy` might be LinkAll() or UnlinkAll(), for example. We - # might need to convert this to a set of fixed transforms. - transform_strategy = if fix_transforms - # Reevaluate model again to determine the fixed transforms. This is kind of - # wasteful: for example, we could tie this model evaluation to one of the - # previous ones, but it's fine, since it's only done once in the LDF - # constructor. - transforms_vnt = get_fixed_transforms(model, dynamic_transform_strategy) - fixed_transform_strategy = WithTransforms( - transforms_vnt, dynamic_transform_strategy - ) - # We need to update `vnt` to be consistent with the new transform strategy. - vnt = update_transforms!!(vnt, transforms_vnt) - fixed_transform_strategy - else - # No fixing; just traverse the VNT to determine both. - dynamic_transform_strategy - end - ranges_and_transforms = get_rangeandtransforms(vnt) - + transform_strategy = infer_transform_strategy_from_values(ranges_and_transforms) # Determine whether all transforms are fixed. This enables fast parameter # extraction in ParamsWithStats without model re-evaluation. all_fixed = all( rat -> rat.transform isa FixedTransform, values(ranges_and_transforms) ) - - # Get vectorised parameters. Note that `internal_values_as_vector` just concatenates - # all the vectors inside in iteration order of the VNT's keys. *In principle*, the - # result of that should always be consistent with the ranges extracted above via - # `get_rangeandtransforms`, since both are based on the same underlying VNT, and both - # iterate over the keys in the same order. However, this is an implementation - # detail, and so we should probably not rely on it! - # Therefore, we use `to_vector_params_inner` to also perform some checks that the - # vectorised parameters are concatenated in the order specified by the ranges. 
We do - # need to use internal_values_as_vector here once to get the correct element type - # and dimension. - trial_x = internal_values_as_vector(vnt) - dim, et = length(trial_x), eltype(trial_x) - x = to_vector_params_inner(vnt, ranges_and_transforms, et, dim) # convert to AccumulatorTuple if needed accs = AccumulatorTuple(accs) # Do AD prep if needed @@ -286,11 +251,98 @@ struct LogDensityFunction{ ranges_and_transforms, prep, dim, + x, accs, ) end end +""" + DynamicPPL.LogDensityFunction( + model::Model, + getlogdensity::Any=getlogjoint_internal, + vecvals_or_strategy=UnlinkAll(), + accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs(getlogdensity); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + fix_transforms::Bool=false, + ) + +Most users of LogDensityFunction should use this constructor, which does **not** require +passing a sample input vector `x`. + +The first two arguments are the same as in the four-argument constructor. + +- `model` is the DynamicPPL model for which we want to construct a LogDensityFunction. +- `getlogdensity` is a callable which takes a single argument: an `OnlyAccsVarInfo`, and + returns a `Real` corresponding to the log density of interest. Most of the time this is + `getlogjoint_internal`. + +**The third argument** can take many forms, but it essentially encodes all the necessary +information to generate the `RangeAndTransform`s, as well as the sample input vector `x` +that is required for the four-argument constructor. + +You can pass one of: + +- **`vnt`**: a `VarNamedTuple` which contains vectorised representations of all the random + variables in the model (i.e., it maps `VarName`s to + `TransformedValue{<:AbstractVector}`s). This is useful if you already have one, either by + creating a full `VarInfo` and accessing its `values` field, or by creating an + `OnlyAccsVarInfo` with a `VectorValueAccumulator` and calling `get_vector_values` on it. 
+ +- **`vi`**: a `VarInfo`, in which case the `vi.values` field is used. + +- **`oavi`**: an [`OnlyAccsVarInfo`](@ref), in which case the [`get_vector_values`](@ref) + function is used to extract a VarNamedTuple of vector values from the + [`VectorValueAccumulator`](@ref) inside it. If the `OnlyAccsVarInfo` does not contain a + `VectorValueAccumulator`, then an error is thrown. + +- **`transform_strategy`**: *by far the most convenient way*. In this case, the transform + strategy will be first used to generate a set of vectorised parameters, from which the + relevant information will be extracted. This does come at the cost of doing one extra + model evaluation. Whilst `LogDensityFunction` construction is unlikely to occur in + performance-sensitive code paths, if you absolutely cannot pay this price, then you should + generate the vectorised parameters yourself and pass them here instead. + +## Keyword arguments + +The `adtype` keyword argument allows you to specify an AD type for gradient preparation and +calculation. + +The `fix_transforms` keyword argument allows you to specify whether the transforms used in +the `LogDensityFunction` should be cached at the time of construction. If so, the model is +evaluated once using the provided transform strategy, and the transforms used for each +variable are stored in the `LogDensityFunction`. This allows you to avoid the overhead of +recalculating transforms during each log-density evaluation. See [the documentation on fixed +transforms](@ref fixed-transforms) for more information. +""" +function LogDensityFunction( + model::Model, + getlogdensity::Any, + # This VNT should map varnames to TransformedValue{<:AbstractVector} + vecvals::VarNamedTuple, + accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs(getlogdensity); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + fix_transforms::Bool=false, +) + # Handle fixed transforms flag. 
+ if fix_transforms + all_fixed = all(tv -> get_transform(tv) isa FixedTransform, values(vecvals)) + if !all_fixed + # We need to update the transforms in `vecvals` to be consistent with the new + # transform strategy. This requires reevaluating the model in + # `get_fixed_transforms`, which is perhaps a bit unfortunate, but probably + # tolerable since this isn't something that is in a performance-sensitive code + # path. + dynamic_transform_strategy = infer_transform_strategy_from_values(vecvals) + transforms_vnt = get_fixed_transforms(model, dynamic_transform_strategy) + vecvals = update_transforms!!(vecvals, transforms_vnt) + end + end + ranges_and_transforms, x = get_rat_and_samplevec(vecvals) + return LogDensityFunction( + model, getlogdensity, ranges_and_transforms, x, accs; adtype=adtype + ) +end function LogDensityFunction( model::Model, getlogdensity::Any, @@ -323,22 +375,19 @@ function LogDensityFunction( end function LogDensityFunction( model::Model, - getlogdensity::Any, - link_strat::AbstractTransformStrategy, + getlogdensity::Any=getlogjoint_internal, + transform_strategy::AbstractTransformStrategy=UnlinkAll(), accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs(getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, fix_transforms::Bool=false, ) - vnt = _default_vnt(model, link_strat) - return LogDensityFunction( - model, getlogdensity, vnt, accs; adtype=adtype, fix_transforms=fix_transforms - ) -end - -function _default_vnt(model::Model, transform_strategy::AbstractTransformStrategy) + # note that this reevaluates the model oavi = OnlyAccsVarInfo(VectorValueAccumulator()) _, oavi = DynamicPPL.init!!(model, oavi, InitFromPrior(), transform_strategy) - return getacc(oavi, Val(VECTORVAL_ACCNAME)).values + vecvals = getacc(oavi, Val(VECTORVAL_ACCNAME)).values + return LogDensityFunction( + model, getlogdensity, vecvals, accs; adtype=adtype, fix_transforms=fix_transforms + ) end """ @@ -351,20 +400,32 @@ Note that if you 
pass a vector of a different type, it will be converted to the type. This allows you however to determine upfront what kind of vector should be passed in. It is also useful for determining e.g. whether Float32 or Float64 parameters are expected. """ -function get_input_vector_type(::LogDensityFunction{M,A,L,G,R,P,X}) where {M,A,L,G,R,P,X} - return X -end +get_input_vector_type(::LogDensityFunction{M,A,L,G,R,P,X}) where {M,A,L,G,R,P,X} = X + +""" + DynamicPPL.get_sample_input_vector(::LogDensityFunction)::AbstractVector{<:Real} + +Get the sample input vector `x` used to construct the LogDensityFunction. +""" +get_sample_input_vector(ldf::LogDensityFunction) = ldf._x """ DynamicPPL.get_range_and_transform(ldf::LogDensityFunction, vn::VarName)::RangeAndTransform A `LogDensityFunction` stores a mapping from `VarName`s to their corresponding ranges in the vectorised parameter representation, along with their transform status. This function -retrieves that information. +retrieves that information for a single VarName. """ -function get_range_and_transform(ldf::LogDensityFunction, vn::VarName) - return ldf._varname_ranges[vn] -end +get_range_and_transform(ldf::LogDensityFunction, vn::VarName) = ldf._varname_ranges[vn] + +""" + DynamicPPL.get_all_ranges_and_transforms(ldf::LogDensityFunction)::VarNamedTuple + +A `LogDensityFunction` stores a mapping from `VarName`s to their corresponding ranges in the +vectorised parameter representation, along with their transform status. This function +retrieves the complete mapping. +""" +get_all_ranges_and_transforms(ldf::LogDensityFunction) = ldf._varname_ranges """ DynamicPPL.get_logdensity_callable(ldf::LogDensityFunction) @@ -375,9 +436,7 @@ usecases in DynamicPPL use [`DynamicPPL.getlogjoint_internal`](@ref) for this pu This function retrieves that callable. 
""" -function get_logdensity_callable(l::LogDensityFunction) - return l._getlogdensity -end +get_logdensity_callable(l::LogDensityFunction) = l._getlogdensity ################################### # LogDensityProblems.jl interface # @@ -580,7 +639,7 @@ _use_closure(::ADTypes.AbstractADType) = false ###################################################### """ - get_rangeandtransforms(vnt::VarNamedTuple) + get_rat_and_samplevec(vnt::VarNamedTuple) Given a `VarNamedTuple` that contains vectorised values (i.e., `TransformedValue{<:AbstractVector}`), extract the ranges of each variable in the vectorised @@ -588,26 +647,35 @@ parameter representation. Further infer the transform status of each variable fr of the vectorised value. This function returns a VarNamedTuple mapping all VarNames to their corresponding -`RangeAndTransform`. +`RangeAndTransform`, plus a vector of parameters obtained by concatenating all the +vectorised values together. """ -function get_rangeandtransforms(vnt::VarNamedTuple) +function get_rat_and_samplevec(vnt::VarNamedTuple) # Note: can't use map_values!! here as that might mutate the VNT itself! 
- ranges_vnt, _ = mapreduce( + ranges_vnt, x, _ = mapreduce( identity, - function ((ranges_vnt, offset), pair) + function ((ranges_vnt, params, offset), pair) vn, tv = pair val = get_internal_value(tv) + if !(val isa AbstractVector) + throw( + ArgumentError( + "Expected all values in the provided VarNamedTuple to be TransformedValues wrapping AbstractVectors, but the value for variable `$vn` is a $(typeof(val)).", + ), + ) + end range = offset:(offset + length(val) - 1) offset += length(val) + params = vcat(params, val) ral = RangeAndTransform(range, tv.transform) template = vnt.data[AbstractPPL.getsym(vn)] ranges_vnt = templated_setindex!!(ranges_vnt, ral, vn, template) - return ranges_vnt, offset + return ranges_vnt, params, offset end, vnt; - init=(VarNamedTuple(), 1), + init=(VarNamedTuple(), Union{}[], 1), ) - return ranges_vnt + return ranges_vnt, x end """ diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 2a41034ba..107772c61 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -88,9 +88,15 @@ end ldf = LogDensityFunction(f(), getlogprior, UnlinkAll()) @test DynamicPPL.get_logdensity_callable(ldf) == getlogprior @test DynamicPPL.get_input_vector_type(ldf) == Vector{Float64} + @test DynamicPPL.get_sample_input_vector(ldf) isa Vector{Float64} + @test length(DynamicPPL.get_sample_input_vector(ldf)) == 1 rat = DynamicPPL.get_range_and_transform(ldf, @varname(x)) @test rat.range == 1:1 @test rat.transform isa Unlink + vnt = DynamicPPL.get_all_ranges_and_transforms(ldf) + @test only(keys(vnt)) == @varname(x) + @test vnt[@varname(x)].range == 1:1 + @test vnt[@varname(x)].transform isa Unlink end @testset "LogDensityFunction: correctness with multiple threads" begin