@@ -221,11 +221,6 @@ def parse_args():
221
221
default = None ,
222
222
help = "Revision of controlnet model identifier from huggingface.co/models." ,
223
223
)
224
- parser .add_argument (
225
- "--controlnet_from_pt" ,
226
- action = "store_true" ,
227
- help = "Load the controlnet model from a PyTorch checkpoint." ,
228
- )
229
224
parser .add_argument (
230
225
"--profile_steps" ,
231
226
type = int ,
@@ -248,12 +243,6 @@ def parse_args():
248
243
default = None ,
249
244
help = "Enables compilation cache." ,
250
245
)
251
- parser .add_argument (
252
- "--controlnet_revision" ,
253
- type = str ,
254
- default = None ,
255
- help = "Revision of controlnet model identifier from huggingface.co/models." ,
256
- )
257
246
parser .add_argument (
258
247
"--controlnet_from_pt" ,
259
248
action = "store_true" ,
@@ -1058,16 +1047,18 @@ def l2(xs):
1058
1047
)
1059
1048
# train
1060
1049
for batch in train_dataloader :
1061
- if args .profile_steps and global_step == 0 :
1062
- jax .profiler .start_trace (args .output_dir )
1063
- if args .profile_steps and args .profile_steps == global_step :
1064
- jax .profiler .stop_trace ()
1065
1050
1066
1051
batch = shard (batch )
1067
1052
with jax .profiler .StepTraceAnnotation ("train" , step_num = global_step ):
1068
1053
state , train_metric , train_rngs = p_train_step (
1069
1054
state , unet_params , text_encoder_params , vae_params , batch , train_rngs
1070
1055
)
1056
+ if args .profile_steps and global_step == 0 :
1057
+ train_metric ['loss' ].block_until_ready ()
1058
+ jax .profiler .start_trace (args .output_dir )
1059
+ if args .profile_steps and args .profile_steps == global_step :
1060
+ train_metric ['loss' ].block_until_ready ()
1061
+ jax .profiler .stop_trace ()
1071
1062
train_metrics .append (train_metric )
1072
1063
1073
1064
train_step_progress_bar .update (1 )
0 commit comments