Skip to content

Commit 0d5a4ba

Browse files
committed
excuter(cpu/cuda):mem继承,支持bf16
1 parent bdd7377 commit 0d5a4ba

File tree

36 files changed

+1150
-544
lines changed

36 files changed

+1150
-544
lines changed

doc/excuter/op-mem-cuda/list.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
| Operation | Author | Func Def | Math Formula | IR Instruction |
66
|-----------|--------|------------|--------------|----------------|
7-
| constant | miaobyte | constant(tensor<any> t, var<any> value)->() | print(T1) | constant(tensor<any> t, var<any> value)->() |
7+
| uniform | miaobyte | uniform(tensor<any> t, var<any> low, var<any> high, var<int32> seed)->() | uniform(T1,low,high,seed) | uniform(tensor<any> t, var<any> low, var<any> high, var<int32> seed)->() |
8+
| arange | miaobyte | arange(tensor<any> t, var<any> start, var<any> step)->() | arange(T1,start,step) | arange(tensor<any> t, var<any> start, var<any> step)->() |
9+
| constant | miaobyte | constant(tensor<any> t, var<any> value)->() | constant(T1) | constant(tensor<any> t, var<any> value)->() |
810
| print | miaobyte | print(tensor<any> )->() | print(T1) | print(tensor<any> )->() |
911
| print | miaobyte | print(tensor<any> , var<string> )->() | print(T1) | print(tensor<any> , var<string> )->() |
1012
| newtensor | none | newtensor(vector<int32> shape)->(tensor<any> tensor1) | T1 = zeros(shape) | newtensor(vector<int32> shape)->(tensor<any> tensor1) |

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
| Operation | Author | Func Def | Math Formula | IR Instruction |
66
|-----------|--------|------------|--------------|----------------|
77
| concat | none | concat()->() | Tresult = concat([T1, T2...], axis=3) | concat()->() |
8+
| uniform | miaobyte | uniform(tensor<any> t, var<any> low, var<any> high, var<int32> seed)->() | uniform(T1,low,high,seed) | uniform(tensor<any> t, var<any> low, var<any> high, var<int32> seed)->() |
9+
| arange | miaobyte | arange(tensor<any> t, var<any> start, var<any> step)->() | arange(T1,start,step) | arange(tensor<any> t, var<any> start, var<any> step)->() |
810
| constant | miaobyte | constant(tensor<any> t, var<any> value)->() | print(T1) | constant(tensor<any> t, var<any> value)->() |
911
| print | miaobyte | print(tensor<any> )->() | print(T1) | print(tensor<any> )->() |
1012
| print | miaobyte | print(tensor<any> , var<string> )->() | print(T1) | print(tensor<any> , var<string> )->() |

excuter/cpp-common/src/client/udpserver.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <sstream>
2-
2+
#include <queue>
3+
34
#include "client/udpserver.hpp"
45

56
namespace client

excuter/cpp-common/src/client/udpserver.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
#include <sys/un.h>
88
#include <unistd.h>
99
#include <functional>
10-
#include "deepx/tf/tf.hpp"
1110
#include <queue>
1211

