Skip to content

Commit 84d18b8

Browse files
Merge pull request #287 from AsymmetryChou/fix_ovp_np
fix: Ensure proper tensor conversion for numpy solver in Eigenvalues …
2 parents fd65c44 + 52d5928 commit 84d18b8

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

dptb/nn/energy.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,28 @@ def forward(self,
8383
for i in range(int(np.ceil(num_k / nk))):
8484
data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk]
8585
data = self.h2k(data)
86+
h_transformed_np = None
8687
if self.overlap:
8788
data = self.s2k(data)
8889
if eig_solver == 'torch':
8990
chklowt = torch.linalg.cholesky(data[self.s_out_field])
9091
chklowtinv = torch.linalg.inv(chklowt)
9192
data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
9293
elif eig_solver == 'numpy':
93-
chklowt = np.linalg.cholesky(data[self.s_out_field].detach().numpy())
94+
s_np = data[self.s_out_field].detach().cpu().numpy()
95+
h_np = data[self.h_out_field].detach().cpu().numpy()
96+
chklowt = np.linalg.cholesky(s_np)
9497
chklowtinv = np.linalg.inv(chklowt)
95-
data[self.h_out_field] = (chklowtinv @ data[self.h_out_field].detach().numpy() @ np.transpose(chklowtinv,(0,2,1)).conj())
96-
else:
97-
data[self.h_out_field] = data[self.h_out_field]
98-
98+
h_transformed_np = chklowtinv @ h_np @ np.transpose(chklowtinv,(0,2,1)).conj()
99+
99100
if eig_solver == 'torch':
100101
eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
101102
elif eig_solver == 'numpy':
102-
eigvals.append(torch.from_numpy(np.linalg.eigvalsh(a=data[self.h_out_field])))
103+
if h_transformed_np is None:
104+
h_transformed_np = data[self.h_out_field].detach().cpu().numpy()
105+
eigvals_np = np.linalg.eigvalsh(a=h_transformed_np)
106+
# Preserve dtype by converting to the Hamiltonian's original dtype
107+
eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype, device=self.h2k.device))
103108

104109
data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
105110
if nested:

0 commit comments

Comments
 (0)