@@ -1031,16 +1031,10 @@ def custom_forward(*inputs):
10311031 hidden_states = torch .utils .checkpoint .checkpoint (
10321032 create_custom_forward (resnet ), hidden_states , temb , scale
10331033 )
1034- hidden_states = torch .utils .checkpoint .checkpoint (
1035- create_custom_forward (motion_module ),
1036- hidden_states .requires_grad_ (),
1037- temb ,
1038- num_frames ,
1039- )
10401034
10411035 else :
10421036 hidden_states = resnet (hidden_states , temb , scale = scale )
1043- hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
1037+ hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
10441038
10451039 output_states = output_states + (hidden_states ,)
10461040
@@ -1221,10 +1215,10 @@ def custom_forward(*inputs):
12211215 encoder_attention_mask = encoder_attention_mask ,
12221216 return_dict = False ,
12231217 )[0 ]
1224- hidden_states = motion_module (
1225- hidden_states ,
1226- num_frames = num_frames ,
1227- )[0 ]
1218+ hidden_states = motion_module (
1219+ hidden_states ,
1220+ num_frames = num_frames ,
1221+ )[0 ]
12281222
12291223 # apply additional residuals to the output of the last pair of resnet and attention blocks
12301224 if i == len (blocks ) - 1 and additional_residuals is not None :
@@ -1425,10 +1419,10 @@ def custom_forward(*inputs):
14251419 encoder_attention_mask = encoder_attention_mask ,
14261420 return_dict = False ,
14271421 )[0 ]
1428- hidden_states = motion_module (
1429- hidden_states ,
1430- num_frames = num_frames ,
1431- )[0 ]
1422+ hidden_states = motion_module (
1423+ hidden_states ,
1424+ num_frames = num_frames ,
1425+ )[0 ]
14321426
14331427 if self .upsamplers is not None :
14341428 for upsampler in self .upsamplers :
@@ -1563,15 +1557,10 @@ def custom_forward(*inputs):
15631557 hidden_states = torch .utils .checkpoint .checkpoint (
15641558 create_custom_forward (resnet ), hidden_states , temb
15651559 )
1566- hidden_states = torch .utils .checkpoint .checkpoint (
1567- create_custom_forward (resnet ),
1568- hidden_states ,
1569- temb ,
1570- )
15711560
15721561 else :
15731562 hidden_states = resnet (hidden_states , temb , scale = scale )
1574- hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
1563+ hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
15751564
15761565 if self .upsamplers is not None :
15771566 for upsampler in self .upsamplers :
0 commit comments