diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 191a5eb2..3db5b8ca 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -797,6 +797,15 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, end symtype = symbolic_type(_arg) elsymtype = symbolic_type(eltype(_arg)) + # For symbolic indices, `A[sym, :]` is semantically equivalent to `A[sym]` + # (the no-args symbolic getindex already returns the full timeseries). + # Routing through the no-args path here also avoids a broadcast shape bug + # in SymbolicIndexingInterface's `GetStateIndex` when the underlying index + # is a `Vector{Int}` (array-symbolic) combined with a `Colon` time index. + if (symtype != NotSymbolic() || elsymtype != NotSymbolic()) && + length(args) == 1 && args[1] === Colon() + return A[_arg] + end return if symtype == NotSymbolic() && elsymtype == NotSymbolic() if _arg isa Union{Tuple, AbstractArray} && diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 3ed72f06..ad7277e8 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -4,7 +4,9 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +RecursiveArrayToolsShorthandConstructors = "39fb7555-b4ad-4efd-8abe-30331df017d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -18,6 +20,7 @@ ModelingToolkit = "8.33, 9, 10, 11" MonteCarloMeasurements = "1.1" NLsolve = "4" OrdinaryDiffEq = "6.31, 7" +OrdinaryDiffEqRosenbrock = "1, 2" StaticArrays = "1" SymbolicIndexingInterface = "0.3" Tables = "1" diff --git a/test/downstream/downstream_events.jl b/test/downstream/downstream_events.jl index 230dcbb0..2118cbc7 100644 --- a/test/downstream/downstream_events.jl +++ b/test/downstream/downstream_events.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools +using OrdinaryDiffEq, StaticArrays, RecursiveArrayTools, RecursiveArrayToolsShorthandConstructors u0 = AP[SVector{1}(50.0), SVector{1}(0.0)] tspan = (0.0, 15.0) diff --git a/test/downstream/odesolve.jl b/test/downstream/odesolve.jl index 2ca163ab..5fe162d8 100644 --- a/test/downstream/odesolve.jl +++ b/test/downstream/odesolve.jl @@ -1,4 +1,5 @@ -using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface, StaticArrays +using OrdinaryDiffEq, OrdinaryDiffEqRosenbrock, NLsolve, RecursiveArrayTools, + RecursiveArrayToolsShorthandConstructors, Test, ArrayInterface, StaticArrays function lorenz(du, u, p, t) du[1] = 10.0 * (u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] @@ -9,7 +10,7 @@ u0 = AP[[1.0, 0.0], [0.0]] tspan = (0.0, 100.0) prob = ODEProblem(lorenz, u0, tspan) sol = solve(prob, Tsit5()) -sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = false))) +sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = AutoFiniteDiff()))) sol = solve(prob, AutoTsit5(Rosenbrock23())) @test all(Array(sol) .== sol) @@ -72,4 +73,4 @@ end u = fill(SVector{2}(ones(2)), 2, 3) ode = ODEProblem(rhs!, VectorOfArray(u), (0.0, 1.0)) sol = solve(ode, Tsit5()) -@test SciMLBase.successful_retcode(sol) +@test successful_retcode(sol) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index b420d80c..3cbae750 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -77,7 +77,7 @@ ts = 0:0.5:10 sol_ts = sol(ts) @assert sol_ts isa DiffEqArray test_tables_interface( - sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")], + sol_ts, [:timestamp; Symbol.(string.(unknowns(lv)))], hcat(ts, Array(sol_ts)') )