Add nested conversion of Float64 to Float32 #748

Merged
maleadt merged 3 commits into JuliaGPU:main from papo1011:main on Mar 3, 2026

papo1011 commented Mar 1, 2026
Contributor

Fixes inconsistent behavior (#650):
- `mtl(rand(SVector{2,Float64}, n))` raised an error
- `mtl(rand(Float64, n))` silently converted to Float32

Now no error is raised; instead, Float64 is converted to Float32:

mtl(rand(SVector{2, Float64}, 1000))
# MtlVector{SVector{2, Float32}, Metal.PrivateStorage}
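
For readers without a Metal device, the intended element conversion can be sketched in plain Julia. This is only an illustration under stated assumptions: `to_f32` is a hypothetical helper (not part of Metal.jl or Adapt.jl), and `NTuple{2, Float64}` stands in for `SVector{2, Float64}` so that no packages are needed.

```julia
# Hypothetical helper: mimics the element-wise narrowing described above.
# It is NOT the Metal.jl implementation.
to_f32(x::Float64) = Float32(x)               # leaf: narrow the scalar
to_f32(x::Complex{Float64}) = ComplexF32(x)   # leaf: narrow both components
to_f32(x::Tuple) = map(to_f32, x)             # composite: recurse into elements
to_f32(x) = x                                 # everything else passes through

# NTuple{2, Float64} is an assumed stand-in for SVector{2, Float64}.
xs = [(rand(), rand()) for _ in 1:1000]
ys = map(to_f32, xs)

@assert eltype(ys) == NTuple{2, Float32}      # nested Float64 became Float32
@assert to_f32(Float16(1)) isa Float16        # Float16 is preserved
@assert to_f32(1) isa Int                     # non-float values are untouched
```

The point of the sketch is that the plain and the nested case go through the same rule, which is the consistency this PR is after.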

github-actions Bot commented Mar 1, 2026
Contributor

Note: Please ignore this comment unless asked otherwise by a reviewer.

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/array.jl b/src/array.jl
index 0f8c9cc7..ea4ec7d7 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -485,26 +485,26 @@ Adapt.adapt_storage(::MtlArrayAdaptor, x::Float64) = Float32(x)
 Adapt.adapt_storage(::MtlArrayAdaptor, x::Complex{Float64}) = ComplexF32(x)
 
 # AbstractFloat → Float32
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,S} =
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: AbstractFloat, N, S} =
     isbits(xs) ? xs : MtlArray{Float32,N,S}(xs)
 
 # Float16 — preserve (more specific)
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float16,N,S} =
-    isbits(xs) ? xs : MtlArray{T,N,S}(xs)
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: Float16, N, S} =
+    isbits(xs) ? xs : MtlArray{T, N, S}(xs)
 
 # Complex{AbstractFloat} → ComplexF32
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,S} =
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: Complex{<:AbstractFloat}, N, S} =
     isbits(xs) ? xs : MtlArray{ComplexF32,N,S}(xs)
 
 # Complex{Float16} — preserve
-Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{Float16},N,S} =
-    isbits(xs) ? xs : MtlArray{T,N,S}(xs)
+Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T <: Complex{Float16}, N, S} =
+    isbits(xs) ? xs : MtlArray{T, N, S}(xs)
 
 # Generic — descend via adapt to handle composite types
-function Adapt.adapt_storage(to::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T,N,S}
+function Adapt.adapt_storage(to::MtlArrayAdaptor{S}, xs::AbstractArray{T, N}) where {T, N, S}
     isbits(xs) && return xs
     adapted = map(x -> adapt(to, x), xs)
-    MtlArray{eltype(adapted),N,S}(adapted)
+    return MtlArray{eltype(adapted), N, S}(adapted)
 end
 
 """

codecov Bot commented Mar 1, 2026

Codecov Report

❌ Patch coverage is 90.00000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 82.26%. Comparing base (1d2f000) to head (f83fed5).
⚠️ Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/array.jl | 90.00% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   82.01%   82.26%   +0.25%     
==========================================
  Files          62       62              
  Lines        2874     2881       +7     
==========================================
+ Hits         2357     2370      +13     
+ Misses        517      511       -6     

github-actions Bot left a comment

Contributor

Metal Benchmarks

| Benchmark suite | Current: f83fed5 | Previous: 1848fba | Ratio |
|---|---|---|---|
| latency/precompile | 25782881959 ns | 25661029083 ns | 1.00 |
| latency/ttfp | 2389108146 ns | 2347953229.5 ns | 1.02 |
| latency/import | 1452582834 ns | 1431131708 ns | 1.01 |
| integration/metaldevrt | 854062.5 ns | 877062.5 ns | 0.97 |
| integration/byval/slices=1 | 1598791 ns | 1562250 ns | 1.02 |
| integration/byval/slices=3 | 20475166 ns | 9561167 ns | 2.14 |
| integration/byval/reference | 1599416 ns | 1555625 ns | 1.03 |
| integration/byval/slices=2 | 2734125 ns | 2601916.5 ns | 1.05 |
| kernel/indexing | 508208 ns | 625375 ns | 0.81 |
| kernel/indexing_checked | 511750 ns | 659417 ns | 0.78 |
| kernel/launch | 12500 ns | 11541 ns | 1.08 |
| kernel/rand | 516042 ns | 567291 ns | 0.91 |
| array/construct | 6334 ns | 6333 ns | 1.00 |
| array/broadcast | 510625 ns | 608958 ns | 0.84 |
| array/random/randn/Float32 | 1040084 ns | 963500 ns | 1.08 |
| array/random/randn!/Float32 | 729625 ns | 752729 ns | 0.97 |
| array/random/rand!/Int64 | 538375 ns | 558083 ns | 0.96 |
| array/random/rand!/Float32 | 540875 ns | 586958 ns | 0.92 |
| array/random/rand/Int64 | 960708 ns | 752416 ns | 1.28 |
| array/random/rand/Float32 | 828666 ns | 662625 ns | 1.25 |
| array/accumulate/Int64/1d | 1287958 ns | 1244875 ns | 1.03 |
| array/accumulate/Int64/dims=1 | 1892459 ns | 1839791 ns | 1.03 |
| array/accumulate/Int64/dims=2 | 2288000 ns | 2174979 ns | 1.05 |
| array/accumulate/Int64/dims=1L | 12240666.5 ns | 11510459 ns | 1.06 |
| array/accumulate/Int64/dims=2L | 9678750 ns | 9796125 ns | 0.99 |
| array/accumulate/Float32/1d | 1085375 ns | 1111917 ns | 0.98 |
| array/accumulate/Float32/dims=1 | 1626000 ns | 1567354.5 ns | 1.04 |
| array/accumulate/Float32/dims=2 | 2038750 ns | 1882084 ns | 1.08 |
| array/accumulate/Float32/dims=1L | 10420479.5 ns | 9815541 ns | 1.06 |
| array/accumulate/Float32/dims=2L | 7111666.5 ns | 7235812.5 ns | 0.98 |
| array/reductions/reduce/Int64/1d | 1326271 ns | 1622437.5 ns | 0.82 |
| array/reductions/reduce/Int64/dims=1 | 1127000 ns | 1148333 ns | 0.98 |
| array/reductions/reduce/Int64/dims=2 | 1175437.5 ns | 1185375 ns | 0.99 |
| array/reductions/reduce/Int64/dims=1L | 2041917 ns | 2026167 ns | 1.01 |
| array/reductions/reduce/Int64/dims=2L | 3755583 ns | 4244833 ns | 0.88 |
| array/reductions/reduce/Float32/1d | 826645.5 ns | 1009292 ns | 0.82 |
| array/reductions/reduce/Float32/dims=1 | 806562.5 ns | 851166 ns | 0.95 |
| array/reductions/reduce/Float32/dims=2 | 837791 ns | 852062.5 ns | 0.98 |
| array/reductions/reduce/Float32/dims=1L | 1359417 ns | 1328000 ns | 1.02 |
| array/reductions/reduce/Float32/dims=2L | 1835917 ns | 1813625 ns | 1.01 |
| array/reductions/mapreduce/Int64/1d | 1322875 ns | 1566000 ns | 0.84 |
| array/reductions/mapreduce/Int64/dims=1 | 1127708 ns | 1108708 ns | 1.02 |
| array/reductions/mapreduce/Int64/dims=2 | 1165292 ns | 1153187.5 ns | 1.01 |
| array/reductions/mapreduce/Int64/dims=1L | 1981417 ns | 2031583 ns | 0.98 |
| array/reductions/mapreduce/Int64/dims=2L | 3656937.5 ns | 3617875 ns | 1.01 |
| array/reductions/mapreduce/Float32/1d | 863396 ns | 1047375 ns | 0.82 |
| array/reductions/mapreduce/Float32/dims=1 | 805500 ns | 834687.5 ns | 0.97 |
| array/reductions/mapreduce/Float32/dims=2 | 830583 ns | 873208.5 ns | 0.95 |
| array/reductions/mapreduce/Float32/dims=1L | 1362584 ns | 1318500 ns | 1.03 |
| array/reductions/mapreduce/Float32/dims=2L | 1826271 ns | 1842479 ns | 0.99 |
| array/private/copyto!/gpu_to_gpu | 530834 ns | 631625 ns | 0.84 |
| array/private/copyto!/cpu_to_gpu | 735958.5 ns | 796333 ns | 0.92 |
| array/private/copyto!/gpu_to_cpu | 689312.5 ns | 374541.5 ns | 1.84 |
| array/private/iteration/findall/int | 1585854 ns | 1555542 ns | 1.02 |
| array/private/iteration/findall/bool | 1466042 ns | 1407417 ns | 1.04 |
| array/private/iteration/findfirst/int | 2112625 ns | 2072500 ns | 1.02 |
| array/private/iteration/findfirst/bool | 2043750 ns | 2051292 ns | 1.00 |
| array/private/iteration/scalar | 3481333.5 ns | 5517729 ns | 0.63 |
| array/private/iteration/logical | 2666187.5 ns | 2653750 ns | 1.00 |
| array/private/iteration/findmin/1d | 2583792 ns | 2503750 ns | 1.03 |
| array/private/iteration/findmin/2d | 1853375 ns | 1792250 ns | 1.03 |
| array/private/copy | 791042 ns | 569166 ns | 1.39 |
| array/shared/copyto!/gpu_to_gpu | 86250 ns | 84166 ns | 1.02 |
| array/shared/copyto!/cpu_to_gpu | 86959 ns | 81417 ns | 1.07 |
| array/shared/copyto!/gpu_to_cpu | 85833 ns | 81875 ns | 1.05 |
| array/shared/iteration/findall/int | 1580208 ns | 1574708.5 ns | 1.00 |
| array/shared/iteration/findall/bool | 1483625 ns | 1417688 ns | 1.05 |
| array/shared/iteration/findfirst/int | 1693958.5 ns | 1669333.5 ns | 1.01 |
| array/shared/iteration/findfirst/bool | 1654395.5 ns | 1652125 ns | 1.00 |
| array/shared/iteration/scalar | 204375 ns | 205583.5 ns | 0.99 |
| array/shared/iteration/logical | 2358021 ns | 2455354.5 ns | 0.96 |
| array/shared/iteration/findmin/1d | 2211229 ns | 2120042 ns | 1.04 |
| array/shared/iteration/findmin/2d | 1870083 ns | 1807291 ns | 1.03 |
| array/shared/copy | 216209 ns | 242458 ns | 0.89 |
| array/permutedims/4d | 2506042 ns | 2399104 ns | 1.04 |
| array/permutedims/2d | 1195125 ns | 1179666 ns | 1.01 |
| array/permutedims/3d | 1798854.5 ns | 1685395.5 ns | 1.07 |
| metal/synchronization/stream | 19542 ns | 19209 ns | 1.02 |
| metal/synchronization/context | 20000 ns | 19917 ns | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.

papo1011 changed the title from "Add nested conversion of Float64 to Float32" to "[Do not merge] Add nested conversion of Float64 to Float32" on Mar 1, 2026
- add unit tests for `_mtl_adapt_eltype`
- integration tests for `mtl` with SVector
- format changes with runic
papo1011 changed the title from "[Do not merge] Add nested conversion of Float64 to Float32" to "Add nested conversion of Float64 to Float32" on Mar 1, 2026
maleadt linked an issue on Mar 3, 2026 that may be closed by this pull request

maleadt commented Mar 3, 2026

Member

Thanks. Inspecting a datatype is almost always the wrong solution though. I've made Adapt behave recursively for StaticArrays, JuliaGPU/Adapt.jl#104, and simplified the implementation here.
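
To illustrate the distinction drawn above between inspecting a datatype and recursing through dispatch, here is a minimal package-free sketch. All names in it (`ToF32`, `convert_leaf`, `convert_tree`, `Point`) are hypothetical stand-ins chosen for this example; it only borrows the general shape of a leaf rule plus structural rules and is not the code from Adapt.jl or from this PR.

```julia
# Hypothetical stand-ins; these are NOT the Adapt.jl functions.
struct ToF32 end                                  # marker type in the adaptor role

# Leaf rules: the only place where a concrete number type is mentioned.
convert_leaf(::ToF32, x) = x
convert_leaf(::ToF32, x::Float64) = Float32(x)

# Structural rules: rebuild a container from its converted parts.
convert_tree(to, x) = convert_leaf(to, x)
convert_tree(to, x::Tuple) = map(y -> convert_tree(to, y), x)
convert_tree(to, x::NamedTuple) = map(y -> convert_tree(to, y), x)

# A user-defined wrapper opts in with one method; no eltype inspection is needed.
struct Point{T}
    x::T
    y::T
end
convert_tree(to, p::Point) = Point(convert_tree(to, p.x), convert_tree(to, p.y))

r = convert_tree(ToF32(), (Point(1.0, 2.0), (a = 3.0, b = 4)))
@assert r[1] isa Point{Float32}
@assert r[2].a isa Float32 && r[2].b isa Int
```

With this shape, supporting a new wrapper type means adding one method for it rather than teaching a central function how to rewrite that type's parameters.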

maleadt merged commit 14d326a into JuliaGPU:main on Mar 3, 2026
17 checks passed

Development

Successfully merging this pull request may close these issues.

Nested conversion of Float64 to Float32

2 participants