@@ -83,23 +83,28 @@ def forward(self,
8383 for i in range (int (np .ceil (num_k / nk ))):
8484 data [AtomicDataDict .KPOINT_KEY ] = kpoints [i * nk :(i + 1 )* nk ]
8585 data = self .h2k (data )
86+ h_transformed_np = None
8687 if self .overlap :
8788 data = self .s2k (data )
8889 if eig_solver == 'torch' :
8990 chklowt = torch .linalg .cholesky (data [self .s_out_field ])
9091 chklowtinv = torch .linalg .inv (chklowt )
9192 data [self .h_out_field ] = (chklowtinv @ data [self .h_out_field ] @ torch .transpose (chklowtinv ,dim0 = 1 ,dim1 = 2 ).conj ())
9293 elif eig_solver == 'numpy' :
93- chklowt = np .linalg .cholesky (data [self .s_out_field ].detach ().numpy ())
94+ s_np = data [self .s_out_field ].detach ().cpu ().numpy ()
95+ h_np = data [self .h_out_field ].detach ().cpu ().numpy ()
96+ chklowt = np .linalg .cholesky (s_np )
9497 chklowtinv = np .linalg .inv (chklowt )
95- data [self .h_out_field ] = (chklowtinv @ data [self .h_out_field ].detach ().numpy () @ np .transpose (chklowtinv ,(0 ,2 ,1 )).conj ())
96- else :
97- data [self .h_out_field ] = data [self .h_out_field ]
98-
98+ h_transformed_np = chklowtinv @ h_np @ np .transpose (chklowtinv ,(0 ,2 ,1 )).conj ()
99+
99100 if eig_solver == 'torch' :
100101 eigvals .append (torch .linalg .eigvalsh (data [self .h_out_field ]))
101102 elif eig_solver == 'numpy' :
102- eigvals .append (torch .from_numpy (np .linalg .eigvalsh (a = data [self .h_out_field ])))
103+ if h_transformed_np is None :
104+ h_transformed_np = data [self .h_out_field ].detach ().cpu ().numpy ()
105+ eigvals_np = np .linalg .eigvalsh (a = h_transformed_np )
106+ # Preserve dtype by converting to the Hamiltonian's original dtype
107+ eigvals .append (torch .from_numpy (eigvals_np ).to (dtype = self .h2k .dtype , device = self .h2k .device ))
103108
104109 data [self .out_field ] = torch .nested .as_nested_tensor ([torch .cat (eigvals , dim = 0 )])
105110 if nested :
0 commit comments