
[Community] Testing Stable Diffusion is hard 🥵 #937

@patrickvonplaten

It's really difficult to test Stable Diffusion, for the following reasons:

    1. Continuous outputs: Diffusion models take float values as input and produce float values as output. This is different from NLP models, which tend to take int64 token IDs as inputs and produce int64 token IDs as outputs.
    2. Huge output dimensions: If an image has an output shape of (1, 512, 512, 3), there are 512 * 512 * 3 ≈ 800,000 values that need to be within a given range. So if we test for a max difference of (pred - ref).abs() < 1e-3, this has to hold for roughly 800,000 values at once (see the sketch after this list). This is quite different from NLP, where we rather test things like text generation or final logit layers, which usually aren't bigger than a dozen or so tensors of size 768 or 1024.
    3. Error propagation: We cannot simply test one forward pass for Stable Diffusion because in practice people use 50 forward passes, and error propagation becomes a real problem. This again is different from, say, generation in NLP, where errors can be somewhat "smoothed out" at every generation step because an "argmax" over the "softmax" output is applied after each step (a toy illustration follows below).
    4. Composite systems: Stable Diffusion has three main components for inference: a UNet, a scheduler, and a VAE decoder. The UNet and scheduler are very entangled during the forward pass. Just because we know the forward passes of the scheduler and the UNet work independently doesn't mean that using them together works.
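
To make points 1 and 2 concrete, here is a minimal sketch of the slice-based tolerance check this implies (the prompt and the reference values are illustrative placeholders, not real test fixtures):

```python
import numpy as np
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

generator = torch.Generator("cuda").manual_seed(0)
image = pipe(
    "a photograph of an astronaut riding a horse",  # illustrative prompt
    generator=generator,
    output_type="np",
).images[0]

# Comparing all 512 * 512 * 3 ≈ 800,000 float values against a tight
# tolerance is brittle, so we compare a small pixel slice against stored
# reference values instead (the numbers below are made up).
image_slice = image[-3:, -3:, -1].flatten()
expected_slice = np.array([0.57, 0.47, 0.48, 0.52, 0.49, 0.46, 0.51, 0.50, 0.47])
assert np.abs(image_slice - expected_slice).max() < 1e-2
```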

=> Therefore, we need to run full integration tests, meaning we need to make sure that the output of a full denoising process stays within a given error range. At the moment, though, we're having quite a few problems getting fully reproducible results across different GPUs, CUDA versions, etc. (especially for FP16).
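
To see why error propagation over 50 steps is so problematic, here is a toy recurrence (just a stand-in, nothing to do with the actual UNet/scheduler loop) where a tiny one-step discrepancy grows far beyond its original size:

```python
import torch

torch.manual_seed(0)

# Stand-in for one denoising step: a random tanh recurrence with gain > 1,
# which tends to amplify small differences between nearby trajectories.
w = 1.5 * torch.randn(64, 64) / 64 ** 0.5

x_ref = torch.randn(64)
x_pert = x_ref + 1e-5 * torch.randn(64)  # simulate a tiny kernel/FP16 discrepancy

for _ in range(50):
    x_ref = torch.tanh(w @ x_ref)
    x_pert = torch.tanh(w @ x_pert)

# The gap typically ends up many orders of magnitude above the injected 1e-5,
# so a tolerance that holds for one step says little about the 50-step output.
print((x_ref - x_pert).abs().max())
```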

That being said, it is extremely important to test Stable Diffusion to avoid issues like #902 in the future, while still being able to improve speed with PRs like #371.

At the moment, we're running multiple integration tests for all 50 diffusion steps every time a PR is merged to master, see:

Nevertheless, those tests weren't sufficient to detect #902.

Testing Puzzle 🧩: How can we find the best trade-off between fast & inexpensive tests and the best possible test coverage, taking into account the points above?
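
For illustration only, one common pattern (a hedged sketch, not a decided answer to the puzzle) is a two-tier suite: a fast tier that runs tiny, randomly initialized components for a couple of steps on every PR, and a slow tier that runs the full 50-step pipeline on real weights, e.g. nightly. The tiny config below is hypothetical:

```python
import pytest
import torch
from diffusers import DDIMScheduler, UNet2DConditionModel

def tiny_unet():
    # Hypothetical helper: a deliberately tiny, randomly initialized UNet so
    # the fast tier runs in seconds while still exercising real code paths.
    return UNet2DConditionModel(
        sample_size=8,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
        layers_per_block=1,
    )

def test_unet_scheduler_smoke():
    # Fast tier: 2 denoising steps on random weights; checks shapes and NaNs,
    # not pixel values, so it can run on every PR.
    unet = tiny_unet()
    scheduler = DDIMScheduler()
    scheduler.set_timesteps(2)
    sample = torch.randn(1, unet.config.in_channels, 8, 8)
    encoder_hidden_states = torch.randn(1, 4, 32)
    for t in scheduler.timesteps:
        noise_pred = unet(sample, t, encoder_hidden_states).sample
        sample = scheduler.step(noise_pred, t, sample).prev_sample
    assert sample.shape == (1, unet.config.in_channels, 8, 8)
    assert not torch.isnan(sample).any()

@pytest.mark.slow
def test_full_pipeline_against_references():
    # Slow tier: real checkpoint, all 50 steps, slice-based tolerance check
    # as sketched above; runs nightly / on release rather than per PR.
    ...
```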

We already looked quite a bit into: https://pytorch.org/docs/stable/notes/randomness.html
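
For reference, the relevant knobs from those notes look like the following (a sketch; note that even with all of them set, bitwise-identical results across different GPU models or CUDA versions are not guaranteed):

```python
import os
import torch

# Must be set before the CUDA context is created to get deterministic cuBLAS.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

torch.manual_seed(0)                      # seeds the CPU and all CUDA RNGs
torch.use_deterministic_algorithms(True)  # error out on nondeterministic ops
torch.backends.cudnn.benchmark = False    # disable autotuned kernel selection
torch.backends.cudnn.deterministic = True # force deterministic cuDNN kernels
```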
