Major refactoring for JAX-style classes.#29
Conversation
- Implemented `fo_integrators.py` for full orbit tracing with various methods and parameters. - Implemented `gc_integrators.py` for guiding center dynamics with adaptative and constant step sizes. - Enhanced `Tracing` class in `dynamics.py` to support multiple methods and step sizes.
…d adjust num_steps based on dt
… class for improved step size handling
…for adaptive step size
…ters for performance
…ance plots, and improve layout for better visualization
…n to scale the modes with different norms, optimization.py slightly changed to accomodate changes in surfaces. The example optimize_coils_and_surfaces.py was also changed to accomodate the changes
…d Coils and Curves into correct PyTrees
…arate gamma computation
|
Tests need fixing; |
There was a problem hiding this comment.
Pull request overview
This PR implements a major refactoring to make coils and surfaces proper JAX PyTrees, enabling automatic differentiation. It introduces a new loss wrapper system for gradient-based optimization and adds comprehensive analysis and validation code comparing ESSOS with SIMSOPT.
Key changes:
- Refactored
Coils,Curves,SurfaceRZFourier, andBiotSavartclasses as JAX PyTrees with proper tree flattening/unflattening - Added
essos/losses.pywithcustom_lossandcomposite_lossclasses for differentiable loss functions - Updated API:
Coils_from_json()→Coils.from_json(),tracing.energyproperty →tracing.energy()method - Added extensive analysis scripts for validation against SIMSOPT
Reviewed changes
Copilot reviewed 32 out of 32 changed files in this pull request and generated 39 comments.
Show a summary per file
| File | Description |
|---|---|
essos/losses.py |
New module implementing base_loss, custom_loss, and composite_loss classes for automatic differentiation |
essos/surfaces.py |
Refactored SurfaceRZFourier as PyTree with cached properties and improved initialization |
essos/coils.py |
Refactored Curves and Coils as PyTrees with cached properties, changed to classmethod constructors |
essos/fields.py |
Added MagneticField base class and registered BiotSavart as PyTree |
essos/dynamics.py |
Changed energy from cached property to method, added Particles.join() method |
essos/objective_functions.py |
Removed deprecated loss functions, added new coil separation and curvature losses |
essos/optimization.py |
Updated surface instantiation to include mpol/ntor parameters |
examples/optimize_coils_vmec_surface.py |
Major rewrite using new loss wrapper API instead of old optimization functions |
examples/trace_particles_coils_guidingcenter.py |
Updated imports and API calls (from_json, energy method) |
examples/trace_fieldlines_coils.py |
Updated to use Coils.from_json() |
examples/optimize_coils_particle_confinement_fullorbit.py |
Minor formatting and parameter updates |
examples/optimize_coils_and_surface.py |
Added mpol/ntor parameters, simplified loss calculations |
examples/input_files/*. |
Updated VMEC input file coefficients |
examples/comparisons_SIMSOPT/*.py |
Deleted old comparison scripts |
examples/compare_guidingcenter_fullorbit.py |
Updated particle initialization and energy calculation |
analysis/*.py |
New analysis scripts for validation and benchmarking |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): | ||
| # assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" |
There was a problem hiding this comment.
This comment appears to contain commented-out code.
| # if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): | |
| # assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" | |
| if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): | |
| assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" |
| # if nphi is not None: | ||
| # self.nphi = nphi | ||
| # else: | ||
| # nphi = self.nphi | ||
|
|
||
| # #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) | ||
| # #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) | ||
| # rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | ||
| # zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | ||
| # m_keep = min(mpol_old, mpol) | ||
| # n_keep = min(ntor_old, ntor) | ||
|
|
||
| # xm_old=self.xm | ||
| # xn_old=self.xn | ||
| # self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] | ||
| # self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] | ||
| # # Copy overlapping region | ||
| # for l in range(len(self.xm)): | ||
| # if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: | ||
| # index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp | ||
| # rc_new=rc_new.at[l].set(self.rc[index]) | ||
| # zs_new=zs_new.at[l].set(self.zs[index]) | ||
|
|
||
|
|
||
| # # Update attributes | ||
| # self.mpol, self.ntor = mpol, ntor | ||
| # self.rc, self.zs = rc_new, zs_new | ||
|
|
||
| # self.rmnc_interp = self.rc | ||
| # self.zmns_interp = self.zs | ||
|
|
There was a problem hiding this comment.
This comment appears to contain commented-out code.
| # if nphi is not None: | |
| # self.nphi = nphi | |
| # else: | |
| # nphi = self.nphi | |
| # #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) | |
| # #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) | |
| # rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | |
| # zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) | |
| # m_keep = min(mpol_old, mpol) | |
| # n_keep = min(ntor_old, ntor) | |
| # xm_old=self.xm | |
| # xn_old=self.xn | |
| # self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] | |
| # self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] | |
| # # Copy overlapping region | |
| # for l in range(len(self.xm)): | |
| # if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: | |
| # index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp | |
| # rc_new=rc_new.at[l].set(self.rc[index]) | |
| # zs_new=zs_new.at[l].set(self.zs[index]) | |
| # # Update attributes | |
| # self.mpol, self.ntor = mpol, ntor | |
| # self.rc, self.zs = rc_new, zs_new | |
| # self.rmnc_interp = self.rc | |
| # self.zmns_interp = self.zs |
| # self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) | ||
|
|
||
| # # Recompute angles and geometry | ||
| # if self.range_torus == 'full torus': div = 1 | ||
| # else: div = self.nfp | ||
| # if self.range_torus == 'half period': end_val = 0.5 | ||
| # else: end_val = 1.0 | ||
| # self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) | ||
| # self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) | ||
| # self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) | ||
|
|
||
| # self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) | ||
| # (self._gamma, self._gammadash_theta, self._gammadash_phi, | ||
| # self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) | ||
|
|
||
|
|
||
| # # Recompute AbsB if available | ||
| # if hasattr(self, 'bmnc'): | ||
| # self._AbsB = self._set_AbsB() | ||
|
|
||
| # return self | ||
|
|
There was a problem hiding this comment.
This comment appears to contain commented-out code.
| # self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) | |
| # # Recompute angles and geometry | |
| # if self.range_torus == 'full torus': div = 1 | |
| # else: div = self.nfp | |
| # if self.range_torus == 'half period': end_val = 0.5 | |
| # else: end_val = 1.0 | |
| # self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) | |
| # self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) | |
| # self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) | |
| # self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) | |
| # (self._gamma, self._gammadash_theta, self._gammadash_phi, | |
| # self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) | |
| # # Recompute AbsB if available | |
| # if hasattr(self, 'bmnc'): | |
| # self._AbsB = self._set_AbsB() | |
| # return self |
| @property | ||
| def dependencies_buffer(self): | ||
| if self._dependencies_buffer is None: | ||
| self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) |
There was a problem hiding this comment.
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
| self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) | |
| self._dependencies_buffer = tree_util.tree_map(jnp.zeros_like, self.dependencies) |
| json_file_stel = curves_stel | ||
| field_simsopt = load(json_file_stel) | ||
| coils_simsopt = field_simsopt.coils | ||
| curves_simsopt = [coil.curve for coil in coils_simsopt] |
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', | ||
| stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) | ||
| block_until_ready(compile_tracing.trajectories) | ||
|
|
||
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | ||
| num_steps_essos = avg_steps_SIMSOPT_array[index] | ||
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | ||
| start_time = time() | ||
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', | ||
| stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) |
There was a problem hiding this comment.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
Keyword argument 'method' is not a supported parameter name of Tracing.init.
Keyword argument 'stepsize' is not a supported parameter name of Tracing.init.
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', | |
| stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) | |
| block_until_ready(compile_tracing.trajectories) | |
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | |
| num_steps_essos = avg_steps_SIMSOPT_array[index] | |
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | |
| start_time = time() | |
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', | |
| stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) | |
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles) | |
| block_until_ready(compile_tracing.trajectories) | |
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | |
| num_steps_essos = avg_steps_SIMSOPT_array[index] | |
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | |
| start_time = time() | |
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles) |
| tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo, | ||
| timesteps=timesteps_fo, tol_step_size=trace_tolerance) |
There was a problem hiding this comment.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
| tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, maxtime=tmax_gc, | ||
| timesteps=timesteps_gc, tol_step_size=trace_tolerance) |
There was a problem hiding this comment.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
| nfp=number_of_field_periods, stellsym=True) | ||
| coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) | ||
| field_essos = BiotSavart(coils_essos) | ||
| surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) |
There was a problem hiding this comment.
Call to SurfaceRZFourier.init with too few arguments; should be no fewer than 5.
| surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) | |
| surface_essos = SurfaceRZFourier_ESSOS(vmec, order_Fourier_series_coils, ntheta=ntheta, nphi=nphi, close=False) |
| EXPORT = False |
There was a problem hiding this comment.
This statement is unreachable.
| EXPORT = False | |
| EXPORT = True |
Refactor of coils & surfaces to be proper PyTrees;
Added a loss wrapper to differentiate with respect to the dogs (PyTree leaves);
Added analysis & validation of the code