File tree Expand file tree Collapse file tree 5 files changed +50
-5
lines changed
Expand file tree Collapse file tree 5 files changed +50
-5
lines changed Original file line number Diff line number Diff line change @@ -734,7 +734,16 @@ def add_noise(
734734 schedule_timesteps = self .timesteps .to (original_samples .device )
735735 timesteps = timesteps .to (original_samples .device )
736736
737- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
737+ step_indices = []
738+ for timestep in timesteps :
739+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
740+ if len (index_candidates ) == 0 :
741+ step_index = len (schedule_timesteps ) - 1
742+ elif len (index_candidates ) > 1 :
743+ step_index = index_candidates [1 ].item ()
744+ else :
745+ step_index = index_candidates [0 ].item ()
746+ step_indices .append (step_index )
738747
739748 sigma = sigmas [step_indices ].flatten ()
740749 while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -896,7 +896,16 @@ def add_noise(
896896 schedule_timesteps = self .timesteps .to (original_samples .device )
897897 timesteps = timesteps .to (original_samples .device )
898898
899- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
899+ step_indices = []
900+ for timestep in timesteps :
901+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
902+ if len (index_candidates ) == 0 :
903+ step_index = len (schedule_timesteps ) - 1
904+ elif len (index_candidates ) > 1 :
905+ step_index = index_candidates [1 ].item ()
906+ else :
907+ step_index = index_candidates [0 ].item ()
908+ step_indices .append (step_index )
900909
901910 sigma = sigmas [step_indices ].flatten ()
902911 while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -891,7 +891,16 @@ def add_noise(
891891 schedule_timesteps = self .timesteps .to (original_samples .device )
892892 timesteps = timesteps .to (original_samples .device )
893893
894- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
894+ step_indices = []
895+ for timestep in timesteps :
896+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
897+ if len (index_candidates ) == 0 :
898+ step_index = len (schedule_timesteps ) - 1
899+ elif len (index_candidates ) > 1 :
900+ step_index = index_candidates [1 ].item ()
901+ else :
902+ step_index = index_candidates [0 ].item ()
903+ step_indices .append (step_index )
895904
896905 sigma = sigmas [step_indices ].flatten ()
897906 while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -897,7 +897,16 @@ def add_noise(
897897 schedule_timesteps = self .timesteps .to (original_samples .device )
898898 timesteps = timesteps .to (original_samples .device )
899899
900- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
900+ step_indices = []
901+ for timestep in timesteps :
902+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
903+ if len (index_candidates ) == 0 :
904+ step_index = len (schedule_timesteps ) - 1
905+ elif len (index_candidates ) > 1 :
906+ step_index = index_candidates [1 ].item ()
907+ else :
908+ step_index = index_candidates [0 ].item ()
909+ step_indices .append (step_index )
901910
902911 sigma = sigmas [step_indices ].flatten ()
903912 while len (sigma .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -828,7 +828,16 @@ def add_noise(
828828 schedule_timesteps = self .timesteps .to (original_samples .device )
829829 timesteps = timesteps .to (original_samples .device )
830830
831- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
831+ step_indices = []
832+ for timestep in timesteps :
833+ index_candidates = (schedule_timesteps == timestep ).nonzero ()
834+ if len (index_candidates ) == 0 :
835+ step_index = len (schedule_timesteps ) - 1
836+ elif len (index_candidates ) > 1 :
837+ step_index = index_candidates [1 ].item ()
838+ else :
839+ step_index = index_candidates [0 ].item ()
840+ step_indices .append (step_index )
832841
833842 sigma = sigmas [step_indices ].flatten ()
834843 while len (sigma .shape ) < len (original_samples .shape ):
You can’t perform that action at this time.
0 commit comments