Skip to content

Commit acd2079

Browse files
committed
Update
[ghstack-poisoned]
1 parent eedd833 commit acd2079

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

extension/llm/export/export_llm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@
3434
from typing import Any, List, Tuple
3535

3636
import hydra
37-
import yaml
3837

3938
from executorch.examples.models.llama.config.llm_config import LlmConfig
4039
from executorch.examples.models.llama.export_llama_lib import export_llama
4140
from hydra.core.config_store import ConfigStore
42-
from omegaconf import DictConfig, OmegaConf
41+
from omegaconf import OmegaConf
4342

4443
cs = ConfigStore.instance()
4544
cs.store(name="llm_config", node=LlmConfig)
@@ -79,7 +78,7 @@ def main() -> None:
7978
"Cannot specify additional CLI arguments when using --config. "
8079
f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
8180
)
82-
81+
8382
config_file_path = pop_config_arg()
8483
default_llm_config = LlmConfig()
8584
llm_config_from_file = OmegaConf.load(config_file_path)

extension/llm/export/test/test_export_llm.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
import unittest
1111
from unittest.mock import MagicMock, patch
1212

13-
from executorch.examples.models.llama.config.llm_config import LlmConfig
14-
from executorch.extension.llm.export.export_llm import main, parse_config_arg, pop_config_arg
13+
from executorch.extension.llm.export.export_llm import (
14+
main,
15+
parse_config_arg,
16+
pop_config_arg,
17+
)
1518

1619

1720
class TestExportLlm(unittest.TestCase):
@@ -45,12 +48,14 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
4548
"""Test main function with --config file and no hydra args."""
4649
# Create a temporary config file
4750
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
48-
f.write("""
51+
f.write(
52+
"""
4953
base:
5054
tokenizer_path: /path/to/tokenizer.json
5155
export:
5256
max_seq_length: 256
53-
""")
57+
"""
58+
)
5459
config_file = f.name
5560

5661
try:
@@ -61,7 +66,9 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
6166
# Verify export_llama was called with config
6267
mock_export_llama.assert_called_once()
6368
called_config = mock_export_llama.call_args[0][0]
64-
self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json")
69+
self.assertEqual(
70+
called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
71+
)
6572
self.assertEqual(called_config["export"]["max_seq_length"], 256)
6673
finally:
6774
os.unlink(config_file)
@@ -70,7 +77,9 @@ def test_with_cli_args(self) -> None:
7077
"""Test main function with only hydra CLI args."""
7178
test_argv = ["script.py", "debug.verbose=True"]
7279
with patch.object(sys, "argv", test_argv):
73-
with patch("executorch.extension.llm.export.export_llm.hydra_main") as mock_hydra:
80+
with patch(
81+
"executorch.extension.llm.export.export_llm.hydra_main"
82+
) as mock_hydra:
7483
main()
7584
mock_hydra.assert_called_once()
7685

@@ -86,9 +95,12 @@ def test_config_with_cli_args_error(self) -> None:
8695
with patch.object(sys, "argv", test_argv):
8796
with self.assertRaises(ValueError) as cm:
8897
main()
89-
98+
9099
error_msg = str(cm.exception)
91-
self.assertIn("Cannot specify additional CLI arguments when using --config", error_msg)
100+
self.assertIn(
101+
"Cannot specify additional CLI arguments when using --config",
102+
error_msg,
103+
)
92104
finally:
93105
os.unlink(config_file)
94106

@@ -99,7 +111,13 @@ def test_config_rejects_multiple_cli_args(self) -> None:
99111
config_file = f.name
100112

101113
try:
102-
test_argv = ["script.py", "--config", config_file, "debug.verbose=True", "export.output_dir=/tmp"]
114+
test_argv = [
115+
"script.py",
116+
"--config",
117+
config_file,
118+
"debug.verbose=True",
119+
"export.output_dir=/tmp",
120+
]
103121
with patch.object(sys, "argv", test_argv):
104122
with self.assertRaises(ValueError):
105123
main()
@@ -109,4 +127,3 @@ def test_config_rejects_multiple_cli_args(self) -> None:
109127

110128
if __name__ == "__main__":
111129
unittest.main()
112-

0 commit comments

Comments
 (0)