Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/DynamicPPL.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- uses: julia-actions/julia-runtest@v1
env:
GROUP: DynamicPPL
AD: ReverseDiff
AD: Mooncake
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
with:
Expand Down
14 changes: 0 additions & 14 deletions .github/workflows/Enzyme.yml

This file was deleted.

14 changes: 0 additions & 14 deletions .github/workflows/ReverseDiff.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- uses: julia-actions/julia-runtest@v1
env:
GROUP: All
AD: ReverseDiff
AD: Mooncake
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
with:
Expand Down
14 changes: 0 additions & 14 deletions .github/workflows/Zygote.yml

This file was deleted.

1 change: 0 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Release 0.7

## Removal of special treatment to `Bijectors.TransformedDistribution`
Expand Down
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.7.1"
version = "0.7"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -17,19 +16,24 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"

[extensions]
AdvancedVIDifferentiationInterfaceExt = ["DifferentiationInterface"]
AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
AdvancedVIForwardDiffExt = ["ForwardDiff"]
AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]
AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "DifferentiationInterface", "LogDensityProblems"]
AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "LogDensityProblems"]

[compat]
ADTypes = "1"
Expand All @@ -42,13 +46,15 @@ DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.39, 0.40"
Enzyme = "0.13"
FillArrays = "1.3"
ForwardDiff = "0.10, 1"
Functors = "0.4, 0.5"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4, 0.5"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.6"
Random = "1"
Statistics = "1"
ReverseDiff = "1"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.10, 1.11.2"
Expand Down
38 changes: 19 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
| AD Backend | Integration Status |
|:---------------------------------------------------------- |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |

# AdvancedVI.jl

Expand Down Expand Up @@ -53,7 +50,7 @@ function LogDensityProblems.logdensity(model::LogReg, θ)
logprior_β = logpdf(MvNormal(Zeros(d), σ), β)
logprior_σ = logpdf(LogNormal(0, 3), σ)

logit = X*β
logit = X * β
loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y)
return loglike_y + logprior_β + logprior_σ
end
Expand Down Expand Up @@ -89,6 +86,8 @@ Since most VI algorithms assume that the posterior is unconstrained, we will app
This amounts to wrapping it into a `LogDensityProblem` that applies the transformation and the corresponding Jacobian adjustment.

```julia
using ForwardDiff

struct TransformedLogDensityProblem{Prob,BInv}
prob::Prob
binv::BInv
Expand All @@ -113,10 +112,17 @@ function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
return prod(Bijectors.output_size(b, (d,)))
end

function LogDensityProblems.logdensity_and_gradient(
prob_trans::TransformedLogDensityProblem, θ
)
f = Base.Fix1(LogDensityProblems.logdensity, prob_trans)
return f(θ), ForwardDiff.gradient(f, θ)
end

function LogDensityProblems.capabilities(
::Type{TransformedLogDensityProblem{Prob,BInv}}
) where {Prob,BInv}
return LogDensityProblems.capabilities(Prob)
return LogDensityProblems.LogDensityOrder{1}()
end;
```

Expand Down Expand Up @@ -151,10 +157,10 @@ prob_trans = TransformedLogDensityProblem(prob)
For the VI algorithm, we will use `KLMinRepGradDescent`:

```julia
using ADTypes, ReverseDiff
using ADTypes, Mooncake
using AdvancedVI

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff(); operator=ClipScale())
alg = KLMinRepGradDescent(ADTypes.AutoMooncake(); operator=ClipScale())
```

This algorithm minimizes the exclusive/reverse KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation with the reparametrization gradient[^TL2014][^RMW2014][^KW2014].
Expand All @@ -165,31 +171,23 @@ For this example, we will use Gaussian variational family, which is part of the
These require the scale matrix to have strictly positive eigenvalues at all times.
Here, the projection operator `ClipScale` ensures this.

This `KLMinRepGradDescent`, in particular, assumes that the target `LogDensityProblem` has gradients.
For this, it is straightforward to use `LogDensityProblemsAD`:

```julia
using DifferentiationInterface: DifferentiationInterface
using LogDensityProblemsAD: LogDensityProblemsAD

prob_trans_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), prob_trans);
```
This `KLMinRepGradDescent`, in particular, differentiates the target `LogDensityProblem` directly through the chosen backend.

For the variational family, we will consider a `FullRankGaussian` approximation:

```julia
using LinearAlgebra

d = LogDensityProblems.dimension(prob_trans_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d)))
d = LogDensityProblems.dimension(prob_trans)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d)))
q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)));
```

We can now run VI:

```julia
max_iter = 10^3
q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans_ad, q);
q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans, q);
```

Recall that we applied a change-of-variable to the posterior to make it unconstrained.
Expand All @@ -205,5 +203,7 @@ q_trans = Bijectors.TransformedDistribution(q_opt, binv)
For more examples and details, please refer to the documentation.

[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR.

[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR.

[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*.
2 changes: 0 additions & 2 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
Expand All @@ -34,5 +33,4 @@ Random = "1"
ReverseDiff = "1"
SimpleUnPack = "1"
StableRNGs = "1"
Zygote = "0.6, 0.7"
julia = "1.10, 1.11.2"
3 changes: 1 addition & 2 deletions bench/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using AdvancedVI
using BenchmarkTools
using Distributions
using DistributionsAD
using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake
using Enzyme, ForwardDiff, ReverseDiff, Mooncake
using FillArrays
using InteractiveUtils
using LinearAlgebra
Expand Down Expand Up @@ -68,7 +68,6 @@ begin
("RepGradELBO + STL", StickingTheLandingEntropy()),
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ReverseDiff", AutoReverseDiff()),
("Mooncake", AutoMooncake(; config=Mooncake.Config())),
# ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)),
Expand Down
12 changes: 2 additions & 10 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,18 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
OpenML = "8b6db2d4-7670-4922-a472-f9537c81ab66"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StanLogDensityProblems = "a545de4d-8dba-46db-9d34-4e41d3f07807"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

Expand All @@ -27,22 +23,18 @@ ADTypes = "1"
Accessors = "0.1"
AdvancedVI = "0.7, 0.6"
Bijectors = "0.13.6, 0.14, 0.15"
DataFrames = "1"
DifferentiationInterface = "0.7"
Distributions = "0.25"
Documenter = "1"
FillArrays = "1"
ForwardDiff = "0.10, 1"
Functors = "0.5"
JSON = "0.21, 1"
LogDensityProblems = "2.1.1"
LogDensityProblemsAD = "1"
Mooncake = "0.4, 0.5"
NormalizingFlows = "0.2.2"
OpenML = "0.3"
Optimisers = "0.3, 0.4"
Plots = "1"
QuasiMonteCarlo = "0.3"
ReverseDiff = "1"
StanLogDensityProblems = "0.1"
StatsFuns = "1"
julia = "1.10, 1.11.2"
7 changes: 7 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@

# Keep GR headless during docs builds.
ENV["GKSwstype"] = "100"

using AdvancedVI
using Documenter
using Mooncake
using Plots

# Necessary for invoking the docstring specializations
using Random
using ADTypes

default(; size=(700, 420))

DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true)

makedocs(;
Expand Down
Loading
Loading