Skip to content

Commit 79d06bd

Browse files
Merge branch 'dev/decay_turb-sugon' of https://github.com/StaticObserver/entity into dev/decay_turb-sugon
2 parents e3d29cc + 377f981 commit 79d06bd

2 files changed

Lines changed: 212 additions & 0 deletions

File tree

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import numpy as np
2+
import xarray as xr
3+
import nt2.read as nt2r
4+
import matplotlib.pyplot as plt
5+
from tqdm import tqdm
6+
7+
def parallel(func, steps, dataset, num_cpus=None):
8+
import multiprocessing as mp
9+
import numpy as np # 添加numpy导入
10+
11+
if num_cpus is None:
12+
num_cpus = mp.cpu_count()
13+
14+
global calculate
15+
def calculate(t):
16+
try:
17+
value = func(t, dataset)
18+
except Exception as e:
19+
print(f"Error in processing {t}: {e}")
20+
return t, None
21+
return t, value
22+
23+
# 初始化多进程池
24+
pool = mp.Pool(num_cpus)
25+
try:
26+
# 添加进度条
27+
results = [pool.apply_async(calculate, args=(t,)) for t in tqdm(steps)]
28+
pool.close() # 关闭输入通道
29+
pool.join() # 等待所有进程完成
30+
except Exception as e:
31+
pool.terminate() # 遇到异常时终止所有进程
32+
print(f"Error during multiprocessing: {e}")
33+
raise
34+
35+
# 获取结果并排序
36+
sorted_results = sorted([r.get() for r in results], key=lambda x: x[0])
37+
38+
# 提取结果值为numpy数组
39+
return np.array([value for t, value in sorted_results])
40+
41+
42+
def compute_spectrum(field, dx, k_min=None, k_max=None, num_bins=200):
43+
Ny, Nx = field.shape
44+
power_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(field)))**2
45+
dkx = 2 * np.pi / (Nx * dx)
46+
dky = 2 * np.pi / (Ny * dx)
47+
kx, ky = np.meshgrid(
48+
np.linspace(-Nx/2 * dkx, Nx/2 * dkx - dkx, Nx),
49+
np.linspace(-Ny/2 * dky, Ny/2 * dky - dky, Ny)
50+
)
51+
k_mag = np.sqrt(kx**2 + ky**2).flatten()
52+
power_spectrum_flatten = power_spectrum.flatten()
53+
if k_max is None:
54+
k_max = np.max(k_mag)
55+
if k_min is None:
56+
k_min = np.min(k_mag[k_mag > 0])
57+
k_bins = np.linspace(k_min, k_max, num=num_bins)
58+
k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])
59+
dk = k_bins[1] - k_bins[0]
60+
k_indices = np.digitize(k_mag, k_bins, right=False)
61+
power_spectrum_binned = np.array([
62+
np.sum(power_spectrum_flatten[k_indices == i]) if np.any(k_indices == i) else 0
63+
for i in range(1, len(k_bins))
64+
])
65+
power_spectrum_binned /= (Nx * Ny )
66+
print(power_spectrum_binned)
67+
return k_bin_centers, power_spectrum_binned
68+
69+
def compute_L(field, dx, k_min=None, k_max=None, num_bins=200):
70+
k_bin_centers, power_spectrum_binned = compute_spectrum(field, dx, k_min, k_max, num_bins)
71+
return np.dot(1.0 / k_bin_centers, power_spectrum_binned) / np.sum(power_spectrum_binned)
72+
73+
class Visualizer(nt2r.Data):
74+
def __init__(self, filename, d0, rho0, num_cpus):
75+
super().__init__(filename)
76+
self.sigma0 = (d0 / rho0)**2
77+
self.num_cpus = num_cpus
78+
self.times = self.coords['t'].values
79+
self.dx = self.coords['x'].values[1] - self.coords['x'].values[0]
80+
self.field_map = {
81+
'B2': lambda data: (data.Bx**2 + data.By**2 + data.Bz**2) ,
82+
'E2': lambda data: (data.Ex**2 + data.Ey**2 + data.Ez**2) ,
83+
'EM_Energy': lambda data: 0.5 * (data.Ex**2 + data.Ey**2 + data.Ez**2 +
84+
data.Bx**2 + data.By**2 + data.Bz**2) * self.sigma0,
85+
'Prtl_Energy': lambda data: data.T00,
86+
'Total_Energy': lambda data: 0.5 * (data.Ex**2 + data.Ey**2 + data.Ez**2 +
87+
data.Bx**2 + data.By**2 + data.Bz**2) * self.sigma0+ data.T00,
88+
'N' : lambda data: data.N_1 + data.N_2,
89+
'Bxy_Energy' : lambda data: 0.5 * (data.Bx**2 + data.By**2) * self.sigma0,
90+
}
91+
92+
def set_cpu(self, num_cpus):
93+
self.num_cpus = num_cpus
94+
95+
def set_para(self, d0, rho0):
96+
self.sigma0 = (d0 / rho0)**2
97+
98+
def get_field(self, name):
99+
if name in self.field_map:
100+
return self.field_map[name](self)
101+
else:
102+
raise ValueError(f"{name} not found.")
103+
104+
def get_means(self, name, times=None):
105+
if times is None:
106+
times = self.times
107+
if name in self.field_map:
108+
field = self.field_map[name](self)
109+
elif hasattr(self, name):
110+
field = getattr(self, name)
111+
else:
112+
raise ValueError("Invalid type.")
113+
return parallel(lambda t, data: data.sel({'t':t}, method='nearest').mean(('x', 'y')).compute().item(),
114+
times,
115+
field,
116+
self.num_cpus)
117+
118+
def get_spectrum(self, name, t, k_min=None, k_max=None, num_bins=200):
119+
if name in self.field_map:
120+
field = self.field_map[name](self)
121+
elif hasattr(self, name):
122+
field = getattr(self, name)
123+
else:
124+
raise ValueError("Invalid type.")
125+
field = field.sel({'t': t}, method='nearest').values
126+
print(field)
127+
k_bins, powers = compute_spectrum(field, self.dx, k_min, k_max, num_bins)
128+
return k_bins, powers
129+
130+
def get_L(self, name, times=None, k_min=None, k_max=None, num_bins=200):
131+
if times is None:
132+
times = self.times
133+
if name in self.field_map:
134+
field = self.field_map[name](self)
135+
elif hasattr(self, name):
136+
field = getattr(self, name)
137+
else:
138+
raise ValueError("Invalid type.")
139+
return parallel(lambda t, data: compute_L(data.sel({'t': t}, method='nearest'), self.dx, k_min, k_max, num_bins),
140+
times,
141+
field,
142+
self.num_cpus)
143+
144+
def decay_rate(self, name, times=None, **kwargs):
145+
if times is None:
146+
times = self.times
147+
ts = times[1:]
148+
Qs = self.get_means(name, ts)
149+
rate = np.array([np.log(Qs[i-1] / Qs[i+1]) / np.log(ts[i+1] / ts[i-1]) for i in tqdm(range(1, len(ts) - 1))])
150+
plt.plot(ts[1:-1], rate, **kwargs)
151+
plt.savefig("decay_{}.png".format(name), dpi=300, bbox_inches="tight")
152+
np.savetxt("decay_{}.dat".format(name), np.column_stack((ts[1:-1], rate)))
153+
154+
def increase_rate_L(self, name, times=None, k_min=None, k_max=None, num_bins=200, **kwargs):
155+
if times is None:
156+
times = self.times
157+
ts = times[1:]
158+
Qs = self.get_L(name, ts, k_min, k_max, num_bins)
159+
rate = np.array([np.log(Qs[i+1] / Qs[i-1]) / np.log(ts[i+1] / ts[i-1]) for i in range(1, len(ts) - 1)])
160+
plt.plot(ts[1:-1], rate, **kwargs)
161+
plt.savefig("L_{}.png".format(name), dpi=300, bbox_inches="tight")
162+
np.savetxt("L_{}.dat".format(name), np.column_stack((ts[1:-1], rate)))
163+
164+
def plot_mean(self, name, times=None, xscale=None, yscale=None, **kwargs):
165+
if times is None:
166+
times = self.times
167+
means = self.get_means(name, times)
168+
plt.plot(times, means, **kwargs)
169+
if xscale is not None:
170+
plt.xscale(xscale)
171+
if yscale is not None:
172+
plt.yscale(yscale)
173+
plt.savefig("mean_{}.png".format(name), dpi=300, bbox_inches="tight")
174+
plt.close()
175+
np.savetxt("mean_{}.dat".format(name), np.column_stack((times, means)))
176+
177+
def plot_spectrum(self, name, t, num_bins=200, y_min=None, y_max=None, **kwargs):
178+
k_bins, powers = self.get_spectrum(name, t, num_bins=num_bins)
179+
print(powers)
180+
plt.plot(k_bins, powers, **kwargs)
181+
if y_max is None:
182+
y_max = np.max(powers)
183+
if y_min is None:
184+
y_min = y_max / 1e6
185+
plt.ylim([y_min, y_max])
186+
plt.xscale('log')
187+
plt.yscale('log')
188+
189+
def field_line(self, name, t, density=2):
190+
x = np.array(self.coords['x'].values)
191+
y = np.array(self.coords['y'].values)
192+
X, Y = np.meshgrid(x,y)
193+
if name=='magnetic':
194+
vx = self.Bx.sel({'t':t}, method='nearest')
195+
vy = self.By.sel({'t':t}, method='nearest')
196+
else:
197+
raise ValueError("Invalid type.")
198+
plt.figure(figsize=(8, 6))
199+
plt.streamplot(X, Y, vx, vy, density=density)
200+
plt.xlabel("x")
201+
plt.ylabel("y")
202+
plt.axis('equal')
203+
204+
205+
206+
207+
208+
209+
210+
211+

src/engines/engine.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ namespace ntt {
6262
adios2::ADIOS m_adios;
6363
#endif
6464
#endif
65+
#endif // OUTPUT_ENABLED
6566

6667
SimulationParams m_params;
6768
Metadomain<S, M> m_metadomain;

0 commit comments

Comments
 (0)