Skip to content

Commit 510a09f

Browse files
front&excuter:联合调试matmul.cblas,cublas (#9)
* excuter(cpu/cuda):subscalar * front:newtensor,print 联合调试 * front:newtensor,print 联合调试 * Fix build error in gcc compiler. (#5) In gcc/++13 compiler, it shows error: ``` dtype.hpp:8:29: error: found ‘:’ in nested-name-specifier, expected ‘::’ 8 | enum class DataCategory : uint8_t ``` * front&excuter:联合调试,修复init、elementwise的IR * front&excuter:联合调试matmul.cblas,cublas --------- Co-authored-by: harryharrygo <harryharrygogogo@gmail.com>
1 parent 26c8fc3 commit 510a09f

12 files changed

Lines changed: 358 additions & 26 deletions

File tree

doc/excuter/op-mem-cuda/list.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
| newtensor | none | newtensor(vector<int32> shape)->(tensor<any> tensor1) | T1 = zeros(shape) | newtensor(vector<int32> shape)->(tensor<any> tensor1) |
1717
| newtensor | none | newtensor(var<string> shape)->(tensor<any> tensor1) | T1 = zeros(shape) | newtensor(var<string> shape)->(tensor<any> tensor1) |
1818
| vecset | none | vecset(vector<any> value)->(vector<any> name) | shape = [3 4 5] | vecset(vector<any> value)->(vector<any> name) |
19+
| matmul | cublas | matmul(tensor<any> A, tensor<any> B)->(tensor<any> C) | T3=T1 @ T2 | matmul(tensor<any> A, tensor<any> B)->(tensor<any> C) |
1920
| sub | miaobyte | sub(tensor<any> A, tensor<any> B)->(tensor<any> C) | T3=T1-T2 | sub(tensor<any> A, tensor<any> B)->(tensor<any> C) |
2021
| argset | none | argset(var<any> value)->(var<any> name) | var argname = argvalue | argset(var<any> value)->(var<any> name) |

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@
1717
| newtensor | none | newtensor(vector<int32> shape)->(tensor<any> tensor1) | T1 =Tensor(shape=[...]) | newtensor(vector<int32> shape)->(tensor<any> tensor1) |
1818
| newtensor | none | newtensor(var<string> shape)->(tensor<any> tensor1) | T1 =Tensor(shape=[...]) | newtensor(var<string> shape)->(tensor<any> tensor1) |
1919
| vecset | none | vecset(vector<any> value)->(vector<any> name) | shape = [3 4 5] | vecset(vector<any> value)->(vector<any> name) |
20+
| matmul | cblas | matmul(tensor<float64|float32> A, tensor<float64|float32> B)->(tensor<float64|float32> C) | T3=T1 @ T2 | matmul(tensor<float64|float32> A, tensor<float64|float32> B)->(tensor<float64|float32> C) |
21+
| matmul | miaobyte | matmul(tensor<any> A, tensor<any> B)->(tensor<any> C) | T3=T1 @ T2 | matmul(tensor<any> A, tensor<any> B)->(tensor<any> C) |
2022
| sub | miaobyte | sub(tensor<any> a, tensor<any> b)->(tensor<any> c) | T3=T1-T2 | sub(tensor<any> a, tensor<any> b)->(tensor<any> c) |
2123
| argset | none | argset(var<any> value)->(var<any> name) | var argname = argvalue | argset(var<any> value)->(var<any> name) |

excuter/cpp-common/src/deepx/tensorfunc/matmul.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include "deepx/tensor.hpp"
55
#include "deepx/tensorfunc/authors.hpp"
6-
6+
#include "stdutil/error.hpp"
77
namespace deepx::tensorfunc
88
{
99
bool check_matmul_shape(const Shape &a, const Shape &b)
@@ -29,7 +29,10 @@ namespace deepx::tensorfunc
2929
template <typename Author, typename T>
3030
struct matmulDispatcher
3131
{
32-
static void matmul(const Tensor<T> &A, const Tensor<T> &B, Tensor<T> &C) = delete;
32+
static void matmul(const Tensor<T> &A, const Tensor<T> &B, Tensor<T> &C)
33+
{
34+
throw NotImplementError("matmul");
35+
}
3336
};
3437

3538
template <typename Author, typename T>

excuter/op-mem-cuda/src/client/tfs.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "deepx/tf/print.hpp"
55
#include "deepx/tf/init.hpp"
66
#include "deepx/tf/elementwise_basic.hpp"
7+
#include "deepx/tf/matmul.hpp"
78
#include "deepx/dtype.hpp"
89
#include "deepx/tf/tffactory.hpp"
910
#include "deepx/tensorfunc/authors.hpp"
@@ -173,12 +174,19 @@ namespace deepx::tf
173174
// opfactory.add_op(Powscalar_miaobyte<float>());
174175
// opfactory.add_op(Powscalar_miaobyte<double>());
175176
}
176-
// // matmul
177-
// void register_matmul(OpFactory &opfactory)
178-
// {
179-
// opfactory.add_op(MatMul<float>());
180-
// opfactory.add_op(MatMul<double>());
181-
// }
177+
// matmul
178+
void register_matmul(TfFactory &tffactory)
179+
{
180+
tffactory.add_tf(std::make_shared<MatMul<cublas>>(vector<Param>(
181+
{
182+
Param("A", DataCategory::Tensor, Precision::Any),
183+
Param("B", DataCategory::Tensor, Precision::Any),
184+
}),
185+
vector<Param>(
186+
{
187+
Param("C", DataCategory::Tensor, Precision::Any),
188+
})));
189+
}
182190
// // changeshape
183191
void register_changeshape(TfFactory &tffactory)
184192
{
@@ -207,7 +215,7 @@ namespace deepx::tf
207215
register_init(tffactory);
208216
register_util(tffactory);
209217
register_elementwise(tffactory);
210-
// register_matmul(opfactory);
218+
register_matmul(tffactory);
211219
register_changeshape(tffactory);
212220
// register_reduce(opfactory);
213221
return 0;
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#ifndef DEEPX_TF_MATMUL_HPP
2+
#define DEEPX_TF_MATMUL_HPP
3+
4+
#include <cuda_fp16.h>
5+
#include <cuda_bf16.h>
6+
7+
#include "deepx/tf/tf.hpp"
8+
#include "deepx/dtype.hpp"
9+
#include "deepx/dtype_cuda.hpp"
10+
#include "deepx/tensorfunc/matmul_cublas.hpp"
11+
12+
namespace deepx::tf
13+
{
14+
template <typename Author>
15+
class MatMul : public TF
16+
{
17+
public:
18+
MatMul(const vector<Param> &args, const vector<Param> &returns)
19+
{
20+
this->name = "matmul";
21+
this->author = Author::name();
22+
this->args = args;
23+
this->returns = returns;
24+
}
25+
26+
MatMul(string text)
27+
{
28+
this->parse(text);
29+
this->author = Author::name();
30+
if (this->name != "matmul")
31+
{
32+
throw std::runtime_error("Invalid name: " + this->name);
33+
}
34+
}
35+
string math_formula() const override
36+
{
37+
return "T3=T1 @ T2";
38+
}
39+
shared_ptr<TF> clone() const override
40+
{
41+
return make_shared<MatMul<Author>>(*this);
42+
}
43+
int run(shared_ptr<MemBase> mem, string &error) override
44+
{
45+
Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
46+
Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
47+
Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
48+
if (a_type != b_type || a_type != c_type)
49+
{
50+
error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " != " + precision_str(c_type);
51+
return 1;
52+
}
53+
switch (a_type)
54+
{
55+
case Precision::Float64:
56+
tensorfunc::matmul<Author, double>(*mem->gettensor<double>(this->args[0].textvalue), *mem->gettensor<double>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
57+
break;
58+
case Precision::Float32:
59+
tensorfunc::matmul<Author, float>(*mem->gettensor<float>(this->args[0].textvalue), *mem->gettensor<float>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
60+
break;
61+
case Precision::Float16:
62+
tensorfunc::matmul<Author, half>(*mem->gettensor<half>(this->args[0].textvalue), *mem->gettensor<half>(this->args[1].textvalue), *mem->gettensor<half>(this->returns[0].textvalue));
63+
break;
64+
case Precision::BFloat16:
65+
tensorfunc::matmul<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
66+
break;
67+
case Precision::Int64:
68+
tensorfunc::matmul<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
69+
break;
70+
case Precision::Int32:
71+
tensorfunc::matmul<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
72+
break;
73+
case Precision::Int16:
74+
tensorfunc::matmul<Author, int16_t>(*mem->gettensor<int16_t>(this->args[0].textvalue), *mem->gettensor<int16_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
75+
break;
76+
case Precision::Int8:
77+
tensorfunc::matmul<Author, int8_t>(*mem->gettensor<int8_t>(this->args[0].textvalue), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
78+
break;
79+
default:
80+
error = "Unsupported dtype: " + precision_str(a_type);
81+
return 1;
82+
}
83+
return 0;
84+
}
85+
};
86+
}
87+
88+
#endif

excuter/op-mem-ompsimd/src/client/tfs.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "deepx/tf/changeshape.hpp"
99
#include "deepx/tf/elementwise.hpp"
1010
#include "deepx/tf/tffactory.hpp"
11-
11+
#include "deepx/tf/matmul.hpp"
1212
#include "deepx/tensorfunc/authors.hpp"
1313
namespace deepx::tf
1414
{
@@ -186,12 +186,28 @@ namespace deepx::tf
186186
// opfactory.add_op(Powscalar_miaobyte<float>());
187187
// opfactory.add_op(Powscalar_miaobyte<double>());
188188
}
189-
// // matmul
190-
// void register_matmul(OpFactory &opfactory)
191-
// {
192-
// opfactory.add_op(MatMul<float>());
193-
// opfactory.add_op(MatMul<double>());
194-
// }
189+
// matmul
190+
void register_matmul(TfFactory &tffactory)
191+
{
192+
tffactory.add_tf(std::make_shared<MatMul<miaobyte>>(vector<Param>(
193+
{
194+
Param("A", DataCategory::Tensor, Precision::Any),
195+
Param("B", DataCategory::Tensor, Precision::Any),
196+
}),
197+
vector<Param>(
198+
{
199+
Param("C", DataCategory::Tensor, Precision::Any),
200+
})));
201+
tffactory.add_tf(std::make_shared<MatMul<cblas>>(vector<Param>(
202+
{
203+
Param("A", DataCategory::Tensor, Precision::Float64|Precision::Float32),
204+
Param("B", DataCategory::Tensor, Precision::Float64|Precision::Float32),
205+
}),
206+
vector<Param>(
207+
{
208+
Param("C", DataCategory::Tensor, Precision::Float64|Precision::Float32),
209+
})));
210+
}
195211
// // changeshape
196212
void register_changeshape(TfFactory &tffactory)
197213
{
@@ -220,7 +236,7 @@ namespace deepx::tf
220236
register_init(tffactory);
221237
register_util(tffactory);
222238
register_elementwise(tffactory);
223-
// register_matmul(opfactory);
239+
register_matmul(tffactory);
224240
register_changeshape(tffactory);
225241
// register_reduce(opfactory);
226242
return 0;

excuter/op-mem-ompsimd/src/deepx/tensorfunc/matmul_cblas.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef DEEPX_TENSORFUNC_MATMUL_HPP
2-
#define DEEPX_TENSORFUNC_MATMUL_HPP
1+
#ifndef DEEPX_TENSORFUNC_MATMUL_CBLAS_HPP
2+
#define DEEPX_TENSORFUNC_MATMUL_CBLAS_HPP
33

44
#include <cblas.h> // 如果使用 OpenBLAS
55
#include "deepx/tensor.hpp"
@@ -64,7 +64,7 @@ namespace deepx::tensorfunc
6464
{
6565
static void matmul(const Tensor<double> &a, const Tensor<double> &b, Tensor<double> &c)
6666
{
67-
if (!check_shape(a.shape, b.shape))
67+
if (!check_matmul_shape(a.shape, b.shape))
6868
{
6969
throw std::invalid_argument("a.shape could matmul with b.shape");
7070
}
@@ -150,7 +150,7 @@ namespace deepx::tensorfunc
150150
{
151151
static void matmuladd(const Tensor<float> &a, const Tensor<float> &b, const float &alpha, const float &beta, Tensor<float> &c)
152152
{
153-
if (!check_shape(a.shape, b.shape))
153+
if (!check_matmul_shape(a.shape, b.shape))
154154
{
155155
throw std::invalid_argument("a.shape could matmul with b.shape");
156156
}
@@ -208,7 +208,7 @@ namespace deepx::tensorfunc
208208
{
209209
static void matmuladd(const Tensor<double> &a, const Tensor<double> &b, const double &alpha, const double &beta, Tensor<double> &c)
210210
{
211-
if (!check_shape(a.shape, b.shape))
211+
if (!check_matmul_shape(a.shape, b.shape))
212212
{
213213
throw std::invalid_argument("a.shape could matmul with b.shape");
214214
}
@@ -261,4 +261,4 @@ namespace deepx::tensorfunc
261261
}
262262
};
263263
}
264-
#endif // DEEPX_TENSORFUNC_MATMUL_HPP
264+
#endif // DEEPX_TENSORFUNC_MATMUL_CBLAS_HPP
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#ifndef DEEPX_TF_MATMUL_HPP
2+
#define DEEPX_TF_MATMUL_HPP
3+
4+
#include "deepx/tf/tf.hpp"
5+
#include "deepx/dtype.hpp"
6+
#include "deepx/dtype_ompsimd.hpp"
7+
#include "deepx/tensorfunc/matmul.hpp"
8+
#include "deepx/tensorfunc/matmul_cblas.hpp"
9+
#include "deepx/tensorfunc/matmul_miaobyte.hpp"
10+
namespace deepx::tf
11+
{
12+
template <typename Author>
13+
class MatMul : public TF
14+
{
15+
public:
16+
MatMul(const vector<Param> &args, const vector<Param> &returns)
17+
{
18+
this->name = "matmul";
19+
this->author = Author::name();
20+
this->args = args;
21+
this->returns = returns;
22+
}
23+
24+
MatMul(string text)
25+
{
26+
this->parse(text);
27+
this->author = Author::name();
28+
if (this->name != "matmul")
29+
{
30+
throw std::runtime_error("Invalid name: " + this->name);
31+
}
32+
}
33+
string math_formula() const override
34+
{
35+
return "T3=T1 @ T2";
36+
}
37+
shared_ptr<TF> clone() const override
38+
{
39+
return make_shared<MatMul<Author>>(*this);
40+
}
41+
int run(shared_ptr<MemBase> mem, string &error) override
42+
{
43+
Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
44+
Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
45+
Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
46+
if (a_type != b_type || a_type != c_type)
47+
{
48+
error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " != " + precision_str(c_type);
49+
return 1;
50+
}
51+
switch (a_type)
52+
{
53+
case Precision::Float64:
54+
tensorfunc::matmul<Author, double>(*mem->gettensor<double>(this->args[0].textvalue), *mem->gettensor<double>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
55+
break;
56+
case Precision::Float32:
57+
tensorfunc::matmul<Author, float>(*mem->gettensor<float>(this->args[0].textvalue), *mem->gettensor<float>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
58+
break;
59+
case Precision::Int64:
60+
tensorfunc::matmul<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
61+
break;
62+
case Precision::Int32:
63+
tensorfunc::matmul<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
64+
break;
65+
case Precision::Int16:
66+
tensorfunc::matmul<Author, int16_t>(*mem->gettensor<int16_t>(this->args[0].textvalue), *mem->gettensor<int16_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
67+
break;
68+
case Precision::Int8:
69+
tensorfunc::matmul<Author, int8_t>(*mem->gettensor<int8_t>(this->args[0].textvalue), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
70+
break;
71+
default:
72+
error = "Unsupported dtype: " + precision_str(a_type);
73+
return 1;
74+
}
75+
return 0;
76+
}
77+
};
78+
}
79+
80+
#endif

front/py/deepx/nn/functional/matmul.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
def matmul(
1212
a:Tensor,
1313
b: Tensor,
14-
out:Union[Tensor,str]='')->Tensor:
14+
out:Union[Tensor,str]='',
15+
author:str='cublas'):
1516
opnode = a.graph.add_op("matmul")
1617
opnode.add_input(a.node)
1718
opnode.add_input(b.node)
@@ -25,6 +26,6 @@ def matmul(
2526
outtensor=out
2627
outtensor.node.add_input(opnode)
2728
if a.graph.eager:
28-
ir=DeepxIR("matmul", a.dtype, [a.node.name,b.node.name], [outtensor.node.name])
29+
ir=DeepxIR("matmul", [a.node.name,b.node.name], [outtensor.node.name], author=author)
2930
send(ir)
3031
return outtensor

front/py/deepx/scheduler/client/udpconn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import select
44

55
class UDPConn:
6-
def __init__(self, endpoint: str = "localhost:8080"):
6+
def __init__(self, endpoint: str = "localhost:9090"):
77
# 解析endpoint
88
self._host, port_str = endpoint.split(':')
99
self._port = int(port_str)

0 commit comments

Comments (0)