
Commit 1d7b4b6

pcuenca and Skylion007 authored
Ruff: apply same rules as in transformers (#2827)
* Apply same ruff settings as in transformers (see https://github.com/huggingface/transformers/blob/main/pyproject.toml)
* Apply new style rules
* Style
* style
* remove list, ruff wouldn't auto fix.

Co-authored-by: Aaron Gokaslan <[email protected]>
1 parent abb22b4 commit 1d7b4b6

45 files changed (+209, -213 lines)


examples/community/checkpoint_merger.py

Lines changed: 8 additions & 12 deletions

@@ -199,24 +199,20 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
             if not attr.startswith("_"):
                 checkpoint_path_1 = os.path.join(cached_folders[1], attr)
                 if os.path.exists(checkpoint_path_1):
-                    files = list(
-                        (
-                            *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
-                            *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
-                        )
-                    )
+                    files = [
+                        *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
+                        *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
+                    ]
                     checkpoint_path_1 = files[0] if len(files) > 0 else None
                 if len(cached_folders) < 3:
                     checkpoint_path_2 = None
                 else:
                     checkpoint_path_2 = os.path.join(cached_folders[2], attr)
                     if os.path.exists(checkpoint_path_2):
-                        files = list(
-                            (
-                                *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
-                                *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
-                            )
-                        )
+                        files = [
+                            *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
+                            *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
+                        ]
                         checkpoint_path_2 = files[0] if len(files) > 0 else None
             # For an attr if both checkpoint_path_1 and 2 are None, ignore.
             # If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
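The rewrite above drops a tuple literal that existed only to be copied by list(), one of the patterns the newly selected comprehension ("C") rules target. A minimal standalone sketch of the before/after (the folder name is made up for illustration; per the commit message, ruff would not auto-fix this one, so it was rewritten by hand):

import glob
import os

folder = "checkpoints"  # hypothetical directory, for illustration only

# Before: a throwaway tuple is built just so list() can copy it.
files = list(
    (
        *glob.glob(os.path.join(folder, "*.safetensors")),
        *glob.glob(os.path.join(folder, "*.bin")),
    )
)

# After: unpack both glob results straight into a list literal.
files = [
    *glob.glob(os.path.join(folder, "*.safetensors")),
    *glob.glob(os.path.join(folder, "*.bin")),
]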

examples/community/imagic_stable_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@

 def preprocess(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
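Both spellings behave identically under tuple unpacking; a quick sketch with made-up dimensions:

# Before: map() plus a lambda, the style the new rules reject.
w, h = 513, 769
w, h = map(lambda x: x - x % 32, (w, h))
assert (w, h) == (512, 768)

# After: a generator expression unpacks the same way, with no lambda object.
w, h = 513, 769
w, h = (x - x % 32 for x in (w, h))
assert (w, h) == (512, 768)

The same rewrite repeats in the preprocess helpers of the two lpw_stable_diffusion files below.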

examples/community/lpw_stable_diffusion.py

Lines changed: 2 additions & 2 deletions

@@ -376,7 +376,7 @@ def get_weighted_text_embeddings(

 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -387,7 +387,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))

examples/community/lpw_stable_diffusion_onnx.py

Lines changed: 2 additions & 2 deletions

@@ -403,7 +403,7 @@ def get_weighted_text_embeddings(

 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -413,7 +413,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))

examples/community/stable_unclip.py

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def __init__(
     ):
         super().__init__()

-        decoder_pipe_kwargs = dict(image_encoder=None) if decoder_pipe_kwargs is None else decoder_pipe_kwargs
+        decoder_pipe_kwargs = {"image_encoder": None} if decoder_pipe_kwargs is None else decoder_pipe_kwargs

         decoder_pipe_kwargs["torch_dtype"] = decoder_pipe_kwargs.get("torch_dtype", None) or prior.dtype

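dict(image_encoder=None) and {"image_encoder": None} build the same mapping; the literal just skips the name lookup and function call. A tiny sketch:

# Before: dict() invoked with keyword arguments.
kwargs_before = dict(image_encoder=None)

# After: the equivalent literal the new rules prefer.
kwargs_after = {"image_encoder": None}

assert kwargs_before == kwargs_after

The same dict()-to-literal rewrite appears again in run_diffuser_locomotion.py and convert_models_diffuser_to_diffusers.py below.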

examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 1 addition & 1 deletion

@@ -673,7 +673,7 @@ def preprocess_train(examples):
         examples["edited_pixel_values"] = edited_images

         # Preprocess the captions.
-        captions = [caption for caption in examples[edit_prompt_column]]
+        captions = list(examples[edit_prompt_column])
         examples["input_ids"] = tokenize_captions(captions)
         return examples

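The deleted comprehension copied elements without transforming them, which the comprehension rules treat as unnecessary; list() states the shallow copy directly. A sketch with a made-up captions column:

examples = {"edit_prompt": ["make it night", "add snow"]}  # hypothetical data
edit_prompt_column = "edit_prompt"

# Before: an identity comprehension.
captions = [caption for caption in examples[edit_prompt_column]]

# After: an explicit shallow copy.
captions = list(examples[edit_prompt_column])

assert captions == examples[edit_prompt_column]
assert captions is not examples[edit_prompt_column]  # still a fresh list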

examples/rl/run_diffuser_locomotion.py

Lines changed: 11 additions & 11 deletions

@@ -4,17 +4,17 @@
 from diffusers.experimental import ValueGuidedRLPipeline


-config = dict(
-    n_samples=64,
-    horizon=32,
-    num_inference_steps=20,
-    n_guide_steps=2,  # can set to 0 for faster sampling, does not use value network
-    scale_grad_by_std=True,
-    scale=0.1,
-    eta=0.0,
-    t_grad_cutoff=2,
-    device="cpu",
-)
+config = {
+    "n_samples": 64,
+    "horizon": 32,
+    "num_inference_steps": 20,
+    "n_guide_steps": 2,  # can set to 0 for faster sampling, does not use value network
+    "scale_grad_by_std": True,
+    "scale": 0.1,
+    "eta": 0.0,
+    "t_grad_cutoff": 2,
+    "device": "cpu",
+}


 if __name__ == "__main__":

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -4,8 +4,8 @@ target-version = ['py37']

 [tool.ruff]
 # Never enforce `E501` (line length violations).
-ignore = ["E501", "E741", "W605"]
-select = ["E", "F", "I", "W"]
+ignore = ["C901", "E501", "E741", "W605"]
+select = ["C", "E", "F", "I", "W"]
 line-length = 119

 # Ignore import violations in all `__init__.py` files.
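This is the configuration change the rest of the commit follows from: selecting "C" enables ruff's comprehension checks (the C4xx family) along with mccabe complexity (C901), and listing C901 under ignore opts back out of the complexity check only. A sketch of the kind of code the new selection flags (assuming the C402-style rule for generators passed to dict()):

# Flagged once "C" is selected: a generator dict() does not need.
squares = dict((n, n * n) for n in range(4))

# The comprehension form ruff suggests instead.
squares = {n: n * n for n in range(4)}

assert squares == {0: 0, 1: 1, 2: 4, 3: 9}

# C901 ("function is too complex") stays ignored, so deeply branched
# functions are not reported under this config.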

scripts/convert_ddpm_original_checkpoint_to_diffusers.py

Lines changed: 1 addition & 1 deletion

@@ -404,7 +404,7 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
         config = json.loads(f.read())

     # unet case
-    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
+    key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
     if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
         converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
     else:
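set() wrapped around a generator is another pattern the comprehension rules rewrite; the set comprehension is the equivalent direct form. A sketch with a made-up checkpoint:

checkpoint = {"encoder.conv_in.weight": 0, "decoder.conv_out.bias": 1}  # hypothetical keys

# Before: a generator fed to set().
key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())

# After: the comprehension form.
key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}

assert key_prefix_set == {"encoder", "decoder"}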

scripts/convert_models_diffuser_to_diffusers.py

Lines changed: 40 additions & 40 deletions

@@ -24,29 +24,29 @@ def unet(hor):
     up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
     model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
     state_dict = model.state_dict()
-    config = dict(
-        down_block_types=down_block_types,
-        block_out_channels=block_out_channels,
-        up_block_types=up_block_types,
-        layers_per_block=1,
-        use_timestep_embedding=True,
-        out_block_type="OutConv1DBlock",
-        norm_num_groups=8,
-        downsample_each_block=False,
-        in_channels=14,
-        out_channels=14,
-        extra_in_channels=0,
-        time_embedding_type="positional",
-        flip_sin_to_cos=False,
-        freq_shift=1,
-        sample_size=65536,
-        mid_block_type="MidResTemporalBlock1D",
-        act_fn="mish",
-    )
+    config = {
+        "down_block_types": down_block_types,
+        "block_out_channels": block_out_channels,
+        "up_block_types": up_block_types,
+        "layers_per_block": 1,
+        "use_timestep_embedding": True,
+        "out_block_type": "OutConv1DBlock",
+        "norm_num_groups": 8,
+        "downsample_each_block": False,
+        "in_channels": 14,
+        "out_channels": 14,
+        "extra_in_channels": 0,
+        "time_embedding_type": "positional",
+        "flip_sin_to_cos": False,
+        "freq_shift": 1,
+        "sample_size": 65536,
+        "mid_block_type": "MidResTemporalBlock1D",
+        "act_fn": "mish",
+    }
     hf_value_function = UNet1DModel(**config)
     print(f"length of state dict: {len(state_dict.keys())}")
     print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
-    mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
+    mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
     for k, v in mapping.items():
         state_dict[v] = state_dict.pop(k)
     hf_value_function.load_state_dict(state_dict)
@@ -57,33 +57,33 @@ def unet(hor):


 def value_function():
-    config = dict(
-        in_channels=14,
-        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
-        up_block_types=(),
-        out_block_type="ValueFunction",
-        mid_block_type="ValueFunctionMidBlock1D",
-        block_out_channels=(32, 64, 128, 256),
-        layers_per_block=1,
-        downsample_each_block=True,
-        sample_size=65536,
-        out_channels=14,
-        extra_in_channels=0,
-        time_embedding_type="positional",
-        use_timestep_embedding=True,
-        flip_sin_to_cos=False,
-        freq_shift=1,
-        norm_num_groups=8,
-        act_fn="mish",
-    )
+    config = {
+        "in_channels": 14,
+        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
+        "up_block_types": (),
+        "out_block_type": "ValueFunction",
+        "mid_block_type": "ValueFunctionMidBlock1D",
+        "block_out_channels": (32, 64, 128, 256),
+        "layers_per_block": 1,
+        "downsample_each_block": True,
+        "sample_size": 65536,
+        "out_channels": 14,
+        "extra_in_channels": 0,
+        "time_embedding_type": "positional",
+        "use_timestep_embedding": True,
+        "flip_sin_to_cos": False,
+        "freq_shift": 1,
+        "norm_num_groups": 8,
+        "act_fn": "mish",
+    }

     model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
     state_dict = model
     hf_value_function = UNet1DModel(**config)
     print(f"length of state dict: {len(state_dict.keys())}")
     print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

-    mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
+    mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
     for k, v in mapping.items():
         state_dict[v] = state_dict.pop(k)

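The deleted generators re-packed pairs that zip() already produces, so dict(zip(...)) is the direct spelling. A standalone sketch with hypothetical key names:

old_keys = ["blocks.0.weight", "blocks.0.bias"]            # made-up source keys
new_keys = ["down_blocks.0.weight", "down_blocks.0.bias"]  # made-up target keys

# Before: a generator that unpacks zip's pairs only to rebuild them.
mapping = dict((k, hfk) for k, hfk in zip(old_keys, new_keys))

# After: zip() already yields the (old, new) pairs dict() consumes.
mapping = dict(zip(old_keys, new_keys))

assert mapping == {
    "blocks.0.weight": "down_blocks.0.weight",
    "blocks.0.bias": "down_blocks.0.bias",
}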
