Skip to content

Commit cdb3cc6

Browse files
committed
Adds --ccache (doesn't really help though).
1 parent 8b58016 commit cdb3cc6

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

examples/controlnet/train_controlnet_flax.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import time
2323

2424
import jax
25+
import jax.experimental.compilation_cache.compilation_cache as cc
2526
import jax.numpy as jnp
2627
import numpy as np
2728
import optax
@@ -230,6 +231,12 @@ def parse_args():
230231
action="store_true",
231232
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
232233
)
234+
parser.add_argument(
235+
"--ccache",
236+
type=str,
237+
default=None,
238+
help="Enables compilation cache.",
239+
)
233240
parser.add_argument(
234241
"--tokenizer_name",
235242
type=str,
@@ -971,6 +978,9 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
971978
}
972979
)
973980

981+
if args.ccache:
982+
cc.initialize_cache(args.ccache)
983+
974984
global_step = 0
975985
epochs = tqdm(
976986
range(args.num_train_epochs),

0 commit comments

Comments
 (0)