diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index efe996a9..3a06a443 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -45,7 +45,22 @@ function Integrals._compute_dfdp_and_f(::Integrals.MooncakeVJP, cache, p, Δ) if isinplace(cache) if cache.f isa BatchIntegralFunction - error("MooncakeVJP does not yet support BatchIntegralFunction with in-place functions") + dx = similar( + cache.f.integrand_prototype, + size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1 + ) + _f = x -> (cache.f(dx, x, p); dx) + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + integralfunc_closure_p = p -> (cache.f(dx, x_, p); dx) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + Δ_batch = Δ_val isa AbstractArray ? reshape(Δ_val, size(Δ_val)..., 1) : [Δ_val] + z, grads = Mooncake.value_and_pullback!!( + cache_z, Δ_batch, integralfunc_closure_p, p + ) + return grads[2] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) else dx = similar(cache.f.integrand_prototype) _f = x -> (cache.f(dx, x, p); dx) @@ -63,7 +78,17 @@ function Integrals._compute_dfdp_and_f(::Integrals.MooncakeVJP, cache, p, Δ) else _f = x -> cache.f(x, p) if cache.f isa BatchIntegralFunction - error("MooncakeVJP does not yet support BatchIntegralFunction") + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + integralfunc_closure_p = p -> cache.f(x_, p) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + Δ_batch = Δ_val isa AbstractArray ? reshape(Δ_val, size(Δ_val)..., 1) : [Δ_val] + z, grads = Mooncake.value_and_pullback!!( + cache_z, Δ_batch, integralfunc_closure_p, p + ) + return grads[2] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) else dfdp_ = function (x, p) integralfunc_closure_p = p -> cache.f(x, p) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index c5806a36..c054a5fb 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -182,13 +182,8 @@ do_tests_mooncake = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) end sol_fp = testf(lb, ub, p) - cache = Mooncake.prepare_gradient_cache( - testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p - ) - forwpassval, - gradients = Mooncake.value_and_gradient!!( - cache, testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p - ) + cache = Mooncake.prepare_gradient_cache(testf, lb, ub, p) + forwpassval, gradients = Mooncake.value_and_gradient!!(cache, testf, lb, ub, p) @test forwpassval == sol_fp @@ -363,6 +358,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "Batched, one-dimensional, scalar, oop derivative test" alg = nameof(typeof(alg)) integrand = j scalarize = i bf = BatchIntegralFunction((x, p) -> batch_helper(f, x, p)) do_tests(; f = bf, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f = bf, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) end ## Batch, One-dimensional nout @@ -378,6 +374,10 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), f = bf, scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol ) + do_tests_mooncake(; + f = bf, scalarize, lb = 1.0, ub = 3.0, + p = [2.0i for i in 1:nout], alg, abstol, reltol + ) end ### Batch, N-dimensional @@ -392,6 +392,9 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), do_tests(; f = bf, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol ) + do_tests_mooncake(; + f = bf, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol + ) end ### Batch, N-dimensional nout @@ -408,6 +411,10 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), f = bf, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol ) + do_tests_mooncake(; + f = bf, scalarize, lb = ones(dim), ub = 3ones(dim), + p = [2.0i for i in 1:nout], alg, abstol, reltol + ) end ### Batch, one-dimensional @@ -421,6 +428,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "Batched, one-dimensional, scalar, iip derivative test" alg = nameof(typeof(alg)) integrand = j scalarize = i bfiip = BatchIntegralFunction((y, x, p) -> batch_helper!(f, y, x, p), zeros(0)) do_tests(; f = bfiip, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f = bfiip, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) end ## Batch, one-dimensional nout @@ -437,6 +445,10 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), f = bfiip, scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol ) + do_tests_mooncake(; + f = bfiip, scalarize, lb = 1.0, ub = 3.0, + p = [2.0i for i in 1:nout], alg, abstol, reltol + ) end ### Batch, N-dimensional @@ -453,6 +465,10 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), f = bfiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol ) + do_tests_mooncake(; + f = bfiip, scalarize, lb = ones(dim), + ub = 3ones(dim), p = 2.0, alg, abstol, reltol + ) end ### Batch, N-dimensional nout iip @@ -470,6 +486,10 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), f = bfiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol ) + do_tests_mooncake(; + f = bfiip, scalarize, lb = ones(dim), ub = 3ones(dim), + p = [2.0i for i in 1:nout], alg, abstol, reltol + ) end @testset "ChangeOfVariables rrules" begin diff --git a/test/nested_ad_tests.jl b/test/nested_ad_tests.jl index 80baad34..a0ea832f 100644 --- a/test/nested_ad_tests.jl +++ b/test/nested_ad_tests.jl @@ -1,4 +1,4 @@ -using Integrals, FiniteDiff, ForwardDiff, Cubature, Cuba, Zygote, Test +using Integrals, FiniteDiff, ForwardDiff, Cubature, Cuba, Zygote, Mooncake, Test my_parameters = [1.0, 2.0] my_function(x, p) = x^2 + p[1]^3 * x + p[2]^2 @@ -47,3 +47,19 @@ dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 @test dp2 ≈ dp3 + +function testf3_mooncake(lb, ub, p) + prob = IntegralProblem(_ff3, (lb, ub), p) + return solve( + prob, CubatureJLh(); reltol = 1.0e-3, abstol = 1.0e-3, + sensealg = Integrals.ReCallVJP(Integrals.MooncakeVJP()) + )[1] +end + +if pkgversion(Mooncake) >= v"0.5.26" + testf3_mooncake_p = p -> testf3_mooncake(lb, ub, p) + cache = Mooncake.prepare_gradient_cache(testf3_mooncake_p, p) + _, grads = Mooncake.value_and_gradient!!(cache, testf3_mooncake_p, p) + dp4 = grads[2] + @test dp4 ≈ dp3 +end