|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import os |
| 8 | +import sys |
| 9 | +import tempfile |
| 10 | +import unittest |
| 11 | +from unittest.mock import MagicMock, patch |
| 12 | + |
| 13 | +from executorch.extension.llm.export.export_llm import ( |
| 14 | + main, |
| 15 | + parse_config_arg, |
| 16 | + pop_config_arg, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +class TestExportLlm(unittest.TestCase): |
| 21 | + def test_parse_config_arg_with_config(self) -> None: |
| 22 | + """Test parse_config_arg when --config is provided.""" |
| 23 | + # Mock sys.argv to include --config |
| 24 | + test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"] |
| 25 | + with patch.object(sys, "argv", test_argv): |
| 26 | + config_path, remaining = parse_config_arg() |
| 27 | + self.assertEqual(config_path, "test_config.yaml") |
| 28 | + self.assertEqual(remaining, ["extra", "args"]) |
| 29 | + |
| 30 | + def test_parse_config_arg_without_config(self) -> None: |
| 31 | + """Test parse_config_arg when --config is not provided.""" |
| 32 | + test_argv = ["script.py", "debug.verbose=True"] |
| 33 | + with patch.object(sys, "argv", test_argv): |
| 34 | + config_path, remaining = parse_config_arg() |
| 35 | + self.assertIsNone(config_path) |
| 36 | + self.assertEqual(remaining, ["debug.verbose=True"]) |
| 37 | + |
| 38 | + def test_pop_config_arg(self) -> None: |
| 39 | + """Test pop_config_arg removes --config and its value from sys.argv.""" |
| 40 | + test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"] |
| 41 | + with patch.object(sys, "argv", test_argv): |
| 42 | + config_path = pop_config_arg() |
| 43 | + self.assertEqual(config_path, "test_config.yaml") |
| 44 | + self.assertEqual(sys.argv, ["script.py", "other", "args"]) |
| 45 | + |
| 46 | + @patch("executorch.extension.llm.export.export_llm.export_llama") |
| 47 | + def test_with_config(self, mock_export_llama: MagicMock) -> None: |
| 48 | + """Test main function with --config file and no hydra args.""" |
| 49 | + # Create a temporary config file |
| 50 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
| 51 | + f.write( |
| 52 | + """ |
| 53 | +base: |
| 54 | + tokenizer_path: /path/to/tokenizer.json |
| 55 | +export: |
| 56 | + max_seq_length: 256 |
| 57 | +""" |
| 58 | + ) |
| 59 | + config_file = f.name |
| 60 | + |
| 61 | + try: |
| 62 | + test_argv = ["script.py", "--config", config_file] |
| 63 | + with patch.object(sys, "argv", test_argv): |
| 64 | + main() |
| 65 | + |
| 66 | + # Verify export_llama was called with config |
| 67 | + mock_export_llama.assert_called_once() |
| 68 | + called_config = mock_export_llama.call_args[0][0] |
| 69 | + self.assertEqual( |
| 70 | + called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json" |
| 71 | + ) |
| 72 | + self.assertEqual(called_config["export"]["max_seq_length"], 256) |
| 73 | + finally: |
| 74 | + os.unlink(config_file) |
| 75 | + |
| 76 | + def test_with_cli_args(self) -> None: |
| 77 | + """Test main function with only hydra CLI args.""" |
| 78 | + test_argv = ["script.py", "debug.verbose=True"] |
| 79 | + with patch.object(sys, "argv", test_argv): |
| 80 | + with patch( |
| 81 | + "executorch.extension.llm.export.export_llm.hydra_main" |
| 82 | + ) as mock_hydra: |
| 83 | + main() |
| 84 | + mock_hydra.assert_called_once() |
| 85 | + |
| 86 | + def test_config_with_cli_args_error(self) -> None: |
| 87 | + """Test that --config rejects additional CLI arguments to prevent mixing approaches.""" |
| 88 | + # Create a temporary config file |
| 89 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
| 90 | + f.write("base:\n checkpoint: /path/to/checkpoint.pth") |
| 91 | + config_file = f.name |
| 92 | + |
| 93 | + try: |
| 94 | + test_argv = ["script.py", "--config", config_file, "debug.verbose=True"] |
| 95 | + with patch.object(sys, "argv", test_argv): |
| 96 | + with self.assertRaises(ValueError) as cm: |
| 97 | + main() |
| 98 | + |
| 99 | + error_msg = str(cm.exception) |
| 100 | + self.assertIn( |
| 101 | + "Cannot specify additional CLI arguments when using --config", |
| 102 | + error_msg, |
| 103 | + ) |
| 104 | + finally: |
| 105 | + os.unlink(config_file) |
| 106 | + |
| 107 | + def test_config_rejects_multiple_cli_args(self) -> None: |
| 108 | + """Test that --config rejects multiple CLI arguments (not just single ones).""" |
| 109 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
| 110 | + f.write("export:\n max_seq_length: 128") |
| 111 | + config_file = f.name |
| 112 | + |
| 113 | + try: |
| 114 | + test_argv = [ |
| 115 | + "script.py", |
| 116 | + "--config", |
| 117 | + config_file, |
| 118 | + "debug.verbose=True", |
| 119 | + "export.output_dir=/tmp", |
| 120 | + ] |
| 121 | + with patch.object(sys, "argv", test_argv): |
| 122 | + with self.assertRaises(ValueError): |
| 123 | + main() |
| 124 | + finally: |
| 125 | + os.unlink(config_file) |
| 126 | + |
| 127 | + |
| 128 | +if __name__ == "__main__": |
| 129 | + unittest.main() |
0 commit comments