From 260dae8c9b2a41d999b331a9f59ecad1b78cceea Mon Sep 17 00:00:00 2001 From: Zhe Wang Date: Tue, 22 Oct 2024 14:36:35 -0400 Subject: [PATCH] Add the additional triples term when the field is on. --- pycc/ccdensity.py | 108 +++++++++++++++++++++++++++++++++++++++------- pycc/cclambda.py | 92 ++++++++++++++++++++++++++++----------- pycc/cctriples.py | 90 +++++++++++++++++++++++++++++++++++--- pycc/ccwfn.py | 63 ++++++++++++++++++++------- 4 files changed, 289 insertions(+), 64 deletions(-) diff --git a/pycc/ccdensity.py b/pycc/ccdensity.py index cdc42ab..dc18ff4 100644 --- a/pycc/ccdensity.py +++ b/pycc/ccdensity.py @@ -8,7 +8,7 @@ import time import numpy as np import torch -from .cctriples import t3c_ijk, t3c_abc, l3_ijk, l3_abc, t3c_bc, l3_bc, t3_pert_ijk, t3_pert_bc +from .cctriples import t3c_ijk, t3c_abc, l3_ijk, l3_abc, t3c_bc, l3_bc, t3_pert_ijk, t3_pert_bc, t3_ijkabc_alt class ccdensity(object): """ @@ -172,6 +172,8 @@ def compute_onepdm(self, t1, t2, l1, l2, real_time=False): F = self.ccwfn.H.F ERI = self.ccwfn.H.ERI L = self.ccwfn.H.L + + contract = self.contract if isinstance(t1, torch.Tensor): if self.ccwfn.precision == 'DP': @@ -195,16 +197,37 @@ def compute_onepdm(self, t1, t2, l1, l2, real_time=False): Wovoo = self.ccwfn.build_cc3_Wmbij(o, v, ERI, t1, Woooo) Wvovv = self.ccwfn.build_cc3_Wamef(o, v, ERI, t1) Wooov = self.ccwfn.build_cc3_Wmnie(o, v, ERI, t1) + + if real_time is True: + t3_full = t3_ijkabc_alt(o, v, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + """ + for i in range(no): + for j in range(no): + for k in range(no): + for a in range(nv): + for b in range(nv): + for c in range(nv): + t3_full[i,j,k,a,b,c] = t3_ijkabc(o, v, i, j, k, a, b, c, t2, Wvvvo, Wovoo, F, contract) + """ + opdm[o,v] += self.build_cc3_Dov_pert(o, v, no, nv, F, L, t1, t2, l1, l2, t3_full, Wvvvo, Wovoo, Fov, Wvovv, Wooov, real_time=real_time) + # Density matrix blocks in contractions with T1-transformed dipole integrals + if isinstance(t1, torch.Tensor): + opdm_cc3 = torch.zeros_like(opdm) + else: + opdm_cc3 = np.zeros_like(opdm) + opdm_cc3[o,o] += self.build_cc3_Doo_pert(o, v, no, nv, F, L, t2, l1, l2, t3_full, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=real_time) + opdm_cc3[v,v] += self.build_cc3_Dvv_pert(o, v, no, nv, F, L, t2, l1, l2, t3_full, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=real_time) - opdm[o,v] += self.build_cc3_Dov(o, v, no, nv, F, L, t1, t2, l1, l2, Wvvvo, Wovoo, Fov, Wvovv, Wooov, real_time=real_time) - - # Density matrix blocks in contractions with T1-transformed dipole integrals - if isinstance(t1, torch.Tensor): - opdm_cc3 = torch.zeros_like(opdm) else: - opdm_cc3 = np.zeros_like(opdm) - opdm_cc3[o,o] += self.build_cc3_Doo(o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov) - opdm_cc3[v,v] += self.build_cc3_Dvv(o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov) + opdm[o,v] += self.build_cc3_Dov(o, v, no, nv, F, L, t1, t2, l1, l2, Wvvvo, Wovoo, Fov, Wvovv, Wooov) + + # Density matrix blocks in contractions with T1-transformed dipole integrals + if isinstance(t1, torch.Tensor): + opdm_cc3 = torch.zeros_like(opdm) + else: + opdm_cc3 = np.zeros_like(opdm) + opdm_cc3[o,o] += self.build_cc3_Doo(o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov) + opdm_cc3[v,v] += self.build_cc3_Dvv(o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov) return (opdm, opdm_cc3) @@ -275,6 +298,60 @@ def build_Dov(self, t1, t2, l1, l2): # complete # CC3 contributions to the one electron densities def build_cc3_Dov(self, o, v, no, nv, F, L, t1, t2, l1, l2, Wvvvo, Wovoo, Fov, Wvovv, Wooov, real_time=False): + contract = self.contract + if isinstance(t1, torch.Tensor): + Dov = torch.zeros_like(t1) + Zlmdi = torch.zeros_like(t2[:,:,:,:no]) + else: + Dov = np.zeros_like(t1) + Zlmdi = np.zeros_like(t2[:,:,:,:no]) + for i in range(no): + for j in range(no): + for k in range(no): + l3 = l3_ijk(i, j, k, o, v, L, l1, l2, Fov, Wvovv, Wooov, F, contract) + # Intermediate for Dov_2 + Zlmdi[i,j] += contract('def,ife->di', l3, t2[k]) + # Dov_1 + t3 = t3c_ijk(o, v, i, j, k, t2, Wvvvo, Wovoo, F, contract) + Dov[i] += contract('abc,bc->a', t3 - t3.swapaxes(0,1), l2[j,k]) + # Dov_2 + Dov -= contract('lmdi, lmda->ia', Zlmdi, t2) + + return Dov + + def build_cc3_Doo(self, o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=False): + contract = self.contract + if isinstance(l1, torch.Tensor): + Doo = torch.zeros_like(l1[:,:no]) + else: + Doo = np.zeros_like(l1[:,:no]) + for b in range(nv): + for c in range(nv): + t3 = t3c_bc(o, v, b, c, t2, Wvvvo, Wovoo, F, contract) + l3 = l3_bc(b, c, o, v, L, l1, l2, Fov, Wvovv, Wooov, F, contract) + Doo -= 0.5 * contract('lmia,lmja->ij', t3, l3) + + return Doo + + def build_cc3_Dvv(self, o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=False): + contract = self.contract + if isinstance(l1, torch.Tensor): + Dvv = torch.zeros_like(l1) + Dvv = torch.nn.functional.pad(Dvv, (0,0,0,nv-no)) + else: + Dvv = np.zeros_like(l1) + Dvv = np.pad(Dvv, ((0,nv-no), (0,0))) + for i in range(no): + for j in range(no): + for k in range(no): + t3 = t3c_ijk(o, v, i, j, k, t2, Wvvvo, Wovoo, F, contract) + l3 = l3_ijk(i, j, k, o, v, L, l1, l2, Fov, Wvovv, Wooov, F, contract) + Dvv += 0.5 * contract('bdc,adc->ab', t3, l3) + + return Dvv + + # CC3 contributions to the one electron densities when a perturbation is present + def build_cc3_Dov_pert(self, o, v, no, nv, F, L, t1, t2, l1, l2, t3_full, Wvvvo, Wovoo, Fov, Wvovv, Wooov, real_time=False): contract = self.contract if isinstance(t1, torch.Tensor): Dov = torch.zeros_like(t1) @@ -295,14 +372,14 @@ def build_cc3_Dov(self, o, v, no, nv, F, L, t1, t2, l1, l2, Wvvvo, Wovoo, Fov, W V = F - self.ccwfn.H.F.clone() else: V = F - self.ccwfn.H.F.copy() - t3 -= t3_pert_ijk(o, v, i, j, k, t2, V, F, contract) + t3 -= t3_pert_ijk(o, v, i, j, k, t2, t3_full, V, F, contract) Dov[i] += contract('abc,bc->a', t3 - t3.swapaxes(0,1), l2[j,k]) # Dov_2 Dov -= contract('lmdi, lmda->ia', Zlmdi, t2) return Dov - def build_cc3_Doo(self, o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=False): + def build_cc3_Doo_pert(self, o, v, no, nv, F, L, t2, l1, l2, t3_full, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=False): contract = self.contract if isinstance(l1, torch.Tensor): Doo = torch.zeros_like(l1[:,:no]) @@ -316,13 +393,13 @@ def build_cc3_Doo(self, o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv V = F - self.ccwfn.H.F.clone() else: V = F - self.ccwfn.H.F.copy() - t3 -= t3_pert_bc(o, v, b, c, t2, V, F, contract) + t3 -= t3_pert_bc(o, v, b, c, t2, t3_full, V, F, contract) l3 = l3_bc(b, c, o, v, L, l1, l2, Fov, Wvovv, Wooov, F, contract) Doo -= 0.5 * contract('lmia,lmja->ij', t3, l3) return Doo - def build_cc3_Dvv(self, o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=False): + def build_cc3_Dvv_pert(self, o, v, no, nv, F, L, t2, l1, l2, t3_full, Fov, Wvvvo, Wovoo, Wvovv, Wooov, real_time=False): contract = self.contract if isinstance(l1, torch.Tensor): Dvv = torch.zeros_like(l1) @@ -339,7 +416,7 @@ def build_cc3_Dvv(self, o, v, no, nv, F, L, t2, l1, l2, Fov, Wvvvo, Wovoo, Wvovv V = F - self.ccwfn.H.F.clone() else: V = F - self.ccwfn.H.F.copy() - t3 -= t3_pert_ijk(o, v, i, j, k, t2, V, F, contract) + t3 -= t3_pert_ijk(o, v, i, j, k, t2, t3_full, V, F, contract) l3 = l3_ijk(i, j, k, o, v, L, l1, l2, Fov, Wvovv, Wooov, F, contract) Dvv += 0.5 * contract('bdc,adc->ab', t3, l3) @@ -610,5 +687,4 @@ def build_Mvv(self, no, nv, ints, t1): Mvv = Mvv - contract('ie,ia->ae', ints[:no,-nv:], t1) return Mvv - - + diff --git a/pycc/cclambda.py b/pycc/cclambda.py index a958971..51c2b8b 100644 --- a/pycc/cclambda.py +++ b/pycc/cclambda.py @@ -11,7 +11,7 @@ from opt_einsum import contract from .utils import helper_diis import torch -from .cctriples import t3c_ijk, l3_ijk, l3_ijk_alt, t3_pert_ijk +from .cctriples import t3c_ijk, l3_ijk, l3_ijk_alt, t3_pert_ijk, t3_ijkabc_alt class cclambda(object): @@ -342,20 +342,51 @@ def residuals(self, F, t1, t2, l1, l2): Zmndi = np.zeros_like(t2[:,:,:,:no]) Zmdfa = np.zeros_like(t2) Zmdfa = np.pad(Zmdfa, ((0,0), (0,nv-no), (0,0), (0,0))) - for m in range(no): - for n in range(no): - for l in range(no): - t3_lmn = t3c_ijk(o, v, l, m, n, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) - if self.ccwfn.real_time is True: - if isinstance(t1, torch.Tensor): - V = F - self.ccwfn.H.F.clone() - else: - V = F - self.ccwfn.H.F.copy() - t3_lmn -= t3_pert_ijk(o, v, l, m, n, t2, V, F, contract) - Zmndi[m,n] += contract('def,ief->di', t3_lmn, ERI[o,l,v,v]) - Zmndi[m,n] -= contract('fed,ief->di', t3_lmn, L[o,l,v,v]) - Zmdfa[m] += contract('def,ea->dfa', t3_lmn, ERI[n,l,v,v]) - Zmdfa[m] -= contract('dfe,ea->dfa', t3_lmn, L[n,l,v,v]) + """ + if self.ccwfn.real_time is True: + t3_full = np.zeros((no,no,no,nv,nv,nv)) + for i in range(no): + for j in range(no): + for k in range(no): + for a in range(nv): + for b in range(nv): + for c in range(nv): + t3_full[i,j,k,a,b,c] = t3_ijkabc(o, v, i, j, k, a, b, c, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + """ + if self.ccwfn.real_time is True: + if isinstance(t1, torch.Tensor): + V = F - self.ccwfn.H.F.clone() + else: + V = F - self.ccwfn.H.F.copy() + if abs(V[0,0]) >= 1E-6: + t3_full = t3_ijkabc_alt(o, v, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + for m in range(no): + for n in range(no): + for l in range(no): + t3_lmn = t3_full[l,m,n] + t3_lmn -= t3_pert_ijk(o, v, l, m, n, t2, t3_full, V, F, contract) + Zmndi[m,n] += contract('def,ief->di', t3_lmn, ERI[o,l,v,v]) + Zmndi[m,n] -= contract('fed,ief->di', t3_lmn, L[o,l,v,v]) + Zmdfa[m] += contract('def,ea->dfa', t3_lmn, ERI[n,l,v,v]) + Zmdfa[m] -= contract('dfe,ea->dfa', t3_lmn, L[n,l,v,v]) + else: + for m in range(no): + for n in range(no): + for l in range(no): + t3_lmn = t3c_ijk(o, v, l, m, n, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + Zmndi[m,n] += contract('def,ief->di', t3_lmn, ERI[o,l,v,v]) + Zmndi[m,n] -= contract('fed,ief->di', t3_lmn, L[o,l,v,v]) + Zmdfa[m] += contract('def,ea->dfa', t3_lmn, ERI[n,l,v,v]) + Zmdfa[m] -= contract('dfe,ea->dfa', t3_lmn, L[n,l,v,v]) + else: + for m in range(no): + for n in range(no): + for l in range(no): + t3_lmn = t3c_ijk(o, v, l, m, n, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + Zmndi[m,n] += contract('def,ief->di', t3_lmn, ERI[o,l,v,v]) + Zmndi[m,n] -= contract('fed,ief->di', t3_lmn, L[o,l,v,v]) + Zmdfa[m] += contract('def,ea->dfa', t3_lmn, ERI[n,l,v,v]) + Zmdfa[m] -= contract('dfe,ea->dfa', t3_lmn, L[n,l,v,v]) if isinstance(t1, torch.Tensor): Y1 = torch.zeros_like(l1) Y2 = torch.zeros_like(l2) @@ -389,17 +420,26 @@ def residuals(self, F, t1, t2, l1, l2): Zjlid_1 = np.zeros_like(l2[:,:,:no,:]) Zjlid_2 = np.zeros_like(l2[:,:,:no,:]) # t3l1 - for l in range(no): - for m in range(no): - for n in range(no): - t3_lmn = t3c_ijk(o, v, l, m, n, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) - if self.ccwfn.real_time is True: - if isinstance(t1, torch.Tensor): - V = F - self.ccwfn.H.F.clone() - else: - V = F - self.ccwfn.H.F.copy() - t3_lmn -= t3_pert_ijk(o, v, l, m, n, t2, V, F, contract) - Znf[n] += contract('de,def->f', l2[l,m], (t3_lmn - t3_lmn.swapaxes(0,2))) + if self.ccwfn.real_time is True: + if abs(V[0,0]) >= 1E-6: + for l in range(no): + for m in range(no): + for n in range(no): + t3_lmn = t3_full[l,m,n] + t3_lmn -= t3_pert_ijk(o, v, l, m, n, t2, t3_full, V, F, contract) + Znf[n] += contract('de,def->f', l2[l,m], (t3_lmn - t3_lmn.swapaxes(0,2))) + else: + for l in range(no): + for m in range(no): + for n in range(no): + t3_lmn = t3c_ijk(o, v, l, m, n, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + Znf[n] += contract('de,def->f', l2[l,m], (t3_lmn - t3_lmn.swapaxes(0,2))) + else: + for l in range(no): + for m in range(no): + for n in range(no): + t3_lmn = t3c_ijk(o, v, l, m, n, t2, Wvvvo, Wovoo, F, contract, WithDenom=True) + Znf[n] += contract('de,def->f', l2[l,m], (t3_lmn - t3_lmn.swapaxes(0,2))) for m in range(no): Y1 += contract('idf,dfa->ia', l2[:,m], Zmdfa[m]) Y1 += contract('iaf,f->ia', L[o,m,v,v], Znf[m]) diff --git a/pycc/cctriples.py b/pycc/cctriples.py index 8a4e27e..5081f4a 100644 --- a/pycc/cctriples.py +++ b/pycc/cctriples.py @@ -2,6 +2,7 @@ import numpy as np import torch +import time # Various triples drivers; useful for (T) corrections and CC3 @@ -544,9 +545,17 @@ def l3_bc(b, c, o, v, L, l1, l2, Fov, Wvovv, Wooov, F, contract, WithDenom=True) # Useful for RT-CC3 # Additional term in T3 equation when an external perturbation is present -def t3_pert_ijk(o, v, i, j, k, t2, V, F, contract, WithDenom=True): +def t3_pert_ijk(o, v, i, j, k, t2, t3_full, V, F, contract, WithDenom=True): + + #time1 = time.time() + t3 = -0.5 * contract('ad,dbc->abc', V[v,v], t3_full[i,j,k]) + t3 -= -0.5 * contract('l,labc->abc', V[o,i], t3_full[:,j,k]) + #print("Time(term1): ", time.time() - time1) + + #time1 = time.time() tmp = contract('ld,ad->al', V[o,v], t2[i,j]) - t3 = contract('al,lcb->abc', tmp, t2[k]) + t3 -= contract('al,lcb->abc', tmp, t2[k]) + #print("Time(term2): ", time.time() - time1) if WithDenom is True: if isinstance(t2, torch.Tensor): @@ -561,9 +570,13 @@ def t3_pert_ijk(o, v, i, j, k, t2, V, F, contract, WithDenom=True): else: return t3 -def t3_pert_abc(o, v, a, b, c, t2, V, F, contract, WithDenom=True): +def t3_pert_abc(o, v, a, b, c, t2, t3_full, V, F, contract, WithDenom=True): + + t3 = -0.5 * contract('d,ijkd->ijk', V[a,v], t3_full[:,:,:,:,b,c]) + t3 -= -0.5 * contract('li,ljk->ijk', V[o,o], t3_full[:,:,:,a,b,c]) + tmp = contract('ld,ijd->ijl', V[o,v], t2[:,:,a]) - t3 = contract('ijl,kl->ijk', tmp, t2[:,:,c,b]) + t3 -= contract('ijl,kl->ijk', tmp, t2[:,:,c,b]) if WithDenom is True: no = o.stop @@ -579,9 +592,13 @@ def t3_pert_abc(o, v, a, b, c, t2, V, F, contract, WithDenom=True): else: return t3 -def t3_pert_bc(o, v, b, c, t2, V, F, contract, WithDenom=True): +def t3_pert_bc(o, v, b, c, t2, t3_full, V, F, contract, WithDenom=True): + + t3 = -0.5 * contract('ad,ijkd->ijka', V[v,v], t3_full[:,:,:,:,b,c]) + t3 -= -0.5 * contract('li,ljka->ijka', V[o,o], t3_full[:,:,:,:,b,c]) + tmp = contract('ld,ijad->ijal', V[o,v], t2) - t3 = contract('ijal,kl->ijka', tmp, t2[:,:,c,b]) + t3 -= contract('ijal,kl->ijka', tmp, t2[:,:,c,b]) if WithDenom is True: no = o.stop @@ -599,3 +616,64 @@ def t3_pert_bc(o, v, b, c, t2, V, F, contract, WithDenom=True): return t3/denom else: return t3 + +def t3_ijkabc(o, v, i, j, k, a, b, c, t2, Wvvvo, Wovoo, F, contract, WithDenom=True): + + time0 = time.time() + + t3 = contract('e,e->', Wvvvo[b,a,:,i], t2[k,j,c]) + t3 += contract('e,e->', Wvvvo[c,a,:,i], t2[j,k,b]) + t3 += contract('e,e->', Wvvvo[a,c,:,k], t2[j,i,b]) + t3 += contract('e,e->', Wvvvo[b,c,:,k], t2[i,j,a]) + t3 += contract('e,e->', Wvvvo[c,b,:,j], t2[i,k,a]) + t3 += contract('e,e->', Wvvvo[a,b,:,j], t2[k,i,c]) + + t3 -= contract('m,m->', Wovoo[:,c,j,k], t2[i,:,a,b]) + t3 -= contract('m,m->', Wovoo[:,b,k,j], t2[i,:,a,c]) + t3 -= contract('m,m->', Wovoo[:,b,i,j], t2[k,:,c,a]) + t3 -= contract('m,m->', Wovoo[:,a,j,i], t2[k,:,c,b]) + t3 -= contract('m,m->', Wovoo[:,a,k,i], t2[j,:,b,c]) + t3 -= contract('m,m->', Wovoo[:,c,i,k], t2[j,:,b,a]) + + if WithDenom is True: + no = o.stop + denom = F[i,i] + F[j,j] + F[k,k] + denom -= F[a+no,a+no] + F[b+no,b+no] + F[c+no,c+no] + return t3/denom + else: + return t3 + +def t3_ijkabc_alt(o, v, t2, Wvvvo, Wovoo, F, contract, WithDenom=True): + time0 = time.time() + + t3 = contract('baei,kjce->ijkabc', Wvvvo, t2) + t3 += contract('caei,jkbe->ijkabc', Wvvvo, t2) + t3 += contract('acek,jibe->ijkabc', Wvvvo, t2) + t3 += contract('bcek,ijae->ijkabc', Wvvvo, t2) + t3 += contract('cbej,ikae->ijkabc', Wvvvo, t2) + t3 += contract('abej,kice->ijkabc', Wvvvo, t2) + + t3 -= contract('mcjk,imab->ijkabc', Wovoo, t2) + t3 -= contract('mbkj,imac->ijkabc', Wovoo, t2) + t3 -= contract('mbij,kmca->ijkabc', Wovoo, t2) + t3 -= contract('maji,kmcb->ijkabc', Wovoo, t2) + t3 -= contract('maki,jmbc->ijkabc', Wovoo, t2) + t3 -= contract('mcik,jmba->ijkabc', Wovoo, t2) + + if WithDenom is True: + if isinstance(t2, torch.Tensor): + Fv = torch.diag(F)[v] + Fo = torch.diag(F)[o] + denom = torch.zeros_like(t3) + else: + Fv = np.diag(F)[v] + Fo = np.diag(F)[o] + denom = np.zeros_like(t3) + denom += Fo.reshape(-1,1,1,1,1,1) + Fo.reshape(1,-1,1,1,1,1) + Fo.reshape(1,1,-1,1,1,1) + denom -= Fv.reshape(1,1,1,-1,1,1) + Fv.reshape(1,1,1,1,-1,1) + Fv.reshape(1,1,1,1,1,-1) + + print("t3(full): ", time.time() - time0) + return t3/denom + else: + return t3 + diff --git a/pycc/ccwfn.py b/pycc/ccwfn.py index d6fdc6d..09f7522 100644 --- a/pycc/ccwfn.py +++ b/pycc/ccwfn.py @@ -13,7 +13,7 @@ from .utils import helper_diis, cc_contract from .hamiltonian import Hamiltonian from .local import Local -from .cctriples import t_tjl, t3c_ijk, t3d_ijk, t3c_abc, t3d_abc, t3_pert_ijk +from .cctriples import t_tjl, t3c_ijk, t3d_ijk, t3c_abc, t3d_abc, t3_pert_ijk, t3_ijkabc_alt, t3_ijkabc from .lccwfn import lccwfn class ccwfn(object): @@ -381,21 +381,52 @@ def residuals(self, F, t1, t2, real_time=False): else: X1 = np.zeros_like(t1) X2 = np.zeros_like(t2) - - for i in range(no): - for j in range(no): - for k in range(no): - t3 = t3c_ijk(o, v, i, j, k, t2, Wabei_cc3, Wmbij_cc3, F, contract, WithDenom=True) - if real_time is True: - if isinstance(t1, torch.Tensor): - V = F - self.H.F.clone() - else: - V = F - self.H.F.copy() - t3 -= t3_pert_ijk(o, v, i, j, k, t2, V, F, contract) - X1[i] += contract('abc,bc->a', t3 - t3.swapaxes(0,2), L[j,k,v,v]) - X2[i,j] += contract('abc,c->ab', t3 - t3.swapaxes(0,2), Fme[k]) - X2[i,j] += contract('abc,dbc->ad', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wamef_cc3.swapaxes(0,1)[k]) - X2[i] -= contract('abc,lc->lab', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wmnie_cc3[j,k]) + + """ + if real_time is True: + t3_full = np.zeros((no,no,no,nv,nv,nv)) + for i in range(no): + for j in range(no): + for k in range(no): + for a in range(nv): + for b in range(nv): + for c in range(nv): + t3_full[i,j,k,a,b,c] = t3_ijkabc(o, v, i, j, k, a, b, c, t2, Wabei_cc3, Wmbij_cc3, F, contract, WithDenom=True) + """ + if real_time is True: + if isinstance(t1, torch.Tensor): + V = F - self.H.F.clone() + else: + V = F - self.H.F.copy() + if abs(V[0,0]) >= 1E-6: + t3_full = t3_ijkabc_alt(o, v, t2, Wabei_cc3, Wmbij_cc3, F, contract, WithDenom=True) + for i in range(no): + for j in range(no): + for k in range(no): + t3 = t3_full[i,j,k] + t3 -= t3_pert_ijk(o, v, i, j, k, t2, t3_full, V, F, contract) + X1[i] += contract('abc,bc->a', t3 - t3.swapaxes(0,2), L[j,k,v,v]) + X2[i,j] += contract('abc,c->ab', t3 - t3.swapaxes(0,2), Fme[k]) + X2[i,j] += contract('abc,dbc->ad', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wamef_cc3.swapaxes(0,1)[k]) + X2[i] -= contract('abc,lc->lab', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wmnie_cc3[j,k]) + else: + for i in range(no): + for j in range(no): + for k in range(no): + t3 = t3c_ijk(o, v, i, j, k, t2, Wabei_cc3, Wmbij_cc3, F, contract, WithDenom=True) + X1[i] += contract('abc,bc->a', t3 - t3.swapaxes(0,2), L[j,k,v,v]) + X2[i,j] += contract('abc,c->ab', t3 - t3.swapaxes(0,2), Fme[k]) + X2[i,j] += contract('abc,dbc->ad', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wamef_cc3.swapaxes(0,1)[k]) + X2[i] -= contract('abc,lc->lab', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wmnie_cc3[j,k]) + else: + for i in range(no): + for j in range(no): + for k in range(no): + t3 = t3c_ijk(o, v, i, j, k, t2, Wabei_cc3, Wmbij_cc3, F, contract, WithDenom=True) + X1[i] += contract('abc,bc->a', t3 - t3.swapaxes(0,2), L[j,k,v,v]) + X2[i,j] += contract('abc,c->ab', t3 - t3.swapaxes(0,2), Fme[k]) + X2[i,j] += contract('abc,dbc->ad', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wamef_cc3.swapaxes(0,1)[k]) + X2[i] -= contract('abc,lc->lab', 2 * t3 - t3.swapaxes(1,2) - t3.swapaxes(0,2), Wmnie_cc3[j,k]) r1 += X1 r2 += X2 + X2.swapaxes(0,1).swapaxes(2,3)