Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/fairseq2/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from fairseq2.recipe.evaluator import EvalUnit as EvalUnit
from fairseq2.recipe.generator import Generator as Generator
from fairseq2.recipe.generator import GeneratorUnit as GeneratorUnit
from fairseq2.recipe.model import DDPModel as DDPModel
from fairseq2.recipe.model import FSDP1Model as FSDP1Model
from fairseq2.recipe.model import FSDP2Model as FSDP2Model
from fairseq2.recipe.model import RecipeModel as RecipeModel
from fairseq2.recipe.model import StandardRecipeModel as StandardRecipeModel
from fairseq2.recipe.run import evaluate as evaluate
from fairseq2.recipe.run import generate as generate
from fairseq2.recipe.run import train as train
Expand Down
30 changes: 25 additions & 5 deletions src/fairseq2/recipe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from fairseq2.recipe.config import RecipeConfig
from fairseq2.recipe.error import (
BeamSearchAlgorithmNotKnownError,
DatasetTypeNotValidError,
DeviceTypeNotSupportedError,
ErrorContext,
FSDPNotSupportedError,
Expand All @@ -89,6 +90,7 @@
SequenceGeneratorNotKnownError,
SplitNotKnownError,
TokenizerModelNotFoundError,
TokenizerTypeNotValidError,
TorchCompileError,
TorchCompileNotSupportedError,
WandbInitializationError,
Expand Down Expand Up @@ -507,9 +509,10 @@ def register(kls: type[ExceptionT], handler: ExceptionHandler[ExceptionT]) -> No
register(ClusterNotKnownError, _handle_cluster_not_known_error)
register(ComponentNotKnownError, _handle_component_not_known_error)
register(DataReadError, _handle_data_read_error)
register(DatasetError, _handle_dataset_error)
register(DatasetFamilyNotKnownError, _handle_dataset_family_not_known_error)
register(DatasetNotKnownError, _handle_dataset_not_known_error)
register(DatasetError, _handle_dataset_error)
register(DatasetTypeNotValidError, _handle_dataset_type_not_valid_error)
register(DeviceTypeNotSupportedError, _handle_device_type_not_supported_error)
register(EnvironmentVariableError, _handle_env_variable_error)
register(FSDPNotSupportedError, _handle_fsdp_not_supported_error)
Expand Down Expand Up @@ -538,6 +541,7 @@ def register(kls: type[ExceptionT], handler: ExceptionHandler[ExceptionT]) -> No
register(TokenizerModelError, _handle_tokenizer_model_error)
register(TokenizerModelNotFoundError, _handle_tokenizer_model_not_found_error)
register(TokenizerNotKnownError, _handle_tokenizer_not_known_error)
register(TokenizerTypeNotValidError, _handle_tokenizer_type_not_valid_error)
register(TorchCompileError, _handle_torch_compile_error)
register(TorchCompileNotSupportedError, _handle_torch_compile_not_supported_error)
register(WandbInitializationError, _handle_wandb_init_error)
Expand Down Expand Up @@ -637,6 +641,15 @@ def _handle_dataset_error(ex: DatasetError) -> int:
return 1


def _handle_dataset_type_not_valid_error(ex: DatasetTypeNotValidError) -> int:
    """Log a dataset type mismatch and return exit code 2.

    The default ``dataset`` section gets the short message; any other
    section name is included in the message so the user can locate it.
    """
    if ex.section_name != "dataset":
        log.error("Dataset specified in `{}` section must be of type `{}`, but is of type `{}` instead.", ex.section_name, ex.expected_kls, ex.kls)
    else:
        log.error("Dataset must be of type `{}`, but is of type `{}` instead.", ex.expected_kls, ex.kls)

    return 2


def _handle_device_type_not_supported_error(ex: DeviceTypeNotSupportedError) -> int:
log.error("For distributed jobs, only `cpu` and `cuda` devices are supported, but the device of the process is `{}`.", ex.device)

Expand Down Expand Up @@ -761,12 +774,10 @@ def _handle_model_not_known_error(ex: ModelNotKnownError) -> int:


def _handle_model_type_not_valid_error(ex: ModelTypeNotValidError) -> int:
    """Log a model type mismatch and return exit code 2.

    The default ``model`` section gets the short message; any other
    section name is included in the message so the user can locate it.
    """
    # NOTE(review): the scraped diff interleaved the removed
    # `ErrorContext.maybe_get_config_section_name` lookup with the new
    # `ex.section_name`-based branch; only the post-change version is kept.
    if ex.section_name == "model":
        log.error("Model must be of type `{}`, but is of type `{}` instead.", ex.expected_kls, ex.kls)
    else:
        log.error("Model specified in `{}` section must be of type `{}`, but is of type `{}` instead.", ex.section_name, ex.expected_kls, ex.kls)

    return 2

Expand Down Expand Up @@ -832,6 +843,15 @@ def _handle_tokenizer_not_known_error(ex: TokenizerNotKnownError) -> int:
return 2


def _handle_tokenizer_type_not_valid_error(ex: TokenizerTypeNotValidError) -> int:
    """Log a tokenizer type mismatch and return exit code 2."""
    # The default `tokenizer` section gets the short message; otherwise
    # the section name is included so the user can locate the setting.
    in_default_section = ex.section_name == "tokenizer"

    if in_default_section:
        log.error("Tokenizer must be of type `{}`, but is of type `{}` instead.", ex.expected_kls, ex.kls)
    else:
        log.error("Tokenizer specified in `{}` section must be of type `{}`, but is of type `{}` instead.", ex.section_name, ex.expected_kls, ex.kls)

    return 2


def _handle_torch_compile_error(ex: TorchCompileError) -> int:
section_name = ErrorContext.maybe_get_config_section_name(ex)

Expand Down
24 changes: 17 additions & 7 deletions src/fairseq2/recipe/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,30 @@

from typing import TypeVar, final

from fairseq2.datasets import DatasetFamily
from fairseq2.recipe.error import DatasetTypeNotValidError

DatasetT = TypeVar("DatasetT")


@final
class RecipeDataset:
    """Wraps a loaded dataset together with its recipe configuration.

    NOTE(review): the scraped diff interleaved the pre-change
    (``family``-based) and post-change (``family_name``-based) versions of
    this class; only the post-change version is kept.

    :param inner_dataset: The actual dataset instance loaded by the family.
    :param config: The configuration the dataset was loaded with.
    :param family_name: The name of the dataset family that loaded it.
    :param section_name: The recipe config section the dataset was
        specified in; defaults to the standard ``dataset`` section.
    """

    def __init__(
        self,
        inner_dataset: object,
        config: object,
        family_name: str,
        *,
        section_name: str = "dataset",
    ) -> None:
        self._inner_dataset = inner_dataset
        self._config = config
        self._family_name = family_name
        self._section_name = section_name

    def as_(self, kls: type[DatasetT]) -> DatasetT:
        """Return the inner dataset narrowed to ``kls``.

        :raises DatasetTypeNotValidError: If the inner dataset is not an
            instance of ``kls``; carries the section name so the CLI error
            handler can point the user at the right config section.
        """
        if not isinstance(self._inner_dataset, kls):
            raise DatasetTypeNotValidError(
                type(self._inner_dataset), kls, self._section_name
            )

        return self._inner_dataset

    @property
    def config(self) -> object:
        """The configuration the dataset was loaded with."""
        return self._config

    @property
    def family_name(self) -> str:
        """The name of the dataset family that loaded the dataset."""
        return self._family_name

    @property
    def section_name(self) -> str:
        """The recipe config section the dataset was specified in."""
        return self._section_name
32 changes: 31 additions & 1 deletion src/fairseq2/recipe/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from torch.nn import Module

from fairseq2.data.tokenizers import Tokenizer
from fairseq2.device import Device


Expand All @@ -30,6 +31,19 @@ def __init__(self, name: str) -> None:
self.name = name


class DatasetTypeNotValidError(Exception):
    """Raised when a loaded dataset is not of the type a recipe expects.

    :param kls: The actual type of the dataset.
    :param expected_kls: The type the recipe expects.
    :param section_name: The recipe config section the dataset was
        specified in.
    """

    def __init__(
        self, kls: type[object], expected_kls: type[object], section_name: str
    ) -> None:
        message = "Dataset must be of type `{}`, but is of type `{}` instead.".format(
            expected_kls, kls
        )

        super().__init__(message)

        self.kls = kls
        self.expected_kls = expected_kls
        self.section_name = section_name


class DeviceTypeNotSupportedError(Exception):
def __init__(self, device: Device) -> None:
super().__init__(
Expand Down Expand Up @@ -111,13 +125,16 @@ def __init__(self, path: Path) -> None:


class ModelTypeNotValidError(Exception):
    """Raised when a loaded model is not of the type a recipe expects.

    NOTE(review): the scraped diff left the removed two-parameter
    ``__init__`` signature interleaved with the new three-parameter one;
    only the post-change version is kept.

    :param kls: The actual type of the model.
    :param expected_kls: The type the recipe expects.
    :param section_name: The recipe config section the model was
        specified in.
    """

    def __init__(
        self, kls: type[Module], expected_kls: type[Module], section_name: str
    ) -> None:
        super().__init__(
            f"Model must be of type `{expected_kls}`, but is of type `{kls}` instead."
        )

        self.kls = kls
        self.expected_kls = expected_kls
        self.section_name = section_name


class OptimizerNotKnownError(Exception):
Expand Down Expand Up @@ -159,6 +176,19 @@ def __init__(self, path: Path) -> None:
self.path = path


class TokenizerTypeNotValidError(Exception):
    """Raised when a loaded tokenizer is not of the type a recipe expects.

    :param kls: The actual type of the tokenizer.
    :param expected_kls: The type the recipe expects.
    :param section_name: The recipe config section the tokenizer was
        specified in.
    """

    def __init__(
        self, kls: type[Tokenizer], expected_kls: type[Tokenizer], section_name: str
    ) -> None:
        message = (
            f"Tokenizer must be of type `{expected_kls}`, but is of type `{kls}` instead."
        )

        super().__init__(message)

        self.kls = kls
        self.expected_kls = expected_kls
        self.section_name = section_name


class TorchCompileError(Exception):
def __init__(self) -> None:
super().__init__("torch.compile() failed.")
Expand Down
9 changes: 6 additions & 3 deletions src/fairseq2/recipe/internal/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@
from __future__ import annotations

from fairseq2.logging import log
from fairseq2.models import ModelFamily
from fairseq2.recipe.config import CompileOptions
from fairseq2.recipe.error import TorchCompileError, TorchCompileNotSupportedError
from fairseq2.recipe.model import RecipeModel


def _compile_model(model: RecipeModel, options: CompileOptions) -> None:
if not model.family.supports_compilation:
def _compile_model(
model: RecipeModel, family: ModelFamily, options: CompileOptions
) -> None:
if not family.supports_compilation:
raise TorchCompileNotSupportedError()

log.info("Applying torch.compile() to the model.")

try:
model.family.compile(
family.compile(
model.module,
fullgraph=options.fullgraph,
dynamic=options.dynamic,
Expand Down
Loading
Loading