Skip to content

Commit 72c0732

Browse files
committed
py:deepx.tensor
1 parent 4f003fa commit 72c0732

16 files changed

Lines changed: 225 additions & 153 deletions

File tree

front/go/deepx/attention.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ func (m *MultiHeadAttention) Forward(q, k, v *Tensor) *Tensor {
6565

6666
// 5. Scale
6767
d_k := float32(head_dim)
68-
scores = scores.Scale(1.0 / float32(math.Sqrt(float64(d_k))))
68+
scores = scores.MulScalar(1.0 / float32(math.Sqrt(float64(d_k))))
6969

7070
// 6. Softmax
71-
attn := scores.Softmax()
71+
attn := scores.Softmax(1)
7272

7373
// 7. 加权求和
7474
out := attn.Matmul(value)

front/go/deepx/graph.go

Lines changed: 0 additions & 84 deletions
This file was deleted.

front/go/deepx/norm.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ func NewLayerNorm(name string, normalized_shape int, dtype Dtype, g *Graph) *Lay
2222
}
2323

2424
func (m *LayerNorm) LayerNorm(x *Tensor) *Tensor {
25-
op := x.LayerNorm(m.weight, m.bias)
26-
return op
25+
26+
return nil
2727
}
2828

2929
// BatchNorm 批归一化
@@ -52,8 +52,7 @@ func NewBatchNorm(name string, num_features int, dtype Dtype, g *Graph) *BatchNo
5252
}
5353

5454
func (m *BatchNorm) BatchNorm(x *Tensor) *Tensor {
55-
op := x.BatchNorm(m.weight, m.bias, m.running_mean, m.running_var)
56-
return op
55+
return nil
5756
}
5857

5958
// InstanceNorm 实例归一化
@@ -78,8 +77,7 @@ func NewInstanceNorm(name string, num_features int, dtype Dtype, g *Graph) *Inst
7877
}
7978

8079
func (m *InstanceNorm) InstanceNorm(x *Tensor) *Tensor {
81-
op := x.InstanceNorm(m.weight, m.bias)
82-
return op
80+
return nil
8381
}
8482

8583
// GroupNorm 组归一化
@@ -110,8 +108,7 @@ func NewGroupNorm(name string, num_groups, num_channels int, dtype Dtype, g *Gra
110108
}
111109

112110
func (m *GroupNorm) GroupNorm(x *Tensor) *Tensor {
113-
op := x.GroupNorm(m.weight, m.bias, m.num_groups)
114-
return op
111+
return nil
115112
}
116113

117114
// RMSNorm Root Mean Square Layer Normalization
@@ -137,6 +134,5 @@ func NewRMSNorm(name string, normalized_shape int, dtype Dtype, g *Graph) *RMSNo
137134
}
138135

139136
func (m *RMSNorm) RMSNorm(x *Tensor) *Tensor {
140-
op := x.RMSNorm(m.weight, m.eps)
141-
return op
137+
return nil
142138
}

front/py/deepx/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
from .tensor import Tensor
1+
from .tensor import Tensor,Shape,Device,DeviceType
2+
from .tensor import zeros, ones, arange, rand, randn, eye
23

3-
__all__ = ['Tensor']
4+
__all__ = [
5+
'Tensor',
6+
'Shape',
7+
'Device','DeviceType',
8+
'zeros', 'ones', 'arange', 'rand', 'randn', 'eye'
9+
]
10+
11+
# 为了支持 import deepx as dx 的用法
412
tensor = Tensor
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .graph import Graph
2+
3+
__all__ = ['Graph']

front/py/deepx/autograd/graph.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
1-
from ..tensor import Tensor
2-
from .tensornode import TensorNode
3-
from .opnode import OpNode
4-
from .constargnode import ConstArgNode
1+
from ..tensor.tensor import Tensor
2+
from ._tensornode import TensorNode
3+
from ._opnode import OpNode
4+
from ._constargnode import ConstArgNode
5+
56
class Graph:
7+
# 类属性存储默认实例
8+
_default_graph = None
9+
10+
@classmethod
11+
def get_default(cls):
12+
"""获取或创建默认计算图(线程不安全)"""
13+
if cls._default_graph is None:
14+
cls._default_graph = Graph()
15+
return cls._default_graph
16+
17+
@classmethod
18+
def set_default(cls, graph):
19+
"""设置新的默认计算图(用于上下文管理)"""
20+
if not isinstance(graph, Graph):
21+
raise TypeError("Must be a Graph instance")
22+
cls._default_graph = graph
23+
624
def __init__(self):
725
self.nodes = []
826
self.inputs = []
@@ -29,4 +47,7 @@ def add_constarg(self, value):
2947
name = f"constarg_{self.constarg_counter}"
3048
node=ConstArgNode(value)
3149
self.nodes.append(node)
32-
return node
50+
return node
51+
52+
# 初始化默认图
53+
Graph._default_graph = Graph()

front/py/deepx/autograd/graph_viz.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import graphviz
22
from .graph import Graph
3-
from .tensornode import TensorNode
4-
from .opnode import OpNode
5-
from .constargnode import ConstArgNode
3+
from ._tensornode import TensorNode
4+
from ._opnode import OpNode
5+
from ._constargnode import ConstArgNode
66

77
def graph_method(f):
88
"""装饰器:将函数注册为Graph类的方法"""

0 commit comments

Comments (0)