|
17 | 17 | from tempfile import TemporaryDirectory |
18 | 18 | from unittest.mock import Mock, patch |
19 | 19 |
|
20 | | -import diffusers.utils.hub_utils |
| 20 | +from diffusers.utils.hub_utils import create_model_card, generate_model_card |
21 | 21 |
|
22 | 22 |
|
23 | 23 | class CreateModelCardTest(unittest.TestCase): |
| 24 | + def create_dummy_args(self, output_dir): |
| 25 | + # Dummy args values |
| 26 | + args = Mock() |
| 27 | + args.output_dir = output_dir |
| 28 | + args.local_rank = 0 |
| 29 | + args.hub_token = "hub_token" |
| 30 | + args.dataset_name = "dataset_name" |
| 31 | + args.learning_rate = 0.01 |
| 32 | + args.train_batch_size = 100000 |
| 33 | + args.eval_batch_size = 10000 |
| 34 | + args.gradient_accumulation_steps = 0.01 |
| 35 | + args.adam_beta1 = 0.02 |
| 36 | + args.adam_beta2 = 0.03 |
| 37 | + args.adam_weight_decay = 0.0005 |
| 38 | + args.adam_epsilon = 0.000001 |
| 39 | + args.lr_scheduler = 1 |
| 40 | + args.lr_warmup_steps = 10 |
| 41 | + args.ema_inv_gamma = 0.001 |
| 42 | + args.ema_power = 0.1 |
| 43 | + args.ema_max_decay = 0.2 |
| 44 | + args.mixed_precision = True |
| 45 | + return args |
| 46 | + |
24 | 47 | @patch("diffusers.utils.hub_utils.get_full_repo_name") |
25 | 48 | def test_create_model_card(self, repo_name_mock: Mock) -> None: |
26 | 49 | repo_name_mock.return_value = "full_repo_name" |
27 | 50 | with TemporaryDirectory() as tmpdir: |
28 | | - # Dummy args values |
29 | | - args = Mock() |
30 | | - args.output_dir = tmpdir |
31 | | - args.local_rank = 0 |
32 | | - args.hub_token = "hub_token" |
33 | | - args.dataset_name = "dataset_name" |
34 | | - args.learning_rate = 0.01 |
35 | | - args.train_batch_size = 100000 |
36 | | - args.eval_batch_size = 10000 |
37 | | - args.gradient_accumulation_steps = 0.01 |
38 | | - args.adam_beta1 = 0.02 |
39 | | - args.adam_beta2 = 0.03 |
40 | | - args.adam_weight_decay = 0.0005 |
41 | | - args.adam_epsilon = 0.000001 |
42 | | - args.lr_scheduler = 1 |
43 | | - args.lr_warmup_steps = 10 |
44 | | - args.ema_inv_gamma = 0.001 |
45 | | - args.ema_power = 0.1 |
46 | | - args.ema_max_decay = 0.2 |
47 | | - args.mixed_precision = True |
| 51 | + args = self.create_dummy_args(output_dir=tmpdir) |
| 52 | + |
| 53 | + # Model card mush be rendered and saved |
| 54 | + create_model_card(args, model_name="model_name") |
| 55 | + self.assertTrue((Path(tmpdir) / "README.md").is_file()) |
| 56 | + |
| 57 | + def test_generate_existing_model_card_with_library_name(self): |
| 58 | + with TemporaryDirectory() as tmpdir: |
| 59 | + args = self.create_dummy_args(output_dir=tmpdir) |
48 | 60 |
|
49 | 61 | # Model card mush be rendered and saved |
50 | | - diffusers.utils.hub_utils.create_model_card(args, model_name="model_name") |
| 62 | + create_model_card(args, model_name="model_name") |
51 | 63 | self.assertTrue((Path(tmpdir) / "README.md").is_file()) |
| 64 | + |
| 65 | + model_card = generate_model_card(tmpdir) |
| 66 | + assert model_card.data.library_name == "diffusers" |
| 67 | + |
| 68 | + def test_generate_model_card_with_library_name(self): |
| 69 | + with TemporaryDirectory() as tmpdir: |
| 70 | + model_card = generate_model_card(tmpdir) |
| 71 | + |
| 72 | + model_card = generate_model_card(tmpdir) |
| 73 | + assert model_card.data.library_name == "diffusers" |
0 commit comments