77
88#include " deepx/shape.hpp"
99#include " deepx/dtype.hpp"
10+ #include " deepx/tensorbase.hpp"
11+
1012namespace deepx
1113{
12-
13-
14- enum DeviceType
15- {
16- CPU = 0 ,
17- CUDA = 1 ,
18- };
19-
14+ using namespace std ;
15+
2016 template <typename T>
21- struct Tensor
17+ struct Tensor : public TensorBase
2218 {
23- Shape shape;
2419 T *data;
25- DeviceType device;
26-
27-
28- using NewFn = T* (*)(int );
20+
21+ using NewFn = T *(*)(int );
2922 NewFn newer; // 申请内存
30-
23+
3124 using DeleteFn = void (*)(T *);
3225 DeleteFn deleter; // 释放内存
3326
34- using CopyFn = void (*)(T *,T *,int );
27+ using CopyFn = void (*)(T *, T *, int );
3528 CopyFn copyer; // 拷贝内存
3629
37- Tensor ( )=default ;
38- Tensor (const Shape &s):shape(s){
39- shape.dtype =dtype<T>::name ();
30+ Tensor () = default ;
31+ Tensor (const vector<int > &s)
32+ {
33+ shape = Shape (s);
34+ shape.dtype = dtype<T>::name ();
35+ }
36+ Tensor (const Shape &s)
37+ {
38+ shape = s;
39+ shape.dtype = dtype<T>::name ();
4040 }
4141
4242 ~Tensor ()
@@ -52,95 +52,99 @@ namespace deepx
5252 * 该构造函数用于创建一个新的Tensor对象,并将现有Tensor对象的内容复制到新对象中。
5353 * 它会分配新的内存并使用copyer函数将数据从源Tensor复制到新Tensor。
5454 */
55- Tensor (const Tensor<T> &tensor){
56- shape=tensor.shape ;
57- shape.dtype =dtype<T>::name ();
58- device=tensor.device ;
59- newer=tensor.newer ;
60- deleter=tensor.deleter ;
61- copyer=tensor.copyer ;
62-
63- data=newer (shape.size );
64- copyer (tensor.data ,data,tensor.shape .size );
55+ Tensor (const Tensor<T> &tensor)
56+ {
57+ shape = tensor.shape ;
58+ shape.dtype = dtype<T>::name ();
59+ device = tensor.device ;
60+ newer = tensor.newer ;
61+ deleter = tensor.deleter ;
62+ copyer = tensor.copyer ;
63+
64+ data = newer (shape.size );
65+ copyer (tensor.data , data, tensor.shape .size );
6566 }
66-
67+
6768 /* *
6869 * 移动构造
6970 * 该构造函数用于通过转移资源来创建一个新的Tensor对象。
7071 * 它会将源Tensor的资源(如数据指针)转移到新对象中,并将源Tensor的数据指针置为nullptr。
7172 * 这样可以避免不必要的内存分配,提高性能。
7273 */
73-
74-
75- Tensor (Tensor<T> &&other) noexcept {
76- shape=std::move (other.shape );
77- device=other.device ;
7874
79- deleter=other.deleter ;
80- copyer=other.copyer ;
81- newer=other.newer ;
75+ Tensor (Tensor<T> &&other) noexcept
76+ {
77+ shape = std::move (other.shape );
78+ device = other.device ;
79+
80+ deleter = other.deleter ;
81+ copyer = other.copyer ;
82+ newer = other.newer ;
8283
83- data= other.data ;
84+ data = other.data ;
8485
85- other.data = nullptr ;
86+ other.data = nullptr ;
8687
87- other.deleter = nullptr ;
88- other.copyer = nullptr ;
89- other.newer = nullptr ;
88+ other.deleter = nullptr ;
89+ other.copyer = nullptr ;
90+ other.newer = nullptr ;
9091 }
91-
92+
9293 /* *
9394 * 拷贝赋值运算符
9495 * 该运算符用于将一个Tensor对象的内容赋值给另一个Tensor对象。
9596 * 它会先检查自赋值的情况,然后使用copyer函数将数据从源Tensor复制到目标Tensor。
9697 * 需要注意的是,目标Tensor的原有数据会被释放。
9798 */
9899
99- Tensor<T>& operator =(const Tensor<T> &tensor) {
100- if (this ==&tensor)
101- return *this ;
102-
103- shape=tensor.shape ;
104- shape.dtype =dtype<T>::name ();
105- device=tensor.device ;
106- deleter=tensor.deleter ;
107- copyer=tensor.copyer ;
108- newer=tensor.newer ;
109-
110- data=newer (shape.size );
111- if (data!=nullptr ){
100+ Tensor<T> &operator =(const Tensor<T> &tensor)
101+ {
102+ if (this == &tensor)
103+ return *this ;
104+
105+ shape = tensor.shape ;
106+ shape.dtype = dtype<T>::name ();
107+ device = tensor.device ;
108+ deleter = tensor.deleter ;
109+ copyer = tensor.copyer ;
110+ newer = tensor.newer ;
111+
112+ data = newer (shape.size );
113+ if (data != nullptr )
114+ {
112115 deleter (data);
113116 }
114- copyer (tensor.data ,data,tensor.shape .size );
117+ copyer (tensor.data , data, tensor.shape .size );
115118 return *this ;
116119 }
117-
120+
118121 /* *
119122 * 移动赋值运算符
120123 * 该运算符用于将一个Tensor对象的资源转移到另一个Tensor对象。
121124 * 它会先检查自赋值的情况,然后将源Tensor的资源转移到目标Tensor,并将源Tensor的数据指针置为nullptr。
122125 * 这样可以避免不必要的内存分配,提高性能。
123126 */
124- Tensor<T>& operator =( Tensor<T> &&tensor) noexcept {
125- if (this ==&tensor) return *this ;
126- shape=tensor.shape ;
127- shape.dtype =dtype<T>::name ();
128- device=tensor.device ;
129- newer=tensor.newer ;
130- deleter=tensor.deleter ;
131- copyer=tensor.copyer ;
132- if (data!=nullptr ){
127+ Tensor<T> &operator =(Tensor<T> &&tensor) noexcept
128+ {
129+ if (this == &tensor)
130+ return *this ;
131+ shape = tensor.shape ;
132+ shape.dtype = dtype<T>::name ();
133+ device = tensor.device ;
134+ newer = tensor.newer ;
135+ deleter = tensor.deleter ;
136+ copyer = tensor.copyer ;
137+ if (data != nullptr )
138+ {
133139 deleter (data);
134140 }
135- data= tensor.data ;
136- tensor.data = nullptr ;
137- tensor.deleter = nullptr ;
138- tensor.copyer = nullptr ;
139- tensor.newer = nullptr ;
141+ data = tensor.data ;
142+ tensor.data = nullptr ;
143+ tensor.deleter = nullptr ;
144+ tensor.copyer = nullptr ;
145+ tensor.newer = nullptr ;
140146 return *this ;
141147 }
142-
143-
144148 };
145149
146150 // template <typename T>
0 commit comments