@@ -135,16 +135,28 @@ def test_streaming_encoding_decoding(self):
135
135
136
136
all_codes_th = torch .cat (all_codes , dim = - 1 )
137
137
138
+ pcm_ref = self .mimi .decode (all_codes_th )
139
+
138
140
all_pcms = []
141
+ for i in range (all_codes_th .shape [- 1 ]):
142
+ codes = all_codes_th [..., i : i + 1 ]
143
+ pcm = self .mimi .decode (codes )
144
+ all_pcms .append (pcm )
145
+ all_pcms = torch .cat (all_pcms , dim = - 1 )
146
+ sqnr = compute_sqnr (pcm_ref , all_pcms )
147
+ print (f"sqnr = { sqnr } dB" )
148
+ self .assertTrue (sqnr > 4 )
149
+
150
+ all_pcms_streaming = []
139
151
with self .mimi .streaming (1 ):
140
152
for i in range (all_codes_th .shape [- 1 ]):
141
153
codes = all_codes_th [..., i : i + 1 ]
142
- pcm = self .mimi .decode (codes )
143
- all_pcms .append (pcm )
144
- all_pcms = torch .cat (all_pcms , dim = - 1 )
145
-
146
- pcm_ref = self . mimi . decode ( all_codes_th )
147
- self .assertTrue (torch . allclose ( pcm_ref , all_pcms , atol = 1e-5 ) )
154
+ pcm_streaming = self .mimi .decode (codes )
155
+ all_pcms_streaming .append (pcm_streaming )
156
+ all_pcms_streaming = torch .cat (all_pcms_streaming , dim = - 1 )
157
+ sqnr_streaming = compute_sqnr ( pcm_ref , all_pcms_streaming )
158
+ print ( f"sqnr_streaming = { sqnr_streaming } dB" )
159
+ self .assertTrue (sqnr_streaming > 100 )
148
160
149
161
def test_exported_encoding (self ):
150
162
"""Ensure exported encoding model is consistent with reference output."""
0 commit comments