Skip to content

Commit e8e9d12

Browse files
committed
Fix MUAP temporal resampling in surface EMG convolution
The MUAP timestep was incorrectly computed from the output sampling rate (e.g. 2048 Hz) instead of the native MUAP sampling rate (~27 kHz). This caused MUAPs to be stretched ~13x in time during spike train convolution, producing unrealistically smooth EMG signals dominated by low-frequency content (~27 Hz peak) instead of the expected broadband interference pattern (~100-200 Hz peak). The fix reads the native sampling rate directly from the stored MUAP AnalogSignal and uses it for the interpolation time axes. Also includes performance optimizations: - Pre-compute electrode positions once per MU (avoid per-fiber grid recomputation) - Skip deep copy in sequential mode (n_jobs=1) - Direct Python function call (skip dispatcher overhead) - Pre-extract scalar values from Quantity objects outside fiber loop
1 parent feab9ba commit e8e9d12

2 files changed

Lines changed: 187 additions & 107 deletions

File tree

myogen/simulator/core/emg/surface/simulate_fiber.py

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ def _compute_phi_fast(
216216
_numba_B_kz_func, _numba_phi_func = _get_numba_functions()
217217

218218

219-
@beartowertype
219+
# NOTE: beartype disabled for performance — this is called 100s-1000s of times per MU
220+
# Type validation happens at the simulate_fiber_v2 dispatcher level instead
220221
def _simulate_fiber_v2_python(
221222
Fs: float,
222223
v: float,
@@ -240,6 +241,10 @@ def _simulate_fiber_v2_python(
240241
B_incomplete: np.ndarray | None = None,
241242
use_gpu: bool = True,
242243
fiber_length__mm: float | None = None,
244+
theta_offset: float = 0.0,
245+
pos_z_precomputed: np.ndarray | None = None,
246+
pos_theta_precomputed: np.ndarray | None = None,
247+
rele_precomputed: float | None = None,
243248
):
244249
"""
245250
Simulate a single fiber (Python implementation).
@@ -316,14 +321,16 @@ def _simulate_fiber_v2_python(
316321
# Get electrode configuration from the array
317322
channels = [electrode_array.num_rows, electrode_array.num_cols]
318323

319-
# Extract magnitudes from Quantity objects for numerical operations
324+
# Use pre-computed electrode positions if available (avoids per-fiber grid recomputation)
320325
import quantities as pq
321-
322-
rele = float(electrode_array.electrode_radius__mm.rescale(pq.mm).magnitude)
323-
pos_z_mm = electrode_array.pos_z.rescale(pq.mm).magnitude # Extract as plain array in mm
324-
pos_theta_rad = electrode_array.pos_theta.rescale(
325-
pq.rad
326-
).magnitude # Extract as plain array in rad
326+
if pos_z_precomputed is not None and pos_theta_precomputed is not None:
327+
rele = rele_precomputed if rele_precomputed is not None else float(electrode_array.electrode_radius__mm.rescale(pq.mm).magnitude)
328+
pos_z_mm = pos_z_precomputed
329+
pos_theta_rad = pos_theta_precomputed + theta_offset # Apply angular offset as simple addition
330+
else:
331+
rele = float(electrode_array.electrode_radius__mm.rescale(pq.mm).magnitude)
332+
pos_z_mm = electrode_array.pos_z.rescale(pq.mm).magnitude
333+
pos_theta_rad = electrode_array.pos_theta.rescale(pq.rad).magnitude + theta_offset
327334

328335
###################################################################################################
329336
## 1. Constants
@@ -813,44 +820,75 @@ def _simulate_fiber_v2_python(
813820

814821
# Use the electrode array's pre-computed positions
815822
H_glo = np.multiply(H_vc, H_ele)
816-
# print(f"DEBUG: H_glo range = [{np.min(np.abs(H_glo)):.6e}, {np.max(np.abs(H_glo)):.6e}]")
817-
B_kz = np.zeros((channels[0], channels[1], len(k_z)))
818-
819-
for channel_z in range(channels[0]):
820-
for channel_theta in range(channels[1]):
821-
arg = np.multiply(
822-
H_glo,
823-
np.exp(1j * pos_theta_rad[channel_z, channel_theta] * ktheta_mesh_kzktheta)
824-
* (k_theta[1] - k_theta[0]),
825-
)
826-
B_kz[channel_z, channel_theta, :] = sum(np.transpose(arg)) / 2 / math.pi
827823

828-
# print(f"DEBUG: B_kz range = [{np.min(np.abs(B_kz)):.6e}, {np.max(np.abs(B_kz)):.6e}]")
824+
k_theta_diff = k_theta[1] - k_theta[0]
825+
k_z_diff = k_z[1] - k_z[0]
826+
827+
if _numba_B_kz_func is not None:
828+
# Use Numba-optimized parallel computation
829+
H_glo_f64 = np.ascontiguousarray(H_glo.real, dtype=np.float64)
830+
H_glo_i64 = np.ascontiguousarray(H_glo.imag, dtype=np.float64)
831+
pt_f64 = np.ascontiguousarray(pos_theta_rad, dtype=np.float64)
832+
km_f64 = np.ascontiguousarray(ktheta_mesh_kzktheta, dtype=np.float64)
833+
B_kz = _numba_B_kz_func(
834+
H_glo_f64, H_glo_i64, pt_f64,
835+
float(k_theta_diff), km_f64,
836+
channels[0], channels[1], len(k_z), len(k_theta),
837+
)
838+
else:
839+
B_kz = np.zeros((channels[0], channels[1], len(k_z)))
840+
for channel_z in range(channels[0]):
841+
for channel_theta in range(channels[1]):
842+
arg = np.multiply(
843+
H_glo,
844+
np.exp(1j * pos_theta_rad[channel_z, channel_theta] * ktheta_mesh_kzktheta)
845+
* k_theta_diff,
846+
)
847+
B_kz[channel_z, channel_theta, :] = sum(np.transpose(arg)) / 2 / math.pi
829848

830849
###################################################################################################
831850
## 6. phi(t) for each channel
832851

833-
phi = np.zeros((channels[0], channels[1], len(t)))
834-
for channel_z in range(channels[0]):
835-
for channel_theta in range(channels[1]):
836-
auxiliar = np.dot(
837-
np.ones((len(I_kzkt[1, :]), 1)),
838-
B_kz[channel_z, channel_theta, :].reshape(1, -1),
839-
)
840-
auxiliar = np.transpose(auxiliar)
841-
arg = np.multiply(I_kzkt, auxiliar)
842-
arg2 = np.multiply(
843-
arg,
844-
np.exp(1j * pos_z_mm[channel_z, channel_theta] * kz_mesh_kzkt) * (k_z[1] - k_z[0]),
845-
)
846-
PHI = sum(arg2)
847-
phi[channel_z, channel_theta, :] = np.real(
848-
(
852+
if _numba_phi_func is not None:
853+
# Use Numba-optimized parallel computation
854+
PHI_complex = _numba_phi_func(
855+
np.ascontiguousarray(I_kzkt.real, dtype=np.float64),
856+
np.ascontiguousarray(I_kzkt.imag, dtype=np.float64),
857+
np.ascontiguousarray(B_kz, dtype=np.float64),
858+
np.ascontiguousarray(pos_z_mm, dtype=np.float64),
859+
np.ascontiguousarray(kz_mesh_kzkt, dtype=np.float64),
860+
channels[0], channels[1], len(k_z), len(k_t),
861+
float(k_z_diff),
862+
)
863+
# Apply IFFT to get time-domain signal
864+
phi = np.zeros((channels[0], channels[1], len(t)))
865+
for channel_z in range(channels[0]):
866+
for channel_theta in range(channels[1]):
867+
phi[channel_z, channel_theta, :] = np.real(
868+
np.fft.ifft(
869+
np.fft.fftshift(PHI_complex[channel_z, channel_theta, :] * len(psi))
870+
)
871+
)
872+
else:
873+
phi = np.zeros((channels[0], channels[1], len(t)))
874+
for channel_z in range(channels[0]):
875+
for channel_theta in range(channels[1]):
876+
auxiliar = np.dot(
877+
np.ones((len(I_kzkt[1, :]), 1)),
878+
B_kz[channel_z, channel_theta, :].reshape(1, -1),
879+
)
880+
auxiliar = np.transpose(auxiliar)
881+
arg = np.multiply(I_kzkt, auxiliar)
882+
arg2 = np.multiply(
883+
arg,
884+
np.exp(1j * pos_z_mm[channel_z, channel_theta] * kz_mesh_kzkt) * k_z_diff,
885+
)
886+
PHI = sum(arg2)
887+
phi[channel_z, channel_theta, :] = np.real(
849888
np.fft.ifft(
850889
np.fft.fftshift(PHI / 2 / math.pi * len(psi))
851-
) # Matches original line 239
890+
)
852891
)
853-
)
854892

855893
# Center the MUAP signal in the time window by finding the peak and shifting
856894
# For each electrode channel, find the peak and center it
@@ -877,7 +915,7 @@ def _simulate_fiber_v2_python(
877915
return phi, A_matrix, B_incomplete
878916

879917

880-
@beartowertype
918+
# NOTE: beartype disabled — called 100s-1000s of times per MU in the inner fiber loop
881919
def simulate_fiber_v2(
882920
Fs: float,
883921
v: float,
@@ -902,6 +940,10 @@ def simulate_fiber_v2(
902940
use_cython: bool = True,
903941
use_gpu: bool = True,
904942
fiber_length__mm: float | None = None,
943+
theta_offset: float = 0.0,
944+
pos_z_precomputed: np.ndarray | None = None,
945+
pos_theta_precomputed: np.ndarray | None = None,
946+
rele_precomputed: float | None = None,
905947
):
906948
"""
907949
Simulate a single fiber (dispatcher to Cython or Python implementation).
@@ -1066,4 +1108,8 @@ def simulate_fiber_v2(
10661108
B_incomplete,
10671109
use_gpu,
10681110
fiber_length__mm,
1111+
theta_offset,
1112+
pos_z_precomputed,
1113+
pos_theta_precomputed,
1114+
rele_precomputed,
10691115
)

0 commit comments

Comments
 (0)