12+
#include "deepx/tf/tf.hpp"
1313
namespace client{
1414
using namespace std;
1515
class udpserver

excuter/cpp-common/src/deepx/dtype.hpp

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef DEEPX_DTYPE_HPP
22
#define DEEPX_DTYPE_HPP
33

4-
#include <typeinfo>
54
#include <string>
65
namespace deepx
76
{
@@ -229,32 +228,7 @@ namespace deepx
229228
return TypeDef(category, precision);
230229
}
231230

232-
// 获取类型对应的Precision
233-
template <typename T>
234-
constexpr Precision precision()
235-
{
236-
if constexpr (std::is_same_v<T, double>)
237-
return Precision::Float64;
238-
else if constexpr (std::is_same_v<T, float>)
239-
return Precision::Float32;
240-
// else if constexpr (std::is_same_v<T, half>) return Precision::Float16;
241-
// else if constexpr (std::is_same_v<T, nv_bfloat16>) return Precision::BFloat16;
242-
else if constexpr (std::is_same_v<T, int64_t>)
243-
return Precision::Int64;
244-
else if constexpr (std::is_same_v<T, int32_t>)
245-
return Precision::Int32;
246-
else if constexpr (std::is_same_v<T, int16_t>)
247-
return Precision::Int16;
248-
else if constexpr (std::is_same_v<T, int8_t>)
249-
return Precision::Int8;
250-
// else if constexpr (std::is_same_v<T, int4_t>) return Precision::Int4;
251-
else if constexpr (std::is_same_v<T, bool>)
252-
return Precision::Bool;
253-
else if constexpr (std::is_same_v<T, std::string>)
254-
return Precision::String;
255-
else
256-
return Precision::Any;
257-
}
231+
258232

259233
// 修改precision_str函数以使用标准命名格式
260234
inline std::string precision_str(Precision p)
Lines changed: 17 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,33 @@
1-
#ifndef DEEPX_MEM_MEM_HPP
2-
#define DEEPX_MEM_MEM_HPP
1+
#ifndef DEEPX_MEM_MEMBASE_HPP
2+
#define DEEPX_MEM_MEMBASE_HPP
33

44
#include <any>
55
#include <unordered_map>
66
#include <vector>
7-
#include <atomic>
87
#include <memory>
9-
#include "deepx/tensor.hpp"
8+
#include "iostream"
109

10+
#include "deepx/tensor.hpp"
1111
namespace deepx::mem
1212
{
1313
using namespace std;
14-
class Mem
14+
15+
class MemBase
1516
{
16-
private:
17+
protected:
1718
unordered_map<string, std::any> args;
18-
1919
std::unordered_map<std::string, std::shared_ptr<TensorBase>> mem;
2020
int tempidx = 0;
21+
2122
public:
22-
Mem() = default;
23-
~Mem() = default;
24-
Mem(const Mem &other)
25-
{
26-
args = other.args;
27-
mem = other.mem;
28-
}
29-
Mem(Mem &&other) noexcept
30-
{
31-
args = std::move(other.args);
32-
mem = std::move(other.mem);
33-
}
34-
Mem &operator=(const Mem &other)
35-
{
36-
args = other.args;
37-
mem = other.mem;
38-
return *this;
39-
}
40-
Mem &operator=(Mem &&other) noexcept
23+
// 基本操作接口
24+
virtual void clear()
4125
{
42-
args = std::move(other.args);
43-
mem = std::move(other.mem);
44-
return *this;
26+
args.clear();
27+
mem.clear();
4528
}
29+
30+
// 通用的arg操作
4631
template <typename T>
4732
void addarg(const string &name, const T value)
4833
{
@@ -136,74 +121,15 @@ namespace deepx::mem
136121
template <typename T>
137122
shared_ptr<Tensor<T>> gettensor(const string &name) const
138123
{
139-
if (mem.find(name)== mem.end())
124+
if (mem.find(name) == mem.end())
140125
{
141126
throw std::runtime_error("tensor not found: " + name);
142127
}
143128
auto ptr = mem.at(name);
144129
return std::static_pointer_cast<Tensor<T>>(ptr);
145130
}
146131

147-
//TODO
148-
shared_ptr<Tensor<void>> gettensor(const string &name) const
149-
{
150-
if (mem.find(name) == mem.end())
151-
{
152-
throw std::runtime_error("tensor not found: " + name);
153-
}
154-
auto ptr = mem.at(name);
155-
auto result = make_shared<Tensor<void>>();
156-
result->shape = ptr->shape;
157-
result->device = ptr->device;
158-
result->deleter = nullptr;
159-
result->copyer = nullptr;
160-
result->newer = nullptr;
161-
162-
switch (ptr->shape.dtype)
163-
{
164-
case Precision::Float64:
165-
{
166-
auto ptr_tensor = std::static_pointer_cast<Tensor<double>>(ptr);
167-
result->data = ptr_tensor->data;
168-
break;
169-
}
170-
case Precision::Float32:
171-
{
172-
auto ptr_tensor = std::static_pointer_cast<Tensor<float>>(ptr);
173-
result->data = ptr_tensor->data;
174-
break;
175-
}
176-
case Precision::Int64:
177-
{
178-
auto ptr_tensor = std::static_pointer_cast<Tensor<int64_t>>(ptr);
179-
result->data = ptr_tensor->data;
180-
break;
181-
}
182-
case Precision::Int32:
183-
{
184-
auto ptr_tensor = std::static_pointer_cast<Tensor<int32_t>>(ptr);
185-
result->data = ptr_tensor->data;
186-
break;
187-
}
188-
case Precision::Int16:
189-
{
190-
auto ptr_tensor = std::static_pointer_cast<Tensor<int16_t>>(ptr);
191-
result->data = ptr_tensor->data;
192-
break;
193-
}
194-
case Precision::Int8:
195-
{
196-
auto ptr_tensor = std::static_pointer_cast<Tensor<int8_t>>(ptr);
197-
result->data = ptr_tensor->data;
198-
break;
199-
}
200-
201-
default:
202-
throw std::runtime_error("Unsupported dtype: " + precision_str(ptr->shape.dtype));
203-
}
204-
205-
return result;
206-
}
132+
virtual shared_ptr<Tensor<void>> gettensor(const string &name) const = 0;
207133

208134
// 获取多个张量
209135
template <typename T>
@@ -225,7 +151,6 @@ namespace deepx::mem
225151
return tensors;
226152
}
227153

228-
229154
void delete_tensor(const string &name)
230155
{
231156
mem.erase(name);
@@ -235,11 +160,6 @@ namespace deepx::mem
235160
{
236161
args.erase(name);
237162
}
238-
void clear()
239-
{
240-
args.clear();
241-
mem.clear();
242-
};
243163
};
244164
}
245-
#endif // DEEPX_MEM_MEM_HPP
165+
#endif // DEEPX_MEM_MEMBASE_HPP

excuter/cpp-common/src/deepx/tensor.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@ namespace deepx
3131
Tensor(const vector<int> &s)
3232
{
3333
shape = Shape(s);
34-
shape.dtype = precision<T>();
3534
}
3635
Tensor(const Shape &s)
3736
{
3837
shape = s;
39-
shape.dtype = precision<T>();
4038
}
4139

4240
~Tensor()
@@ -55,7 +53,6 @@ namespace deepx
5553
Tensor(const Tensor<T> &tensor)
5654
{
5755
shape = tensor.shape;
58-
shape.dtype = precision<T>();
5956
device = tensor.device;
6057
newer = tensor.newer;
6158
deleter = tensor.deleter;
@@ -103,7 +100,6 @@ namespace deepx
103100
return *this;
104101

105102
shape = tensor.shape;
106-
shape.dtype = precision<T>();
107103
device = tensor.device;
108104
deleter = tensor.deleter;
109105
copyer = tensor.copyer;
@@ -129,7 +125,6 @@ namespace deepx
129125
if (this == &tensor)
130126
return *this;
131127
shape = tensor.shape;
132-
shape.dtype = precision_str(precision<T>());
133128
device = tensor.device;
134129
newer = tensor.newer;
135130
deleter = tensor.deleter;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class ArgSet : public TF
2+
{
3+
public:
4+
// ... 其他现有代码 ...
5+
6+
shared_ptr<TF> clone() const override {
7+
return make_shared<ArgSet>(*this);
8+
}
9+
};
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class NewTensor : public TF
2+
{
3+
public:
4+
// ... 其他现有代码 ...
5+
6+
shared_ptr<TF> clone() const override {
7+
return make_shared<NewTensor>(*this);
8+
}
9+
};

excuter/cpp-common/src/deepx/tf/tf.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,25 @@ namespace deepx::tf
1919
size_t arrow_pos = body.find("->");
2020
if (arrow_pos == string::npos)
2121
{
22-
throw runtime_error("Invalid IR format: missing arrow");
22+
throw runtime_error("Invalid IR format: missing arrow");
2323
}
2424

2525
// 获取输入和输出部分的原始字符串
2626
string input_part = body.substr(0, arrow_pos);
2727
string output_part = body.substr(arrow_pos + 2);
2828

29-
// 提取函数名
29+
// 提取函数名 - 修改这部分逻辑
3030
size_t space_pos = input_part.find(' ');
3131
size_t paren_pos = input_part.find('(');
32-
size_t name_end = std::min(
33-
space_pos != string::npos ? space_pos : input_part.length(),
34-
paren_pos != string::npos ? paren_pos : input_part.length());
32+
size_t name_end;
33+
34+
if (paren_pos != string::npos && (space_pos == string::npos || paren_pos < space_pos)) {
35+
// 如果有括号且括号在空格之前,使用括号位置
36+
name_end = paren_pos;
37+
} else {
38+
// 否则使用空格位置或字符串末尾
39+
name_end = space_pos != string::npos ? space_pos : input_part.length();
40+
}
3541
string func_name = input_part.substr(0, name_end);
3642

3743
// 处理输入部分,去掉函数名

0 commit comments

Comments
 (0)