|
| 1 | +import numpy as np |
| 2 | +import xarray as xr |
| 3 | +import nt2.read as nt2r |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +from tqdm import tqdm |
| 6 | + |
| 7 | +def parallel(func, steps, dataset, num_cpus=None): |
| 8 | + import multiprocessing as mp |
| 9 | + import numpy as np # 添加numpy导入 |
| 10 | + |
| 11 | + if num_cpus is None: |
| 12 | + num_cpus = mp.cpu_count() |
| 13 | + |
| 14 | + global calculate |
| 15 | + def calculate(t): |
| 16 | + try: |
| 17 | + value = func(t, dataset) |
| 18 | + except Exception as e: |
| 19 | + print(f"Error in processing {t}: {e}") |
| 20 | + return t, None |
| 21 | + return t, value |
| 22 | + |
| 23 | + # 初始化多进程池 |
| 24 | + pool = mp.Pool(num_cpus) |
| 25 | + try: |
| 26 | + # 添加进度条 |
| 27 | + results = [pool.apply_async(calculate, args=(t,)) for t in tqdm(steps)] |
| 28 | + pool.close() # 关闭输入通道 |
| 29 | + pool.join() # 等待所有进程完成 |
| 30 | + except Exception as e: |
| 31 | + pool.terminate() # 遇到异常时终止所有进程 |
| 32 | + print(f"Error during multiprocessing: {e}") |
| 33 | + raise |
| 34 | + |
| 35 | + # 获取结果并排序 |
| 36 | + sorted_results = sorted([r.get() for r in results], key=lambda x: x[0]) |
| 37 | + |
| 38 | + # 提取结果值为numpy数组 |
| 39 | + return np.array([value for t, value in sorted_results]) |
| 40 | + |
| 41 | + |
| 42 | +def compute_spectrum(field, dx, k_min=None, k_max=None, num_bins=200): |
| 43 | + Ny, Nx = field.shape |
| 44 | + power_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(field)))**2 |
| 45 | + dkx = 2 * np.pi / (Nx * dx) |
| 46 | + dky = 2 * np.pi / (Ny * dx) |
| 47 | + kx, ky = np.meshgrid( |
| 48 | + np.linspace(-Nx/2 * dkx, Nx/2 * dkx - dkx, Nx), |
| 49 | + np.linspace(-Ny/2 * dky, Ny/2 * dky - dky, Ny) |
| 50 | + ) |
| 51 | + k_mag = np.sqrt(kx**2 + ky**2).flatten() |
| 52 | + power_spectrum_flatten = power_spectrum.flatten() |
| 53 | + if k_max is None: |
| 54 | + k_max = np.max(k_mag) |
| 55 | + if k_min is None: |
| 56 | + k_min = np.min(k_mag[k_mag > 0]) |
| 57 | + k_bins = np.linspace(k_min, k_max, num=num_bins) |
| 58 | + k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:]) |
| 59 | + dk = k_bins[1] - k_bins[0] |
| 60 | + k_indices = np.digitize(k_mag, k_bins, right=False) |
| 61 | + power_spectrum_binned = np.array([ |
| 62 | + np.sum(power_spectrum_flatten[k_indices == i]) if np.any(k_indices == i) else 0 |
| 63 | + for i in range(1, len(k_bins)) |
| 64 | + ]) |
| 65 | + power_spectrum_binned /= (Nx * Ny ) |
| 66 | + print(power_spectrum_binned) |
| 67 | + return k_bin_centers, power_spectrum_binned |
| 68 | + |
| 69 | +def compute_L(field, dx, k_min=None, k_max=None, num_bins=200): |
| 70 | + k_bin_centers, power_spectrum_binned = compute_spectrum(field, dx, k_min, k_max, num_bins) |
| 71 | + return np.dot(1.0 / k_bin_centers, power_spectrum_binned) / np.sum(power_spectrum_binned) |
| 72 | + |
| 73 | +class Visualizer(nt2r.Data): |
| 74 | + def __init__(self, filename, d0, rho0, num_cpus): |
| 75 | + super().__init__(filename) |
| 76 | + self.sigma0 = (d0 / rho0)**2 |
| 77 | + self.num_cpus = num_cpus |
| 78 | + self.times = self.coords['t'].values |
| 79 | + self.dx = self.coords['x'].values[1] - self.coords['x'].values[0] |
| 80 | + self.field_map = { |
| 81 | + 'B2': lambda data: (data.Bx**2 + data.By**2 + data.Bz**2) , |
| 82 | + 'E2': lambda data: (data.Ex**2 + data.Ey**2 + data.Ez**2) , |
| 83 | + 'EM_Energy': lambda data: 0.5 * (data.Ex**2 + data.Ey**2 + data.Ez**2 + |
| 84 | + data.Bx**2 + data.By**2 + data.Bz**2) * self.sigma0, |
| 85 | + 'Prtl_Energy': lambda data: data.T00, |
| 86 | + 'Total_Energy': lambda data: 0.5 * (data.Ex**2 + data.Ey**2 + data.Ez**2 + |
| 87 | + data.Bx**2 + data.By**2 + data.Bz**2) * self.sigma0+ data.T00, |
| 88 | + 'N' : lambda data: data.N_1 + data.N_2, |
| 89 | + 'Bxy_Energy' : lambda data: 0.5 * (data.Bx**2 + data.By**2) * self.sigma0, |
| 90 | + } |
| 91 | + |
| 92 | + def set_cpu(self, num_cpus): |
| 93 | + self.num_cpus = num_cpus |
| 94 | + |
| 95 | + def set_para(self, d0, rho0): |
| 96 | + self.sigma0 = (d0 / rho0)**2 |
| 97 | + |
| 98 | + def get_field(self, name): |
| 99 | + if name in self.field_map: |
| 100 | + return self.field_map[name](self) |
| 101 | + else: |
| 102 | + raise ValueError(f"{name} not found.") |
| 103 | + |
| 104 | + def get_means(self, name, times=None): |
| 105 | + if times is None: |
| 106 | + times = self.times |
| 107 | + if name in self.field_map: |
| 108 | + field = self.field_map[name](self) |
| 109 | + elif hasattr(self, name): |
| 110 | + field = getattr(self, name) |
| 111 | + else: |
| 112 | + raise ValueError("Invalid type.") |
| 113 | + return parallel(lambda t, data: data.sel({'t':t}, method='nearest').mean(('x', 'y')).compute().item(), |
| 114 | + times, |
| 115 | + field, |
| 116 | + self.num_cpus) |
| 117 | + |
| 118 | + def get_spectrum(self, name, t, k_min=None, k_max=None, num_bins=200): |
| 119 | + if name in self.field_map: |
| 120 | + field = self.field_map[name](self) |
| 121 | + elif hasattr(self, name): |
| 122 | + field = getattr(self, name) |
| 123 | + else: |
| 124 | + raise ValueError("Invalid type.") |
| 125 | + field = field.sel({'t': t}, method='nearest').values |
| 126 | + print(field) |
| 127 | + k_bins, powers = compute_spectrum(field, self.dx, k_min, k_max, num_bins) |
| 128 | + return k_bins, powers |
| 129 | + |
| 130 | + def get_L(self, name, times=None, k_min=None, k_max=None, num_bins=200): |
| 131 | + if times is None: |
| 132 | + times = self.times |
| 133 | + if name in self.field_map: |
| 134 | + field = self.field_map[name](self) |
| 135 | + elif hasattr(self, name): |
| 136 | + field = getattr(self, name) |
| 137 | + else: |
| 138 | + raise ValueError("Invalid type.") |
| 139 | + return parallel(lambda t, data: compute_L(data.sel({'t': t}, method='nearest'), self.dx, k_min, k_max, num_bins), |
| 140 | + times, |
| 141 | + field, |
| 142 | + self.num_cpus) |
| 143 | + |
| 144 | + def decay_rate(self, name, times=None, **kwargs): |
| 145 | + if times is None: |
| 146 | + times = self.times |
| 147 | + ts = times[1:] |
| 148 | + Qs = self.get_means(name, ts) |
| 149 | + rate = np.array([np.log(Qs[i-1] / Qs[i+1]) / np.log(ts[i+1] / ts[i-1]) for i in tqdm(range(1, len(ts) - 1))]) |
| 150 | + plt.plot(ts[1:-1], rate, **kwargs) |
| 151 | + plt.savefig("decay_{}.png".format(name), dpi=300, bbox_inches="tight") |
| 152 | + np.savetxt("decay_{}.dat".format(name), np.column_stack((ts[1:-1], rate))) |
| 153 | + |
| 154 | + def increase_rate_L(self, name, times=None, k_min=None, k_max=None, num_bins=200, **kwargs): |
| 155 | + if times is None: |
| 156 | + times = self.times |
| 157 | + ts = times[1:] |
| 158 | + Qs = self.get_L(name, ts, k_min, k_max, num_bins) |
| 159 | + rate = np.array([np.log(Qs[i+1] / Qs[i-1]) / np.log(ts[i+1] / ts[i-1]) for i in range(1, len(ts) - 1)]) |
| 160 | + plt.plot(ts[1:-1], rate, **kwargs) |
| 161 | + plt.savefig("L_{}.png".format(name), dpi=300, bbox_inches="tight") |
| 162 | + np.savetxt("L_{}.dat".format(name), np.column_stack((ts[1:-1], rate))) |
| 163 | + |
| 164 | + def plot_mean(self, name, times=None, xscale=None, yscale=None, **kwargs): |
| 165 | + if times is None: |
| 166 | + times = self.times |
| 167 | + means = self.get_means(name, times) |
| 168 | + plt.plot(times, means, **kwargs) |
| 169 | + if xscale is not None: |
| 170 | + plt.xscale(xscale) |
| 171 | + if yscale is not None: |
| 172 | + plt.yscale(yscale) |
| 173 | + plt.savefig("mean_{}.png".format(name), dpi=300, bbox_inches="tight") |
| 174 | + plt.close() |
| 175 | + np.savetxt("mean_{}.dat".format(name), np.column_stack((times, means))) |
| 176 | + |
| 177 | + def plot_spectrum(self, name, t, num_bins=200, y_min=None, y_max=None, **kwargs): |
| 178 | + k_bins, powers = self.get_spectrum(name, t, num_bins=num_bins) |
| 179 | + print(powers) |
| 180 | + plt.plot(k_bins, powers, **kwargs) |
| 181 | + if y_max is None: |
| 182 | + y_max = np.max(powers) |
| 183 | + if y_min is None: |
| 184 | + y_min = y_max / 1e6 |
| 185 | + plt.ylim([y_min, y_max]) |
| 186 | + plt.xscale('log') |
| 187 | + plt.yscale('log') |
| 188 | + |
| 189 | + def field_line(self, name, t, density=2): |
| 190 | + x = np.array(self.coords['x'].values) |
| 191 | + y = np.array(self.coords['y'].values) |
| 192 | + X, Y = np.meshgrid(x,y) |
| 193 | + if name=='magnetic': |
| 194 | + vx = self.Bx.sel({'t':t}, method='nearest') |
| 195 | + vy = self.By.sel({'t':t}, method='nearest') |
| 196 | + else: |
| 197 | + raise ValueError("Invalid type.") |
| 198 | + plt.figure(figsize=(8, 6)) |
| 199 | + plt.streamplot(X, Y, vx, vy, density=density) |
| 200 | + plt.xlabel("x") |
| 201 | + plt.ylabel("y") |
| 202 | + plt.axis('equal') |
| 203 | + |
| 204 | + |
| 205 | + |
| 206 | + |
| 207 | + |
| 208 | + |
| 209 | + |
| 210 | + |
| 211 | + |
0 commit comments