diff --git a/Project.toml b/Project.toml index da88f4e6..8e78e0e7 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ MuladdMacro = "0.2" OpenCL = "0.9, 0.10" Parameters = "0.12" RecursiveArrayTools = "3.37, 4" -SciMLBase = "2.144" +SciMLBase = "3" Setfield = "1" SimpleDiffEq = "1.11" SimpleNonlinearSolve = "2" diff --git a/src/algorithms.jl b/src/algorithms.jl index efada405..7f8878f3 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -77,7 +77,7 @@ u0 = Float32[1.0; 0.0; 0.0] tspan = (0.0f0, 100.0f0) p = [10.0f0, 28.0f0, 8 / 3.0f0] prob = ODEProblem(lorenz, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p) +prob_func = (prob, ctx) -> remake(prob, p = rand(Float32, 3) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) @time sol = solve(monteprob, Tsit5(), EnsembleGPUArray(CUDADevice()), trajectories = 10_000, saveat = 1.0f0) @@ -143,7 +143,7 @@ u0 = @SVector [1.0f0; 0.0f0; 0.0f0] tspan = (0.0f0, 10.0f0) p = @SVector [10.0f0, 28.0f0, 8 / 3.0f0] prob = ODEProblem{false}(lorenz, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) +prob_func = (prob, ctx) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) @time sol = solve( diff --git a/src/solve.jl b/src/solve.jl index bf510d66..3e88f1c5 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -10,15 +10,30 @@ function SciMLBase.__solve( }; trajectories, batch_size = trajectories, unstable_check = (dt, u, p, t) -> false, adaptive = true, + seed = nothing, + rng = nothing, + rng_func = SciMLBase.default_rng_func, kwargs... ) if trajectories == 1 return SciMLBase.__solve( ensembleprob, alg, EnsembleSerial(); trajectories = 1, - kwargs... + seed, rng, rng_func, kwargs... ) end + # Pre-generate per-trajectory seeds for reproducibility (matching SciMLBase v3 protocol) + sim_seeds = (rng !== nothing || seed !== nothing) ? + SciMLBase.generate_sim_seeds(rng, seed, trajectories) : nothing + + # Bundle ensemble RNG state for passing to SciMLBase.solve_batch (CPU offload path) + ensemble_rng_state = (; + sim_seeds, + _solve_rng_mode = Val(:none), + rng_func, + master_rng = rng, + ) + cpu_trajectories = ( ( ensemblealg isa EnsembleGPUArray || @@ -53,8 +68,8 @@ function SciMLBase.__solve( function f() return SciMLBase.solve_batch( - ensembleprob, _alg, EnsembleThreads(), cpu_II, nothing; - kwargs... + ensembleprob, _alg, EnsembleThreads(), cpu_II, nothing, + ensemble_rng_state; kwargs... ) end @@ -69,6 +84,7 @@ function SciMLBase.__solve( time = @elapsed sol = batch_solve( ensembleprob, alg, ensemblealg, 1:gpu_trajectories, adaptive; + sim_seeds, rng_func, master_rng = rng, unstable_check = unstable_check, kwargs... ) if cpu_trajectories != 0 @@ -83,6 +99,7 @@ function SciMLBase.__solve( similar( batch_solve( ensembleprob, alg, ensemblealg, 1:batch_size, adaptive; + sim_seeds, rng_func, master_rng = rng, unstable_check = unstable_check, kwargs... ), 0 @@ -100,6 +117,7 @@ function SciMLBase.__solve( end batch_data = batch_solve( ensembleprob, alg, ensemblealg, I, adaptive; + sim_seeds, rng_func, master_rng = rng, unstable_check = unstable_check, kwargs... ) if ensembleprob.reduction !== SciMLBase.DEFAULT_REDUCTION @@ -120,6 +138,7 @@ function SciMLBase.__solve( end x = batch_solve( ensembleprob, alg, ensemblealg, I, adaptive; + sim_seeds, rng_func, master_rng = rng, unstable_check = unstable_check, kwargs... ) yield() @@ -145,10 +164,20 @@ function SciMLBase.__solve( end end +function _make_ensemble_context(i, sim_seeds, rng_func, master_rng) + sim_seed = sim_seeds !== nothing ? sim_seeds[i] : nothing + pre_ctx = SciMLBase.EnsembleContext(i, 1, 0, sim_seed, nothing, master_rng) + sim_rng = rng_func(pre_ctx) + return @set pre_ctx.rng = sim_rng +end + function batch_solve( ensembleprob, alg, ensemblealg::Union{EnsembleArrayAlgorithm, EnsembleKernelAlgorithm}, I, adaptive; + sim_seeds = nothing, + rng_func = SciMLBase.default_rng_func, + master_rng = nothing, kwargs... ) @assert !isempty(I) @@ -157,17 +186,18 @@ function batch_solve( return if ensemblealg isa EnsembleGPUKernel if ensembleprob.safetycopy probs = map(I) do i + ctx = _make_ensemble_context(i, sim_seeds, rng_func, master_rng) make_prob_compatible( ensembleprob.prob_func( deepcopy(ensembleprob.prob), - i, - 1 + ctx ) ) end else probs = map(I) do i - make_prob_compatible(ensembleprob.prob_func(ensembleprob.prob, i, 1)) + ctx = _make_ensemble_context(i, sim_seeds, rng_func, master_rng) + make_prob_compatible(ensembleprob.prob_func(ensembleprob.prob, ctx)) end end # Using inner saveat requires all of them to be of same size, @@ -246,7 +276,7 @@ function batch_solve( ReturnCode.Terminated : ReturnCode.Success ), - i + _make_ensemble_context(I[i], sim_seeds, rng_func, master_rng) )[1] end for i in eachindex(probs) @@ -258,11 +288,13 @@ function batch_solve( else if ensembleprob.safetycopy probs = map(I) do i - ensembleprob.prob_func(deepcopy(ensembleprob.prob), i, 1) + ctx = _make_ensemble_context(i, sim_seeds, rng_func, master_rng) + ensembleprob.prob_func(deepcopy(ensembleprob.prob), ctx) end else probs = map(I) do i - ensembleprob.prob_func(ensembleprob.prob, i, 1) + ctx = _make_ensemble_context(i, sim_seeds, rng_func, master_rng) + ensembleprob.prob_func(ensembleprob.prob, ctx) end end u0 = reduce(hcat, Array(probs[i].u0) for i in 1:length(I)) @@ -316,7 +348,7 @@ function batch_solve( stats = sol.stats, retcode = sol.retcode ), - i + _make_ensemble_context(I[i], sim_seeds, rng_func, master_rng) )[1] for i in 1:length(probs) ] @@ -339,7 +371,7 @@ function batch_solve( stats = sol.stats, retcode = sol.retcode ), - i + _make_ensemble_context(I[i], sim_seeds, rng_func, master_rng) )[1] for i in 1:length(probs) ] @@ -539,13 +571,13 @@ function ChainRulesCore.rrule( end function solve_batch( - prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_size; - kwargs... + prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_size, + ensemble_rng_state; kwargs... ) if length(II) == 1 || Threads.nthreads() == 1 return SciMLBase.solve_batch( - prob, alg, EnsembleSerial(), II, pmap_batch_size; - kwargs... + prob, alg, EnsembleSerial(), II, pmap_batch_size, + ensemble_rng_state; kwargs... ) end @@ -565,8 +597,8 @@ function solve_batch( I_local = II[(batch_size * (i - 1) + 1):(batch_size * i)] end SciMLBase.solve_batch( - prob, alg, EnsembleSerial(), I_local, pmap_batch_size; - kwargs... + prob, alg, EnsembleSerial(), I_local, pmap_batch_size, + ensemble_rng_state; kwargs... ) end return SciMLBase.tighten_container_eltype(batch_data) diff --git a/test/distributed_multi_gpu.jl b/test/distributed_multi_gpu.jl index d0d12338..08ccc552 100644 --- a/test/distributed_multi_gpu.jl +++ b/test/distributed_multi_gpu.jl @@ -14,8 +14,8 @@ addprocs(2) p = (10.0f0, 28.0f0, 8 / 3.0f0) Random.seed!(1) pre_p_distributed = [rand(Float32, 3) for i in 1:10] - function prob_func_distributed(prob, i, repeat) - remake(prob, p = pre_p_distributed[i] .* p) + function prob_func_distributed(prob, ctx) + remake(prob, p = pre_p_distributed[ctx.sim_id] .* p) end end diff --git a/test/ensemblegpuarray.jl b/test/ensemblegpuarray.jl index aac5a531..3e5f39a6 100644 --- a/test/ensemblegpuarray.jl +++ b/test/ensemblegpuarray.jl @@ -13,7 +13,7 @@ tspan = (0.0f0, 100.0f0) p = (10.0f0, 28.0f0, 8 / 3.0f0) prob = ODEProblem(lorenz, u0, tspan, p) const pre_p = [rand(Float32, 3) for i in 1:10] -prob_func = (prob, i, repeat) -> remake(prob, p = pre_p[i] .* p) +prob_func = (prob, ctx) -> remake(prob, p = pre_p[ctx.sim_id] .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func) @info "Explicit Methods" @@ -254,7 +254,8 @@ u0 = Float32[1.0; 0.0; 0.0] tspan = (0.0f0, 100.0f0) p = LorenzParameters(10.0f0, 28.0f0, 8 / 3.0f0) prob = ODEProblem(lorenzp, u0, tspan, p) -function param_prob_func(prob, i, repeat) +function param_prob_func(prob, ctx) + i = ctx.sim_id p = LorenzParameters( pre_p[i][1] .* 10.0f0, pre_p[i][2] .* 28.0f0, @@ -277,11 +278,11 @@ saveats = 1.0f0:1.0f0:10.0f0 prob = ODEProblem(lorenz, u0, tspan, p) monteprob = EnsembleProblem( prob_jac, - prob_func = (prob, i, repeat) -> remake( + prob_func = (prob, ctx) -> remake( prob; tspan = ( 0.0f0, - saveats[i], + saveats[ctx.sim_id], ) ) ) diff --git a/test/ensemblegpuarray_inputtypes.jl b/test/ensemblegpuarray_inputtypes.jl index adbf2497..cd5e26b3 100644 --- a/test/ensemblegpuarray_inputtypes.jl +++ b/test/ensemblegpuarray_inputtypes.jl @@ -15,7 +15,7 @@ u0 = [ tspan = (0.0f0, 100.0f0) p = (10.0f0, 28.0f0, 8 / 3.0f0) prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p) +prob_func = (prob, ctx) -> remake(prob, p = rand(Float32, 3) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func) @time sol = solve( monteprob, Tsit5(), EnsembleGPUArray(backend), trajectories = 10_000, @@ -27,7 +27,7 @@ u0 = [1f0u"m";0u"m";0u"m"] tspan = (0.0f0u"s",100.0f0u"s") p = (10.0f0,28.0f0,8/3f0) prob = ODEProblem(lorenz,u0,tspan,p) -prob_func = (prob,i,repeat) -> remake(prob,p=rand(Float32,3).*p) +prob_func = (prob,ctx) -> remake(prob,p=rand(Float32,3).*p) monteprob = EnsembleProblem(prob, prob_func = prob_func) @test_broken sol = solve(monteprob,Tsit5(),EnsembleGPUArray(),trajectories=10_000,saveat=1.0f0u"s") =# diff --git a/test/ensemblegpuarray_oop.jl b/test/ensemblegpuarray_oop.jl index 95a06da9..4cb2a51f 100644 --- a/test/ensemblegpuarray_oop.jl +++ b/test/ensemblegpuarray_oop.jl @@ -32,7 +32,7 @@ u0 = SA[1.0f0; 0.0f0; 0.0f0] tspan = (0.0f0, 100.0f0) p = SA[10.0f0, 28.0f0, 8 / 3.0f0] prob = ODEProblem(func, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p) +prob_func = (prob, ctx) -> remake(prob, p = rand(Float32, 3) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) @time sol = solve( monteprob, Tsit5(), EnsembleGPUArray(backend), trajectories = 10_000, diff --git a/test/ensemblegpuarray_sde.jl b/test/ensemblegpuarray_sde.jl index 95a7d3b0..6a126144 100644 --- a/test/ensemblegpuarray_sde.jl +++ b/test/ensemblegpuarray_sde.jl @@ -19,7 +19,7 @@ tspan = (0.0f0, 10.0f0) p = (10.0f0, 28.0f0, 8 / 3.0f0) prob = SDEProblem(lorenz, multiplicative_noise, u0, tspan, p) const pre_p = [rand(Float32, 3) for i in 1:10] -prob_func = (prob, i, repeat) -> remake(prob, p = pre_p[i] .* p) +prob_func = (prob, ctx) -> remake(prob, p = pre_p[ctx.sim_id] .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func) @info "Explicit Methods" @@ -54,7 +54,7 @@ tspan = (0.0f0, 10.0f0) p = (10.0f0, 28.0f0, 8 / 3.0f0) prob = SDEProblem(lorenz, multiplicative_noise, u0, tspan, p, noise_rate_prototype = NRate) -prob_func = (prob, i, repeat) -> remake(prob, p = p) +prob_func = (prob, ctx) -> remake(prob, p = p) monteprob = EnsembleProblem(prob, prob_func = prob_func) @test_throws "Incompatible problem detected. EnsembleGPUArray currently requires `prob.noise_rate_prototype === nothing`, i.e. only diagonal noise is currently supported. Track https://github.com/SciML/DiffEqGPU.jl/issues/331 for more information." sol = solve( diff --git a/test/gpu_kernel_de/conversions.jl b/test/gpu_kernel_de/conversions.jl index 7339b5d6..1106f44e 100644 --- a/test/gpu_kernel_de/conversions.jl +++ b/test/gpu_kernel_de/conversions.jl @@ -15,7 +15,7 @@ u0 = [1.0f0; 0.0f0; 0.0f0] tspan = (0.0f0, 10.0f0) p = [10.0f0, 28.0f0, 8 / 3.0f0] prob = ODEProblem{false}(lorenz, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) +prob_func = (prob, ctx) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) ## Don't test the problems in which GPUs don't support FP64 completely yet diff --git a/test/gpu_kernel_de/finite_diff.jl b/test/gpu_kernel_de/finite_diff.jl index 8f3527ed..ac7ef7e2 100644 --- a/test/gpu_kernel_de/finite_diff.jl +++ b/test/gpu_kernel_de/finite_diff.jl @@ -14,7 +14,7 @@ tspan = (0.0f0, 10.0f0) prob = ODEProblem{false}(f, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = p) +prob_func = (prob, ctx) -> remake(prob, p = p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) osol = solve(prob, Rodas5P(), dt = 0.01f0, save_everystep = false) diff --git a/test/gpu_kernel_de/forward_diff.jl b/test/gpu_kernel_de/forward_diff.jl index 5e3f5d7f..9a554263 100644 --- a/test/gpu_kernel_de/forward_diff.jl +++ b/test/gpu_kernel_de/forward_diff.jl @@ -29,7 +29,7 @@ tspan = (0.0f0, 10.0f0) prob = ODEProblem{false}(lorenz, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = p) +prob_func = (prob, ctx) -> remake(prob, p = p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) for alg in ( diff --git a/test/gpu_kernel_de/gpu_ode_continuous_callbacks.jl b/test/gpu_kernel_de/gpu_ode_continuous_callbacks.jl index 2f4dad1e..8cb5bcd8 100644 --- a/test/gpu_kernel_de/gpu_ode_continuous_callbacks.jl +++ b/test/gpu_kernel_de/gpu_ode_continuous_callbacks.jl @@ -13,7 +13,7 @@ u0 = @SVector[45.0f0, 0.0f0] tspan = (0.0f0, 15.0f0) p = @SVector [10.0f0] prob = ODEProblem{false}(f, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = prob.p) +prob_func = (prob, ctx) -> remake(prob, p = prob.p) monteprob = EnsembleProblem(prob, safetycopy = false) function affect!(integrator) diff --git a/test/gpu_kernel_de/gpu_ode_discrete_callbacks.jl b/test/gpu_kernel_de/gpu_ode_discrete_callbacks.jl index 22e58095..c3099b11 100644 --- a/test/gpu_kernel_de/gpu_ode_discrete_callbacks.jl +++ b/test/gpu_kernel_de/gpu_ode_discrete_callbacks.jl @@ -9,7 +9,7 @@ function f(u, p, t) end u0 = @SVector [10.0f0] prob = ODEProblem{false}(f, u0, (0.0f0, 10.0f0)) -prob_func = (prob, i, repeat) -> remake(prob, p = prob.p) +prob_func = (prob, ctx) -> remake(prob, p = prob.p) monteprob = EnsembleProblem(prob, safetycopy = false) algs = [GPUTsit5(), GPUVern7(), GPUVern9()] diff --git a/test/gpu_kernel_de/gpu_ode_regression.jl b/test/gpu_kernel_de/gpu_ode_regression.jl index 4d10075f..54f6d76f 100644 --- a/test/gpu_kernel_de/gpu_ode_regression.jl +++ b/test/gpu_kernel_de/gpu_ode_regression.jl @@ -18,7 +18,7 @@ prob = ODEProblem{false}(lorenz, u0, tspan, p) algs = (GPUTsit5(), GPUVern7(), GPUVern9()) for alg in algs - prob_func = (prob, i, repeat) -> remake(prob, p = p) + prob_func = (prob, ctx) -> remake(prob, p = p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) @info typeof(alg) @@ -125,7 +125,7 @@ for alg in algs ## With random parameters - prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) + prob_func = (prob, ctx) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) local sol = solve( diff --git a/test/gpu_kernel_de/gpu_sde_convergence.jl b/test/gpu_kernel_de/gpu_sde_convergence.jl index b9e0f395..17c1d0d9 100644 --- a/test/gpu_kernel_de/gpu_sde_convergence.jl +++ b/test/gpu_kernel_de/gpu_sde_convergence.jl @@ -15,7 +15,7 @@ dts = 1 .// 2 .^ (5:-1:2) ensemble_prob = EnsembleProblem( prob; - output_func = (sol, i) -> (sol.u[end], false) + output_func = (sol, ctx) -> (sol.u[end], false) ) @info "EM" diff --git a/test/gpu_kernel_de/stiff_ode/gpu_ode_continuous_callbacks.jl b/test/gpu_kernel_de/stiff_ode/gpu_ode_continuous_callbacks.jl index 00f5715c..1cbe90ed 100644 --- a/test/gpu_kernel_de/stiff_ode/gpu_ode_continuous_callbacks.jl +++ b/test/gpu_kernel_de/stiff_ode/gpu_ode_continuous_callbacks.jl @@ -27,7 +27,7 @@ p = @SVector [10.0f0] func = ODEFunction(f, jac = f_jac, tgrad = f_tgrad) prob = ODEProblem{false}(func, u0, tspan, p) -prob_func = (prob, i, repeat) -> remake(prob, p = prob.p) +prob_func = (prob, ctx) -> remake(prob, p = prob.p) monteprob = EnsembleProblem(prob, safetycopy = false) function affect!(integrator) diff --git a/test/gpu_kernel_de/stiff_ode/gpu_ode_discrete_callbacks.jl b/test/gpu_kernel_de/stiff_ode/gpu_ode_discrete_callbacks.jl index e7acf2f5..418f130d 100644 --- a/test/gpu_kernel_de/stiff_ode/gpu_ode_discrete_callbacks.jl +++ b/test/gpu_kernel_de/stiff_ode/gpu_ode_discrete_callbacks.jl @@ -19,7 +19,7 @@ end func = ODEFunction(f, jac = f_jac, tgrad = f_tgrad) u0 = @SVector [10.0f0] prob = ODEProblem{false}(func, u0, (0.0f0, 10.0f0)) -prob_func = (prob, i, repeat) -> remake(prob, p = prob.p) +prob_func = (prob, ctx) -> remake(prob, p = prob.p) monteprob = EnsembleProblem(prob, safetycopy = false) algs = [GPURosenbrock23(), GPURodas4()] diff --git a/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl b/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl index 234d430e..04ec0c84 100644 --- a/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl +++ b/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl @@ -30,7 +30,7 @@ large_prob = ODEProblem(f_large, large_u0, (0.0f0, 10.0f0)) algs = (GPURosenbrock23(), GPURodas4(), GPURodas5P(), GPUKvaerno3(), GPUKvaerno5()) for alg in algs - prob_func = (prob, i, repeat) -> remake(prob, p = p) + prob_func = (prob, ctx) -> remake(prob, p = p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) @info typeof(alg) @@ -147,7 +147,7 @@ for alg in algs ## With random parameters - prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 1)) .* p) + prob_func = (prob, ctx) -> remake(prob, p = (@SVector rand(Float32, 1)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) local sol = solve( diff --git a/test/reduction.jl b/test/reduction.jl index 5ca435a3..391c5b07 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -13,12 +13,12 @@ end prob = ODEProblem(f!, [0.5], (0.0, 1.0)) -function output_func(sol, i) +function output_func(sol, ctx) return last(sol), false end -function prob_func(prob, i, repeat) - return remake(prob, u0 = ra[i] * prob.u0) +function prob_func(prob, ctx) + return remake(prob, u0 = ra[ctx.sim_id] * prob.u0) end function reduction(u, batch, I) diff --git a/test/reverse_ad_tests.jl b/test/reverse_ad_tests.jl index 931f1512..5d9be305 100644 --- a/test/reverse_ad_tests.jl +++ b/test/reverse_ad_tests.jl @@ -10,8 +10,8 @@ end function model(θ, ensemblealg) prob = ODEProblem(modelf, [θ[1]], (0.0, 1.0), [θ[2], θ[3]]) - function prob_func(prob, i, repeat) - return remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + function prob_func(prob, ctx) + return remake(prob, u0 = 0.5 .+ ctx.sim_id / 100 .* prob.u0) end ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)