-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
60 lines (53 loc) · 2.04 KB
/
plot.py
File metadata and controls
60 lines (53 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from ase.io import read,write
import re
import dpdata
import os
from ase import db
import pandas as pd
from ase.db import connect
import matplotlib.pyplot as plt
from matplotlib import ticker
def loss_curve_plot(path, rmse=False, energy=True, force=True, mode='plot', savefig=False, savefigpath=None, show=False,):
    """Plot training RMSE curves against epoch.

    The data come from ``MLIP_processing.utils.loss_curve(path, energy=..., force=...)``,
    whose result is unpacked as ``(epoch, rmse_total, rmse_energy, rmse_force)``.

    Parameters
    ----------
    path : str or path-like
        Forwarded unchanged to ``loss_curve``.
    rmse : bool, default False
        Draw the total-RMSE curve (legend label ``'RMSE'``).
    energy : bool, default True
        Draw the energy-RMSE curve; also forwarded to ``loss_curve``.
    force : bool, default True
        Draw the force-RMSE curve; also forwarded to ``loss_curve``.
    mode : {'plot', 'loglog', 'semilogx', 'semilogy'}, default 'plot'
        Axis scaling; each value is the name of the matplotlib ``Axes``
        method used to draw the curves.
    savefig : bool, default False
        Save the figure to ``savefigpath``.
    savefigpath : str or path-like, optional
        Output file for ``savefig``; required when ``savefig`` is true.
    show : bool, default False
        Call ``plt.show()`` after drawing (and saving).

    Returns
    -------
    tuple
        ``(fig, ax)`` -- the matplotlib figure and axes that were drawn on.
        The figure is left open; call ``plt.close(fig)`` when finished with
        it (important when plotting many logs in a loop).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values, or if ``savefig`` is
        true but ``savefigpath`` is None.
    """
    supported_modes = ('plot', 'loglog', 'semilogx', 'semilogy')
    # Validate arguments before doing the (possibly slow) log parsing, so bad
    # calls fail fast with a clear message instead of an empty/unsaved figure.
    if mode not in supported_modes:
        raise ValueError(
            f"mode must be one of {supported_modes}, got {mode!r}"
        )
    if savefig and savefigpath is None:
        raise ValueError("savefigpath must be given when savefig is True")

    # Project-local import kept inside the function, as in the original.
    from MLIP_processing.utils import loss_curve
    epoch, rmse_total, rmse_energy, rmse_force = loss_curve(path, energy=energy, force=force)

    fig, ax = plt.subplots()
    # Safe only because `mode` was validated above: every supported mode name
    # is also an Axes plotting method (ax.plot / ax.loglog / ax.semilogx / ax.semilogy).
    draw = getattr(ax, mode)

    curves = (
        (rmse, rmse_total, 'RMSE'),
        (energy, rmse_energy, 'Energy_RMSE (eV)'),
        (force, rmse_force, 'Force_RMSE (eV/Å)'),
    )
    for enabled, values, label in curves:
        if enabled:
            draw(epoch, values, label=label)

    if mode == 'plot':
        # Turn off scientific/offset notation. ticklabel_format only works
        # with ScalarFormatter, i.e. linear axes -- hence linear mode only.
        ax.ticklabel_format(style='plain')
    else:
        if mode == 'semilogy':
            # The x-axis stays linear here; keep ticks on whole epochs so the
            # '%d' formatter below cannot print duplicate labels (0, 0, 1, 1, ...)
            # when only a few epochs are plotted.
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        # Show epochs as plain integers (e.g. 100 rather than 10^2 on log axes).
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))

    ax.set_xlabel("Epoch")
    ax.set_ylabel('RMSE')
    # Only build a legend if something was drawn; otherwise matplotlib warns
    # "No artists with labels found to put in legend".
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    if savefig:
        fig.savefig(savefigpath)
    if show:
        plt.show()
    return fig, ax