3535 "moving_avg" ,
3636 "permutation_signal" ,
3737 "permute_channels" ,
38+ "phase_swap" ,
3839 "random_FT_phase" ,
3940 "random_slope_scale" ,
4041 "scaling" ,
@@ -114,7 +115,7 @@ def shift_horizontal(
114115 batch_equal : bool = True ,
115116) -> ArrayLike :
116117 """
117- shifts temporally the elements of the ArrayLike object.
118+ Shifts temporally the elements of the ArrayLike object.
118119
119120 Shift is applied along the last dimension.
120121 The empty elements at beginning or the ending part
@@ -438,6 +439,90 @@ def shift_frequency(
438439 return _shift_frequency (x , shift_freq , Fs , forward , random_shift , batch_equal , t , h )
439440
440441
442+ def phase_swap (x : ArrayLike ) -> ArrayLike :
443+ """
444+ Apply the phase swap data augmentation to the ArrayLike object.
445+
446+ The phase swap data augmentation consists in merging the amplitude
447+ and phase components of biosignals from different sources to help
448+ the model learn their coupling.
449+ Specifically, the amplitude and phase of two randomly selected EEG samples
450+ are extracted using the Fourier transform.
451+ New samples are then generated by applying the inverse Fourier transform,
452+ combining the amplitude from one sample with the phase from the other.
453+ See the following paper for more information [phaseswap]_.
454+
455+ Parameters
456+ ----------
457+ x : ArrayLike
458+ A 3-dimensional torch tensor or numpy array.
459+ The last two dimensions must refer to the EEG (Channels x Samples).
460+
461+ Returns
462+ -------
463+ x: ArrayLike
464+ The augmented version of the input Tensor or Array.
465+
466+ Note
467+ ----
468+ `Phase swap` ignores the class of each sample.
469+
470+
471+ References
472+ ----------
473+ .. [phaseswap] Lemkhenter, Abdelhak, and Favaro, Paolo.
474+ "Boosting Generalization in Bio-signal Classification by
475+ Learning the Phase-Amplitude Coupling". DAGM GCPR (2020).
476+
477+ """
478+
479+ Ndim = len (x .shape )
480+ if Ndim != 3 :
481+ raise ValueError ("x must be a 3-dimensional array or tensor" )
482+
483+ N = x .shape [0 ]
484+
485+ if isinstance (x , torch .Tensor ):
486+ # Compute fft, module and phase
487+ xfft = torch .fft .fft (x )
488+ amplitude = xfft .abs ()
489+ phase = xfft .angle ()
490+ x_aug = torch .clone (xfft )
491+
492+ # Random shuffle indeces
493+ idx_shuffle = torch .randperm (N ).to (device = x .device )
494+ idx_shuffle_1 = idx_shuffle [: (N // 2 )]
495+ idx_shuffle_2 = idx_shuffle [(N // 2 ) : (N // 2 ) * 2 ]
496+
497+ # Apply phase swap
498+ x_aug [idx_shuffle_1 ] = amplitude [idx_shuffle_1 ] * torch .exp (1j * phase [idx_shuffle_2 ])
499+ x_aug [idx_shuffle_2 ] = amplitude [idx_shuffle_2 ] * torch .exp (1j * phase [idx_shuffle_1 ])
500+
501+ # Reconstruct the signal
502+ x_aug = (torch .fft .ifft (x_aug )).real .to (device = x .device )
503+
504+ else :
505+
506+ xfft = np .fft .fft (x )
507+ amplitude = np .abs (xfft )
508+ phase = np .angle (xfft )
509+ x_aug = np .copy (xfft )
510+
511+ # Random shuffle indeces
512+ idx_shuffle = np .random .permutation (N )
513+ idx_shuffle_1 = idx_shuffle [: (N // 2 )]
514+ idx_shuffle_2 = idx_shuffle [(N // 2 ) : (N // 2 ) * 2 ]
515+
516+ # Apply phase swap
517+ x_aug [idx_shuffle_1 ] = amplitude [idx_shuffle_1 ] * np .exp (1j * phase [idx_shuffle_2 ])
518+ x_aug [idx_shuffle_2 ] = amplitude [idx_shuffle_2 ] * np .exp (1j * phase [idx_shuffle_1 ])
519+
520+ # Reconstruct the signal
521+ x_aug = (np .fft .ifft (x_aug )).real
522+
523+ return x_aug
524+
525+
441526def flip_vertical (x : ArrayLike ) -> ArrayLike :
442527 """
443528 changes the sign of all the elements of the input.
@@ -456,9 +541,8 @@ def flip_vertical(x: ArrayLike) -> ArrayLike:
456541 -------
457542 >>> import torch
458543 >>> import selfeeg.augmentation as aug
459- >>> x = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
460- >>> xaug = aug.flip_vertical(x)
461- >>> print(torch.equal(xaug, x*(-1))) # should return True
544+ >>> x = torch.randn(64,32,512)
545+ >>> xaug = aug.phase_swap(x)
462546
463547 """
464548 x_flip = x * (- 1 )
0 commit comments