2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
history = history_file.read()

requirements = [

]

setup_requirements = [
86 changes: 42 additions & 44 deletions torchcomplex/nn/functional.py
@@ -62,7 +62,7 @@ def _fcaller(funtinal_handle, *args):
else:
b_r = None
b_i = None

# Perform complex valued convolution
if type(args[0]) is tuple: #only in case of bilinear
MrKr = funtinal_handle(inp1_r, inp2_r, w_r, b_r, *args[3:]) #Real Feature Maps *(conv) Real Kernels
@@ -76,9 +76,7 @@
MiKr = funtinal_handle(inp_i, w_r, b_r, *args[3:]) #Imaginary Feature Maps * Real Kernels
real = MrKr - MiKi
imag = MrKi + MiKr
out = torch.view_as_complex(torch.stack((real,imag),dim=-1))

return out
return torch.view_as_complex(torch.stack((real,imag),dim=-1))
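The four calls above just expand (x_r + i*x_i) * (w_r + i*w_i) = (x_r*w_r - x_i*w_i) + i*(x_r*w_i + x_i*w_r). A standalone sketch of the same decomposition, with hypothetical 1-D shapes and F.conv1d standing in for the functional handle:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, dtype=torch.cfloat)   # (batch, in_ch, length)
w = torch.randn(3, 2, 3, dtype=torch.cfloat)   # (out_ch, in_ch, kernel)

# Complex convolution via four real-valued convolutions
real = F.conv1d(x.real, w.real) - F.conv1d(x.imag, w.imag)
imag = F.conv1d(x.real, w.imag) + F.conv1d(x.imag, w.real)
out = torch.view_as_complex(torch.stack((real, imag), dim=-1))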

#Convolutions

@@ -224,21 +222,22 @@ def _whiten2x2(tensor, training=True, running_mean=None, running_cov=None,
p, q = (cov_vv + sqrdet) / denom, -cov_uv / denom
r, s = -cov_vu / denom, (cov_uu + sqrdet) / denom

# 4. apply Q to x (manually)
out = torch.stack([
tensor[0] * p.reshape(tail) + tensor[1] * r.reshape(tail),
tensor[0] * q.reshape(tail) + tensor[1] * s.reshape(tail),
], dim=0)
return out # , torch.cat([p, q, r, s], dim=0).reshape(2, 2, -1)
return torch.stack(
[
tensor[0] * p.reshape(tail) + tensor[1] * r.reshape(tail),
tensor[0] * q.reshape(tail) + tensor[1] * s.reshape(tail),
],
dim=0,
)
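The coefficients p, q, r, s form the closed-form inverse square root of the 2x2 covariance matrix: for an SPD matrix M with determinant d and trace t, M^(-1/2) = adj(M + sqrt(d)*I) / (sqrt(d) * sqrt(t + 2*sqrt(d))). A quick standalone check with illustrative numbers:

import torch

M = torch.tensor([[2.0, 0.5], [0.5, 1.0]])     # toy SPD covariance
s = torch.sqrt(torch.det(M))
denom = s * torch.sqrt(M.trace() + 2 * s)
Q = torch.stack((torch.stack((M[1, 1] + s, -M[0, 1])),
                 torch.stack((-M[1, 0], M[0, 0] + s)))) / denom
print(Q @ M @ Q)   # ~ identity, since Q = M^(-1/2) is symmetric here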

def batch_norm(input, running_mean, running_var, weight=None, bias=None,
training=False, momentum=0.1, eps=1e-5, naive=False):

"""
Source: https://github.com/ivannz/cplxmodule/blob/master/cplxmodule/nn/modules/batchnorm.py
"""
complex_weight = not(type(weight) == torch.nn.ParameterList)
if naive:
complex_weight = type(weight) != torch.nn.ParameterList
real = F.batch_norm(input.real,
running_mean[0] if running_mean is not None else None,
running_var[0] if running_var is not None else None,
@@ -287,7 +286,7 @@ def zrelu(input: Tensor, inplace: bool = False) -> Tensor:
https://arxiv.org/pdf/1705.09792.pdf
'''
if input.is_complex():
return input * ((0 < input.angle()) * (input.angle() < math.pi/2)).float()
return input * ((input.angle() > 0) * (input.angle() < math.pi/2)).float()
else:
return F.relu(input, inplace=inplace)
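zReLU passes a value only when its phase lies in the first quadrant (0, pi/2) and zeroes everything else. A small illustration with hand-picked inputs:

import math
import torch

z = torch.tensor([1 + 1j, -1 + 1j, 1 - 1j, -1 - 1j])
mask = ((z.angle() > 0) * (z.angle() < math.pi / 2)).float()
print(z * mask)   # only 1+1j survives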

@@ -297,22 +296,20 @@ def modrelu(input: Tensor, bias: int, inplace: bool = False) -> Tensor:
Notice that |z| (the magnitude of z) is always positive, so if b > 0 then |z| + b >= 0 always.
In order to have any non-linearity effect, b must be negative (b < 0).
'''
if input.is_complex():
z_mag = torch.abs(input)
return input * ((z_mag + bias) >= 0).float() * (1 + bias / z_mag)
else:
if not input.is_complex():
return F.relu(input, inplace=inplace)
z_mag = torch.abs(input)
return input * ((z_mag + bias) >= 0).float() * (1 + bias / z_mag)
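With b < 0, modReLU zeroes values whose magnitude falls below |b| and shrinks the rest radially while preserving phase. Illustrative values:

import torch

z = torch.tensor([0.1 + 0.1j, 3 + 4j])
b = -1.0
mag = torch.abs(z)
print(z * ((mag + b) >= 0).float() * (1 + b / mag))
# first entry -> 0; second keeps its phase, magnitude 5 -> 4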

def cmodrelu(input: Tensor, threshold: Tensor, inplace: bool = False):
r"""Compute the Complex modulus relu of the complex tensor in re-im pair.
As proposed in : https://arxiv.org/pdf/1802.08026.pdf
Source: https://github.com/ivannz/cplxmodule"""
if input.is_complex():
modulus = torch.clamp(torch.abs(input), min=1e-5)
_tmp_newshape = (1,len(threshold)) + (1,)*len(input.shape[2:])
return input * F.relu(1. - threshold.view(_tmp_newshape) / modulus)
else:
if not input.is_complex():
return F.relu(input, inplace=inplace)
modulus = torch.clamp(torch.abs(input), min=1e-5)
_tmp_newshape = (1,len(threshold)) + (1,)*len(input.shape[2:])
return input * F.relu(1. - threshold.view(_tmp_newshape) / modulus)
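The threshold is per-channel; the reshape only broadcasts it across the batch and spatial axes. An equivalent standalone computation, assuming a (batch, channel, length) layout:

import torch
import torch.nn.functional as F

z = torch.randn(2, 3, 5, dtype=torch.cfloat)
t = torch.tensor([0.1, 0.5, 1.0])               # one threshold per channel
mod = torch.clamp(torch.abs(z), min=1e-5)
out = z * F.relu(1. - t.view(1, 3, 1) / mod)    # zeroed where |z| <= threshold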

def softmax(input, dim=None, _stacklevel=3, dtype=None):
'''
@@ -326,35 +323,32 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
return F.softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)

def tanh(input: Tensor):
if input.is_complex():
a, b = input.real, input.imag
denominator = torch.cosh(2*a) + torch.cos(2*b)
real = torch.sinh(2 * a) / denominator
imag = torch.sin(2 * a) / denominator
return torch.view_as_complex(torch.stack((real, imag),dim=-1))
else:
if not input.is_complex():
return F.tanh(input)
a, b = input.real, input.imag
denominator = torch.cosh(2*a) + torch.cos(2*b)
real = torch.sinh(2 * a) / denominator
imag = torch.sin(2 * b) / denominator
return torch.view_as_complex(torch.stack((real, imag),dim=-1))
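The identity used here is tanh(a + ib) = (sinh 2a + i sin 2b) / (cosh 2a + cos 2b); note the imaginary numerator takes sin(2b), not sin(2a). It can be checked against PyTorch's native complex tanh:

import torch

z = torch.randn(4, dtype=torch.cfloat)
a, b = z.real, z.imag
den = torch.cosh(2 * a) + torch.cos(2 * b)
manual = torch.view_as_complex(
    torch.stack((torch.sinh(2 * a) / den, torch.sin(2 * b) / den), dim=-1))
print(torch.allclose(manual, torch.tanh(z)))   # True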

def sigmoid(input: Tensor):
if input.is_complex():
a, b = input.real, input.imag
denominator = 1 + 2 * torch.exp(-a) * torch.cos(b) + torch.exp(-2 * a)
real = 1 + torch.exp(-a) * torch.cos(b) / denominator
imag = torch.exp(-a) * torch.sin(b) / denominator
return torch.view_as_complex(torch.stack((real, imag),dim=-1))
else:
if not input.is_complex():
return F.sigmoid(input)
a, b = input.real, input.imag
denominator = 1 + 2 * torch.exp(-a) * torch.cos(b) + torch.exp(-2 * a)
real = (1 + torch.exp(-a) * torch.cos(b)) / denominator
imag = torch.exp(-a) * torch.sin(b) / denominator
return torch.view_as_complex(torch.stack((real, imag),dim=-1))
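Likewise, expanding 1 / (1 + e^(-z)) with the conjugate gives sigmoid(a + ib) = ((1 + e^(-a) cos b) + i e^(-a) sin b) / (1 + 2 e^(-a) cos b + e^(-2a)); the parentheses around the real numerator matter. A check against the native complex sigmoid, which recent PyTorch builds provide:

import torch

z = torch.randn(4, dtype=torch.cfloat)
a, b = z.real, z.imag
den = 1 + 2 * torch.exp(-a) * torch.cos(b) + torch.exp(-2 * a)
manual = torch.view_as_complex(torch.stack(
    ((1 + torch.exp(-a) * torch.cos(b)) / den,
     torch.exp(-a) * torch.sin(b) / den), dim=-1))
print(torch.allclose(manual, torch.sigmoid(z)))   # True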

def _sinc_interpolate(input, size):
axes = np.argwhere(np.not_equal(input.shape[2:], size)).squeeze(1) #2 dims for batch and channel
out_shape = [size[i] for i in axes]
return resample(input, out_shape, axis=axes+2) #2 dims for batch and channel
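resample here is assumed to be scipy.signal.resample, which interpolates through the FFT and therefore handles complex inputs; only axes whose size differs from the target are touched. A minimal sketch:

import torch
from scipy.signal import resample

x = torch.randn(1, 1, 32, dtype=torch.cfloat)
y = resample(x.numpy(), 64, axis=2)   # sinc-interpolate the last axis to length 64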

def interpolate(input, size=None, scale_factor=None, mode='sinc', align_corners=None, recompute_scale_factor=None):
if mode in ('nearest', 'area', 'sinc'):
if align_corners is not None:
raise ValueError("align_corners option can only be set with the "
"interpolating modes: linear | bilinear | bicubic | trilinear")
if mode in ('nearest', 'area', 'sinc') and align_corners is not None:
raise ValueError("align_corners option can only be set with the "
"interpolating modes: linear | bilinear | bicubic | trilinear")

dim = input.dim() - 2 # Number of spatial dimensions.

@@ -369,8 +363,10 @@ def interpolate(input, size=None, scale_factor=None, mode='sinc', align_corners=
scale_factors = None
if isinstance(size, (list, tuple)):
if len(size) != dim:
raise ValueError('size shape must match input shape. '
'Input is {}D, size is {}'.format(dim, len(size)))
raise ValueError(
f'size shape must match input shape. Input is {dim}D, size is {len(size)}'
)

output_size = size
else:
output_size = [size for _ in range(dim)]
@@ -379,8 +375,10 @@ def interpolate(input, size=None, scale_factor=None, mode='sinc', align_corners=
output_size = None
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != dim:
raise ValueError('scale_factor shape must match input shape. '
'Input is {}D, scale_factor is {}'.format(dim, len(scale_factor)))
raise ValueError(
f'scale_factor shape must match input shape. Input is {dim}D, scale_factor is {len(scale_factor)}'
)

scale_factors = scale_factor
else:
scale_factors = [scale_factor for _ in range(dim)]
@@ -404,7 +402,7 @@ def interpolate(input, size=None, scale_factor=None, mode='sinc', align_corners=

# "area" and "sinc" modes always require an explicit size rather than scale factor.
# Re-use the recompute_scale_factor code path.
if (mode == "area" or mode == "sinc") and output_size is None:
if mode in ["area", "sinc"] and output_size is None:
recompute_scale_factor = True

if recompute_scale_factor is not None and recompute_scale_factor:
15 changes: 5 additions & 10 deletions torchcomplex/nn/init.py
@@ -21,10 +21,7 @@ def _postprocess(cls, tensor):
if cls.complex_weight:
return Parameter(tensor[0] + 1j*tensor[1])
else:
if type(tensor) is ParameterList:
return tensor
else:
return ParameterList(tensor)
return tensor if type(tensor) is ParameterList else ParameterList(tensor)

# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
@@ -272,9 +269,7 @@ def _calculate_fan_in_and_fan_out(tensor):

num_input_fmaps = tensor.size(1)
num_output_fmaps = tensor.size(0)
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
receptive_field_size = tensor[0][0].numel() if tensor.dim() > 2 else 1
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
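For example, a conv weight of shape (out_channels, in_channels, kH, kW) has a receptive field of kH * kW elements:

import torch

w = torch.empty(16, 8, 3, 3)               # (out_ch, in_ch, kH, kW)
receptive_field_size = w[0][0].numel()     # 3 * 3 = 9
fan_in = 8 * receptive_field_size          # 72
fan_out = 16 * receptive_field_size        # 144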

@@ -336,7 +331,7 @@ def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
@@ -485,7 +480,7 @@ def trabelsi_standard_(tensor, kind="glorot"):
tensor = _tensorprocessor._preprocess(tensor)

fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor[0])
if kind == "glorot" or kind == "xavier":
if kind in ["glorot", "xavier"]:
scale = 1 / math.sqrt(fan_in + fan_out)
else:
scale = 1 / math.sqrt(fan_in)
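For context, the Trabelsi schemes from the Deep Complex Networks paper draw the weight modulus from a Rayleigh distribution with this scale and the phase uniformly from (-pi, pi). A sketch of that idea with illustrative shapes; the repo's implementation details may differ:

import math
import torch

fan_in, fan_out = 72, 144
scale = 1 / math.sqrt(fan_in + fan_out)          # "glorot"/"xavier" scale
u = torch.rand(16, 8, 3, 3).clamp_min(1e-7)
mod = scale * torch.sqrt(-2 * torch.log(u))      # Rayleigh(scale) sample
phase = torch.empty_like(mod).uniform_(-math.pi, math.pi)
weight = torch.polar(mod, phase)                 # complex weight tensor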
@@ -524,7 +519,7 @@ def trabelsi_independent_(tensor, kind="glorot"):
M = np.dot(u[:, :k], vh[:, :k].conjugate().T)

fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor[0])
if kind == "glorot" or kind == "xavier":
if kind in ["glorot", "xavier"]:
scale = 1 / math.sqrt(fan_in + fan_out)
else:
scale = 1 / math.sqrt(fan_in)
17 changes: 6 additions & 11 deletions torchcomplex/nn/modules/activation.py
@@ -37,8 +37,7 @@ def forward(self, input: Tensor) -> Tensor:
return cF.crelu(input, inplace=self.inplace)

def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
return 'inplace=True' if self.inplace else ''

class zReLU(Module):
'''
@@ -61,8 +60,7 @@ def forward(self, input: Tensor) -> Tensor:
return cF.zrelu(input, inplace=self.inplace)

def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
return 'inplace=True' if self.inplace else ''

class modReLU(Module):
'''
@@ -85,8 +83,7 @@ def forward(self, input: Tensor) -> Tensor:
return cF.modrelu(input, bias=self.bias, inplace=self.inplace)

def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
return 'inplace=True' if self.inplace else ''

class CmodReLU(Module):
'''Compute the Complex modulus relu of the complex tensor in re-im pair.
@@ -108,8 +105,7 @@ def forward(self, input: Tensor) -> Tensor:
return cF.cmodrelu(input, threshold=self.threshold, inplace=self.inplace)

def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
return 'inplace=True' if self.inplace else ''

class AdaptiveCmodReLU(Module):
'''Compute the Complex modulus relu of the complex tensor in re-im pair.
@@ -123,15 +119,14 @@ class AdaptiveCmodReLU(Module):
def __init__(self, *dim, inplace: bool = False):
super(AdaptiveCmodReLU, self).__init__()
self.inplace = inplace
self.dim = dim if dim else (1,)
self.dim = dim or (1,)
self.threshold = Parameter(torch.randn(*self.dim) * 0.02)

def forward(self, input: Tensor) -> Tensor:
return cF.cmodrelu(input, threshold=self.threshold, inplace=self.inplace)

def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
return 'inplace=True' if self.inplace else ''
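A usage sketch with illustrative shapes; the module learns one threshold per channel, initialized near zero:

import torch
from torchcomplex.nn.modules.activation import AdaptiveCmodReLU

act = AdaptiveCmodReLU(8)                        # threshold has shape (8,)
z = torch.randn(4, 8, 32, dtype=torch.cfloat)    # (batch, channels, length)
out = act(z)                                     # thresholds broadcast over batch and length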

class Softmax(Module):
__constants__ = ['dim']