10
10
import unittest
11
11
from unittest .mock import MagicMock , patch
12
12
13
- from executorch .examples .models .llama .config .llm_config import LlmConfig
14
- from executorch .extension .llm .export .export_llm import main , parse_config_arg , pop_config_arg
13
+ from executorch .extension .llm .export .export_llm import (
14
+ main ,
15
+ parse_config_arg ,
16
+ pop_config_arg ,
17
+ )
15
18
16
19
17
20
class TestExportLlm (unittest .TestCase ):
@@ -45,12 +48,14 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
45
48
"""Test main function with --config file and no hydra args."""
46
49
# Create a temporary config file
47
50
with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
48
- f .write ("""
51
+ f .write (
52
+ """
49
53
base:
50
54
tokenizer_path: /path/to/tokenizer.json
51
55
export:
52
56
max_seq_length: 256
53
- """ )
57
+ """
58
+ )
54
59
config_file = f .name
55
60
56
61
try :
@@ -61,7 +66,9 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
61
66
# Verify export_llama was called with config
62
67
mock_export_llama .assert_called_once ()
63
68
called_config = mock_export_llama .call_args [0 ][0 ]
64
- self .assertEqual (called_config ["base" ]["tokenizer_path" ], "/path/to/tokenizer.json" )
69
+ self .assertEqual (
70
+ called_config ["base" ]["tokenizer_path" ], "/path/to/tokenizer.json"
71
+ )
65
72
self .assertEqual (called_config ["export" ]["max_seq_length" ], 256 )
66
73
finally :
67
74
os .unlink (config_file )
@@ -70,7 +77,9 @@ def test_with_cli_args(self) -> None:
70
77
"""Test main function with only hydra CLI args."""
71
78
test_argv = ["script.py" , "debug.verbose=True" ]
72
79
with patch .object (sys , "argv" , test_argv ):
73
- with patch ("executorch.extension.llm.export.export_llm.hydra_main" ) as mock_hydra :
80
+ with patch (
81
+ "executorch.extension.llm.export.export_llm.hydra_main"
82
+ ) as mock_hydra :
74
83
main ()
75
84
mock_hydra .assert_called_once ()
76
85
@@ -86,9 +95,12 @@ def test_config_with_cli_args_error(self) -> None:
86
95
with patch .object (sys , "argv" , test_argv ):
87
96
with self .assertRaises (ValueError ) as cm :
88
97
main ()
89
-
98
+
90
99
error_msg = str (cm .exception )
91
- self .assertIn ("Cannot specify additional CLI arguments when using --config" , error_msg )
100
+ self .assertIn (
101
+ "Cannot specify additional CLI arguments when using --config" ,
102
+ error_msg ,
103
+ )
92
104
finally :
93
105
os .unlink (config_file )
94
106
@@ -99,7 +111,13 @@ def test_config_rejects_multiple_cli_args(self) -> None:
99
111
config_file = f .name
100
112
101
113
try :
102
- test_argv = ["script.py" , "--config" , config_file , "debug.verbose=True" , "export.output_dir=/tmp" ]
114
+ test_argv = [
115
+ "script.py" ,
116
+ "--config" ,
117
+ config_file ,
118
+ "debug.verbose=True" ,
119
+ "export.output_dir=/tmp" ,
120
+ ]
103
121
with patch .object (sys , "argv" , test_argv ):
104
122
with self .assertRaises (ValueError ):
105
123
main ()
@@ -109,4 +127,3 @@ def test_config_rejects_multiple_cli_args(self) -> None:
109
127
110
128
if __name__ == "__main__" :
111
129
unittest .main ()
112
-
0 commit comments