Skip to content

Commit e986ef5

Browse files
axis values unchanged
1 parent 3a402e3 commit e986ef5

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

keras_rs/src/losses/list_mle_loss.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,13 @@ def compute_unreduced_loss(
105105
logits_masked = ops.where(
106106
valid_mask, logits, ops.full_like(logits, -1e9)
107107
)
108-
109108
sorted_logits, sorted_valid_mask = sort_by_scores(
110109
tensors_to_sort=[logits_masked, valid_mask],
111110
scores=labels_for_sorting,
112111
mask=None,
113112
shuffle_ties=False,
114113
seed=None,
115114
)
116-
117115
sorted_logits = ops.divide(
118116
sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
119117
)
@@ -139,9 +137,9 @@ def compute_unreduced_loss(
139137
# cumsum_forward = ops.cumsum(exp_logits, axis=1)
140138
# total_sum = ops.sum(exp_logits, axis=1, keepdims=True)
141139
# cumsum_from_right = total_sum - cumsum_forward + exp_logits
142-
reversed_exp = ops.flip(exp_logits, axis=[1])
140+
reversed_exp = ops.flip(exp_logits, axis=1)
143141
reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
144-
cumsum_from_right = ops.flip(reversed_cumsum, axis=[1])
142+
cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
145143

146144
log_normalizers = ops.log(cumsum_from_right + self._epsilon)
147145
log_probs = ops.subtract(sorted_logits, log_normalizers)

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras_rs.src import testing
88
from keras_rs.src.losses.list_mle_loss import ListMLELoss
99

10+
1011
class ListMLELossTest(testing.TestCase, parameterized.TestCase):
1112
def setUp(self):
1213
self.unbatched_scores = ops.array(

0 commit comments

Comments
 (0)