@@ -451,7 +451,7 @@ def _get_cell_noise_count_posterior_coo(
451451 f'accurate for your dataset.' )
452452 raise RuntimeError ('Zero cells found!' )
453453
454- dataloader_index_to_analyzed_bc_index = np .where (cell_logic )[0 ]
454+ dataloader_index_to_analyzed_bc_index = torch .where (torch . tensor ( cell_logic ) )[0 ]
455455 cell_data_loader = DataLoader (
456456 count_matrix [cell_logic ],
457457 empty_drop_dataset = None ,
@@ -468,6 +468,12 @@ def _get_cell_noise_count_posterior_coo(
468468 log_probs = []
469469 ind = 0
470470 n_minibatches = len (cell_data_loader )
471+ analyzed_gene_inds = torch .tensor (self .analyzed_gene_inds .copy ())
472+ if analyzed_bcs_only :
473+ barcode_inds = torch .tensor (self .dataset_obj .analyzed_barcode_inds .copy ())
474+ else :
475+ barcode_inds = torch .tensor (self .barcode_inds .copy ())
476+ nonzero_noise_offset_dict = {}
471477
472478 logger .info ('Computing posterior noise count probabilities in mini-batches.' )
473479
@@ -505,57 +511,52 @@ def _get_cell_noise_count_posterior_coo(
505511 )
506512
507513 # Get the original gene index from gene index in the trimmed dataset.
508- genes_i = self . analyzed_gene_inds [genes_i_analyzed ]
514+ genes_i = analyzed_gene_inds [genes_i_analyzed . cpu () ]
509515
510516 # Barcode index in the dataloader.
511- bcs_i = bcs_i_chunk + ind
517+ bcs_i = ( bcs_i_chunk + ind ). cpu ()
512518
513519 # Obtain the real barcode index since we only use cells.
514520 bcs_i = dataloader_index_to_analyzed_bc_index [bcs_i ]
515521
516522 # Translate chunk barcode inds to overall inds.
517- if analyzed_bcs_only :
518- bcs_i = self .dataset_obj .analyzed_barcode_inds [bcs_i ]
519- else :
520- bcs_i = self .barcode_inds [bcs_i ]
523+ bcs_i = barcode_inds [bcs_i ]
521524
522525 # Add sparse matrix values to lists.
523- try :
524- bcs .extend (bcs_i .tolist ())
525- genes .extend (genes_i .tolist ())
526- c .extend (c_i .tolist ())
527- log_probs .extend (log_prob_i .tolist ())
528- c_offset .extend (noise_count_offset_NG [bcs_i_chunk , genes_i_analyzed ]
529- .detach ().cpu ().numpy ())
530- except TypeError as e :
531- # edge case of a single value
532- bcs .append (bcs_i )
533- genes .append (genes_i )
534- c .append (c_i )
535- log_probs .append (log_prob_i )
536- c_offset .append (noise_count_offset_NG [bcs_i_chunk , genes_i_analyzed ]
537- .detach ().cpu ().numpy ())
526+ bcs .append (bcs_i .detach ())
527+ genes .append (genes_i .detach ())
528+ c .append (c_i .detach ().cpu ())
529+ log_probs .append (log_prob_i .detach ().cpu ())
530+
531+ # Update offset dict with any nonzeros.
532+ nonzero_offset_inds , nonzero_noise_count_offsets = dense_to_sparse_op_torch (
533+ noise_count_offset_NG [bcs_i_chunk , genes_i_analyzed ].detach ().flatten (),
534+ )
535+ m_i = self .index_converter .get_m_indices (cell_inds = bcs_i , gene_inds = genes_i )
536+
537+ nonzero_noise_offset_dict .update (
538+ dict (zip (m_i [nonzero_offset_inds .detach ().cpu ()].tolist (),
539+ nonzero_noise_count_offsets .detach ().cpu ().tolist ()))
540+ )
541+ c_offset .append (noise_count_offset_NG [bcs_i_chunk , genes_i_analyzed ].detach ().cpu ())
538542
539543 # Increment barcode index counter.
540544 ind += data .shape [0 ] # Same as data_loader.batch_size
541545
542- # Convert the lists to numpy arrays.
543- log_probs = np .array (log_probs , dtype = float )
544- c = np .array (c , dtype = np .uint32 )
545- barcodes = np .array (bcs , dtype = np .uint64 ) # uint32 is too small!
546- genes = np .array (genes , dtype = np .uint64 ) # use same as above for IndexConverter
547- noise_count_offsets = np .array (c_offset , dtype = np .uint32 )
546+ # Concatenate lists.
547+ log_probs = torch .cat (log_probs )
548+ c = torch .cat (c )
549+ barcodes = torch .cat (bcs )
550+ genes = torch .cat (genes )
548551
549552 # Translate (barcode, gene) inds to 'm' format index.
550553 m = self .index_converter .get_m_indices (cell_inds = barcodes , gene_inds = genes )
551554
552555 # Put the log probabilities into a sparse coo_matrix, indexed by (m, c).
553556 self ._noise_count_posterior_coo = sp .coo_matrix (
554557 (log_probs , (m , c )),
555- shape = [np .prod (self .count_matrix_shape ), n_counts_max ],
558+ shape = [np .prod (self .count_matrix_shape , dtype = np . uint64 ), n_counts_max ],
556559 )
557- noise_offset_dict = dict (zip (m , noise_count_offsets ))
558- nonzero_noise_offset_dict = {k : v for k , v in noise_offset_dict .items () if (v > 0 )}
559560 self ._noise_count_posterior_coo_offsets = nonzero_noise_offset_dict
560561 return self ._noise_count_posterior_coo
561562
@@ -1517,7 +1518,9 @@ def __repr__(self):
15171518 f'\n \t total_n_genes: { self .total_n_genes } '
15181519 f'\n \t matrix_shape: { self .matrix_shape } ' )
15191520
1520- def get_m_indices (self , cell_inds : np .ndarray , gene_inds : np .ndarray ) -> np .ndarray :
1521+ def get_m_indices (self ,
1522+ cell_inds : Union [np .ndarray , torch .Tensor ],
1523+ gene_inds : Union [np .ndarray , torch .Tensor ]) -> Union [np .ndarray , torch .Tensor ]:
15211524 """Given arrays of cell indices and gene indices, suitable for a sparse matrix,
15221525 convert them to 'm' index values.
15231526 """
@@ -1527,7 +1530,12 @@ def get_m_indices(self, cell_inds: np.ndarray, gene_inds: np.ndarray) -> np.ndar
15271530 if not ((gene_inds >= 0 ) & (gene_inds < self .total_n_genes )).all ():
15281531 raise ValueError (f'Requested gene_inds out of range: '
15291532 f'{ gene_inds [(gene_inds < 0 ) | (gene_inds >= self .total_n_genes )]} ' )
1530- return cell_inds * self .total_n_genes + gene_inds
1533+ if type (cell_inds ) == np .ndarray :
1534+ return cell_inds .astype (np .uint64 ) * self .total_n_genes + gene_inds .astype (np .uint64 )
1535+ elif type (cell_inds ) == torch .Tensor :
1536+ return cell_inds .type (torch .int64 ) * self .total_n_genes + gene_inds .type (torch .int64 )
1537+ else :
1538+ raise ValueError ('IndexConverter.get_m_indices received cell_inds of unkown object type' )
15311539
15321540 def get_ng_indices (self , m_inds : np .ndarray ) -> Tuple [np .ndarray , np .ndarray ]:
15331541 """Given a list of 'm' index values, return two arrays: cell index values
0 commit comments