Skip to content

Add image_processor #2617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 46 commits into from
Mar 15, 2023
Merged

Add image_processor #2617

merged 46 commits into from
Mar 15, 2023

Conversation

yiyixuxu
Copy link
Collaborator

@yiyixuxu yiyixuxu commented Mar 9, 2023

added a VaeImageProcessor class that provides unified API for preprocessing and postprocessing of image inputs for pipelines

Original PR:
#2304

to-do:

  • refactor depth_to_image, ControlNet and pix2pix
  • improve tests (think we should move the relevant test to PipelineTesterMixin)

@yiyixuxu yiyixuxu changed the title add image_processor [WIP] add image_processor Mar 9, 2023
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -674,7 +658,7 @@ def __call__(
)

# 4. Preprocess image
image = preprocess(image)
image = self.vae_feature_extractor.encode(image)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great!

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool Think the design is great! Just left a couple of comments to make the image processor class even a bit more robust.

Think once all the comments are treated we can apply the image processor also to depth_to_image and pix2pix pipeline and then add new tests to all three pipelines (img2img, depth_to_image, StableDiffusionControlNetPipeline, and pix2pix) to check that all different input & output combinations work)

@patrickvonplaten
Copy link
Contributor

@pcuenca and @williamberman can you also take a look here?

Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job! Left a few comments / questions.

"""
if images.ndim == 3:
images = images[..., None]
elif images.ndim == 5:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When does this happen?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from what I understand, we accept tensors in 3 forms:

  1. with batch dimension ([B,C,H,W])
  2. without the batch dimension ([C, H, W] )
  3. a list of tensors with shape [C,H,W])

and same goes for numpy array too

the way code is written, we will get ndim=5 for tensors with batch dimension because we put it into a list and do torch.cat()

Comment on lines 442 to 424
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# image = image.cpu().permute(0, 2, 3, 1).float().numpy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the return type different now?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah now it returns torch tensor here so if the output_type is pt it can stays in the device


return image

def decode(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we denormalize here (if appropriate), to return the range to [0, 1]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently, it's denormalized inside decode_latent - I think we can move it to the image processor, but I'm not sure how the decoding part of the image processor fits in the pipeline - I tried to refactor the img2img pipeline with it, but it seems that we can't abstract the post-processing away from the pipeline if we don't move the safty_checker to image processor

images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
return images

def encode(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it a bit confusing that these methods are called encode / decode, same as those of the autoencoder. Is this standard nomenclature we use in transformers, or elsewhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed - maybe preprocess is better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True preprocess and postprocess might be better here

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool! Almost done - think we just have two final TODOs:

  • 1.) Improve the error message in the processor slightly
  • 2.) Remove the deprecation warning for now to not copy it in all other processors (sorry this was a bad call from my side)

@yiyixuxu
Copy link
Collaborator Author

@patrickvonplaten let me know if the resize error message is ok now - I refactored the preprocess method a little bit, so now different input formats are processed in separate elif blocks with no shared processing - and I throw an error message for numpy and pytorch separately at where we would to apply resize if we were to support it

happy to change it if you think it is better to throw one error message. I think I've address all other comments and that's the last thing left

@@ -32,6 +32,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .image_processor import VaeImageProcessor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not make it public fro now

@patrickvonplaten patrickvonplaten changed the title [WIP] add image_processor Add image_processor Mar 15, 2023
@patrickvonplaten
Copy link
Contributor

Great, I think we can merge this and then go fix the other pipelines in follow-up PRs. Think we just need to run a quick make style and this should be good to go.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

@yiyixuxu yiyixuxu merged commit e52cd55 into main Mar 15, 2023
@yiyixuxu yiyixuxu deleted the image-processor branch March 15, 2023 17:56
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* add image_processor

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add image_processor

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add image_processor

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants