Skip to content

Commit ef51233

Browse files
committed
excuter(cpu/cuda):constant int8,cuda
1 parent 6664507 commit ef51233

13 files changed

Lines changed: 89 additions & 143 deletions

File tree

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
| Operation | Author | Func Def | Math Formula | IR Instruction |
66
|-----------|--------|------------|--------------|----------------|
7-
| concat | none | concat(listtensor<any> tensors, var<int32> axis)->(tensor<any> Tresult) | Tresult = concat([T1, T2...], axis=3) | concat(listtensor<any> tensors, var<int32> axis)->(tensor<any> Tresult) |
7+
| concat | none | concat()->() | Tresult = concat([T1, T2...], axis=3) | concat()->() |
88
| constant | miaobyte | constant(tensor<any> t, var<any> value)->() | print(T1) | constant(tensor<any> t, var<any> value)->() |
99
| print | miaobyte | print(tensor<any> )->() | print(T1) | print(tensor<any> )->() |
1010
| print | miaobyte | print(tensor<any> , var<string> )->() | print(T1) | print(tensor<any> , var<string> )->() |

excuter/cpp-common/src/deepx/mem/mem.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,31 @@ namespace deepx::mem
173173
result->data = ptr_tensor->data;
174174
break;
175175
}
176+
case Precision::Int64:
177+
{
178+
auto ptr_tensor = std::static_pointer_cast<Tensor<int64_t>>(ptr);
179+
result->data = ptr_tensor->data;
180+
break;
181+
}
176182
case Precision::Int32:
177183
{
178184
auto ptr_tensor = std::static_pointer_cast<Tensor<int32_t>>(ptr);
179185
result->data = ptr_tensor->data;
180186
break;
181187
}
188+
case Precision::Int16:
189+
{
190+
auto ptr_tensor = std::static_pointer_cast<Tensor<int16_t>>(ptr);
191+
result->data = ptr_tensor->data;
192+
break;
193+
}
194+
case Precision::Int8:
195+
{
196+
auto ptr_tensor = std::static_pointer_cast<Tensor<int8_t>>(ptr);
197+
result->data = ptr_tensor->data;
198+
break;
199+
}
200+
182201
default:
183202
throw std::runtime_error("Unsupported dtype: " + precision_str(ptr->shape.dtype));
184203
}

excuter/cpp-common/src/deepx/tf/tf.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,6 @@ namespace deepx::tf
446446
return true;
447447
}
448448

449-
void TF::funcdef(int polymorphism)
450-
{
451-
// 基类的默认实现为空
452-
// 派生类需要重写这个函数来定义具体的函数签名
453-
}
449+
454450

455451
}

excuter/cpp-common/src/deepx/tf/tf.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ namespace deepx::tf
5353
throw NotImplementError(name);
5454
}
5555
virtual string math_formula() const;
56-
virtual void funcdef(int polymorphism=0);
57-
56+
5857
void parse(const string &str);
5958
std::string to_string(bool show_extra=false, bool show_name=true) const;
6059
void init(const string &opname,

excuter/op-mem-cuda/src/client/tfs.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace deepx::tf
2020
{
2121
Param("name", DataCategory::Var, Precision::Any),
2222
})));
23-
tffactory.add_tf(std::make_shared<VecSet>(
23+
tffactory.add_tf(std::make_shared<VecSet>(
2424
vector<Param>(
2525
{
2626
Param("value", DataCategory::Vector, Precision::Any),
@@ -29,8 +29,22 @@ namespace deepx::tf
2929
{
3030
Param("name", DataCategory::Vector, Precision::Any),
3131
})));
32-
tffactory.add_tf(std::make_shared<NewTensor>(0));
33-
tffactory.add_tf(std::make_shared<NewTensor>(1));
32+
tffactory.add_tf(std::make_shared<NewTensor>(vector<Param>(
33+
{
34+
Param("shape", DataCategory::Vector, Precision::Int32),
35+
}),
36+
vector<Param>(
37+
{
38+
Param("tensor1", DataCategory::Tensor, Precision::Any),
39+
})));
40+
tffactory.add_tf(std::make_shared<NewTensor>(vector<Param>(
41+
{
42+
Param("shape", DataCategory::Var, Precision::String),
43+
}),
44+
vector<Param>(
45+
{
46+
Param("tensor1", DataCategory::Tensor, Precision::Any),
47+
})));
3448
// opfactory.add_op(DelTensor<float>());
3549
}
3650

@@ -41,11 +55,11 @@ namespace deepx::tf
4155
// opfactory.add_op(Uniform<double>());
4256

4357
tffactory.add_tf(std::make_shared<Constant<miaobyte>>(vector<Param>(
44-
{
45-
Param("t", DataCategory::Tensor, Precision::Any),
46-
Param("value", DataCategory::Var, Precision::Any),
47-
}),
48-
vector<Param>()));
58+
{
59+
Param("t", DataCategory::Tensor, Precision::Any),
60+
Param("value", DataCategory::Var, Precision::Any),
61+
}),
62+
vector<Param>()));
4963

