Skip to content

Commit 73addbb

Browse files
authored
attention:rotatehalf ok (#73)
* repeat: pause development of repeat_interleave * llama_rope: extract the shared code * cuda: fix int64-type call errors * py: rotate_half verification complete * attention: rotatehalf
1 parent 104c7b3 commit 73addbb

14 files changed

Lines changed: 102 additions & 67 deletions

File tree

excuter/cpp-common/src/deepx/tensorfunc/changeshape.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,19 @@ namespace deepx::tensorfunc
9292
template <typename Author, typename T>
9393
struct repeat_interleaveDispatcher
9494
{
95-
static void repeat_interleave(const Tensor<T> &A, const int repeats, Tensor<T> &B) = delete;
96-
static void repeat_interleave(const Tensor<T> &A, const Tensor<T> &repeats, Tensor<T> &B) = delete;
95+
static void repeat_interleave(const Tensor<T> &A, const int repeats,const int dim, Tensor<T> &B) = delete;
96+
// static void repeat_interleave(const Tensor<T> &A, const Tensor<T> &repeats, Tensor<T> &B) = delete;
9797
};
9898
template <typename Author, typename T>
99-
void repeat_interleave(const Tensor<T> &A, const int repeats, Tensor<T> &B)
99+
void repeat_interleave(const Tensor<T> &A, const int repeats,const int dim, Tensor<T> &B)
100100
{
101-
repeat_interleaveDispatcher<Author, T>::repeat_interleave(A, repeats, B);
102-
}
103-
template <typename Author, typename T>
104-
void repeat_interleave(const Tensor<T> &A, const Tensor<T> &repeats, Tensor<T> &B)
105-
{
106-
repeat_interleaveDispatcher<Author, T>::repeat_interleave(A, repeats, B);
101+
repeat_interleaveDispatcher<Author, T>::repeat_interleave(A, repeats,dim, B);
107102
}
103+
// template <typename Author, typename T>
104+
// void repeat_interleave(const Tensor<T> &A, const Tensor<T> &repeats, Tensor<T> &B)
105+
// {
106+
// repeat_interleaveDispatcher<Author, T>::repeat_interleave(A, repeats, B);
107+
// }
108108

109109

110110

excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,13 @@ namespace deepx::tensorfunc
8181
const int *repeats,
8282
T *output, const int *outputStrides, const int outputlen,
8383
const int dim);
84+
85+
// repeat_interleave
86+
template <int DIM, typename T>
87+
__global__ void repeat_interleave_kernel(
88+
const T *input, const int *inputStrides,
89+
const int *repeats,
90+
T *output, const int *outputStrides, const int outputlen,
91+
const int dim);
8492
};
8593
#endif // DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_CUH

excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,5 +152,6 @@ namespace deepx::tensorfunc
152152
B.data, B.shape.strides.data(),B.shape.size, B.shape.dim());
153153
}
154154
};
155+
155156
}
156157
#endif // DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_HPP

excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ namespace deepx::tf
411411
tensorfunc::add<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
412412
break;
413413
case Precision::Int64:
414-
tensorfunc::add<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
414+
tensorfunc::add<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
415415
break;
416416
case Precision::Int32:
417417
tensorfunc::add<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -479,7 +479,7 @@ namespace deepx::tf
479479
tensorfunc::addscalar<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), this->getvar<nv_bfloat16>(1, mem), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
480480
break;
481481
case Precision::Int64:
482-
tensorfunc::addscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
482+
tensorfunc::addscalar<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), this->getvar<int64_t>(1, mem), *mem->gettensor<int64_t>(this->returns[0].textvalue));
483483
break;
484484
case Precision::Int32:
485485
tensorfunc::addscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -548,7 +548,7 @@ namespace deepx::tf
548548
tensorfunc::sub<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
549549
break;
550550
case Precision::Int64:
551-
tensorfunc::sub<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
551+
tensorfunc::sub<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
552552
break;
553553
case Precision::Int32:
554554
tensorfunc::sub<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -616,7 +616,7 @@ namespace deepx::tf
616616
tensorfunc::subscalar<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), this->getvar<nv_bfloat16>(1, mem), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
617617
break;
618618
case Precision::Int64:
619-
tensorfunc::subscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
619+
tensorfunc::subscalar<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), this->getvar<int64_t>(1, mem), *mem->gettensor<int64_t>(this->returns[0].textvalue));
620620
break;
621621
case Precision::Int32:
622622
tensorfunc::subscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -685,7 +685,7 @@ namespace deepx::tf
685685
tensorfunc::rsubscalar<Author, nv_bfloat16>(this->getvar<nv_bfloat16>(1, mem), *mem->gettensor<nv_bfloat16>(this->args[0].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
686686
break;
687687
case Precision::Int64:
688-
tensorfunc::rsubscalar<Author, int32_t>(this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
688+
tensorfunc::rsubscalar<Author, int64_t>(this->getvar<int64_t>(1, mem), *mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
689689
break;
690690
case Precision::Int32:
691691
tensorfunc::rsubscalar<Author, int32_t>(this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -754,7 +754,7 @@ namespace deepx::tf
754754
tensorfunc::mul<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
755755
break;
756756
case Precision::Int64:
757-
tensorfunc::mul<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
757+
tensorfunc::mul<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
758758
break;
759759
case Precision::Int32:
760760
tensorfunc::mul<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -822,7 +822,7 @@ namespace deepx::tf
822822
tensorfunc::mulscalar<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), this->getvar<nv_bfloat16>(1, mem), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
823823
break;
824824
case Precision::Int64:
825-
tensorfunc::mulscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
825+
tensorfunc::mulscalar<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), this->getvar<int64_t>(1, mem), *mem->gettensor<int64_t>(this->returns[0].textvalue));
826826
break;
827827
case Precision::Int32:
828828
tensorfunc::mulscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -891,7 +891,7 @@ namespace deepx::tf
891891
tensorfunc::div<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
892892
break;
893893
case Precision::Int64:
894-
tensorfunc::div<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
894+
tensorfunc::div<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
895895
break;
896896
case Precision::Int32:
897897
tensorfunc::div<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -959,7 +959,7 @@ namespace deepx::tf
959959
tensorfunc::divscalar<Author, nv_bfloat16>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), this->getvar<nv_bfloat16>(1, mem), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
960960
break;
961961
case Precision::Int64:
962-
tensorfunc::divscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
962+
tensorfunc::divscalar<Author, int64_t>(*mem->gettensor<int64_t>(this->args[0].textvalue), this->getvar<int64_t>(1, mem), *mem->gettensor<int64_t>(this->returns[0].textvalue));
963963
break;
964964
case Precision::Int32:
965965
tensorfunc::divscalar<Author, int32_t>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<int32_t>(1, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
@@ -1027,7 +1027,7 @@ namespace deepx::tf
10271027
tensorfunc::rdivscalar<Author, nv_bfloat16>(this->getvar<nv_bfloat16>(0, mem), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
10281028
break;
10291029
case Precision::Int64:
1030-
tensorfunc::rdivscalar<Author, int32_t>(this->getvar<int32_t>(0, mem), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
1030+
tensorfunc::rdivscalar<Author, int64_t>(this->getvar<int64_t>(0, mem), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
10311031
break;
10321032
case Precision::Int32:
10331033
tensorfunc::rdivscalar<Author, int32_t>(this->getvar<int32_t>(0, mem), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));

front/py/deepx/tensor/changeshape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def broadcast_to(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
5555
return result
5656

5757
@tensor_method
58-
def indexselect(self,index:Tensor,axis:int=0,out:Union[Tensor,str]='')->Tensor:
58+
def indexselect(self,index:Tensor,gatheraxis:int=0,out:Union[Tensor,str]='')->Tensor:
5959
assert isinstance(index,Tensor)
60-
gatheraxis=axis%self.ndim
60+
gatheraxis=gatheraxis%self.ndim
6161
from deepx.nn.functional import indexselect as indexselect_func
6262
result=indexselect_func(self,index,gatheraxis,out)
6363
return result

front/py/deepx/tensor/tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def __mul__(self, other:Union[Number,'Tensor']):
124124
return self.mul(other)
125125
def __rmul__(self, other:Union[Number,'Tensor']):
126126
return self.mul(other)
127+
def __neg__(self):
128+
return self.mul(-1.0)
127129
def __truediv__(self, other:Union[Number,'Tensor']):
128130
return self.div(other)
129131
def __rtruediv__(self, other:Union[Number,'Tensor']):
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from .embedding import *
2+
from .attention import *
3+
24
__all__ = [
3-
"LlamaRotaryEmbedding"
5+
"LlamaRotaryEmbedding",
6+
"rotate_half"
47
]

front/py/deepx/transformer/models/llama/attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
def rotate_half(x:Tensor):
88
index_front=arange(0,x.shape[-1]//2,dtype="int32")
99
index_back=arange(x.shape[-1]//2,x.shape[-1],dtype="int32")
10-
x1 = x.index_select(dim=-1,index=index_front)
11-
x2 = x.index_select(dim=-1,index=index_back)
12-
return concat((-x2, x1), dim=-1)
10+
x1 = x.indexselect(gatheraxis=-1,index=index_front)
11+
x2 = x.indexselect(gatheraxis=-1,index=index_back)
12+
return concat((-x2, x1,), dim=-1)
1313

1414
def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1):
1515
cos = cos.unsqueeze(unsqueeze_dim)

front/py/examples/4_transformer/llama/llama_

Whitespace-only changes.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from token_text import dir
2+
3+
############-------DEEPX-------################
4+
from deepx import load
5+
from deepx.transformer.models.llama import rotate_half
6+
7+
input=load(dir+'input')
8+
input.print()
9+
r=rotate_half(input)
10+
r.print()
11+

0 commit comments

Comments (0)