2828 "https" : "http://fwdproxy:8080" ,
2929}
3030
31-
3231def compute_sqnr (x : torch .Tensor , y : torch .Tensor ) -> float :
3332 assert x .shape == y .shape , "Tensor shapes do not match"
3433 x = x .float ()
@@ -173,15 +172,23 @@ def __init__(self, mimi: nn.Module):
173172 self .mimi_model = mimi
174173
175174 def forward (self , x ):
176- return self .mimi_model .decode (x )
175+ x = x .transpose (1 , 2 )
176+ x = self .mimi_model .upsample (x )
177+ (emb ,) = self .mimi_model .decoder_transformer (x )
178+ emb .transpose (1 , 2 )
179+ with self .mimi_model ._context_for_encoder_decoder :
180+ out = self .mimi_model .decoder (emb )
181+ return out
177182
178- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
179- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
180- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
181- input = self .mimi .encode (chunk )
183+ emb_input = torch .rand (1 , 1 , 512 , device = "cpu" )
182184
183185 mimi_decode = MimiDecode (self .mimi )
184- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
186+ mimi_decode .eval ()
187+ mimi_decode (emb_input )
188+
189+ exported_decode : ExportedProgram = export (
190+ mimi_decode , (emb_input ,), strict = False
191+ )
185192 quantization_config = get_symmetric_quantization_config (
186193 is_per_channel = True ,
187194 is_dynamic = True ,
@@ -190,12 +197,12 @@ def forward(self, x):
190197 quantizer .set_global (quantization_config )
191198 m = exported_decode .module ()
192199 m = prepare_pt2e (m , quantizer )
193- m (input )
200+ m (emb_input )
194201 m = convert_pt2e (m )
195202 print ("quantized graph:" )
196203 print (m .graph )
197204 # Export quantized module
198- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
205+ exported_decode : ExportedProgram = export (m , (emb_input ,), strict = False )
199206 # Lower
200207 edge_manager = to_edge_transform_and_lower (
201208 exported_decode ,
@@ -208,16 +215,16 @@ def forward(self, x):
208215 with open (output_file , "wb" ) as file :
209216 exec_prog .write_to_file (file )
210217
211- eager_res = mimi_decode (input )
218+ eager_res = mimi_decode (emb_input )
212219 runtime = Runtime .get ()
213220 program = runtime .load_program (output_file )
214221 method = program .load_method ("forward" )
215- flattened_x = tree_flatten (input )[0 ]
222+ flattened_x = tree_flatten (emb_input )[0 ]
216223 res = method .execute (flattened_x )
217224 # Compare results
218225 sqnr = compute_sqnr (eager_res , res [0 ])
219226 print (f"SQNR: { sqnr } " )
220- torch .testing .assert_close (eager_res , res [0 ], atol = 1e -3 , rtol = 1e-3 )
227+ torch .testing .assert_close (eager_res , res [0 ], atol = 4e -3 , rtol = 1e-3 )
221228
222229
223230if __name__ == "__main__" :
0 commit comments