@@ -105,15 +105,13 @@ def compute_unreduced_loss(
105105        logits_masked  =  ops .where (
106106            valid_mask , logits , ops .full_like (logits , - 1e9 )
107107        )
108- 
109108        sorted_logits , sorted_valid_mask  =  sort_by_scores (
110109            tensors_to_sort = [logits_masked , valid_mask ],
111110            scores = labels_for_sorting ,
112111            mask = None ,
113112            shuffle_ties = False ,
114113            seed = None ,
115114        )
116- 
117115        sorted_logits  =  ops .divide (
118116            sorted_logits , ops .cast (self .temperature , dtype = sorted_logits .dtype )
119117        )
@@ -139,9 +137,9 @@ def compute_unreduced_loss(
139137        # cumsum_forward = ops.cumsum(exp_logits, axis=1) 
140138        # total_sum = ops.sum(exp_logits, axis=1, keepdims=True) 
141139        # cumsum_from_right = total_sum - cumsum_forward + exp_logits 
142-         reversed_exp  =  ops .flip (exp_logits , axis = [ 1 ] )
140+         reversed_exp  =  ops .flip (exp_logits , axis = 1 )
143141        reversed_cumsum  =  ops .cumsum (reversed_exp , axis = 1 )
144-         cumsum_from_right  =  ops .flip (reversed_cumsum , axis = [ 1 ] )
142+         cumsum_from_right  =  ops .flip (reversed_cumsum , axis = 1 )
145143
146144        log_normalizers  =  ops .log (cumsum_from_right  +  self ._epsilon )
147145        log_probs  =  ops .subtract (sorted_logits , log_normalizers )
0 commit comments