Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
664480c
cuBlas level 1 method - scal supported
kballeda Oct 31, 2022
5180b9b
scal test case updated in onemkl.jl
kballeda Oct 31, 2022
1307b8d
NITS - cleanup
kballeda Oct 31, 2022
86f9dc7
indentation NITS
kballeda Oct 31, 2022
57495e9
updated with scal - deps
kballeda Oct 31, 2022
28f245f
indentation fixes
kballeda Oct 31, 2022
86591f5
NITS
kballeda Oct 31, 2022
c49620f
NITS
kballeda Oct 31, 2022
990300a
cleanup onemkl.jl
kballeda Oct 31, 2022
1e164db
updated with rmul! and testf usage
kballeda Nov 1, 2022
139de1c
NITS
kballeda Nov 1, 2022
23943bb
testf used for cpu/gpu testing.
kballeda Nov 2, 2022
b51b870
NITS - cleanup & included int specific calls to rmul! diverted to
kballeda Nov 2, 2022
ab0abb2
wrapper alpha takes eltype and supports all combinations
kballeda Nov 3, 2022
d5a58e0
NITS
kballeda Nov 3, 2022
c39e90c
support for Cs, Zd configs of scal function
kballeda Nov 3, 2022
0a17298
updated with static_cast complex alpha
kballeda Nov 3, 2022
e0aa24e
added onestridedarray
kballeda Nov 4, 2022
b9a7d29
enable tests of complex type
kballeda Nov 4, 2022
645dad7
Merge branch 'master' into l1_scal
kballeda Nov 7, 2022
1c1b206
updated with Csscal and Zdscal test enabled
kballeda Nov 8, 2022
d423806
NITS
kballeda Nov 8, 2022
071f69e
NITS
kballeda Nov 8, 2022
8168b50
Merge branch 'master' into l1_scal
kballeda Nov 8, 2022
7be9df6
NITS
kballeda Nov 8, 2022
038252f
NITS
kballeda Nov 8, 2022
d07339e
Cleanup of tests
kballeda Nov 8, 2022
cbbd3a4
Instead of isapprox use compare op
kballeda Nov 9, 2022
af68c99
Bug fix: disable f16 check as it is not supported (CI crash)
kballeda Nov 10, 2022
7dd2fa3
Merge branch 'master' into l1_scal
kballeda Nov 10, 2022
41db9d5
Merge branch 'master' into l1_scal
kballeda Nov 22, 2022
578c300
use force flush instead of wait()
kballeda Nov 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,56 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

// Level-1 BLAS scal: x := alpha * x for real double-precision vectors.
extern "C" void onemklDscal(syclQueue_t device_queue, int64_t n, double alpha,
                            double *x, int64_t incx) {
    // Submit the kernel, then force the queue to flush so the work is
    // dispatched (see __FORCE_MKL_FLUSH__; chosen over wait() per PR history).
    auto evt = oneapi::mkl::blas::column_major::scal(device_queue->val, n,
                                                     alpha, x, incx);
    __FORCE_MKL_FLUSH__(evt);
}

// Level-1 BLAS scal: x := alpha * x for real single-precision vectors.
extern "C" void onemklSscal(syclQueue_t device_queue, int64_t n, float alpha,
                            float *x, int64_t incx) {
    // Submit the kernel, then force the queue to flush the batch.
    auto evt = oneapi::mkl::blas::column_major::scal(device_queue->val, n,
                                                     alpha, x, incx);
    __FORCE_MKL_FLUSH__(evt);
}

// Level-1 BLAS scal: x := alpha * x for single-precision complex vectors.
extern "C" void onemklCscal(syclQueue_t device_queue, int64_t n,
                            float _Complex alpha, float _Complex *x,
                            int64_t incx) {
    // Bridge the C99 _Complex ABI types to the std::complex forms the
    // oneMKL C++ API expects (value cast for alpha, pointer cast for x).
    auto evt = oneapi::mkl::blas::column_major::scal(
        device_queue->val, n, static_cast<std::complex<float> >(alpha),
        reinterpret_cast<std::complex<float> *>(x), incx);
    __FORCE_MKL_FLUSH__(evt);
}

// Mixed-type scal (csscal): scale a single-precision complex vector by a
// real alpha.
extern "C" void onemklCsscal(syclQueue_t device_queue, int64_t n,
                             float alpha, float _Complex *x,
                             int64_t incx) {
    // Only the vector needs reinterpreting; alpha is already a plain float.
    auto evt = oneapi::mkl::blas::column_major::scal(
        device_queue->val, n, alpha,
        reinterpret_cast<std::complex<float> *>(x), incx);
    __FORCE_MKL_FLUSH__(evt);
}

