Skip to content

Commit 95bd47c

Browse files
committed
fix: Ensure proper tensor conversion for numpy solver in Eigenvalues class
1 parent bd6677a commit 95bd47c

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

dptb/nn/energy.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,19 @@ def forward(self,
9090
chklowtinv = torch.linalg.inv(chklowt)
9191
data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
9292
elif eig_solver == 'numpy':
93-
chklowt = np.linalg.cholesky(data[self.s_out_field].detach().numpy())
93+
chklowt = np.linalg.cholesky(data[self.s_out_field].detach().cpu().numpy())
9494
chklowtinv = np.linalg.inv(chklowt)
95-
data[self.h_out_field] = (chklowtinv @ data[self.h_out_field].detach().numpy() @ np.transpose(chklowtinv,(0,2,1)).conj())
96-
else:
97-
data[self.h_out_field] = data[self.h_out_field]
98-
95+
data[self.h_out_field] = (chklowtinv @ data[self.h_out_field].detach().cpu().numpy() @ np.transpose(chklowtinv,(0,2,1)).conj())
96+
elif eig_solver == 'numpy':
97+
# Convert to numpy when using numpy solver without overlap
98+
data[self.h_out_field] = data[self.h_out_field].detach().cpu().numpy()
99+
99100
if eig_solver == 'torch':
100101
eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
101102
elif eig_solver == 'numpy':
102-
eigvals.append(torch.from_numpy(np.linalg.eigvalsh(a=data[self.h_out_field])))
103+
eigvals_np = np.linalg.eigvalsh(a=data[self.h_out_field])
104+
# Preserve dtype by converting to the Hamiltonian's original dtype
105+
eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype))
103106

104107
data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
105108
if nested:

0 commit comments

Comments
 (0)