Skip to content

Commit 1038e37

Browse files
updating plot.py
1 parent 64abe83 commit 1038e37

1 file changed

Lines changed: 171 additions & 3 deletions

File tree

  • setups/srpic/decay_turbulence

setups/srpic/decay_turbulence/plot.py

Lines changed: 171 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,174 @@
22
import matplotlib.pyplot as plt
33
import matplotlib as mpl
44
import numpy as np
5+
from tqdm import tqdm
56

6-
data = nt2.Data(path="turbulence")
7+
sigma0 = 16
78

8-
#data.fields.inspect.plot(name="inspect", only_fields=["N", "Jz"])
9+
field_map = {
10+
'B2': lambda data: (data.Bx**2 + data.By**2 + data.Bz**2) ,
11+
'E2': lambda data: (data.Ex**2 + data.Ey**2 + data.Ez**2) ,
12+
'EM_Energy': lambda data: 0.5 * (data.Ex**2 + data.Ey**2 + data.Ez**2 +
13+
data.Bx**2 + data.By**2 + data.Bz**2) * sigma0,
14+
'Prtl_Energy': lambda data: data.T00,
15+
'Total_Energy': lambda data: 0.5 * (data.Ex**2 + data.Ey**2 + data.Ez**2 +
16+
data.Bx**2 + data.By**2 + data.Bz**2) * sigma0+ data.T00,
17+
'N' : lambda data: data.N_1 + data.N_2,
18+
'Bxy_Energy' : lambda data: 0.5 * (data.Bx**2 + data.By**2) * sigma0,
19+
}
20+
21+
def parallel(func, steps, dataset, num_cpus=None):
22+
import multiprocessing as mp
23+
import numpy as np # 添加numpy导入
24+
25+
if num_cpus is None:
26+
num_cpus = mp.cpu_count()
27+
28+
global calculate
29+
def calculate(t):
30+
try:
31+
value = func(t, dataset)
32+
except Exception as e:
33+
print(f"Error in processing {t}: {e}")
34+
return t, None
35+
return t, value
36+
37+
# 初始化多进程池
38+
pool = mp.Pool(num_cpus)
39+
try:
40+
# 添加进度条
41+
results = [pool.apply_async(calculate, args=(t,)) for t in tqdm(steps)]
42+
pool.close() # 关闭输入通道
43+
pool.join() # 等待所有进程完成
44+
except Exception as e:
45+
pool.terminate() # 遇到异常时终止所有进程
46+
print(f"Error during multiprocessing: {e}")
47+
raise
48+
49+
# 获取结果并排序
50+
sorted_results = sorted([r.get() for r in results], key=lambda x: x[0])
51+
52+
# 提取结果值为numpy数组
53+
return np.array([value for t, value in sorted_results])
54+
55+
56+
def get_means(data, times, name, num_cpus=4):
57+
if name in field_map:
58+
field = field_map[name](data)
59+
elif hasattr(data, name):
60+
field = getattr(data, name)
61+
else:
62+
raise ValueError("Invalid name.")
63+
return parallel(lambda t, data: data.sel({'t':t}, method='nearest').mean(('x', 'y')).compute().item(),
64+
times,
65+
field,
66+
num_cpus)
67+
68+
def plot_means(data, times, name, num_cpus=4):
69+
means = get_means(data, times, name, num_cpus)
70+
plt.plot(times, means)
71+
plt.xscale("log")
72+
plt.yscale("log")
73+
plt.savefig("mean_{}.png".format(name), dpi=100, bbox_inches="tight")
74+
plt.close()
75+
np.savetxt("mean_{}.dat".format(name), np.column_stack((times, means)))
76+
77+
78+
def decay_rate(data, name, ts, num_cpus=4):
79+
Qs = get_means(data, ts, name, num_cpus)
80+
rate = np.array([np.log(Qs[i-1] / Qs[i+1]) / np.log(ts[i+1] / ts[i-1]) for i in tqdm(range(1, len(ts) - 1))])
81+
plt.plot(ts[1:-1], rate)
82+
plt.savefig("decay_{}.png".format(name), dpi=100, bbox_inches="tight")
83+
np.savetxt("decay_{}.dat".format(name), np.column_stack((ts[1:-1], rate)))
84+
85+
def compute_spectrum(field, dx, k_min=None, k_max=None, num_bins=200, use_log_bins=True):
86+
"""
87+
计算二维场的径向功率谱。
88+
89+
参数:
90+
field: 2D numpy数组,输入场
91+
dx: 空间分辨率
92+
k_min: 最小波数 (可选)
93+
k_max: 最大波数 (可选)
94+
num_bins: 波数bin的数量
95+
use_log_bins: 是否使用对数间隔的波数bins
96+
97+
返回:
98+
k_bin_centers: 波数bin中心点
99+
power_spectrum_binned: 相应的功率谱密度
100+
"""
101+
Ny, Nx = field.shape
102+
103+
# 计算FFT并取幅度平方
104+
power_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(field)))**2
105+
106+
# 计算波数网格
107+
dkx = 2 * np.pi / (Nx * dx)
108+
dky = 2 * np.pi / (Ny * dx)
109+
kx, ky = np.meshgrid(
110+
np.linspace(-Nx/2 * dkx, Nx/2 * dkx - dkx, Nx),
111+
np.linspace(-Ny/2 * dky, Ny/2 * dky - dky, Ny)
112+
)
113+
114+
# 计算波数幅度
115+
k_mag = np.sqrt(kx**2 + ky**2).flatten()
116+
power_spectrum_flatten = power_spectrum.flatten()
117+
118+
# 设置波数范围
119+
if k_max is None:
120+
k_max = np.max(k_mag)
121+
if k_min is None:
122+
k_min = 0.0
123+
124+
125+
k_bins = np.linspace(k_min, k_max, num=num_bins)
126+
k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])
127+
128+
power_spectrum_binned = np.zeros(len(k_bin_centers))
129+
130+
for i in range(len(k_bins) - 1):
131+
bin_mask = (k_mag >= k_bins[i]) & (k_mag < k_bins[i + 1])
132+
if np.sum(bin_mask) > 0:
133+
power_spectrum_binned[i] = np.sum(power_spectrum_flatten[bin_mask])
134+
135+
# 归一化
136+
power_spectrum_binned /= (Nx * Ny) # 归一化FFT
137+
138+
139+
return k_bin_centers, power_spectrum_binned
140+
141+
def spectrum(t, data):
142+
frame = data.sel({'t':t}, method="nearest")
143+
dx = data.coords['x'].values[1] - data.coords['x'].values[0]
144+
k_bins, powers = compute_spectrum(frame.values, dx, None, None, 200, True)
145+
plt.plot(k_bins, powers)
146+
#plt.xscale("log")
147+
plt.yscale("log")
148+
149+
150+
def compute_L(field, dx, k_min=None, k_max=None, num_bins=200):
151+
k_bin_centers, power_spectrum_binned = compute_spectrum(field, dx, k_min, k_max, num_bins)
152+
return np.dot(1.0 / k_bin_centers, power_spectrum_binned) / np.sum(power_spectrum_binned)
153+
154+
def get_L(data, name, times, k_min=None, k_max=None, num_bins=200, num_cpus=4):
155+
if name in field_map:
156+
field = field_map[name](data)
157+
elif hasattr(data, name):
158+
field = getattr(data, name)
159+
else:
160+
raise ValueError("Invalid type.")
161+
dx = data.coords['x'].values[1] - data.coords['x'].values[0]
162+
return parallel(lambda t, data: compute_L(data.sel({'t': t}, method='nearest'), dx, k_min, k_max, num_bins),
163+
times,
164+
field,
165+
num_cpus)
166+
167+
def increase_rate_L(data, name, ts, k_min=None, k_max=None, num_bins=200, num_cpus=4):
168+
Qs = get_L(data, name, ts, k_min, k_max, num_bins, num_cpus)
169+
rate = np.array([np.log(Qs[i+1] / Qs[i-1]) / np.log(ts[i+1] / ts[i-1]) for i in range(1, len(ts) - 1)])
170+
plt.plot(ts[1:-1], rate)
171+
plt.savefig("L_{}.png".format(name), dpi=300, bbox_inches="tight")
172+
np.savetxt("L_{}.dat".format(name), np.column_stack((ts[1:-1], rate)))
9173

10174
def plot_spectra(t, data):
11175
frame = data.sel({'t':t}, method="nearest")
@@ -26,12 +190,16 @@ def plot_func(t, fld):
26190
fld.sel({'t':t}, method='nearest').plot(ax=ax, norm=mpl.colors.Normalize(vmin, vmax), cmap=colormap)
27191

28192
def main():
193+
data = nt2.Data(path="turbulence")
29194
num_cpus = 32
30195
times = np.linspace(0, 1000, 200)
196+
plot_means(data.fields, times, 'EM_Energy', num_cpus=num_cpus)
197+
plot_means(data.fields, times, 'Total_Energy', num_cpus=num_cpus)
198+
plot_means(data.fields, times, 'Prtl_Energy', num_cpus=num_cpus)
31199
#sp = data.spectra
32200

33201
#nt2.export.makeFrames(plot_spectra, times, 'spectra', sp, num_cpus=num_cpus)
34202
#nt2.export.makeFrames(plot_func, times, 'N', data.fields.N, num_cpus=num_cpus)
35-
print(data)
203+
36204
if __name__ == '__main__':
37205
main()

0 commit comments

Comments
 (0)