diff --git a/cosformer.py b/cosformer.py index d9653ce..9a8fdca 100644 --- a/cosformer.py +++ b/cosformer.py @@ -51,7 +51,7 @@ def get_act_fun(self, act_fun): if act_fun == "relu": return F.relu elif act_fun == "elu": - return 1 + F.elu + return lambda x: 1 + F.elu(x) def forward( self,