Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions examples/jax/01_slater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import pyscf.pbc.gto as gto
import pyscf.gto as molgto
import pyscf.pbc.dft as dft
import numpy
import pandas as pd


def make_cell_basis():
"""
Here we make an uncontracted basis for the cell.
Our JAX implementation requires this for efficient evaluation,
and often for solids the uncontracted basis is much better because the
others
"""
cell = gto.Cell()
cell.atom = '''C 0. 0. 0.
C 0.8917 0.8917 0.8917
C 1.7834 1.7834 0.
C 2.6751 2.6751 0.8917
C 1.7834 0. 1.7834
C 2.6751 0.8917 2.6751
C 0. 1.7834 1.7834
C 0.8917 2.6751 2.6751'''
cell.basis = 'unc-ccecp-ccpvtz'
cell.pseudo = 'ccecp' #These are high accuracy ECP's for QMC.
cell.cart = True # also important for JAX efficiency.
cell.a = numpy.eye(3)*3.5668
cell.build()
return cell



def run_dft(cell, chkfile):
mf = dft.RKS(cell)
mf.xc = 'lda,vwn'
mf = mf.multigrid_numint()
mf.chkfile = chkfile
mf.kernel()
return mf.e_tot


def generate_etb_set(cell, alpha0=0.2, l_polarization=2):
"""
Generate an even tempered basis which is selected to best reproduce the
basis in cell.

You can use this as written for the first 3 rows and all sp elements.

cell is a cell object with a given basis
alpha0 is the longest range you would like to go. typically 0.2 or 0.1
l_polarization is how many angular momentum functions you'd like to allow
"""
#print(cell._basis)
new_basis ={}
for atomname, basis in cell._basis.items():
maxl_contract = 1 + l_polarization
# If you are doing a F-electron or 4d or 5d element then you
# need to update this.
if atomname in ['Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu']:
maxl_contract = 2+l_polarization

# in this loop find the minimum and maximum exponents
# for each angular momentum channel.
maxexp = numpy.zeros(maxl_contract+1)
minexp = numpy.ones(maxl_contract+1)*1000
for element in basis:
#print(element)
l = element[0]
exponents = numpy.max([e[0] for e in element[1:]])
#print(exponents)
if l <= maxl_contract:
maxexp[l] = numpy.max([exponents, maxexp[l]])
minexp[l] = numpy.min([exponents, minexp[l]])

# Truncate the minimum exponent to alpha0
minexp[minexp < alpha0] = alpha0
# Now
etbs = []
for l, maxe in enumerate(maxexp):
n = int(numpy.log2(maxe/minexp[l]))+2
etbs.append((l, n, minexp[l], 2))
new_basis[atomname] = molgto.etbs(etbs)
newcell = cell.copy()
newcell.basis = new_basis
newcell.build()
return newcell


if __name__ == "__main__":
cell_orig = make_cell_basis()
cell_etb = generate_etb_set(cell_orig, alpha0=0.2, l_polarization=2)
cell_exp2 = cell_orig.copy()
cell_exp2.exp_to_discard = 0.2
cell_exp2.build()

# Sometimes an even tempered basis is better than the uncontract and expand, and
# sometimes not.
print("etb basis", run_dft(cell_etb, 'etb0.2.hdf5'))
print("uncontracted + exp_to_discard", run_dft(cell_exp2, 'exp_to_discard0.2.hdf5'))

# You should probably also try decreasing alpha0, depending on the material.

86 changes: 86 additions & 0 deletions examples/jax/02_check_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
# here we are desperately trying to avoid multithreading
# to get a good read on the single-threaded performance
# you may or many not want this depending on your use case
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["XLA_FLAGS"] = (
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1"
)
os.environ["JAX_NUM_CLIENTS"] = "1"
os.environ["NPROC"] = "1"

import pyqmc.api as pyq
import jax
import time
import numpy as np
import pandas as pd

# you may or may not want to set this.
jax.config.update('jax_platform_name', 'cpu')
# you almost always want 64-bit math.
jax.config.update("jax_enable_x64", True)
print(jax.devices())


def check_value(configs, wf_jax, wf_pyscf):
"""
print out timing and difference in values for recompute()
"""
vals_jax = wf_jax.recompute(configs)
jax.block_until_ready(vals_jax)
start = time.perf_counter()
vals_jax = wf_jax.recompute(configs)
jax.block_until_ready(vals_jax)
jax_time = time.perf_counter()
vals_pyscf = wf_pyscf.recompute(configs)
pyscf = time.perf_counter()
print(f"JAX time {jax_time - start} s PYSCF time {pyscf - jax_time} s "
f" Difference {np.mean(np.abs(vals_jax[0] - vals_pyscf[0]))}")


def check_energy(configs, wf_jax, wf_pyscf) -> pd.DataFrame:
"""
Check the various energies.
Note that between old and new ECPS they should only agree on average

"""
enacc = {'old':pyq.EnergyAccumulator(cell, use_old_ecp=True),
'new': pyq.EnergyAccumulator(cell, use_old_ecp=False) }
wfs = {'jax': wf_jax, 'pyscf': wf_pyscf}
data = []
for ecp in enacc.keys():
for wf in wfs.keys():
if wf == 'jax': #force compile
enacc[ecp](configs, wfs[wf])
start = time.perf_counter()
en = enacc[ecp](configs, wfs[wf])
end = time.perf_counter()
data.append({ 'time': end - start,
'wf':wf,
'ecp':ecp,
'ecp_en': np.mean(en['ecp']),
'grad2': np.mean(en['grad2']),
'ke': np.mean(en['ke']),
'total':np.mean(en['total']),
})
return pd.DataFrame(data)


if __name__ == "__main__":
# we found that etb0.2 was the lowest DFT energy
cell, mf = pyq.recover_pyscf("etb0.2.hdf5")
wf_jax, _ = pyq.generate_slater(cell, mf, jax=True, nimages=2)
wf_pyscf, _ = pyq.generate_slater(cell, mf, jax=False, eval_gto_precision=1e-16 )

for nconfig in [10, 100, 1000]:
print(f"###### nconfig = {nconfig}")
configs = pyq.initial_guess(cell, nconfig)
check_value(configs, wf_jax, wf_pyscf)
df = check_energy(configs, wf_jax, wf_pyscf)
print(df)



Loading