From 846fcdc5a52bb66e166aa53cade01e0d15799869 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Tue, 31 Mar 2026 13:46:51 -0400 Subject: [PATCH] WIP: Inline optimized any/all for ArrayPartition on Julia 1.13+ Julia 1.13 (JuliaLang/julia#61184) removes the f::Function restriction from any/all, so defining any(f, ::ArrayPartition) drops invalidations from ~780 to 1 (verified with SnoopCompileCore). On 1.13+, the optimized partition-level methods are defined directly in the main package. On older Julia, the subpackage split is preserved. Co-Authored-By: Chris Rackauckas Co-Authored-By: Claude Opus 4.6 (1M context) --- src/array_partition.jl | 23 ++++++++++++++++++----- test/partitions_test.jl | 16 ++++++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 0f2eba96..f8bda538 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -741,9 +741,22 @@ ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve """ struct AP end -# Optimized partition-level any/all for ArrayPartition lives in -# RecursiveArrayToolsArrayPartitionAnyAll to avoid ~780 invalidations. -# Without the extension, any/all uses the AbstractArray element-by-element -# fallback, which triggers scalar indexing errors on GPU arrays. -# Load the subpackage to fix: +# Optimized partition-level any/all for ArrayPartition. +# +# On Julia ≥ 1.13 (JuliaLang/julia#61184), Base removes the f::Function +# restriction from any/all, so defining any(f, ::ArrayPartition) causes +# only 1 invalidation (down from ~780). Safe to inline directly. +# +# On Julia < 1.13, the methods live in RecursiveArrayToolsArrayPartitionAnyAll +# to avoid the invalidations. Without the extension, any/all uses the +# AbstractArray element-by-element fallback, which triggers scalar indexing +# errors on GPU arrays. Load the subpackage to fix: # using RecursiveArrayToolsArrayPartitionAnyAll +@static if VERSION >= v"1.13.0-DEV.0" + Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x)) + Base.any(f::Function, A::ArrayPartition) = any((any(f, x) for x in A.x)) + Base.any(A::ArrayPartition) = any(identity, A) + Base.all(f, A::ArrayPartition) = all((all(f, x) for x in A.x)) + Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x)) + Base.all(A::ArrayPartition) = all(identity, A) +end diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 44bf10c9..1b7b7fdb 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -178,13 +178,17 @@ recursivecopy!(dest_ap, src_ap) @inferred mapreduce(string, *, x) @test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q" -# any/all — optimized partition-level iteration requires RecursiveArrayToolsArrayPartitionAnyAll -# to avoid ~780 invalidations. Without the extension, Base's AbstractArray fallback is used. -# On GPU arrays, the fallback triggers scalar indexing errors — load the subpackage to fix. -@test which(any, Tuple{Function, ArrayPartition}).module === Base -@test which(all, Tuple{Function, ArrayPartition}).module === Base +# any/all — on Julia ≥ 1.13, optimized methods are inlined (1 invalidation). +# On older Julia, they require RecursiveArrayToolsArrayPartitionAnyAll (~780 invalidations). +@static if VERSION >= v"1.13.0-DEV.0" + @test which(any, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools + @test which(all, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools +else + @test which(any, Tuple{Function, ArrayPartition}).module === Base + @test which(all, Tuple{Function, ArrayPartition}).module === Base +end -# Correctness tests (work via AbstractArray fallback on CPU) +# Correctness tests @test !any(isnan, AP[[1, 2], [3.0, 4.0]]) @test !any(isnan, AP[[3.0, 4.0]]) @test any(isnan, AP[[NaN], [3.0, 4.0]])