diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp
index bbca1a89..a26ed8f5 100644
--- a/deps/src/onemkl.cpp
+++ b/deps/src/onemkl.cpp
@@ -87,6 +87,57 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
     return 0;
 }
 
+// Level-1 BLAS: SCAL (x <- alpha * x), forwarded to the oneMKL column-major API.
+// `n`, `x` and `incx` are passed straight through to oneMKL.
+// C99 `_Complex` arguments are converted to the `std::complex` types oneMKL expects.
+extern "C" void onemklDscal(syclQueue_t device_queue, int64_t n, double alpha,
+                            double *x, int64_t incx) {
+    auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n, alpha,
+                                                        x, incx);
+    __FORCE_MKL_FLUSH__(status);
+}
+
+extern "C" void onemklSscal(syclQueue_t device_queue, int64_t n, float alpha,
+                            float *x, int64_t incx) {
+    auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n, alpha,
+                                                        x, incx);
+    __FORCE_MKL_FLUSH__(status);
+}
+
+extern "C" void onemklCscal(syclQueue_t device_queue, int64_t n,
+                            float _Complex alpha, float _Complex *x,
+                            int64_t incx) {
+    auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n,
+                                        static_cast<std::complex<float> >(alpha),
+                                        reinterpret_cast<std::complex<float> *>(x), incx);
+    __FORCE_MKL_FLUSH__(status);
+}
+
+extern "C" void onemklCsscal(syclQueue_t device_queue, int64_t n,
+                             float alpha, float _Complex *x,
+                             int64_t incx) {
+    auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n, alpha,
+                                        reinterpret_cast<std::complex<float> *>(x), incx);
+    __FORCE_MKL_FLUSH__(status);
+}
+
+extern "C" void onemklZscal(syclQueue_t device_queue, int64_t n,
+                            double _Complex alpha, double _Complex *x,
+                            int64_t incx) {
+    auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n,
+                                        static_cast<std::complex<double> >(alpha),
+                                        reinterpret_cast<std::complex<double> *>(x), incx);
+    __FORCE_MKL_FLUSH__(status);
+}
+
+extern "C" void onemklZdscal(syclQueue_t device_queue, int64_t n,
+                             double alpha, double _Complex *x,
+                             int64_t incx) {
+    auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n, alpha,
+                                        reinterpret_cast<std::complex<double> *>(x), incx);
+    __FORCE_MKL_FLUSH__(status);
+}
+
 extern "C" void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
                             int64_t incx, double *result) {
     auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n, x, incx, result);
diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h
index f9ea315c..1fec16da 100644
--- a/deps/src/onemkl.h
+++ b/deps/src/onemkl.h
@@ -39,6 +39,20 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
                 const double _Complex *B, int64_t ldb, double _Complex beta,
                 double _Complex *C, int64_t ldc);
 
+// Supported Level-1: scal
+void onemklDscal(syclQueue_t device_queue, int64_t n, double alpha,
+                 double *x, int64_t incx);
+void onemklSscal(syclQueue_t device_queue, int64_t n, float alpha,
+                 float *x, int64_t incx);
+void onemklCscal(syclQueue_t device_queue, int64_t n, float _Complex alpha,
+                 float _Complex *x, int64_t incx);
+void onemklCsscal(syclQueue_t device_queue, int64_t n, float alpha,
+                  float _Complex *x, int64_t incx);
+void onemklZscal(syclQueue_t device_queue, int64_t n, double _Complex alpha,
+                 double _Complex *x, int64_t incx);
+void onemklZdscal(syclQueue_t device_queue, int64_t n, double alpha,
+                  double _Complex *x, int64_t incx);
+
 // Supported Level-1: Nrm2
 void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
                  int64_t incx, double *result);
diff --git a/lib/mkl/libonemkl.jl b/lib/mkl/libonemkl.jl
index f0bd8064..357f12db 100644
--- a/lib/mkl/libonemkl.jl
+++ b/lib/mkl/libonemkl.jl
@@ -42,6 +42,36 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
                                           C::ZePtr{ComplexF64}, ldc::Int64)::Cint
 end
 