5064
// opfactory.add_op(Arange<float>());
5165
// opfactory.add_op(Arange<double>());

excuter/op-mem-cuda/src/deepx/tf/arg.hpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ namespace deepx::tf
2727
throw std::runtime_error("Invalid name: " + this->name);
2828
}
2929
}
30-
void funcdef(int polymorphism = 0) override
31-
{
32-
this->args.push_back(Param("value", DataCategory::Var, Precision::Any));
33-
this->returns.push_back(Param("name", DataCategory::Var, Precision::Any));
34-
}
30+
3531
string math_formula() const override
3632
{
3733
return "var argname = argvalue";
@@ -96,11 +92,7 @@ namespace deepx::tf
9692
throw std::runtime_error("Invalid name: " + this->name);
9793
}
9894
}
99-
void funcdef(int polymorphism = 0) override
100-
{
101-
this->args.push_back(Param("shape", DataCategory::Vector, Precision::Any));
102-
this->returns.push_back(Param("name", DataCategory::Vector, Precision::Any));
103-
}
95+
10496
string math_formula() const override
10597
{
10698
return "shape = [3 4 5]";

excuter/op-mem-cuda/src/deepx/tf/new.hpp

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,14 @@ namespace deepx::tf
1414
class NewTensor : public TF
1515
{
1616
public:
17-
NewTensor()
17+
NewTensor(vector<Param> args, vector<Param> returns)
1818
{
1919
this->name = "newtensor";
20-
this->funcdef();
20+
this->args = args;
21+
this->returns = returns;
2122
}
22-
// support polymorphism=0 or 1
23-
NewTensor(int polymorphism = 0)
24-
{
25-
this->name = "newtensor";
26-
this->funcdef(polymorphism);
27-
}
28-
NewTensor(string text, bool call = false)
23+
24+
NewTensor(string text)
2925
{
3026
this->parse(text);
3127
if (this->name != "newtensor")
@@ -146,22 +142,6 @@ namespace deepx::tf
146142
return 0;
147143
};
148144

149-
void funcdef(int polymorphism = 0) override
150-
{
151-
switch (polymorphism)
152-
{
153-
154-
case 1:
155-
this->args.push_back(Param("shape", DataCategory::Var, Precision::String));
156-
break;
157-
case 0:
158-
default:
159-
this->args.push_back(Param("shape", DataCategory::Vector, Precision::Int32));
160-
break;
161-
}
162-
this->returns.push_back(Param("tensor1", DataCategory::Tensor, Precision::Any));
163-
};
164-
165145
string math_formula() const override
166146
{
167147
return "T1 = zeros(shape)";
@@ -174,7 +154,7 @@ namespace deepx::tf
174154
CopyTensor()
175155
{
176156
this->name = "copytensor";
177-
this->funcdef();
157+
178158
}
179159
CopyTensor(string text)
180160
{
@@ -192,11 +172,7 @@ namespace deepx::tf
192172
// tensorfunc::copytensor(*src,*dst);
193173
return 0;
194174
}
195-
void funcdef(int polymorphism = 0) override
196-
{
197-
this->args.push_back(Param("src", DataCategory::Tensor, Precision::Any));
198-
this->args.push_back(Param("dst", DataCategory::Tensor, Precision::Any));
199-
}
175+
200176
string math_formula() const override
201177
{
202178
return "T2.data = T1.data";
@@ -209,7 +185,7 @@ namespace deepx::tf
209185
CloneTensor()
210186
{
211187
this->name = "clonetensor";
212-
this->funcdef();
188+
213189
}
214190
int run(mem::Mem &mem, string &error) override
215191
{
@@ -220,11 +196,7 @@ namespace deepx::tf
220196
return 0;
221197
}
222198

223-
void funcdef(int polymorphism = 0) override
224-
{
225-
this->args.push_back(Param("src", DataCategory::Tensor, Precision::Any));
226-
this->args.push_back(Param("dst", DataCategory::Var, Precision::String));
227-
}
199+
228200
string math_formula() const override
229201
{
230202
return "T2 = T1.clone()";
@@ -237,7 +209,7 @@ namespace deepx::tf
237209
DelTensor()
238210
{
239211
this->name = "deltensor";
240-
this->funcdef();
212+
241213
}
242214
DelTensor(string text)
243215
{
@@ -253,10 +225,7 @@ namespace deepx::tf
253225
mem.delete_tensor(name);
254226
return 0;
255227
}
256-
void funcdef(int polymorphism=0) override
257-
{
258-
this->args.push_back(Param("tensor1", DataCategory::Tensor, Precision::Any));
259-
}
228+
260229
string math_formula() const override
261230
{
262231
return "del T1";

excuter/op-mem-cuda/src/deepx/tf/print.hpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,7 @@ namespace deepx::tf
5151
}
5252
return 0;
5353
}
54-
void funcdef(int polymorphism = 0) override
55-
{
56-
this->args.push_back(Param("tensor1", DataCategory::Tensor, Precision::Any));
57-
if (polymorphism == 0)
58-
{
59-
this->args.push_back(Param("format", DataCategory::Var, Precision::String));
60-
}
61-
}
54+
6255
string math_formula() const override
6356
{
6457
return "print(T1)";

excuter/op-mem-ompsimd/src/client/tfs.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,22 @@ namespace deepx::tf
3333
{
3434
Param("name", DataCategory::Vector, Precision::Any),
3535
})));
36-
tffactory.add_tf(std::make_shared<NewTensor>(0));
37-
tffactory.add_tf(std::make_shared<NewTensor>(1));
36+
tffactory.add_tf(std::make_shared<NewTensor>(vector<Param>(
37+
{
38+
Param("shape", DataCategory::Vector, Precision::Int32),
39+
}),
40+
vector<Param>(
41+
{
42+
Param("tensor1", DataCategory::Tensor, Precision::Any),
43+
})));
44+
tffactory.add_tf(std::make_shared<NewTensor>(vector<Param>(
45+
{
46+
Param("shape", DataCategory::Var, Precision::String),
47+
}),
48+
vector<Param>(
49+
{
50+
Param("tensor1", DataCategory::Tensor, Precision::Any),
51+
})));
3852
// opfactory.add_op(DelTensor<float>());
3953
}
4054

excuter/op-mem-ompsimd/src/deepx/tf/arg.hpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@ namespace deepx::tf
2828
throw std::runtime_error("Invalid name: " + this->name);
2929
}
3030
}
31-
void funcdef(int polymorphism = 0) override
32-
{
33-
this->args.push_back(Param("value", DataCategory::Var, Precision::Any));
34-
this->returns.push_back(Param("name", DataCategory::Var, Precision::Any));
35-
}
31+
3632
string math_formula() const override
3733
{
3834
return "var argname = argvalue";
@@ -97,11 +93,7 @@ namespace deepx::tf
9793
throw std::runtime_error("Invalid name: " + this->name);
9894
}
9995
}
100-
void funcdef(int polymorphism = 0) override
101-
{
102-
this->args.push_back(Param("shape", DataCategory::Vector, Precision::Any));
103-
this->returns.push_back(Param("name", DataCategory::Vector, Precision::Any));
104-
}
96+
10597
string math_formula() const override
10698
{
10799
return "shape = [3 4 5]";

0 commit comments

Comments
 (0)