
Enzyme reverse stochastic failures with Threads.@threads + TSVI #1131

@penelopeysm

Description


Run the snippet below using #1113 or any later PR. The correct gradient is -51.0, and both FiniteDifferences and ForwardDiff report it. However, Enzyme sometimes drops some terms and reports -50 or -48 instead. This happens with both the slow LDF (`LogDensityFunction`) and the fast LDF (`FastLDF`), so it is unrelated to #1113.

(Make sure to launch Julia with multiple threads, of course, e.g. `julia --threads=4`.)

```julia
using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra, ADTypes, ForwardDiff, Enzyme, FiniteDifferences

# The bug only shows up when `Threads.@threads` actually runs on more than one thread.
if Threads.nthreads() == 1
    error("run this code with multiple threads")
end

# FiniteDifferences and ForwardDiff are the reference backends; Enzyme reverse mode is under test.
const adtypes = (
    AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
    AutoForwardDiff(),
    AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const)
)

# For each backend, compute the log density and gradient at x = 1.0 with both the
# slow LDF (`LogDensityFunction`) and the fast LDF (`FastLDF`), and check that they agree.
function check_ldfs(model)
    vi = VarInfo(model)
    xs = [1.0]
    for adtype in adtypes
        sldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi; adtype=adtype)
        sldf_grad = LogDensityProblems.logdensity_and_gradient(sldf, xs)
        fldf_grad = LogDensityProblems.logdensity_and_gradient(fldf, xs)
        @show adtype
        @show sldf_grad
        @show fldf_grad
        @assert sldf_grad[2] ≈ fldf_grad[2]
    end
end

# One latent `x` and 50 observations whose log-likelihood terms are accumulated in a threaded loop.
@model function threads(y=zeros(50))
    x ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(x)
    end
end
check_ldfs(setthreadsafe(threads(), true))
```
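For reference, the expected value can be derived by hand. With `y = zeros(50)`, the log joint (dropping constants) is -x^2/2 - sum((y[i] - x)^2)/2, so its derivative is -x + sum(y[i] - x) = -51x, i.e. -51.0 at x = 1.0. At that point the prior and each of the 50 observations contribute -1.0 apiece, so -50 and -48 are consistent with one or three terms being lost. The sketch below checks this without any AD backend; `logjoint` and `dlogjoint` are hypothetical helper names for illustration, not DynamicPPL API.

```julia
# Hand-derived log joint of the `threads` model above, with constants dropped:
#   x ~ Normal(0, 1) contributes -x^2 / 2
#   y[i] ~ Normal(x, 1) contributes -(y[i] - x)^2 / 2 for each i
logjoint(x, y) = -x^2 / 2 - sum(abs2, y .- x) / 2

# Derivative with respect to x: -x from the prior, plus (y[i] - x) per observation.
dlogjoint(x, y) = -x + sum(y .- x)

y = zeros(50)
println(dlogjoint(1.0, y))  # -51.0

# Central finite difference as a cross-check on the closed form.
h = 1e-6
println((logjoint(1.0 + h, y) - logjoint(1.0 - h, y)) / (2h))
```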