+function onemklDscal(device_queue, n, alpha, x, incx)
+    @ccall liboneapi_support.onemklDscal(device_queue::syclQueue_t, n::Int64,
+                                         alpha::Cdouble, x::ZePtr{Cdouble}, incx::Int64)::Cvoid
+end
+
+function onemklSscal(device_queue, n, alpha, x, incx)
+    @ccall liboneapi_support.onemklSscal(device_queue::syclQueue_t, n::Int64,
+                                         alpha::Cfloat, x::ZePtr{Cfloat}, incx::Int64)::Cvoid
+end
+
+function onemklZscal(device_queue, n, alpha, x, incx)
+    @ccall liboneapi_support.onemklZscal(device_queue::syclQueue_t, n::Int64,
+                                         alpha::ComplexF64, x::ZePtr{ComplexF64}, incx::Int64)::Cvoid
+end
+
+function onemklZdscal(device_queue, n, alpha, x, incx)
+    @ccall liboneapi_support.onemklZdscal(device_queue::syclQueue_t, n::Int64,
+                                          alpha::Cdouble, x::ZePtr{ComplexF64}, incx::Int64)::Cvoid
+end
+
+function onemklCscal(device_queue, n, alpha, x, incx)
+    @ccall liboneapi_support.onemklCscal(device_queue::syclQueue_t, n::Int64,
+                                         alpha::ComplexF32, x::ZePtr{ComplexF32}, incx::Int64)::Cvoid
+end
+
+function onemklCsscal(device_queue, n, alpha, x, incx)
+    @ccall liboneapi_support.onemklCsscal(device_queue::syclQueue_t, n::Int64,
+                                          alpha::Cfloat, x::ZePtr{ComplexF32}, incx::Int64)::Cvoid
+end
+
 function onemklDnrm2(device_queue, n, x, incx, result)
     @ccall liboneapi_support.onemklDnrm2(device_queue::syclQueue_t, n::Int64,
                                          x::ZePtr{Cdouble}, incx::Int64,
diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl
index 94905416..3cfad141 100644
--- a/lib/mkl/linalg.jl
+++ b/lib/mkl/linalg.jl
@@ -49,6 +49,14 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N
     end
 end
 
+# In-place scaling through oneMKL scal; `k` is converted to `eltype(x)` first.
+LinearAlgebra.rmul!(x::oneStridedVecOrMat{<:onemklFloat}, k::Number) =
+    oneMKL.scal!(length(x), convert(eltype(x), k), x)
+
+# Work around ambiguity with GPUArrays wrapper
+LinearAlgebra.rmul!(x::oneStridedVecOrMat{<:onemklFloat}, k::Real) =
+    invoke(rmul!, Tuple{typeof(x), Number}, x, k)
+
 LinearAlgebra.norm(x::oneStridedVecOrMat{<:onemklFloat}) = oneMKL.nrm2(length(x), x)
 
 for NT in (Number, Real)
diff --git a/lib/mkl/oneMKL.jl b/lib/mkl/oneMKL.jl
index d83f2141..7b0b24b2 100644
--- a/lib/mkl/oneMKL.jl
+++ b/lib/mkl/oneMKL.jl
@@ -12,7 +12,8 @@ using GPUArrays
 
 include("libonemkl.jl")
 
-const onemklFloat = Union{Float64,Float32,Float16,ComplexF64,ComplexF32}
+# Float16 is excluded for now: several oneMKL routines (e.g. copy, scal) do not accept it.
+const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
 
 include("wrappers.jl")
 include("linalg.jl")
diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl
index aca2e5d0..4f48cb1d 100644
--- a/lib/mkl/wrappers.jl
+++ b/lib/mkl/wrappers.jl
@@ -15,6 +15,23 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
 end
 
 # level 1
+## scal
+for (fname, elty) in
+        ((:onemklDscal,:Float64),
+         (:onemklSscal,:Float32),
+         (:onemklZscal,:ComplexF64),
+         (:onemklCscal,:ComplexF32))
+    @eval begin
+        function scal!(n::Integer,
+                       alpha::$elty,
+                       x::oneStridedArray{$elty})
+            queue = global_queue(context(x), device(x))
+            $fname(sycl_queue(queue), n, alpha, x, stride(x,1))
+            x
+        end
+    end
+end
+
 ## nrm2
 for (fname, elty, ret_type) in
     ((:onemklDnrm2, :Float64,:Float64),
@@ -32,6 +49,20 @@ for (fname, elty, ret_type) in
     end
 end
 
+## scal (real alpha, complex x)
+for (fname, elty, celty) in ((:onemklCsscal, :Float32, :ComplexF32),
+                             (:onemklZdscal, :Float64, :ComplexF64))
+    @eval begin
+        function scal!(n::Integer,
+                       alpha::$elty,
+                       x::oneStridedArray{$celty})
+            queue = global_queue(context(x), device(x))
+            $fname(sycl_queue(queue), n, alpha, x, stride(x,1))
+            x  # return the scaled array, like the same-type scal! methods
+        end
+    end
+end
+
 #
 # BLAS
 #
diff --git a/test/onemkl.jl b/test/onemkl.jl
index aaf6c2d2..8ab9e2a6 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -13,6 +13,21 @@ m = 20
             B = oneArray{T}(undef, m)
             oneMKL.copy!(m,A,B)
             @test Array(A) == Array(B)
+        end
+
+        @testset "scal" begin
+            # Test scal primitive [alpha/x: F32, F64, CF32, CF64]
+            alpha = rand(T)
+            @test testf(rmul!, rand(T,m), alpha)
+
+            # Test scal primitive [alpha - F32, F64, x - CF32, CF64]
+            if T <: Complex
+                A = rand(T,m)
+                gpuA = oneArray(A)
+                alpha_real = rand(real(T))
+                @test oneMKL.scal!(m, alpha_real, gpuA) === gpuA
+                @test A .* alpha_real ≈ Array(gpuA)
+            end
         end
 
         @testset "nrm2" begin