-
Notifications
You must be signed in to change notification settings - Fork 6.1k
small tweaks for parsing thibaudz controlnet checkpoints #3657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
small tweaks for parsing thibaudz controlnet checkpoints #3657
Conversation
|
||
# small workaround to get argparser to parse a boolean input as either true _or_ false | ||
def parse_bool(string): | ||
if string == "True": | ||
return True | ||
elif string == "False": | ||
return False | ||
else: | ||
raise ValueError(f"could not parse string as bool {string}") | ||
|
||
parser.add_argument( | ||
"--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool | ||
) | ||
|
||
parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are overrides for these two configs which aren't correctly set for the config file https://huggingface.co/thibaud/controlnet-sd21/blob/main/control_v11p_sd21_openpose.yaml
if skip_extract_state_dict: | ||
unet_state_dict = checkpoint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this allows us to skip the initial key trimming for when the root level checkpoint doesn't need to be trimmed. If we don't manually skip it, the unet_state_dict
will be empty because none of the keys have the expected prefix
# Some controlnet ckpt files are distributed independently from the rest of the | ||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ | ||
if "time_embed.0.weight" in checkpoint: | ||
skip_extract_state_dict = True | ||
else: | ||
skip_extract_state_dict = False | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the controlnet case, we can manually perform the check if this is the root level state dict
The documentation is not available anymore as the PR was closed or merged. |
Nice! Good to merge for me once the following tests pass:
|
confirmed pass! |
re: #2774