// Level-1 BLAS scal: x := alpha * x for double-precision complex vectors.
extern "C" void onemklZscal(syclQueue_t device_queue, int64_t n,
                            double _Complex alpha, double _Complex *x,
                            int64_t incx) {
    // Bridge the C99 _Complex ABI types to std::complex for the C++ API.
    auto evt = oneapi::mkl::blas::column_major::scal(
        device_queue->val, n, static_cast<std::complex<double> >(alpha),
        reinterpret_cast<std::complex<double> *>(x), incx);
    __FORCE_MKL_FLUSH__(evt);
}

// Mixed-type scal (zdscal): scale a double-precision complex vector by a
// real alpha.
extern "C" void onemklZdscal(syclQueue_t device_queue, int64_t n,
                             double alpha, double _Complex *x,
                             int64_t incx) {
    // Only the vector needs reinterpreting; alpha is already a plain double.
    auto evt = oneapi::mkl::blas::column_major::scal(
        device_queue->val, n, alpha,
        reinterpret_cast<std::complex<double> *>(x), incx);
    __FORCE_MKL_FLUSH__(evt);
}

extern "C" void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *result) {
auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n, x, incx, result);
Expand Down
13 changes: 13 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
const double _Complex *B, int64_t ldb, double _Complex beta,
double _Complex *C, int64_t ldc);

// Level-1: scal oneMKL — scale a vector in place, x := alpha * x.
// BLAS-style prefixes: D = double, S = float, C = single-precision complex,
// Z = double-precision complex. Cs/Zd variants take a REAL alpha for a
// complex x (csscal / zdscal).
void onemklDscal(syclQueue_t device_queue, int64_t n, double alpha,
                 double *x, int64_t incx);
void onemklSscal(syclQueue_t device_queue, int64_t n, float alpha,
                 float *x, int64_t incx);
void onemklCscal(syclQueue_t device_queue, int64_t n, float _Complex alpha,
                 float _Complex *x, int64_t incx);
void onemklCsscal(syclQueue_t device_queue, int64_t n, float alpha,
                  float _Complex *x, int64_t incx);
void onemklZscal(syclQueue_t device_queue, int64_t n, double _Complex alpha,
                 double _Complex *x, int64_t incx);
void onemklZdscal(syclQueue_t device_queue, int64_t n, double alpha,
                  double _Complex *x, int64_t incx);
// Supported Level-1: Nrm2
void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *result);
Expand Down
29 changes: 29 additions & 0 deletions lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,35 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
C::ZePtr{ComplexF64}, ldc::Int64)::Cint
end

# In-place scale of a Float64 device vector: x .= alpha .* x.
function onemklDscal(device_queue, n, alpha, x, incx)
    return @ccall liboneapi_support.onemklDscal(device_queue::syclQueue_t,
                                                n::Int64, alpha::Cdouble,
                                                x::ZePtr{Cdouble},
                                                incx::Int64)::Cvoid
end

# In-place scale of a Float32 device vector: x .= alpha .* x.
function onemklSscal(device_queue, n, alpha, x, incx)
    return @ccall liboneapi_support.onemklSscal(device_queue::syclQueue_t,
                                                n::Int64, alpha::Cfloat,
                                                x::ZePtr{Cfloat},
                                                incx::Int64)::Cvoid
end

# In-place scale of a ComplexF64 device vector by a complex alpha.
function onemklZscal(device_queue, n, alpha, x, incx)
    return @ccall liboneapi_support.onemklZscal(device_queue::syclQueue_t,
                                                n::Int64, alpha::ComplexF64,
                                                x::ZePtr{ComplexF64},
                                                incx::Int64)::Cvoid
end

# In-place scale of a ComplexF64 device vector by a REAL (Float64) alpha.
function onemklZdscal(device_queue, n, alpha, x, incx)
    return @ccall liboneapi_support.onemklZdscal(device_queue::syclQueue_t,
                                                 n::Int64, alpha::Cdouble,
                                                 x::ZePtr{ComplexF64},
                                                 incx::Int64)::Cvoid
end

# In-place scale of a ComplexF32 device vector by a complex alpha.
function onemklCscal(device_queue, n, alpha, x, incx)
    return @ccall liboneapi_support.onemklCscal(device_queue::syclQueue_t,
                                                n::Int64, alpha::ComplexF32,
                                                x::ZePtr{ComplexF32},
                                                incx::Int64)::Cvoid
end

# In-place scale of a ComplexF32 device vector by a REAL (Float32) alpha.
function onemklCsscal(device_queue, n, alpha, x, incx)
    return @ccall liboneapi_support.onemklCsscal(device_queue::syclQueue_t,
                                                 n::Int64, alpha::Cfloat,
                                                 x::ZePtr{ComplexF32},
                                                 incx::Int64)::Cvoid
end
function onemklDnrm2(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklDnrm2(device_queue::syclQueue_t,
n::Int64, x::ZePtr{Cdouble}, incx::Int64,
Expand Down
6 changes: 6 additions & 0 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N
end
end

LinearAlgebra.rmul!(x::oneStridedVecOrMat{<:onemklFloat}, k::Number) =
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if these two rules will properly dispatch the combination of alpha types of ComplexF32, ComplexF64, F32 and F64 to the ComplexF32 and ComplexF64 scal functions.

Please write specific tests for all the combinations to make sure they are dispatched properly.

If not, we will want to use AMDGPU.jl's type based dispatching rules at
https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/blas/wrappers.jl#L85, and

https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/blas/wrappers.jl#L106

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, Thanks

oneMKL.scal!(length(x), convert(eltype(x),k), x)

# Work around ambiguity with the GPUArrays wrapper: forward the Real case to
# the generic Number method defined above via `invoke`.
function LinearAlgebra.rmul!(x::oneStridedVecOrMat{<:onemklFloat}, k::Real)
    return invoke(rmul!, Tuple{typeof(x), Number}, x, k)
end
# Euclidean norm via oneMKL nrm2, treating the array as a flat vector.
function LinearAlgebra.norm(x::oneStridedVecOrMat{<:onemklFloat})
    return oneMKL.nrm2(length(x), x)
end

for NT in (Number, Real)
Expand Down
3 changes: 2 additions & 1 deletion lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ using GPUArrays

include("libonemkl.jl")

const onemklFloat = Union{Float64,Float32,Float16,ComplexF64,ComplexF32}
# Element types with full oneMKL BLAS coverage in these wrappers.
# Float16 is excluded for now, since several oneMKL routines used here
# (e.g. copy, scal) do not accept Float16.
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}

include("wrappers.jl")
include("linalg.jl")
Expand Down
28 changes: 28 additions & 0 deletions lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
end

# level 1
## scal
# Generate scal! methods where alpha and the vector share an element type.
for (fname, elty) in ((:onemklDscal, :Float64),
                      (:onemklSscal, :Float32),
                      (:onemklZscal, :ComplexF64),
                      (:onemklCscal, :ComplexF32))
    @eval begin
        # x .= alpha .* x on the device; returns x (BLAS wrapper convention).
        function scal!(n::Integer, alpha::$elty, x::oneStridedArray{$elty})
            q = global_queue(context(x), device(x))
            $fname(sycl_queue(q), n, alpha, x, stride(x, 1))
            return x
        end
    end
end

## nrm2
for (fname, elty, ret_type) in
((:onemklDnrm2, :Float64,:Float64),
Expand All @@ -32,6 +49,17 @@ for (fname, elty, ret_type) in
end
end

# Mixed-type scal!: scale a complex device vector by a REAL alpha
# (maps to the BLAS csscal / zdscal routines).
for (fname, elty, celty) in ((:onemklCsscal, :Float32, :ComplexF32),
                             (:onemklZdscal, :Float64, :ComplexF64))
    @eval begin
        # x .= alpha .* x on the device; returns x, matching the
        # same-eltype scal! methods above (the previous version returned
        # the Cvoid result of the ccall, i.e. `nothing`).
        function scal!(n::Integer,
                       alpha::$elty,
                       x::oneStridedArray{$celty})
            queue = global_queue(context(x), device(x))
            $fname(sycl_queue(queue), n, alpha, x, stride(x,1))
            return x
        end
    end
end
#
# BLAS
#
Expand Down
23 changes: 22 additions & 1 deletion test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,28 @@ m = 20
A = oneArray(rand(T, m))
B = oneArray{T}(undef, m)
oneMKL.copy!(m,A,B)
@test Array(A) == Array(B)
@test Array(A) == Array(B)
end

@testset "scal" begin
    # Same-eltype scaling: alpha::T against x::Vector{T}
    # for T in (F32, F64, CF32, CF64); compares CPU vs GPU via testf.
    k = rand(T)
    @test testf(rmul!, rand(T, m), k)

    # Mixed-type scaling: REAL alpha against a COMPLEX vector
    # (exercises the Csscal / Zdscal wrappers).
    if T === ComplexF32 || T === ComplexF64
        a = rand(real(T))
        hostA = rand(T, m)
        devA = oneArray(hostA)
        oneMKL.scal!(m, a, devA)
        @test hostA .* a ≈ Array(devA)
    end
end

@testset "nrm2" begin
Expand Down