Note, however, that TPU performance might get bottlenecked, as streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore other data-loading options.
When working with a larger dataset, you may need to run the training process for a long time, and it's useful to save regular checkpoints along the way. You can use the following argument to enable intermediate checkpointing:
```bash
--checkpointing_steps=500
```
This will save the trained model in subfolders of your `output_dir`. Each subfolder name is the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be stored in a subfolder named `500`.
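For illustration, with `--checkpointing_steps=500` and a hypothetical `--output_dir` of `sd-finetuned`, the output directory fills up with step-numbered subfolders as training progresses:

```bash
ls sd-finetuned/
# 500/  1000/  1500/  ...
```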
You can then resume training from one of these saved checkpoints.
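As a sketch only: assuming the script follows the same convention as the PyTorch training scripts and exposes a `--resume_from_checkpoint` flag (the flag name is an assumption, not confirmed by this excerpt), you would point it at one of the step subfolders:

```bash
# Hypothetical flag; check the script's --help for the exact name and value format
--resume_from_checkpoint="500"
```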
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
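For example, to enable it with the recommended value, add the flag to your launch command:

```bash
--snr_gamma=5.0
```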
We also support gradient accumulation - it is a technique that lets you use a bigger effective batch size by accumulating gradients over several smaller steps.
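As a sketch, assuming the script uses the usual Diffusers flag name `--gradient_accumulation_steps` (an assumption, since the exact flag isn't shown in this excerpt), accumulating over 4 steps would look like:

```bash
# Assumed flag name; verify against the script's --help
--gradient_accumulation_steps=4
```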
You can **profile your code** with:
```bash
--profile_steps=5
```
Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start TensorBoard with the profile plugin:
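For example, following the JAX profiling guide (the log directory below is just a placeholder):

```bash
pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir=/path/to/your/logdir
```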