Skip to content

Commit 1ae63fa

Browse files
authored
Merge pull request #17 from fedepup/main
new features for version 0.2.1
2 parents 0042816 + c7561f6 commit 1ae63fa

11 files changed

Lines changed: 1356 additions & 200 deletions

File tree

RELEASE.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
# Version X.X.X (only via git install)
22

3+
# Version 0.2.1 (latest)
4+
35
**Functionality**
46

7+
- **augmentation module**:
8+
- add Circular augmenter in compose module.
9+
- add phase swap augmentation in functional module.
10+
- **models module**:
11+
- models can be initialized with a custom seed.
12+
- add EEGConformer.
13+
- add xEEGNet.
514
- **dataloading module**:
615
- EEGDataset now supports EEG with multiple labels (1 per window partition).
716
- **ssl module**:
@@ -14,7 +23,7 @@
1423
* reduced unittest overall time
1524

1625

17-
# Version 0.2.0 (latest)
26+
# Version 0.2.0
1827

1928
**Functionality**
2029

docs/selfeeg.augmentation.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Functions
5858
moving_avg
5959
permutation_signal
6060
permute_channels
61+
phase_swap
6162
random_FT_phase
6263
random_slope_scale
6364
scaling

docs/selfeeg.models.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Classes
4040
:template: classtemplate.rst
4141

4242
DeepConvNetEncoder
43+
EEGConformerEncoder
4344
EEGInceptionEncoder
4445
EEGNetEncoder
4546
EEGSymEncoder
@@ -49,6 +50,7 @@ Classes
4950
StagerNetEncoder
5051
STNetEncoder
5152
TinySleepNetEncoder
53+
xEEGNetEncoder
5254

5355

5456
models.zoo module
@@ -65,6 +67,7 @@ Classes
6567

6668
ATCNet
6769
DeepConvNet
70+
EEGConformer
6871
EEGInception
6972
EEGNet
7073
EEGSym
@@ -74,3 +77,4 @@ Classes
7477
StagerNet
7578
STNet
7679
TinySleepNet
80+
xEEGNet

selfeeg/augmentation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
moving_avg,
2626
permutation_signal,
2727
permute_channels,
28+
phase_swap,
2829
random_FT_phase,
2930
random_slope_scale,
3031
scaling,

selfeeg/augmentation/functional.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"moving_avg",
3636
"permutation_signal",
3737
"permute_channels",
38+
"phase_swap",
3839
"random_FT_phase",
3940
"random_slope_scale",
4041
"scaling",
@@ -114,7 +115,7 @@ def shift_horizontal(
114115
batch_equal: bool = True,
115116
) -> ArrayLike:
116117
"""
117-
shifts temporally the elements of the ArrayLike object.
118+
Shifts temporally the elements of the ArrayLike object.
118119
119120
Shift is applied along the last dimension.
120121
The empty elements at beginning or the ending part
@@ -438,6 +439,90 @@ def shift_frequency(
438439
return _shift_frequency(x, shift_freq, Fs, forward, random_shift, batch_equal, t, h)
439440

440441

442+
def phase_swap(x: ArrayLike) -> ArrayLike:
443+
"""
444+
Apply the phase swap data augmentation to the ArrayLike object.
445+
446+
The phase swap data augmentation consists in merging the amplitude
447+
and phase components of biosignals from different sources to help
448+
the model learn their coupling.
449+
Specifically, the amplitude and phase of two randomly selected EEG samples
450+
are extracted using the Fourier transform.
451+
New samples are then generated by applying the inverse Fourier transform,
452+
combining the amplitude from one sample with the phase from the other.
453+
See the following paper for more information [phaseswap]_.
454+
455+
Parameters
456+
----------
457+
x : ArrayLike
458+
A 3-dimensional torch tensor or numpy array.
459+
The last two dimensions must refer to the EEG (Channels x Samples).
460+
461+
Returns
462+
-------
463+
x: ArrayLike
464+
The augmented version of the input Tensor or Array.
465+
466+
Note
467+
----
468+
`Phase swap` ignores the class of each sample.
469+
470+
471+
References
472+
----------
473+
.. [phaseswap] Lemkhenter, Abdelhak, and Favaro, Paolo.
474+
"Boosting Generalization in Bio-signal Classification by
475+
Learning the Phase-Amplitude Coupling". DAGM GCPR (2020).
476+
477+
"""
478+
479+
Ndim = len(x.shape)
480+
if Ndim != 3:
481+
raise ValueError("x must be a 3-dimensional array or tensor")
482+
483+
N = x.shape[0]
484+
485+
if isinstance(x, torch.Tensor):
486+
# Compute fft, module and phase
487+
xfft = torch.fft.fft(x)
488+
amplitude = xfft.abs()
489+
phase = xfft.angle()
490+
x_aug = torch.clone(xfft)
491+
492+
# Random shuffle indeces
493+
idx_shuffle = torch.randperm(N).to(device=x.device)
494+
idx_shuffle_1 = idx_shuffle[: (N // 2)]
495+
idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2]
496+
497+
# Apply phase swap
498+
x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * torch.exp(1j * phase[idx_shuffle_2])
499+
x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * torch.exp(1j * phase[idx_shuffle_1])
500+
501+
# Reconstruct the signal
502+
x_aug = (torch.fft.ifft(x_aug)).real.to(device=x.device)
503+
504+
else:
505+
506+
xfft = np.fft.fft(x)
507+
amplitude = np.abs(xfft)
508+
phase = np.angle(xfft)
509+
x_aug = np.copy(xfft)
510+
511+
# Random shuffle indeces
512+
idx_shuffle = np.random.permutation(N)
513+
idx_shuffle_1 = idx_shuffle[: (N // 2)]
514+
idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2]
515+
516+
# Apply phase swap
517+
x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * np.exp(1j * phase[idx_shuffle_2])
518+
x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * np.exp(1j * phase[idx_shuffle_1])
519+
520+
# Reconstruct the signal
521+
x_aug = (np.fft.ifft(x_aug)).real
522+
523+
return x_aug
524+
525+
441526
def flip_vertical(x: ArrayLike) -> ArrayLike:
442527
"""
443528
changes the sign of all the elements of the input.
@@ -456,9 +541,8 @@ def flip_vertical(x: ArrayLike) -> ArrayLike:
456541
-------
457542
>>> import torch
458543
>>> import selfeeg.augmentation as aug
459-
>>> x = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
460-
>>> xaug = aug.flip_vertical(x)
461-
>>> print(torch.equal(xaug, x*(-1))) # should return True
544+
>>> x = torch.randn(64,32,512)
545+
>>> xaug = aug.phase_swap(x)
462546
463547
"""
464548
x_flip = x * (-1)

selfeeg/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .encoders import (
1010
BasicBlock1,
1111
DeepConvNetEncoder,
12+
EEGConformerEncoder,
1213
EEGInceptionEncoder,
1314
EEGNetEncoder,
1415
EEGSymEncoder,
@@ -17,11 +18,13 @@
1718
StagerNetEncoder,
1819
STNetEncoder,
1920
TinySleepNetEncoder,
21+
xEEGNetEncoder,
2022
)
2123

2224
from .zoo import (
2325
ATCNet,
2426
DeepConvNet,
27+
EEGConformer,
2528
EEGInception,
2629
EEGNet,
2730
EEGSym,
@@ -31,4 +34,5 @@
3134
StagerNet,
3235
STNet,
3336
TinySleepNet,
37+
xEEGNet,
3438
)

0 commit comments

Comments
 (0)