|
25 | 25 | from itertools import product
|
26 | 26 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
27 | 27 |
|
28 |
| -from neural_compressor.common.logger import Logger |
29 |
| -from neural_compressor.common.utility import ( |
| 28 | +from neural_compressor.common import Logger |
| 29 | +from neural_compressor.common.utils import ( |
30 | 30 | BASE_CONFIG,
|
31 | 31 | COMPOSABLE_CONFIG,
|
32 | 32 | DEFAULT_WHITE_LIST,
|
| 33 | + DEFAULT_WORKSPACE, |
33 | 34 | EMPTY_WHITE_LIST,
|
34 | 35 | GLOBAL,
|
35 | 36 | LOCAL,
|
|
38 | 39 |
|
39 | 40 | logger = Logger().get_logger()
|
40 | 41 |
|
| 42 | +__all__ = [ |
| 43 | + "ConfigRegistry", |
| 44 | + "register_config", |
| 45 | + "BaseConfig", |
| 46 | + "ComposableConfig", |
| 47 | + "Options", |
| 48 | + "options", |
| 49 | +] |
41 | 50 |
|
42 | 51 | # Dictionary to store registered configurations
|
43 | 52 |
|
@@ -411,3 +420,122 @@ def to_config_mapping(
|
411 | 420 | def register_supported_configs(cls):
|
412 | 421 | """Add all supported configs."""
|
413 | 422 | raise NotImplementedError
|
| 423 | + |
| 424 | + |
| 425 | +def _check_value(name, src, supported_type, supported_value=[]): |
| 426 | + """Check if the given object is the given supported type and in the given supported value. |
| 427 | +
|
| 428 | + Example:: |
| 429 | +
|
| 430 | + from neural_compressor.common.base_config import _check_value |
| 431 | +
|
| 432 | + def datatype(self, datatype): |
| 433 | + if _check_value("datatype", datatype, list, ["fp32", "bf16", "uint8", "int8"]): |
| 434 | + self._datatype = datatype |
| 435 | + """ |
| 436 | + if isinstance(src, list) and any([not isinstance(i, supported_type) for i in src]): |
| 437 | + assert False, "Type of {} items should be {} but not {}".format( |
| 438 | + name, str(supported_type), [type(i) for i in src] |
| 439 | + ) |
| 440 | + elif not isinstance(src, list) and not isinstance(src, supported_type): |
| 441 | + assert False, "Type of {} should be {} but not {}".format(name, str(supported_type), type(src)) |
| 442 | + |
| 443 | + if len(supported_value) > 0: |
| 444 | + if isinstance(src, str) and src not in supported_value: |
| 445 | + assert False, "{} is not in supported {}: {}. Skip setting it.".format(src, name, str(supported_value)) |
| 446 | + elif ( |
| 447 | + isinstance(src, list) |
| 448 | + and all([isinstance(i, str) for i in src]) |
| 449 | + and any([i not in supported_value for i in src]) |
| 450 | + ): |
| 451 | + assert False, "{} is not in supported {}: {}. Skip setting it.".format(src, name, str(supported_value)) |
| 452 | + |
| 453 | + return True |
| 454 | + |
| 455 | + |
| 456 | +class Options: |
| 457 | + """Option Class for configs. |
| 458 | +
|
| 459 | + This class is used for configuring global variables. The global variable options is created with this class. |
| 460 | + If you want to change global variables, you should use functions from neural_compressor.common.utils.utility.py: |
| 461 | + set_random_seed(seed: int) |
| 462 | + set_workspace(workspace: str) |
| 463 | + set_resume_from(resume_from: str) |
| 464 | + set_tensorboard(tensorboard: bool) |
| 465 | +
|
| 466 | + Args: |
| 467 | + random_seed(int): Random seed used in neural compressor. |
| 468 | + Default value is 1978. |
| 469 | + workspace(str): The directory where intermediate files and tuning history file are stored. |
| 470 | + Default value is: |
| 471 | + "./nc_workspace/{}/".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")). |
| 472 | + resume_from(str): The directory you want to resume tuning history file from. |
| 473 | + The tuning history was automatically saved in the workspace directory |
| 474 | + during the last tune process. |
| 475 | + Default value is None. |
| 476 | + tensorboard(bool): This flag indicates whether to save the weights of the model and the inputs of each layer |
| 477 | + for visual display. |
| 478 | + Default value is False. |
| 479 | +
|
| 480 | + Example:: |
| 481 | +
|
| 482 | + from neural_compressor.common import set_random_seed, set_workspace, set_resume_from, set_tensorboard |
| 483 | + set_random_seed(2022) |
| 484 | + set_workspace("workspace_path") |
| 485 | + set_resume_from("workspace_path") |
| 486 | + set_tensorboard(True) |
| 487 | + """ |
| 488 | + |
| 489 | + def __init__(self, random_seed=1978, workspace=DEFAULT_WORKSPACE, resume_from=None, tensorboard=False): |
| 490 | + """Init an Option object.""" |
| 491 | + self.random_seed = random_seed |
| 492 | + self.workspace = workspace |
| 493 | + self.resume_from = resume_from |
| 494 | + self.tensorboard = tensorboard |
| 495 | + |
| 496 | + @property |
| 497 | + def random_seed(self): |
| 498 | + """Get random seed.""" |
| 499 | + return self._random_seed |
| 500 | + |
| 501 | + @random_seed.setter |
| 502 | + def random_seed(self, random_seed): |
| 503 | + """Set random seed.""" |
| 504 | + if _check_value("random_seed", random_seed, int): |
| 505 | + self._random_seed = random_seed |
| 506 | + |
| 507 | + @property |
| 508 | + def workspace(self): |
| 509 | + """Get workspace.""" |
| 510 | + return self._workspace |
| 511 | + |
| 512 | + @workspace.setter |
| 513 | + def workspace(self, workspace): |
| 514 | + """Set workspace.""" |
| 515 | + if _check_value("workspace", workspace, str): |
| 516 | + self._workspace = workspace |
| 517 | + |
| 518 | + @property |
| 519 | + def resume_from(self): |
| 520 | + """Get resume_from.""" |
| 521 | + return self._resume_from |
| 522 | + |
| 523 | + @resume_from.setter |
| 524 | + def resume_from(self, resume_from): |
| 525 | + """Set resume_from.""" |
| 526 | + if resume_from is None or _check_value("resume_from", resume_from, str): |
| 527 | + self._resume_from = resume_from |
| 528 | + |
| 529 | + @property |
| 530 | + def tensorboard(self): |
| 531 | + """Get tensorboard.""" |
| 532 | + return self._tensorboard |
| 533 | + |
| 534 | + @tensorboard.setter |
| 535 | + def tensorboard(self, tensorboard): |
| 536 | + """Set tensorboard.""" |
| 537 | + if _check_value("tensorboard", tensorboard, bool): |
| 538 | + self._tensorboard = tensorboard |
| 539 | + |
| 540 | + |
| 541 | +options = Options() |
0 commit comments