1- #ifndef DEEPX_MEM_MEM_HPP
2- #define DEEPX_MEM_MEM_HPP
1+ #ifndef DEEPX_MEM_MEMBASE_HPP
2+ #define DEEPX_MEM_MEMBASE_HPP
33
44#include < any>
55#include < unordered_map>
66#include < vector>
7- #include < atomic>
87#include < memory>
9- #include " deepx/tensor.hpp "
8+ #include " iostream "
109
10+ #include " deepx/tensor.hpp"
1111namespace deepx ::mem
1212{
1313 using namespace std ;
14- class Mem
14+
15+ class MemBase
1516 {
16- private :
17+ protected :
1718 unordered_map<string, std::any> args;
18-
1919 std::unordered_map<std::string, std::shared_ptr<TensorBase>> mem;
2020 int tempidx = 0 ;
21+
2122 public:
22- Mem () = default ;
23- ~Mem () = default ;
24- Mem (const Mem &other)
25- {
26- args = other.args ;
27- mem = other.mem ;
28- }
29- Mem (Mem &&other) noexcept
30- {
31- args = std::move (other.args );
32- mem = std::move (other.mem );
33- }
34- Mem &operator =(const Mem &other)
35- {
36- args = other.args ;
37- mem = other.mem ;
38- return *this ;
39- }
40- Mem &operator =(Mem &&other) noexcept
23+ // 基本操作接口
24+ virtual void clear ()
4125 {
42- args = std::move (other.args );
43- mem = std::move (other.mem );
44- return *this ;
26+ args.clear ();
27+ mem.clear ();
4528 }
29+
30+ // 通用的arg操作
4631 template <typename T>
4732 void addarg (const string &name, const T value)
4833 {
@@ -136,74 +121,15 @@ namespace deepx::mem
136121 template <typename T>
137122 shared_ptr<Tensor<T>> gettensor (const string &name) const
138123 {
139- if (mem.find (name)== mem.end ())
124+ if (mem.find (name) == mem.end ())
140125 {
141126 throw std::runtime_error (" tensor not found: " + name);
142127 }
143128 auto ptr = mem.at (name);
144129 return std::static_pointer_cast<Tensor<T>>(ptr);
145130 }
146131
147- // TODO
148- shared_ptr<Tensor<void >> gettensor (const string &name) const
149- {
150- if (mem.find (name) == mem.end ())
151- {
152- throw std::runtime_error (" tensor not found: " + name);
153- }
154- auto ptr = mem.at (name);
155- auto result = make_shared<Tensor<void >>();
156- result->shape = ptr->shape ;
157- result->device = ptr->device ;
158- result->deleter = nullptr ;
159- result->copyer = nullptr ;
160- result->newer = nullptr ;
161-
162- switch (ptr->shape .dtype )
163- {
164- case Precision::Float64:
165- {
166- auto ptr_tensor = std::static_pointer_cast<Tensor<double >>(ptr);
167- result->data = ptr_tensor->data ;
168- break ;
169- }
170- case Precision::Float32:
171- {
172- auto ptr_tensor = std::static_pointer_cast<Tensor<float >>(ptr);
173- result->data = ptr_tensor->data ;
174- break ;
175- }
176- case Precision::Int64:
177- {
178- auto ptr_tensor = std::static_pointer_cast<Tensor<int64_t >>(ptr);
179- result->data = ptr_tensor->data ;
180- break ;
181- }
182- case Precision::Int32:
183- {
184- auto ptr_tensor = std::static_pointer_cast<Tensor<int32_t >>(ptr);
185- result->data = ptr_tensor->data ;
186- break ;
187- }
188- case Precision::Int16:
189- {
190- auto ptr_tensor = std::static_pointer_cast<Tensor<int16_t >>(ptr);
191- result->data = ptr_tensor->data ;
192- break ;
193- }
194- case Precision::Int8:
195- {
196- auto ptr_tensor = std::static_pointer_cast<Tensor<int8_t >>(ptr);
197- result->data = ptr_tensor->data ;
198- break ;
199- }
200-
201- default :
202- throw std::runtime_error (" Unsupported dtype: " + precision_str (ptr->shape .dtype ));
203- }
204-
205- return result;
206- }
132+ virtual shared_ptr<Tensor<void >> gettensor (const string &name) const = 0;
207133
208134 // 获取多个张量
209135 template <typename T>
@@ -225,7 +151,6 @@ namespace deepx::mem
225151 return tensors;
226152 }
227153
228-
229154 void delete_tensor (const string &name)
230155 {
231156 mem.erase (name);
@@ -235,11 +160,6 @@ namespace deepx::mem
235160 {
236161 args.erase (name);
237162 }
238- void clear ()
239- {
240- args.clear ();
241- mem.clear ();
242- };
243163 };
244164}
245- #endif // DEEPX_MEM_MEM_HPP
165+ #endif // DEEPX_MEM_MEMBASE_HPP
0 commit comments