Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/workflows/AutoDiffIntegration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ on:
AD:
required: true
type: string
permissions:
contents: read

concurrency:
# Skip intermediate builds: always.
Expand All @@ -14,18 +16,14 @@ concurrency:

jobs:
test:
name: Julia ${{ matrix.AD }} - ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
name: Julia ${{ matrix.AD }} - ${{ matrix.version }} - ubuntu-latest - ${{ github.event_name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- 'lts'
- '1.11'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
10 changes: 4 additions & 6 deletions .github/workflows/DynamicPPL.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,23 @@ on:
tags: ['*']
pull_request:
workflow_dispatch:
permissions:
contents: read
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
name: Julia ${{ matrix.version }} - ubuntu-latest - ${{ github.event_name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- 'lts'
- '1.11'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
10 changes: 4 additions & 6 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,23 @@ on:
tags: ['*']
pull_request:
workflow_dispatch:
permissions:
contents: read
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
name: Julia ${{ matrix.version }} - ubuntu-latest - ${{ github.event_name }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- 'lts'
- '1.11'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
14 changes: 0 additions & 14 deletions .github/workflows/Zygote.yml

This file was deleted.

1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
|:---------------------------------------------------------- |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |

Expand Down
2 changes: 1 addition & 1 deletion src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
# Arguments
- `ad::ADTypes.AbstractADType`:
automatic differentiation backend. Currently supports
`ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`,
`ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`,
`ADTypes.AutoMooncake()` and
`ADTypes.AutoEnzyme(;
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Evidence lower-bound objective with the reparameterization gradient formulation[
# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblem` should satisfy either of the following: The target has a capability of at least `LogDensityProblems.LogDensityOrder{1}()` and the AD backend is one of `ReverseDiff`, `Zygote`, `Mooncake`, and `AutoEnzyme` in reverse mode so that `ADTypes.mode(adtype) == ADTypes.ReverseMode` is true. (In this case, `AdvancedVI` will take advantage of the existing `LogDensityProblems.logdensity_and_gradient`.) Otherwise, `LogDensityProblems.logdensity` should be differentiable under the selected AD backend.
- The target `LogDensityProblem` should satisfy either of the following: The target has a capability of at least `LogDensityProblems.LogDensityOrder{1}()` and the AD backend is one of `ReverseDiff`, `Mooncake`, and `AutoEnzyme` in reverse mode so that `ADTypes.mode(adtype) == ADTypes.ReverseMode` is true. (In this case, `AdvancedVI` will take advantage of the existing `LogDensityProblems.logdensity_and_gradient`.) Otherwise, `LogDensityProblems.logdensity` should be differentiable under the selected AD backend.
- The sampling process `rand(q)` must be differentiable by the selected AD backend.

Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
Expand Down Expand Up @@ -53,10 +53,10 @@ function init(
prob
else
if !(
adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme} &&
adtype isa Union{<:AutoReverseDiff,<:AutoMooncake,<:AutoEnzyme} &&
ADTypes.mode(adtype) isa ADTypes.ReverseMode
)
@info "The capability of the supplied target `LogDensityProblem` $(capability) is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode."
@info "The capability of the supplied target `LogDensityProblem` $(capability) is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoMooncake`, or `AutoEnzyme` in reverse mode."
end
MixedADLogDensityProblem(prob)
end
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2.1, 1"
Expand All @@ -46,5 +45,4 @@ Statistics = "1"
StatsBase = "0.34"
Test = "1"
Tracker = "0.2.20"
Zygote = "0.6.63, 0.7"
julia = "1.10, 1.11.2"
4 changes: 2 additions & 2 deletions test/general/mixedad_logdensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ function mixedad_test_fwd(x, prob)
)/2
end

# MixedADLogDensityProblem only supports ReverseDiff, Zygote, Enzyme, Mooncake in reverse-mode
if (AD isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoEnzyme,<:AutoMooncake}) &&
# MixedADLogDensityProblem only supports ReverseDiff, Enzyme, Mooncake in reverse-mode
if (AD isa Union{<:AutoReverseDiff,<:AutoEnzyme,<:AutoMooncake}) &&
(ADTypes.mode(AD) isa ADTypes.ReverseMode)
@testset "MixedADLogDensityProblem" begin
model = MixedADTestModel()
Expand Down
5 changes: 2 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ const AD = if AD_str == "ReverseDiff"
AutoReverseDiff()
elseif AD_str == "ForwardDiff"
AutoForwardDiff()
elseif AD_str == "Zygote"
using Zygote
AutoZygote()
elseif AD_str == "Mooncake"
using Mooncake
AutoMooncake(; config=Mooncake.Config())
Expand All @@ -39,6 +36,8 @@ elseif AD_str == "Enzyme"
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation=Enzyme.Const,
)
else
throw(ArgumentError("Unsupported AD backend for tests: $AD_str"))
end

if GROUP == "DynamicPPL"
Expand Down
Loading