From 222d314d04c80e931e3d6afcb19efeb702893bba Mon Sep 17 00:00:00 2001 From: Vera <46623149+vera30@users.noreply.github.com> Date: Thu, 26 Feb 2026 16:46:00 +1300 Subject: [PATCH] enabled a working pytorch backend for pfield, simus --- src/pymust/numericalEngine.py | 12 ++++++++++++ src/pymust/pfield.py | 6 ++++-- src/pymust/simus.py | 35 +++++++++++++++++++++++------------ 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/pymust/numericalEngine.py b/src/pymust/numericalEngine.py index ed48597..f61e0dd 100644 --- a/src/pymust/numericalEngine.py +++ b/src/pymust/numericalEngine.py @@ -1,3 +1,4 @@ +from hashlib import new import numpy, importlib @@ -31,6 +32,17 @@ def to_numpy(self, x): return x.detach().to('cpu').numpy() else: return x.detach().numpy() + + def to_backend(self, x): + if self.isNumpy: + return x + else: + return self.backend.asarray(x, device=self.device) + + def __deepcopy__(self, memo): + new = NumericalEngine(self.backend_name, self.device) + memo[id(self)] = new + return new NumpyEngine = NumericalEngine() \ No newline at end of file diff --git a/src/pymust/pfield.py b/src/pymust/pfield.py index 14614c2..74640f0 100644 --- a/src/pymust/pfield.py +++ b/src/pymust/pfield.py @@ -327,7 +327,7 @@ def pfield(x : numpy.ndarray,y : numpy.ndarray, z: numpy.ndarray, # DR: Possibly add explanation of casting RC to single precision if options.RC is not None and len(options.RC): - options.RC = options.RC.astype(np.float32) + options.RC = np.asarray(options.RC, dtype = np.float32) #%------------------------------------% #% END of Check the OPTIONS structure % @@ -524,7 +524,9 @@ def pfield(x : numpy.ndarray,y : numpy.ndarray, z: numpy.ndarray, RP = 0 # % RP = Radiation Pattern if isSIMUS: #%- For SIMUS only (we need the full spectrum of RX signals): - SPECT = np.zeros((nSampling, param.NumberOfElements), dtype = np.complex64) + # print(nSampling, param.NumberOfElements) + # breakpoint() + SPECT = np.zeros((nSampling, param.Nelements), dtype = np.complex64) else: #%- For MKMOVIE only (we need the full spectrum of the pressure field): #%- For using PFIELD alone we need the spectrum recieved on each point: diff --git a/src/pymust/simus.py b/src/pymust/simus.py index e592f0d..4b5c8c8 100644 --- a/src/pymust/simus.py +++ b/src/pymust/simus.py @@ -1,4 +1,4 @@ -from . import utils, pfield, getpulse +from . import utils, pfield, getpulse, numericalEngine import logging, copy, multiprocessing, functools import numpy as np @@ -364,10 +364,13 @@ def simus(*varargin): #%- run PFIELD in a parallel pool (NW workers) if options.get('ParPool', False): - if 'numericalEngine' in options and not options['numericalEngine'].isNumpy: - raise NotImplemented("Cannot use a numerical engine other than numpy for parallel computing") - with options.getParallelPool() as pool: - idx = options.getParallelSplitIndices(x.shape[1]) + engine = options.get('numericalEngine', numericalEngine.NumpyEngine) + if not engine.isNumpy: + raise NotImplementedError("Cannot use a numerical engine other than numpy for parallel computing") + with options.getParallelPool() as pool: # ORIGINAL (debug): + print('Using parallel pool with {} workers'.format(pool._processes)) + #%- split the scatterers into NW chunks + idx = options.getParallelSplitIndices(x.shape[1]) #1 RS = pool.starmap(functools.partial(pfieldParallel, delaysTX = delaysTX, param = param, options = options), [ ( x[:,i:j], @@ -382,13 +385,21 @@ def simus(*varargin): # end else: #%- no parallel pool - options.RC = RC - # - extra_args = {} - if 'engine' in options: - extra_args['numericalEngine'] = options['numericalEngine'] - _, RFsp,idx = pfield(x,y,z,delaysTX,param,options, **extra_args) - + engine = options.get('numericalEngine', numericalEngine.NumpyEngine) + x_in, y_in, z_in = x, y, z + options.RC = RC + if not engine.isNumpy: + x_in = engine.to_backend(x) + y_in = engine.to_backend(y) if not utils.isEmpty(y) else None + z_in = engine.to_backend(z) + options.RC = engine.to_backend(RC) + param.RXdelay = engine.to_backend(param.RXdelay) + + _, RFsp,idx = pfield(x_in, y_in, z_in, delaysTX, param, options, engine=engine) + + if not engine.isNumpy: + RFsp = engine.to_numpy(RFsp) + param.RXdelay = engine.to_numpy(param.RXdelay) RFspectrum[idx,:] = RFsp #%-- RF signals (in the time domain)