Skip to content

Commit d4f846f

Browse files
[WIP]Flax training script for controlnet (#2818)
* add train_controlnet_flax --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 58fc824 commit d4f846f

File tree

2 files changed

+1097
-0
lines changed

2 files changed

+1097
-0
lines changed

examples/controlnet/README.md

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,99 @@ image = pipe(
267267
268268
image.save("./output.png")
269269
```
270+
271+
## Training with Flax/JAX
272+
273+
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
274+
275+
### Running on Google Cloud TPU
276+
277+
See below for commands to set up a TPU VM(`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).
278+
279+
First create a single TPUv4-8 VM and connect to it:
280+
281+
```
282+
ZONE=us-central2-b
283+
TPU_TYPE=v4-8
284+
VM_NAME=hg_flax
285+
286+
gcloud alpha compute tpus tpu-vm create $VM_NAME \
287+
--zone $ZONE \
288+
--accelerator-type $TPU_TYPE \
289+
--version tpu-vm-v4-base
290+
291+
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
292+
```
293+
294+
When connected install JAX `0.4.5`:
295+
296+
```
297+
pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
298+
```
299+
300+
To verify that JAX was correctly installed, you can run the following command:
301+
302+
```
303+
import jax
304+
jax.device_count()
305+
```
306+
307+
This should display the number of TPU cores, which should be 4 on a TPUv4-8 VM.
308+
309+
Then install Diffusers and the library's training dependencies:
310+
311+
```bash
312+
git clone https://github.com/huggingface/diffusers
313+
cd diffusers
314+
pip install .
315+
```
316+
317+
Then cd in the example folder and run
318+
319+
```bash
320+
pip install -U -r requirements_flax.txt
321+
```
322+
323+
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
324+
325+
```
326+
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
327+
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
328+
```
329+
330+
We encourage you to store or share your model with the community. To use huggingface hub, please login to your Hugging Face account, or ([create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already):
331+
332+
```
333+
huggingface-cli login
334+
```
335+
336+
Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:
337+
338+
```
339+
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
340+
export OUTPUT_DIR="control_out"
341+
export HUB_MODEL_ID="fill-circle-controlnet"
342+
```
343+
344+
And finally start the training
345+
346+
```
347+
python3 train_controlnet_flax.py \
348+
--pretrained_model_name_or_path=$MODEL_DIR \
349+
--output_dir=$OUTPUT_DIR \
350+
--dataset_name=fusing/fill50k \
351+
--resolution=512 \
352+
--learning_rate=1e-5 \
353+
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
354+
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
355+
--validation_steps=1000 \
356+
--train_batch_size=2 \
357+
--revision="non-ema" \
358+
--from_pt \
359+
--report_to="wandb" \
360+
--max_train_steps=10000 \
361+
--push_to_hub \
362+
--hub_model_id=$HUB_MODEL_ID
363+
```
364+
365+
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).

0 commit comments

Comments
 (0)