diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index a273b5c03a..ef19560ff9 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -700,6 +700,25 @@ def _e_sup_theta_times_sqrt_g(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="e^theta*rho", + label="\\mathbf{e}^{\\theta} \\rho", + units="m^{-1}", + units_long="inverse meters", + description="Contravariant poloidal basis vector weighted by radial coordinate", + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e^theta"], + parameterization=["desc.equilibrium.equilibrium.Equilibrium"], +) +def _e_sup_theta_times_rho(params, transforms, profiles, data, **kwargs): + data["e^theta*rho"] = data["e^theta"] * transforms["grid"].nodes[:, 0][:, None] + return data + + @register_compute_fun( name="e^theta_r", label="\\partial_{\\rho} \\mathbf{e}^{\\theta}", diff --git a/desc/particles.py b/desc/particles.py index 7701064bc4..652b7de1e7 100644 --- a/desc/particles.py +++ b/desc/particles.py @@ -13,6 +13,7 @@ Tsit5, diffeqsolve, ) +from interpax import Interpolator3D from scipy.constants import Boltzmann, elementary_charge, proton_mass from desc.backend import jax, jit, jnp, tree_map @@ -174,12 +175,17 @@ def vf(self, t, x, args): m, q, mu = model_args if self.frame == "flux": assert isinstance( - eq_or_field, Equilibrium - ), "Integration in flux coordinates requires an Equilibrium." - - return self._compute_flux_coordinates( - x, eq_or_field, params, m, q, mu, **kwargs + eq_or_field, (Equilibrium, FourierChebyshevField, SplineFieldFlux) + ), ( + "Integration in flux coordinates requires an Equilibrium or " + "FourierChebyshevField or SplineFieldFlux." ) + if isinstance(eq_or_field, (FourierChebyshevField, SplineFieldFlux)): + return self._compute_flux_coordinates_with_fit(x, eq_or_field, m, q, mu) + else: + return self._compute_flux_coordinates( + x, eq_or_field, params, m, q, mu, **kwargs + ) elif self.frame == "lab": assert isinstance(eq_or_field, _MagneticField), ( "Integration in lab coordinates requires a MagneticField. If using an " @@ -192,6 +198,7 @@ def vf(self, t, x, args): x, eq_or_field, params, m, q, mu, **kwargs ) + @jit def _compute_flux_coordinates(self, x, eq, params, m, q, mu, **kwargs): """ODE equation for vacuum guiding center in flux coordinates. @@ -257,6 +264,42 @@ def _compute_flux_coordinates(self, x, eq, params, m, q, mu, **kwargs): dxdt = jnp.array([xpdot, ypdot, zetadot, vpardot]).reshape(x.shape) return dxdt.squeeze() + @jit + def _compute_flux_coordinates_with_fit(self, x, field, m, q, mu): + """ODE equation for vacuum guiding center in flux coordinates. + + A Fourier-Chebyshev fit for each component of the 3D magnetic field B, gradient + of the magnetic field strength and the basis vectors must be given as real and + imaginary parts. Basis vector e^theta is not well around the axis and blows up, + instead we use e^theta*rho which results in better fit. + """ + xp, yp, zeta, vpar = x + rho = jnp.sqrt(xp**2 + yp**2) + theta = jnp.arctan2(yp, xp) + # compute functions are not correct for very small rho + rho = jnp.where(rho < 1e-6, 1e-6, rho) + + data = field.evaluate(rho, theta, zeta) + + Rdot = vpar * data["b"] + ( + (m / q / data["|B|"] ** 2) + * ((mu * data["|B|"] / m) + vpar**2) + * cross(data["b"], data["grad(|B|)"]) + ) + # take dot product for rho, theta and zeta coordinates + rhodot = dot(Rdot, data["e^rho"]) + thetadot_x_rho = dot(Rdot, data["e^theta*rho"]) + zetadot = dot(Rdot, data["e^zeta"]) + + # get the derivative for cartesian-like coordinates + xpdot = rhodot * jnp.cos(theta) - thetadot_x_rho * jnp.sin(theta) + ypdot = rhodot * jnp.sin(theta) + thetadot_x_rho * jnp.cos(theta) + # derivative the parallel velocity + vpardot = -mu / m * dot(data["b"], data["grad(|B|)"]) + dxdt = jnp.array([xpdot, ypdot, zetadot, vpardot]).reshape(x.shape) + return dxdt.squeeze() + + @jit def _compute_lab_coordinates(self, x, field, params, m, q, mu, **kwargs): """Compute the RHS of the ODE using MagneticField. @@ -1065,6 +1108,11 @@ def trace_particles( will depend on ``model.vcoords``. """ + assert isinstance(field, (Equilibrium, _MagneticField)), ( + f"field must be either Equilibrium or MagneticField object but {type(field)} " + "given. If field type is FourierChebyshevField or SplineFieldFlux, please use " + "_trace_particles function." + ) if not params: params = field.params_dict if not options: @@ -1160,7 +1208,7 @@ def _trace_particles( Custom event function to stop integration. """ # convert cartesian-like for integration in flux coordinates - if isinstance(field, Equilibrium): + if isinstance(field, (Equilibrium, FourierChebyshevField, SplineFieldFlux)): xp = y0[:, 0] * jnp.cos(y0[:, 1]) yp = y0[:, 0] * jnp.sin(y0[:, 1]) y0 = y0.at[:, 0].set(xp) @@ -1198,7 +1246,7 @@ def _trace_particles( v = yt[:, :, 3:] # convert back to flux coordinates - if isinstance(field, Equilibrium): + if isinstance(field, (Equilibrium, FourierChebyshevField, SplineFieldFlux)): rho = jnp.sqrt(x[:, :, 0] ** 2 + x[:, :, 1] ** 2) theta = jnp.arctan2(x[:, :, 1], x[:, :, 0]) theta = jnp.where(theta < 0, theta + 2 * jnp.pi, theta) @@ -1248,3 +1296,295 @@ def _intfun_wrapper( event=event, throw=throw, ) + + +class FourierChebyshevField(IOAble): + """Convenience class for fitting and evaluating equilibrium fields. + + This class is intended to be used during particle tracing to reduce overhead + of creating transforms. It fits a Fourier-Fourier-Chebyshev series to the + quantities required for guiding center equations, and evaluates them + at requested points during tracing. + + Parameters + ---------- + L : int + Maximum order of the Chebyshev polynomial to be used in the radial direction. + M : int + Maximum order of the Fourier series to be used in the poloidal direction. + N : int + Maximum order of the Fourier series to be used in the toroidal direction. + """ + + _static_attrs = ["L", "M", "N", "M_fft", "N_fft", "data_keys"] + + def __init__(self, L, M, N): + self.L = L + self.M = M + self.N = N + + def build(self, eq): + """Build the constants for fit. + + During optimization, equilibrium field changes, however, the same transforms + can be used to get the fit faster. This method creates the grid and transforms + to be used during the fitting procedure. + + Parameters + ---------- + eq : Equilibrium + Equilibrium to be used to get transforms. + + """ + self.data_keys = ["B", "grad(|B|)", "e^rho", "e^theta*rho", "e^zeta"] + self.l = jnp.arange(self.L) + self.M_fft = 2 * self.M + 1 + self.N_fft = 2 * self.N + 1 + self.m = jnp.fft.fftfreq(self.M_fft) * self.M_fft + self.n = jnp.fft.fftfreq(self.N_fft) * self.N_fft + x = jnp.cos(jnp.pi * (2 * self.l + 1) / (2 * self.L)) + rho = (x + 1) / 2 + self.grid = LinearGrid(rho=rho, M=self.M, N=self.N, sym=False, NFP=eq.NFP) + self.transforms = get_transforms(self.data_keys, eq, self.grid) + + def fit(self, params, profiles): + """Fit a Fourier-Chebyshev series to an equilibrium field. + + First computes the magnetic field, its gradient and basis vectors at + the grid created in build. Then, finds the spectral coefficients to + each component of the computed vectors. Since e^theta doesn't behave + well around axis, the fit is computed for e^theta*rho (which is what actually + required by the guiding center equations). + + Parameters + ---------- + params : dict + Equilibriums `params_dict` which contains the parameters that define + the equiliubrium. + profiles : dict of Profiles + Profiles necessary to compute magnetic field. Either iota or current + profile must be given. + + """ + data_raw = compute_fun( + "desc.equilibrium.equilibrium.Equilibrium", + self.data_keys, + params, + self.transforms, + profiles, + ) + L, M, N = self.L, self.M_fft, self.N_fft + # e^zeta only has one component to fit, deal with it separately + keys3d = [key for key in self.data_keys if key != "e^zeta"] + keys = [key + i for key in keys3d for i in ["_r", "_p", "_z"]] + keys += ["e^zeta_p"] + # stack data to perform 12+1 transforms in batch + stacks = [ + data_raw[key][:, i].reshape(N, L, M) for key in keys3d for i in [0, 1, 2] + ] + stacks.append(data_raw["e^zeta"][:, 1].reshape(N, L, M)) + stacked_data = jnp.stack(stacks) # shape (13, N, L, M) + coefs = jax.scipy.fft.dct(stacked_data, axis=2, norm=None) + # handle the 0-th Chebyshev coefficient and normalization + coefs = coefs.at[:, :, 0, :].divide(2) + coefs /= self.L + + coefs = jnp.fft.fft(coefs, axis=3, norm=None) + coefs = jnp.fft.fft(coefs, axis=1, norm=None) + + data = {} + # stacking/unstacking is unnecessary + data["coefs_real"] = coefs.real + data["coefs_imag"] = coefs.imag + data["l"] = self.l + data["m"] = self.m + data["n"] = self.n + data["M"] = self.M_fft + data["N"] = self.N_fft + + self.params_dict = data + + def evaluate(self, rho, theta, zeta, params=None): + """Evaluate the Fourier-Chebyshev series at a point. + + Parameters + ---------- + rho, theta, zeta : float + Radial, poloidal and toroidal coordinates to evaluate. + params : dict + The spectral coefficients obtained from `FourierChebyshevField.fit()` + which is stored as `self.params_dict`. + """ + if params is None: + params = self.params_dict + + # the cosine transforms reverses the order + r0p = 1 - 2 * rho + Tl = jnp.cos(params["l"] * jnp.arccos(r0p)) + m_theta = params["m"] * theta + expm_real = jnp.cos(m_theta) / params["M"] + expm_imag = jnp.sin(m_theta) / params["M"] + # we computed and fitted the field for zeta in [0, 2pi/NFP] + # so we need to map zeta back to that range + zeta = (zeta * self.grid.NFP) % (2 * jnp.pi) + n_zeta = params["n"] * zeta + expn_real = jnp.cos(n_zeta) / params["N"] + expn_imag = jnp.sin(n_zeta) / params["N"] + + # "knlm,l->knm" contracts the 'l' dimension for all 'k' batches at once + f_l_real = jnp.einsum("knlm,l->knm", params["coefs_real"], Tl) + f_l_imag = jnp.einsum("knlm,l->knm", params["coefs_imag"], Tl) + + # "knm,m->kn" contracts the 'm' dimension for all 'k' batches + f_lm_real = jnp.einsum("knm,m->kn", f_l_real, expm_real) - jnp.einsum( + "knm,m->kn", f_l_imag, expm_imag + ) + f_lm_imag = jnp.einsum("knm,m->kn", f_l_real, expm_imag) + jnp.einsum( + "knm,m->kn", f_l_imag, expm_real + ) + + # "kn,n->k" contracts the 'n' dimension, leaving just the batch dimension + results = jnp.einsum("kn,n->k", f_lm_real, expn_real) - jnp.einsum( + "kn,n->k", f_lm_imag, expn_imag + ) + + out = {} + # Magnetic Field B + B = results[0:3] + out["|B|"] = jnp.linalg.norm(B) + out["b"] = B / out["|B|"] + # grad(|B|) + out["grad(|B|)"] = results[3:6] + # e^rho + out["e^rho"] = results[6:9] + # e^theta*rho + out["e^theta*rho"] = results[9:12] + # e^zeta + out["e^zeta"] = jnp.array([0, results[12], 0]) + + return out + + +class SplineFieldFlux(IOAble): + """Convenience class for splining and evaluating equilibrium fields. + + This class is intended to be used during particle tracing to reduce overhead + of creating transforms. It fits a 3D spline to the quantities required for + guiding center equations, and evaluates them at requested points during tracing. + + Parameters + ---------- + L : int + Radial resolution of the linear grid for spline nodes. + M : int + Poloidal resolution of the linear grid for spline nodes. + N : int + Toroidal resolution of the linear grid for spline nodes. + method : str + Method to use for 3D spline interpolation. See `interpax.interp3d` + for options. Defaults to 'cubic'. + """ + + _static_attrs = ["L", "M", "N", "data_keys", "method"] + + def __init__(self, L, M, N, method="cubic"): + self.L = L + self.M = M + self.N = N + self.method = method + + def build(self, eq): + """Build the constants for fit. + + During optimization, equilibrium field changes, however, the same transforms + can be used to get the fit faster. This method creates the grid and transforms + to be used during the fitting procedure. + + Parameters + ---------- + eq : Equilibrium + Equilibrium to be used to get transforms. + + """ + self.data_keys = ["B", "grad(|B|)", "e^rho", "e^theta*rho", "e^zeta"] + rho = jnp.linspace(1e-6, 1.0, self.L + 1) + self.grid = LinearGrid(rho=rho, M=self.M, N=self.N, sym=False, NFP=eq.NFP) + self.transforms = get_transforms(self.data_keys, eq, self.grid) + self.rhos = self.grid.nodes[self.grid.unique_rho_idx, 0] + self.thetas = self.grid.nodes[self.grid.unique_theta_idx, 1] + self.zetas = self.grid.nodes[self.grid.unique_zeta_idx, 2] + + def fit(self, params, profiles): + """Compute spline nodes for an equilibrium field. + + First computes the magnetic field, its gradient and basis vectors at + the grid created in build. Since e^theta doesn't behave + well around axis, the fit is computed for e^theta*rho (which is what actually + required by the guiding center equations). + + Parameters + ---------- + params : dict + Equilibriums `params_dict` which contains the parameters that define + the equiliubrium. + profiles : dict of Profiles + Profiles necessary to compute magnetic field. Either iota or current + profile must be given. + + """ + data = compute_fun( + "desc.equilibrium.equilibrium.Equilibrium", + self.data_keys, + params, + self.transforms, + profiles, + ) + L, M, N = self.grid.num_rho, self.grid.num_theta, self.grid.num_zeta + # e^zeta only has one component to fit, deal with it separately + keys3d = [key for key in self.data_keys if key != "e^zeta"] + keys = [key + i for key in keys3d for i in ["_r", "_p", "_z"]] + keys += ["e^zeta_p"] + # stack data to query 12+1 quantities in interp3d + stacks = [ + jnp.moveaxis(data[key][:, i].reshape(N, L, M), 0, -1) + for key in keys3d + for i in [0, 1, 2] + ] + stacks.append(jnp.moveaxis(data["e^zeta"][:, 1].reshape(N, L, M), 0, -1)) + self.interpolator = Interpolator3D( + self.rhos, + self.thetas, + self.zetas, + jnp.stack(stacks, axis=3), # shape (L, M, N, 13) + self.method, + extrap=False, + period=(None, 2 * jnp.pi, 2 * jnp.pi / self.grid.NFP), + ) + + def evaluate(self, rho, theta, zeta, params=None): + """Evaluate the 3D spline at a point. + + Parameters + ---------- + rho, theta, zeta : float + Radial, poloidal and toroidal coordinates to evaluate. + params : dict + Computed data at spline nodes stored as ``SplineFieldFlux.params_dict``. + """ + results = self.interpolator(rho, theta, zeta) + + out = {} + # Magnetic Field B + B = results[0:3] + out["|B|"] = jnp.linalg.norm(B) + out["b"] = B / out["|B|"] + # grad(|B|) + out["grad(|B|)"] = results[3:6] + # e^rho + out["e^rho"] = results[6:9] + # e^theta*rho + out["e^theta*rho"] = results[9:12] + # e^zeta + out["e^zeta"] = jnp.array([0, results[12], 0]) + + return out diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 1335503f31..976dee8e8b 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -117,6 +117,7 @@ "iota_den_rrr", "gds2", "(B*grad) grad(rho)", + "e^theta*rho", } diff --git a/tests/test_particles.py b/tests/test_particles.py index f13104e4ee..b068269dfd 100644 --- a/tests/test_particles.py +++ b/tests/test_particles.py @@ -5,6 +5,7 @@ from desc.backend import jit, jnp from desc.equilibrium import Equilibrium +from desc.examples import get from desc.geometry import FourierRZCurve, FourierRZToroidalSurface from desc.grid import Grid, LinearGrid from desc.magnetic_fields import ( @@ -14,6 +15,7 @@ ) from desc.particles import ( CurveParticleInitializer, + FourierChebyshevField, ManualParticleInitializerFlux, ManualParticleInitializerLab, SurfaceParticleInitializer, @@ -467,3 +469,80 @@ def test_init_curve_particles(): # smaller curve is out of larger equilibrium, so it should fail with pytest.raises(match="Mapping from lab to flux coordinates failed"): _, _ = particles.init_particles(model, eq_large) + + +@pytest.mark.unit +def test_FourierChebyshevField_interpolation(): + """Test the interpolation is accurate.""" + eq = get("precise_QA") + with pytest.warns(UserWarning): + eq.iota = eq.get_profile("iota") + interpolator = FourierChebyshevField(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid) + interpolator.build(eq=eq) + interpolator.fit( + params=eq.params_dict, profiles={"current": eq.current, "iota": eq.iota} + ) + + assert interpolator.params_dict["coefs_real"].shape == ( + 13, + interpolator.N_fft, + interpolator.L, + interpolator.M_fft, + ) + + rhos = np.linspace(0.05, 1.0, 5) + grid = LinearGrid(rho=rhos, M=3, N=3, NFP=eq.NFP, sym=eq.sym) + keys = ["|B|", "b", "grad(|B|)", "e^rho", "e^theta*rho", "e^zeta"] + data = eq.compute(keys, grid=grid) + + for i, coord in enumerate(grid.nodes): + rho, theta, zeta = coord + data_interp = interpolator.evaluate(rho, theta, zeta) + for key in keys: + msg = f"{key} mismatch at ρ={rho}, θ={theta}, ζ={zeta}" + np.testing.assert_allclose( + data[key][i], data_interp[key], rtol=5e-4, atol=5e-4, err_msg=msg + ) + + +@pytest.mark.unit +def test_FourierChebyshevField_model_vf(): + """Test the vector field evaluation using interpolated field.""" + eq = get("precise_QA") + iota = eq.get_profile("iota") + interpolator = FourierChebyshevField(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid) + interpolator.build(eq=eq) + interpolator.fit( + params=eq.params_dict, profiles={"current": eq.current, "iota": eq.iota} + ) + + model = VacuumGuidingCenterTrajectory(frame="flux") + rhos = np.linspace(0.05, 1.0, 3) + grid = LinearGrid(rho=rhos, M=2, N=2, NFP=eq.NFP, sym=eq.sym) + particles = ManualParticleInitializerFlux( + rho0=grid.nodes[:, 0], + theta0=grid.nodes[:, 1], + zeta0=grid.nodes[:, 2], + xi0=2 * np.random.rand(grid.num_nodes) - 1, + E=3.5e6, + ) + x0, args = particles.init_particles(model=model, field=eq) + + for xi, argsi in zip(x0, args): + rho, theta, zeta, vpar = xi + xp = rho * np.cos(theta) + yp = rho * np.sin(theta) + x = jnp.array([xp, yp, zeta, vpar]) + + params = eq.params_dict + params["i_l"] = iota.params + exact = model.vf(0, x=x, args=[argsi, eq, params, {"iota": iota}]) + interpolated = model.vf(0, x=x, args=[argsi, interpolator, None, {}]) + + comps = ["xp_dot", "yp_dot", "zeta_dot", "vpar_dot"] + + for i, comp in enumerate(comps): + msg = f"{comp} mismatch at ρ={rho}, θ={theta}, ζ={zeta}" + np.testing.assert_allclose( + exact[i], interpolated[i], rtol=2e-2, atol=1e-3, err_msg=msg + )