Add nested conversion of Float64 to Float32#748
Conversation
Fixes inconsistent behavior (#650):
- `mtl(rand(SVector{2,Float64}, n))` raised an error
- `mtl(rand(Float64, n))` silently converted to Float32
|
Note: Please ignore this comment unless asked otherwise by a reviewer. Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/src/array.jl b/src/array.jl
index 0f8c9cc7..ea4ec7d7 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -485,26 +485,26 @@ Adapt.adapt_storage(::MtlArrayAdaptor, x::Float64) = Float32(x)
Adapt.adapt_storage(::MtlArrayAdaptor, x::Complex{Float64}) = ComplexF32(x)
# AbstractFloat → Float32
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,S} =
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: AbstractFloat, N, S} =
isbits(xs) ? xs : MtlArray{Float32,N,S}(xs)
# Float16 — preserve (more specific)
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float16,N,S} =
- isbits(xs) ? xs : MtlArray{T,N,S}(xs)
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: Float16, N, S} =
+ isbits(xs) ? xs : MtlArray{T, N, S}(xs)
# Complex{AbstractFloat} → ComplexF32
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,S} =
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: Complex{<:AbstractFloat}, N, S} =
isbits(xs) ? xs : MtlArray{ComplexF32,N,S}(xs)
# Complex{Float16} — preserve
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{Float16},N,S} =
- isbits(xs) ? xs : MtlArray{T,N,S}(xs)
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: Complex{Float16}, N, S} =
+ isbits(xs) ? xs : MtlArray{T, N, S}(xs)
# Generic — descend via adapt to handle composite types
-function Adapt.adapt_storage(to::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T,N,S}
+function Adapt.adapt_storage(to::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T, N, S}
isbits(xs) && return xs
adapted = map(x -> adapt(to, x), xs)
- MtlArray{eltype(adapted),N,S}(adapted)
+ return MtlArray{eltype(adapted), N, S}(adapted)
end
""" |
Codecov Report❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #748 +/- ##
==========================================
+ Coverage 82.01% 82.26% +0.25%
==========================================
Files 62 62
Lines 2874 2881 +7
==========================================
+ Hits 2357 2370 +13
+ Misses 517 511 -6
☔ View full report in Codecov by Sentry.
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Metal Benchmarks
Details
| Benchmark suite | Current: f83fed5 | Previous: 1848fba | Ratio |
|---|---|---|---|
latency/precompile |
25782881959 ns |
25661029083 ns |
1.00 |
latency/ttfp |
2389108146 ns |
2347953229.5 ns |
1.02 |
latency/import |
1452582834 ns |
1431131708 ns |
1.01 |
integration/metaldevrt |
854062.5 ns |
877062.5 ns |
0.97 |
integration/byval/slices=1 |
1598791 ns |
1562250 ns |
1.02 |
integration/byval/slices=3 |
20475166 ns |
9561167 ns |
2.14 |
integration/byval/reference |
1599416 ns |
1555625 ns |
1.03 |
integration/byval/slices=2 |
2734125 ns |
2601916.5 ns |
1.05 |
kernel/indexing |
508208 ns |
625375 ns |
0.81 |
kernel/indexing_checked |
511750 ns |
659417 ns |
0.78 |
kernel/launch |
12500 ns |
11541 ns |
1.08 |
kernel/rand |
516042 ns |
567291 ns |
0.91 |
array/construct |
6334 ns |
6333 ns |
1.00 |
array/broadcast |
510625 ns |
608958 ns |
0.84 |
array/random/randn/Float32 |
1040084 ns |
963500 ns |
1.08 |
array/random/randn!/Float32 |
729625 ns |
752729 ns |
0.97 |
array/random/rand!/Int64 |
538375 ns |
558083 ns |
0.96 |
array/random/rand!/Float32 |
540875 ns |
586958 ns |
0.92 |
array/random/rand/Int64 |
960708 ns |
752416 ns |
1.28 |
array/random/rand/Float32 |
828666 ns |
662625 ns |
1.25 |
array/accumulate/Int64/1d |
1287958 ns |
1244875 ns |
1.03 |
array/accumulate/Int64/dims=1 |
1892459 ns |
1839791 ns |
1.03 |
array/accumulate/Int64/dims=2 |
2288000 ns |
2174979 ns |
1.05 |
array/accumulate/Int64/dims=1L |
12240666.5 ns |
11510459 ns |
1.06 |
array/accumulate/Int64/dims=2L |
9678750 ns |
9796125 ns |
0.99 |
array/accumulate/Float32/1d |
1085375 ns |
1111917 ns |
0.98 |
array/accumulate/Float32/dims=1 |
1626000 ns |
1567354.5 ns |
1.04 |
array/accumulate/Float32/dims=2 |
2038750 ns |
1882084 ns |
1.08 |
array/accumulate/Float32/dims=1L |
10420479.5 ns |
9815541 ns |
1.06 |
array/accumulate/Float32/dims=2L |
7111666.5 ns |
7235812.5 ns |
0.98 |
array/reductions/reduce/Int64/1d |
1326271 ns |
1622437.5 ns |
0.82 |
array/reductions/reduce/Int64/dims=1 |
1127000 ns |
1148333 ns |
0.98 |
array/reductions/reduce/Int64/dims=2 |
1175437.5 ns |
1185375 ns |
0.99 |
array/reductions/reduce/Int64/dims=1L |
2041917 ns |
2026167 ns |
1.01 |
array/reductions/reduce/Int64/dims=2L |
3755583 ns |
4244833 ns |
0.88 |
array/reductions/reduce/Float32/1d |
826645.5 ns |
1009292 ns |
0.82 |
array/reductions/reduce/Float32/dims=1 |
806562.5 ns |
851166 ns |
0.95 |
array/reductions/reduce/Float32/dims=2 |
837791 ns |
852062.5 ns |
0.98 |
array/reductions/reduce/Float32/dims=1L |
1359417 ns |
1328000 ns |
1.02 |
array/reductions/reduce/Float32/dims=2L |
1835917 ns |
1813625 ns |
1.01 |
array/reductions/mapreduce/Int64/1d |
1322875 ns |
1566000 ns |
0.84 |
array/reductions/mapreduce/Int64/dims=1 |
1127708 ns |
1108708 ns |
1.02 |
array/reductions/mapreduce/Int64/dims=2 |
1165292 ns |
1153187.5 ns |
1.01 |
array/reductions/mapreduce/Int64/dims=1L |
1981417 ns |
2031583 ns |
0.98 |
array/reductions/mapreduce/Int64/dims=2L |
3656937.5 ns |
3617875 ns |
1.01 |
array/reductions/mapreduce/Float32/1d |
863396 ns |
1047375 ns |
0.82 |
array/reductions/mapreduce/Float32/dims=1 |
805500 ns |
834687.5 ns |
0.97 |
array/reductions/mapreduce/Float32/dims=2 |
830583 ns |
873208.5 ns |
0.95 |
array/reductions/mapreduce/Float32/dims=1L |
1362584 ns |
1318500 ns |
1.03 |
array/reductions/mapreduce/Float32/dims=2L |
1826271 ns |
1842479 ns |
0.99 |
array/private/copyto!/gpu_to_gpu |
530834 ns |
631625 ns |
0.84 |
array/private/copyto!/cpu_to_gpu |
735958.5 ns |
796333 ns |
0.92 |
array/private/copyto!/gpu_to_cpu |
689312.5 ns |
374541.5 ns |
1.84 |
array/private/iteration/findall/int |
1585854 ns |
1555542 ns |
1.02 |
array/private/iteration/findall/bool |
1466042 ns |
1407417 ns |
1.04 |
array/private/iteration/findfirst/int |
2112625 ns |
2072500 ns |
1.02 |
array/private/iteration/findfirst/bool |
2043750 ns |
2051292 ns |
1.00 |
array/private/iteration/scalar |
3481333.5 ns |
5517729 ns |
0.63 |
array/private/iteration/logical |
2666187.5 ns |
2653750 ns |
1.00 |
array/private/iteration/findmin/1d |
2583792 ns |
2503750 ns |
1.03 |
array/private/iteration/findmin/2d |
1853375 ns |
1792250 ns |
1.03 |
array/private/copy |
791042 ns |
569166 ns |
1.39 |
array/shared/copyto!/gpu_to_gpu |
86250 ns |
84166 ns |
1.02 |
array/shared/copyto!/cpu_to_gpu |
86959 ns |
81417 ns |
1.07 |
array/shared/copyto!/gpu_to_cpu |
85833 ns |
81875 ns |
1.05 |
array/shared/iteration/findall/int |
1580208 ns |
1574708.5 ns |
1.00 |
array/shared/iteration/findall/bool |
1483625 ns |
1417688 ns |
1.05 |
array/shared/iteration/findfirst/int |
1693958.5 ns |
1669333.5 ns |
1.01 |
array/shared/iteration/findfirst/bool |
1654395.5 ns |
1652125 ns |
1.00 |
array/shared/iteration/scalar |
204375 ns |
205583.5 ns |
0.99 |
array/shared/iteration/logical |
2358021 ns |
2455354.5 ns |
0.96 |
array/shared/iteration/findmin/1d |
2211229 ns |
2120042 ns |
1.04 |
array/shared/iteration/findmin/2d |
1870083 ns |
1807291 ns |
1.03 |
array/shared/copy |
216209 ns |
242458 ns |
0.89 |
array/permutedims/4d |
2506042 ns |
2399104 ns |
1.04 |
array/permutedims/2d |
1195125 ns |
1179666 ns |
1.01 |
array/permutedims/3d |
1798854.5 ns |
1685395.5 ns |
1.07 |
metal/synchronization/stream |
19542 ns |
19209 ns |
1.02 |
metal/synchronization/context |
20000 ns |
19917 ns |
1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
- add unit tests for `_mtl_adapt_eltype`
- integration tests for `mtl` with SVector
- format changes with runic
|
Thanks. Inspecting a datatype is almost always the wrong solution though. I've made Adapt behave recursively for StaticArrays, JuliaGPU/Adapt.jl#104, and simplified the implementation here. |
Fixes inconsistent behavior (#650):
- `mtl(rand(SVector{2,Float64}, n))` raised an error
- `mtl(rand(Float64, n))` silently converted to Float32

Now no error is raised; instead, `Float64` is converted to `Float32`: