@@ -216,7 +216,8 @@ def _compute_phi_fast(
216216_numba_B_kz_func , _numba_phi_func = _get_numba_functions ()
217217
218218
219- @beartowertype
219+ # NOTE: beartype disabled for performance — this is called 100s-1000s of times per MU
220+ # Type validation happens at the simulate_fiber_v2 dispatcher level instead
220221def _simulate_fiber_v2_python (
221222 Fs : float ,
222223 v : float ,
@@ -240,6 +241,10 @@ def _simulate_fiber_v2_python(
240241 B_incomplete : np .ndarray | None = None ,
241242 use_gpu : bool = True ,
242243 fiber_length__mm : float | None = None ,
244+ theta_offset : float = 0.0 ,
245+ pos_z_precomputed : np .ndarray | None = None ,
246+ pos_theta_precomputed : np .ndarray | None = None ,
247+ rele_precomputed : float | None = None ,
243248):
244249 """
245250 Simulate a single fiber (Python implementation).
@@ -316,14 +321,16 @@ def _simulate_fiber_v2_python(
316321 # Get electrode configuration from the array
317322 channels = [electrode_array .num_rows , electrode_array .num_cols ]
318323
319- # Extract magnitudes from Quantity objects for numerical operations
324+ # Use pre-computed electrode positions if available (avoids per-fiber grid recomputation)
320325 import quantities as pq
321-
322- rele = float (electrode_array .electrode_radius__mm .rescale (pq .mm ).magnitude )
323- pos_z_mm = electrode_array .pos_z .rescale (pq .mm ).magnitude # Extract as plain array in mm
324- pos_theta_rad = electrode_array .pos_theta .rescale (
325- pq .rad
326- ).magnitude # Extract as plain array in rad
326+ if pos_z_precomputed is not None and pos_theta_precomputed is not None :
327+ rele = rele_precomputed if rele_precomputed is not None else float (electrode_array .electrode_radius__mm .rescale (pq .mm ).magnitude )
328+ pos_z_mm = pos_z_precomputed
329+ pos_theta_rad = pos_theta_precomputed + theta_offset # Apply angular offset as simple addition
330+ else :
331+ rele = float (electrode_array .electrode_radius__mm .rescale (pq .mm ).magnitude )
332+ pos_z_mm = electrode_array .pos_z .rescale (pq .mm ).magnitude
333+ pos_theta_rad = electrode_array .pos_theta .rescale (pq .rad ).magnitude + theta_offset
327334
328335 ###################################################################################################
329336 ## 1. Constants
@@ -813,44 +820,75 @@ def _simulate_fiber_v2_python(
813820
814821 # Use the electrode array's pre-computed positions
815822 H_glo = np .multiply (H_vc , H_ele )
816- # print(f"DEBUG: H_glo range = [{np.min(np.abs(H_glo)):.6e}, {np.max(np.abs(H_glo)):.6e}]")
817- B_kz = np .zeros ((channels [0 ], channels [1 ], len (k_z )))
818-
819- for channel_z in range (channels [0 ]):
820- for channel_theta in range (channels [1 ]):
821- arg = np .multiply (
822- H_glo ,
823- np .exp (1j * pos_theta_rad [channel_z , channel_theta ] * ktheta_mesh_kzktheta )
824- * (k_theta [1 ] - k_theta [0 ]),
825- )
826- B_kz [channel_z , channel_theta , :] = sum (np .transpose (arg )) / 2 / math .pi
827823
828- # print(f"DEBUG: B_kz range = [{np.min(np.abs(B_kz)):.6e}, {np.max(np.abs(B_kz)):.6e}]")
824+ k_theta_diff = k_theta [1 ] - k_theta [0 ]
825+ k_z_diff = k_z [1 ] - k_z [0 ]
826+
827+ if _numba_B_kz_func is not None :
828+ # Use Numba-optimized parallel computation
829+ H_glo_f64 = np .ascontiguousarray (H_glo .real , dtype = np .float64 )
830+ H_glo_i64 = np .ascontiguousarray (H_glo .imag , dtype = np .float64 )
831+ pt_f64 = np .ascontiguousarray (pos_theta_rad , dtype = np .float64 )
832+ km_f64 = np .ascontiguousarray (ktheta_mesh_kzktheta , dtype = np .float64 )
833+ B_kz = _numba_B_kz_func (
834+ H_glo_f64 , H_glo_i64 , pt_f64 ,
835+ float (k_theta_diff ), km_f64 ,
836+ channels [0 ], channels [1 ], len (k_z ), len (k_theta ),
837+ )
838+ else :
839+ B_kz = np .zeros ((channels [0 ], channels [1 ], len (k_z )))
840+ for channel_z in range (channels [0 ]):
841+ for channel_theta in range (channels [1 ]):
842+ arg = np .multiply (
843+ H_glo ,
844+ np .exp (1j * pos_theta_rad [channel_z , channel_theta ] * ktheta_mesh_kzktheta )
845+ * k_theta_diff ,
846+ )
847+ B_kz [channel_z , channel_theta , :] = sum (np .transpose (arg )) / 2 / math .pi
829848
830849 ###################################################################################################
831850 ## 6. phi(t) for each channel
832851
833- phi = np .zeros ((channels [0 ], channels [1 ], len (t )))
834- for channel_z in range (channels [0 ]):
835- for channel_theta in range (channels [1 ]):
836- auxiliar = np .dot (
837- np .ones ((len (I_kzkt [1 , :]), 1 )),
838- B_kz [channel_z , channel_theta , :].reshape (1 , - 1 ),
839- )
840- auxiliar = np .transpose (auxiliar )
841- arg = np .multiply (I_kzkt , auxiliar )
842- arg2 = np .multiply (
843- arg ,
844- np .exp (1j * pos_z_mm [channel_z , channel_theta ] * kz_mesh_kzkt ) * (k_z [1 ] - k_z [0 ]),
845- )
846- PHI = sum (arg2 )
847- phi [channel_z , channel_theta , :] = np .real (
848- (
852+ if _numba_phi_func is not None :
853+ # Use Numba-optimized parallel computation
854+ PHI_complex = _numba_phi_func (
855+ np .ascontiguousarray (I_kzkt .real , dtype = np .float64 ),
856+ np .ascontiguousarray (I_kzkt .imag , dtype = np .float64 ),
857+ np .ascontiguousarray (B_kz , dtype = np .float64 ),
858+ np .ascontiguousarray (pos_z_mm , dtype = np .float64 ),
859+ np .ascontiguousarray (kz_mesh_kzkt , dtype = np .float64 ),
860+ channels [0 ], channels [1 ], len (k_z ), len (k_t ),
861+ float (k_z_diff ),
862+ )
863+ # Apply IFFT to get time-domain signal
864+ phi = np .zeros ((channels [0 ], channels [1 ], len (t )))
865+ for channel_z in range (channels [0 ]):
866+ for channel_theta in range (channels [1 ]):
867+ phi [channel_z , channel_theta , :] = np .real (
868+ np .fft .ifft (
869+ np .fft .fftshift (PHI_complex [channel_z , channel_theta , :] * len (psi ))
870+ )
871+ )
872+ else :
873+ phi = np .zeros ((channels [0 ], channels [1 ], len (t )))
874+ for channel_z in range (channels [0 ]):
875+ for channel_theta in range (channels [1 ]):
876+ auxiliar = np .dot (
877+ np .ones ((len (I_kzkt [1 , :]), 1 )),
878+ B_kz [channel_z , channel_theta , :].reshape (1 , - 1 ),
879+ )
880+ auxiliar = np .transpose (auxiliar )
881+ arg = np .multiply (I_kzkt , auxiliar )
882+ arg2 = np .multiply (
883+ arg ,
884+ np .exp (1j * pos_z_mm [channel_z , channel_theta ] * kz_mesh_kzkt ) * k_z_diff ,
885+ )
886+ PHI = sum (arg2 )
887+ phi [channel_z , channel_theta , :] = np .real (
849888 np .fft .ifft (
850889 np .fft .fftshift (PHI / 2 / math .pi * len (psi ))
851- ) # Matches original line 239
890+ )
852891 )
853- )
854892
855893 # Center the MUAP signal in the time window by finding the peak and shifting
856894 # For each electrode channel, find the peak and center it
@@ -877,7 +915,7 @@ def _simulate_fiber_v2_python(
877915 return phi , A_matrix , B_incomplete
878916
879917
880- @ beartowertype
918+ # NOTE: beartype disabled — called 100s-1000s of times per MU in the inner fiber loop
881919def simulate_fiber_v2 (
882920 Fs : float ,
883921 v : float ,
@@ -902,6 +940,10 @@ def simulate_fiber_v2(
902940 use_cython : bool = True ,
903941 use_gpu : bool = True ,
904942 fiber_length__mm : float | None = None ,
943+ theta_offset : float = 0.0 ,
944+ pos_z_precomputed : np .ndarray | None = None ,
945+ pos_theta_precomputed : np .ndarray | None = None ,
946+ rele_precomputed : float | None = None ,
905947):
906948 """
907949 Simulate a single fiber (dispatcher to Cython or Python implementation).
@@ -1066,4 +1108,8 @@ def simulate_fiber_v2(
10661108 B_incomplete ,
10671109 use_gpu ,
10681110 fiber_length__mm ,
1111+ theta_offset ,
1112+ pos_z_precomputed ,
1113+ pos_theta_precomputed ,
1114+ rele_precomputed ,
10691115 )
0 commit comments