Skip to content

Commit 99ce47c

Browse files
committed
add: standalone test for generate_model_card
1 parent 5297ad4 commit 99ce47c

File tree

1 file changed

+44
-22
lines changed

1 file changed

+44
-22
lines changed

tests/others/test_hub_utils.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,57 @@
1717
from tempfile import TemporaryDirectory
1818
from unittest.mock import Mock, patch
1919

20-
import diffusers.utils.hub_utils
20+
from diffusers.utils.hub_utils import create_model_card, generate_model_card
2121

2222

2323
class CreateModelCardTest(unittest.TestCase):
24+
def create_dummy_args(self, output_dir):
25+
# Dummy args values
26+
args = Mock()
27+
args.output_dir = output_dir
28+
args.local_rank = 0
29+
args.hub_token = "hub_token"
30+
args.dataset_name = "dataset_name"
31+
args.learning_rate = 0.01
32+
args.train_batch_size = 100000
33+
args.eval_batch_size = 10000
34+
args.gradient_accumulation_steps = 0.01
35+
args.adam_beta1 = 0.02
36+
args.adam_beta2 = 0.03
37+
args.adam_weight_decay = 0.0005
38+
args.adam_epsilon = 0.000001
39+
args.lr_scheduler = 1
40+
args.lr_warmup_steps = 10
41+
args.ema_inv_gamma = 0.001
42+
args.ema_power = 0.1
43+
args.ema_max_decay = 0.2
44+
args.mixed_precision = True
45+
return args
46+
2447
@patch("diffusers.utils.hub_utils.get_full_repo_name")
2548
def test_create_model_card(self, repo_name_mock: Mock) -> None:
2649
repo_name_mock.return_value = "full_repo_name"
2750
with TemporaryDirectory() as tmpdir:
28-
# Dummy args values
29-
args = Mock()
30-
args.output_dir = tmpdir
31-
args.local_rank = 0
32-
args.hub_token = "hub_token"
33-
args.dataset_name = "dataset_name"
34-
args.learning_rate = 0.01
35-
args.train_batch_size = 100000
36-
args.eval_batch_size = 10000
37-
args.gradient_accumulation_steps = 0.01
38-
args.adam_beta1 = 0.02
39-
args.adam_beta2 = 0.03
40-
args.adam_weight_decay = 0.0005
41-
args.adam_epsilon = 0.000001
42-
args.lr_scheduler = 1
43-
args.lr_warmup_steps = 10
44-
args.ema_inv_gamma = 0.001
45-
args.ema_power = 0.1
46-
args.ema_max_decay = 0.2
47-
args.mixed_precision = True
51+
args = self.create_dummy_args(output_dir=tmpdir)
52+
53+
# Model card mush be rendered and saved
54+
create_model_card(args, model_name="model_name")
55+
self.assertTrue((Path(tmpdir) / "README.md").is_file())
56+
57+
def test_generate_existing_model_card_with_library_name(self):
58+
with TemporaryDirectory() as tmpdir:
59+
args = self.create_dummy_args(output_dir=tmpdir)
4860

4961
# Model card mush be rendered and saved
50-
diffusers.utils.hub_utils.create_model_card(args, model_name="model_name")
62+
create_model_card(args, model_name="model_name")
5163
self.assertTrue((Path(tmpdir) / "README.md").is_file())
64+
65+
model_card = generate_model_card(tmpdir)
66+
assert model_card.data.library_name == "diffusers"
67+
68+
def test_generate_model_card_with_library_name(self):
69+
with TemporaryDirectory() as tmpdir:
70+
model_card = generate_model_card(tmpdir)
71+
72+
model_card = generate_model_card(tmpdir)
73+
assert model_card.data.library_name == "diffusers"

0 commit comments

Comments
 (0)