Skip to content

Commit eacbf73

Browse files
authored
Mimi: sqnr and test without streaming
Differential Revision: D72978040 Pull Request resolved: #10004
1 parent 1eb9546 commit eacbf73

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,28 @@ def test_streaming_encoding_decoding(self):
135135

136136
all_codes_th = torch.cat(all_codes, dim=-1)
137137

138+
pcm_ref = self.mimi.decode(all_codes_th)
139+
138140
all_pcms = []
141+
for i in range(all_codes_th.shape[-1]):
142+
codes = all_codes_th[..., i : i + 1]
143+
pcm = self.mimi.decode(codes)
144+
all_pcms.append(pcm)
145+
all_pcms = torch.cat(all_pcms, dim=-1)
146+
sqnr = compute_sqnr(pcm_ref, all_pcms)
147+
print(f"sqnr = {sqnr} dB")
148+
self.assertTrue(sqnr > 4)
149+
150+
all_pcms_streaming = []
139151
with self.mimi.streaming(1):
140152
for i in range(all_codes_th.shape[-1]):
141153
codes = all_codes_th[..., i : i + 1]
142-
pcm = self.mimi.decode(codes)
143-
all_pcms.append(pcm)
144-
all_pcms = torch.cat(all_pcms, dim=-1)
145-
146-
pcm_ref = self.mimi.decode(all_codes_th)
147-
self.assertTrue(torch.allclose(pcm_ref, all_pcms, atol=1e-5))
154+
pcm_streaming = self.mimi.decode(codes)
155+
all_pcms_streaming.append(pcm_streaming)
156+
all_pcms_streaming = torch.cat(all_pcms_streaming, dim=-1)
157+
sqnr_streaming = compute_sqnr(pcm_ref, all_pcms_streaming)
158+
print(f"sqnr_streaming = {sqnr_streaming} dB")
159+
self.assertTrue(sqnr_streaming > 100)
148160

149161
def test_exported_encoding(self):
150162
"""Ensure exported encoding model is consistent with reference output."""

0 commit comments

Comments
 (0)