
Conversation

pbontrager
Contributor

In an effort to update the recipe to match the design of the recipe RFC, I needed to update the current utils to function in a similar way to the utils in the RFC. This PR refactors existing utils; it doesn't add new utils or update the recipe, which will be a followup PR. These utils are meant to hide a lot of complexity for the user so they can add features without thinking too much about it; if they want more granular control, they have direct access in the recipe to skip them and use core PyTorch.

Changelog

Utils folder after this PR: argparse, data, device, distributed, precision, seed, generation, logits_transforms
(Utils are not meant to be finetuning-specific; logits_transforms might still need to be recategorized.)

  • Split init_from_env into get_device and init_distributed: get_device gives users much finer-grained control of the device they want to use while still being able to pick the device automatically. init_distributed contains the remaining functionality of init_from_env around starting distributed process groups.
  • seed was changed to set_seed and moved to its own utility: set_seed is unchanged except that it can take in None, in which case it randomly selects a seed to set and returns that int. This lets users run with a fixed seed that they can save/log while still getting a new seed every time they launch a new run.
  • precision was changed to get_dtype, autocast, and get_gradient_autoscaler: get_dtype can take either strings or dtypes from the user and returns a dtype object, making it easier to switch between str/object. autocast is get_autocast_manager, just simplified to take advantage of get_dtype. get_gradient_autoscaler is the same as get_grad_scaler except that it returns a disabled GradScaler when fp16 isn't provided, which lets the recipe be simplified (see the usage sketch after this list).
  • env was renamed to distributed, since all the remaining env functions relate to distributed variables, and future distributed utils will be added there.
  • ReproducibleDataLoader was moved to utils so we won't have a trainer folder. It was put in a utils.data file with the batch_seq collate function. The docstring was also updated to pass pydoclint, and the exception was removed from the linter.
  • Added all util functions to utils/__init__.py
  • finetune_llm had small changes to use these new utils but no large changes yet
  • Tests were all updated to handle the new names, imports, and additional edge cases
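To illustrate how these pieces are meant to fit together in a recipe, here is a minimal sketch. Names follow the changelog text above; the exact signatures and argument order are assumptions and may differ from the final code.

```python
import torch
import torch.nn as nn
from torchtune import utils  # names below follow the changelog; exact signatures may differ

device = utils.get_device(None)   # None auto-detects; also accepts "cuda", "cuda:1", or a torch.device
dtype = utils.get_dtype("bf16")   # accepts a string or a torch.dtype, returns a torch.dtype
seed = utils.set_seed(None)       # None picks a random seed and returns it so it can be logged

model = nn.Linear(16, 4).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = utils.get_gradient_autoscaler(dtype)   # disabled GradScaler unless dtype is fp16

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)

with utils.autocast(dtype, device):  # argument order assumed; plain no-op context when dtype is fp32
    loss = nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
```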

Test plan

  • All existing and updated tests (including the recipe) pass without any expected value changes
  • New edge cases added to unit tests since some utils allow more varied user inputs
  • Both commands from the README Quickstart launch and start running


netlify bot commented Jan 11, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: 349b53a
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65aafdd977d0d900080f7ce8
😎 Deploy Preview: https://deploy-preview-180--torchtune-preview.netlify.app

@facebook-github-bot added the CLA Signed label Jan 11, 2024
@@ -17,5 +17,6 @@ optimizer: SGD
loss: CrossEntropyLoss
output_dir: /tmp/alpaca-llama2-finetune
device: cuda
dtype: bf16
Contributor

Unrelated changes; keep them to a separate PR?

Contributor Author

Oops, that should be fp32

_DEFAULT_LABEL_PADDING_IDX: int = -100


class ReproducibleDataLoader(DataLoader):
Member

FYI this is being removed in #161

@@ -82,6 +62,7 @@ def recipe(kwargs):
{TransformerDecoderLayer}
) # TODO: remove model specific components
if kwargs["fsdp"]:
utils.init_distributed(device)
Contributor

For some reason I'm not able to comment on the entire block, but this block of code around FSDP and activation checkpointing is really ugly. If our goal is to abstract this complexity away from the user and to allow them to peel back the layers as much as they want (IMO that should absolutely be the goal), then this isn't achieving that goal.

Curious if @rohan-varma has any thoughts here. In its current form, this is really ugly.

Contributor Author

Everything related to FSDP is out of scope for this PR. This PR was reworking the existing utils. Moving the distributed code to utils is for the followup PR. Specifically "init_distributed" is a temporary function containing the distributed portion of "init_from_env" since the rest of the logic was moved to get_device. init_distributed will likely be removed in the followup PR.

Contributor

See my comment on the PR - I'd just roll in all of the changes in this PR since in its current form it's creating more confusion than is necessary.

@@ -134,10 +116,10 @@ def recipe(kwargs):
input_ids = input_ids.to(device)
labels = labels.to(device)

# Note: context manager for autocast is only applied in forward pass.
# Automatically handles mixed precision when given a low precision dtype.
Contributor

Do we need this level of magic? Why can't we make things a bit more explicit, i.e. ask the user to specify whether they want mixed precision or not?

Contributor Author

IMO, this is the level of magic decided on by the core team with the introduction of torch.autocast. We're just extending autocast to include dtype=fp32
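In other words, the behavior being described amounts to roughly the following (an illustrative sketch, not necessarily the PR's exact implementation):

```python
import contextlib
import torch

def autocast(dtype: torch.dtype, device: torch.device):
    """Return torch.autocast for low-precision dtypes, and a no-op context for fp32."""
    if dtype in (torch.float16, torch.bfloat16):
        return torch.autocast(device_type=device.type, dtype=dtype)
    return contextlib.nullcontext()
```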

from torchtune.utils.generation import GenerationUtils
from torchtune.utils.seed import set_seed
Contributor

thank you for changing the name - seed(..) was really confusing.

assert device.type == "cpu"
assert device.index is None
@patch("torch.backends.mps.is_available", return_value=False)
def test_get_cpu_device(self, mock_cuda, mock_mps):
Contributor

Sorry for the noob question, but I thought the order of the patches needs to be reversed since they are applied bottom-up? If this is true, then the two params need to exchange positions.

Ref: https://stackoverflow.com/questions/47042196/mock-patches-appearing-in-the-wrong-order

Contributor

This hasn't been addressed. Am I wrong or is the test wrong? Also, can we remove the mentions of MPS?



class TestSeed:
def test_seed_range(self) -> None:
Contributor

Sorry for the noob question, but it seems like this should be a validation on the input config rather than a test. What am I missing?

@@ -20,7 +21,7 @@
def get_model(name: str, device: Union[str, torch.device], **kwargs) -> Module:
Contributor

Pardon the noob question, but why are all of these functions in __init__.py instead of models/config_utils.py or something similar? All of these functions are essentially parsing the config, right?

Contributor Author

This refers back to our config philosophy in general. I think that discussion shouldn't happen here with this PR. But as a reference, recipe-level configs are meant to be "one level deep", i.e. the user can select the model they want from the config but not configure the model itself. This forces the user to define a new builder function for a model if they want a new configuration of that model.

Contributor

This still doesn't answer my question of why this logic is in __init__.py. What am I missing?

@@ -20,7 +21,7 @@
def get_model(name: str, device: Union[str, torch.device], **kwargs) -> Module:
"""Get known supported models by name"""
if name in _MODEL_DICT:
Contributor

This seems really clunky. Like if we support 100 models, will we manually expand this dict to a 100 key-value pairs? What alternatives have we considered here? I'm not sold on this being the best way to instantiate the model from the config.

I've mentioned this a few times, but I'm really missing Hydra's ability to just instantiate from a class name over here.

Contributor Author

As mentioned, this discussion is best had on the ArgParse PR. But we made an explicit choice to avoid powerful config parsers like Hydra. The alternative to this is a model registry, where you add a decorator on new models to register them, but @ebsmothers had concerns about that approach causing all models to get imported when a recipe does "from torchtune import models".
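For reference, the registry alternative described above would look roughly like the sketch below (illustrative only; llama2_7b is used just as an example builder name, and the import concern is that the decorator only runs when each model module is imported):

```python
from typing import Callable, Dict

_MODEL_REGISTRY: Dict[str, Callable] = {}

def register_model(name: str):
    """Decorator that records a model builder under a string name."""
    def wrapper(builder: Callable) -> Callable:
        _MODEL_REGISTRY[name] = builder
        return builder
    return wrapper

@register_model("llama2_7b")
def llama2_7b(**kwargs):
    ...  # build and return the model

def get_model(name: str, **kwargs):
    if name not in _MODEL_REGISTRY:
        raise ValueError(f"Unknown model: {name}")
    return _MODEL_REGISTRY[name](**kwargs)
```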

Contributor

@pbontrager I'd like to have this discussion again. Can you link the ArgsParse PR here as well as any design doc where we spoke about configs? This is the entry point for every user bucket and we have to get it right. So I'd like to understand the option space and the associated pros-and-cons.

from .device import get_device
from .distributed import init_distributed
from .precision import autocast, get_dtype, get_gradient_autoscaler, list_dtypes
from .seed import set_seed

__all__ = [
"TuneArgumentParser",
Contributor

nit: TuneArgumentParser is a really weird name. Tune is a verb, and so this makes it sound like we're tuning the argument parser. Can we either expand this to TorchTuneArgParser or something better?

Contributor Author

Can you create an issue for this?

from torch.utils.data.dataloader import _get_distributed_settings

# TokenPair is a pair (tuple) of two lists: tokenized text inputs and labels.
TokenPair = Tuple[List[int], List[int]]
Contributor

nit: This should be TokenPairs or TokensPairs. We're defining a tuple over lists and so the singular form of this doesn't make sense

Contributor Author

See comment below, I don't really want to update data utils in this PR.

Contributor

Made a comment on this. Let's roll in all of the changes into this PR.

Contributor

@pbontrager can we address this comment?

TokenPair = Tuple[List[int], List[int]]

_DEFAULT_INPUT_PADDING_IDX: int = 0
_DEFAULT_LABEL_PADDING_IDX: int = -100
Contributor

What's the reasoning behind these magic values?

Contributor Author

I agree, I don't think we really need these, but I don't really want to make changes to the data utils in this PR. I was just moving them into a consistent namespace.

Comment on lines 151 to 160
input_ids = pad_sequence(
[torch.tensor(x[0]) for x in batch],
batch_first=True,
padding_value=input_padding_idx,
)
labels = pad_sequence(
[torch.tensor(x[1]) for x in batch],
batch_first=True,
padding_value=label_padding_idx,
)
Contributor

Just making a note of how inefficient this is. We have two for loops of size B (for creating the input_ids and the labels sequences) and then each call to pad_sequence is going to iterate through B inputs to find the longest sequence and pad to that length. Just because we're calling built-in functions doesn't mean this is efficient. Can we do this better?

Not sure who the original author here is but cc: @rohan-varma, @gokulavasan
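One possible direction (an untested sketch, not the code under review) is to compute the padded length in a single scan and fill preallocated tensors, which avoids pad_sequence re-scanning the batch for the longest sequence:

```python
from typing import List, Tuple
import torch

def padded_collate(
    batch: List[Tuple[List[int], List[int]]],
    input_padding_idx: int = 0,
    label_padding_idx: int = -100,
):
    # One scan for the target length, then fill preallocated tensors.
    max_len = max(max(len(tokens), len(labels)) for tokens, labels in batch)
    input_ids = torch.full((len(batch), max_len), input_padding_idx, dtype=torch.long)
    label_ids = torch.full((len(batch), max_len), label_padding_idx, dtype=torch.long)
    for i, (tokens, labels) in enumerate(batch):
        input_ids[i, : len(tokens)] = torch.tensor(tokens)
        label_ids[i, : len(labels)] = torch.tensor(labels)
    return input_ids, label_ids
```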

@@ -5,14 +5,15 @@
# LICENSE file in the root directory of this source tree.
Contributor

Sorry if I'm grossly misunderstanding this section, but this file is very poorly written.

  • The functions are not cleanly defined, i.e. both of them seem to be reading from and validating environment variables.
  • MPS support is thrown in as an afterthought, without validation, testing, or any plan for future support.
  • get_device seems to be a generalized entry point (the signature takes in a Union), but I don't see where this generalization is coming from. The only call is from the recipe, which takes in the string from the config.

Contributor Author

This is intended as a rework of the device logic in "init_from_env". The issue with init_from_env is that it doesn't allow the user much control over the device they want to use. So get_device() lets a user request a device, and get_device checks the environment to see if that device is possible; if there is missing information (like the CUDA index) it will set that. This utility makes it easy for the recipe and other utilities to all work with devices the same way and safely: any utility that needs a device can take a string or a torch.device, convert it to a torch.device, and avoid a lot of conditional logic for converting between types.

  • get_device has to validate that the requested device is possible, while _get_device_from_env is responsible for returning a device if the user didn't provide one. I can provide a separate _validate_device_from_env for better separation.
  • I'm not sure why this appears as an afterthought. Before, we were explicitly blocking users from using mps for no reason, which is very frustrating to users who want to take advantage of their hardware. Having our util automatically select mps over cpu is how you'd expect the util to work independent of the recipe; it's up to the user whether they want to run the recipe on mps.
  • Other utils use this internally; if we require that it only takes in a string, then we implicitly enforce an order on when get_device has to be called in people's scripts, which makes it more brittle.

Contributor

> This utility makes it easy for the recipe and other utilities to all work with devices the same way and safely

If this is the case, then the core logic should be separated out into a private function, and get_device (exposed to the user) should only do what the name suggests, i.e. take in a string, validate this requirement, and return a torch.device. Functions called from other utilities should then wrap around this private function and should themselves be private.
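Concretely, the split being proposed might look roughly like this (names and checks are illustrative assumptions, not the PR's code):

```python
import os
import torch

def _validate_device_from_env(device: torch.device) -> torch.device:
    """Shared, private validation usable by other utilities."""
    if device.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(f"{device} was requested but CUDA is not available")
    local_rank = os.environ.get("LOCAL_RANK")
    if device.type == "cuda" and local_rank is not None and device.index not in (None, int(local_rank)):
        raise RuntimeError(f"Device index {device.index} conflicts with LOCAL_RANK={local_rank}")
    return device

def get_device(device: str) -> torch.device:
    """User-facing: take the config string, validate it, and return a torch.device."""
    return _validate_device_from_env(torch.device(device))
```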


import torch


def _get_device_from_env() -> torch.device:
"""Function that gets the torch.device based on the current environment.

This currently supports only CPU and GPU devices. If CUDA is available, this function also sets the CUDA device.
This currently supports CPU, GPU and MPS. If CUDA is available, this function also sets the CUDA device.
Contributor

Have we validated that MPS backend works? Do we have tests for this? Are we going to provide first-class support for this by consistently checking and making sure things aren't breaking when there are updates from Apple? If the answer to any of this is NO (which I believe it is), then please remove any references to MPS.

Contributor Author

Our device util should support every backend that torch core supports (we've currently only added mps here, but we should add all of them). This does not mean that our recipes work on those devices, but it does mean that users can experiment with them.

Contributor

> Our device util should support every backend that torch core supports

We need to align on what "supporting every backend" means. Just adding the device doesn't mean anything. You need to make sure all of the functionality works, numerics are on par, and then have solid tests to make sure nothing breaks. We haven't done any of this. In the future we need to support MPS, AMD, Windows etc. All of this will need work and thought on how we support it. For now, please remove mentions of MPS.

else:
device = torch.device("cpu")
return device


def set_float32_precision(precision: str = "high") -> None:
"""Sets the precision of float32 matrix multiplications and convolution operations.
def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
Contributor

Sorry, but I don't understand this function signature. The recipe is calling get_device on the value specified within the config. Why is this a union? When is it called with a torch.device as input?

Also, what are the sets of valid values for device? This is taking arbitrary strings as input and we need to check for typos.

Contributor Author

get_device isn't just used from the config; it is used by other utils, and the user might not be interacting with the util from the recipe but instead from a Jupyter notebook. The set of strings available is whatever is supported by torch.device. Think of this as an extension of torch.device that takes both string names and devices as input. It just adds additional logic to get/check that the device is compatible with your env settings (including distributed).

Contributor

Refer to my comment about cleanly separating public and private functions.

else:
torch.backends.cudnn.allow_tf32 = True
# Convert device string to torch.device
if type(device) != torch.device:
Contributor

Along the lines of the comment above, I don't understand when this check will fail

Contributor Author

If the user passes in a torch.device instead of a string

Contributor

I didn't mean logically :)

torch.backends.cudnn.allow_tf32 = True
# Convert device string to torch.device
if type(device) != torch.device:
if device is None:
Contributor

What's the case when device is None? When the user doesn't specify this in the config? Why is that a valid setting?

Contributor Author

Our default device value in our configs should likely be None. Then if a user runs it on a machine without cuda, it'll automatically set the device to cpu, though they can set it in the config. This would become more useful if we add ROCm support to _get_device_from_env in the future.

Contributor

> This would become more useful if we add ROCm support to _get_device_from_env

For now, can we define the utilities for the features we support? We can define how to generalize when we get there.

Member

If the feature is kept, what happens when None is passed should be documented

device = torch.device(device)

# Get device rank for cuda devices if not provided, and set Cuda device
if device.type == "cuda":
Contributor

Why is this separate from the block above i.e. we're parsing a string from the config so _get_device_from_env will be called and so the index will be set?

Contributor Author

The string could be "cuda" or "cuda:4", for example. Neither of these calls _get_device_from_env; they instead call torch.device(device_str). Also, as mentioned, the input might not be a string at all. This section ensures that device.index is set if it hasn't been, and also that the index is possible if it was set. (One thing we have to test for is the user setting cuda:4, which is useful for a single-process run, when it's actually a distributed run and the distributed launcher assigned that process to "cuda:7".)

torch.cuda.set_device(device)

# Check if the device index is correct when distributed training
local_rank = os.environ.get("LOCAL_RANK", None)
Contributor

Why is this not checked inside _get_device_from_env?

Contributor Author

This is checking if a specific index provided by the user matches the local rank. If _get_device_from_env is called, then the index would be set from LOCAL_RANK and there wouldn't be an issue.

@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

fyi this file has been deleted upstream

Contributor

@rohan-varma left a comment

Drive-by review, mainly waiting for the next push

Comment on lines 51 to 52
world_size, rank = utils.get_world_size_and_rank()
seed = utils.set_seed(kwargs["seed"])
Contributor

+1

@@ -88,6 +65,7 @@ def recipe(kwargs):
{TransformerDecoderLayer}
) # TODO: remove model specific components
if kwargs["fsdp"]:
utils.init_distributed(device)
Contributor

Prefer to have distributed initialization at the very beginning. We need to avoid bugs such as get_world_size being called before distributed is initialized, and more nuanced things such as how seed setting works in the distributed scenario.

Contributor Author

I've updated the distributed functions to call init_distributed themselves now, so whichever function gets called first will initialize the process group.
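The just-in-time pattern described here would look roughly like the following sketch (assumed shape, not the PR's exact code; the review below pushes back on this side effect):

```python
import os
import torch.distributed as dist

def init_distributed(backend: str = "nccl") -> None:
    # No-op unless launched with torchrun/torch.distributed and not yet initialized.
    if "WORLD_SIZE" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend=backend)

def get_world_size_and_rank():
    init_distributed()  # just-in-time initialization, as described above
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    return 1, 0
```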


if seed is None:
seed = random.randint(min_val, max_val)
local_seed = seed + rank
Contributor

Makes sense, though I guess if dist is not initialized, rank should be 0 as returned by get_world_size_and_rank. If we don't do this here, curious where is your suggestion to do it?

if seed is None:
seed = random.randint(min_val, max_val)
local_seed = seed + rank
_log.debug(f"Setting seed to {seed} and rank local seed to {local_seed}")
Contributor

I think the log should clarify whether distributed is initialized or not, and if it is, clarify we're offsetting the local seed by the rank.

Also, this is confusing because it says "Setting seed to {seed}", but actually what's set in code is the manual_seed.
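Putting both of these suggestions together, the function might look roughly like this (an illustrative sketch; the min_val/max_val defaults and the message wording are assumptions):

```python
import logging
import random
import torch
from torchtune.utils.distributed import get_world_size_and_rank

_log = logging.getLogger(__name__)

def set_seed(seed=None, min_val: int = 0, max_val: int = 2**32 - 1) -> int:
    world_size, rank = get_world_size_and_rank()  # (1, 0) when distributed is not initialized
    if seed is None:
        seed = random.randint(min_val, max_val)
    local_seed = seed + rank  # offset so each rank gets a distinct seed
    _log.debug(
        f"Base seed {seed}; rank {rank} of {world_size}, "
        f"so calling torch.manual_seed({local_seed})"
    )
    torch.manual_seed(local_seed)
    random.seed(local_seed)
    return seed
```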

_log: logging.Logger = logging.getLogger(__name__)


def set_seed(
Contributor

Wonder what our principle for utils is in general on them doing multiple (but somewhat related) things: in this particular case, setting the random seed and the deterministic debug mode. Would it be simpler and lead to better testable components if we had separate utils to do these two tasks, or do we feel that these are intertwined enough to have to be done together?

cc @kartikayk @NicolasHug @pbontrager

Contributor Author

I would be fine with pulling the debug_mode stuff into its own separate function, but then I wouldn't want it included by default in the recipe. In general I'm on the fence when it comes to functions that do multiple things; they should do just one thing as much as possible, but sometimes we need higher-level functions that stitch them together.

Contributor

Think this is useful to discuss, and I'd vote for pulling the determinism to a separate function as it makes this function more in line with its name and less complicated. But feel free to do this in a separate PR.

model=model,
device=device,
dtype=dtype,
strategy="FUll_SHARD",
Contributor

typo? How does this pass the getattr check in _get_sharding_strategy?

Contributor Author

Good catch, it was a typo. It was not getting passed to _get_sharding_strategy because of an incorrect "not" in "if strategy is None", which resulted in it being run with the default "NO_SHARD".
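For reference, a getattr-based lookup only catches the typo once the strategy string is actually passed through (a minimal illustration, not the PR's exact code):

```python
from torch.distributed.fsdp import ShardingStrategy

def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
    # Raises AttributeError for typos like "FUll_SHARD"; "FULL_SHARD" resolves cleanly.
    return getattr(ShardingStrategy, strategy)
```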

auto_wrap_policy=auto_wrap_policy,
device_id=device,
param_init_fn=lambda m: m.to_empty(device=device, recurse=False),
logger = logging.getLogger()
Contributor

nit: All functions are following the format "get_" except this one. Let's be consistent.

Contributor Author

This isn't our function, it's the Python logger. I wonder if, for a user-facing script, we should just use "print".

logger.info(msg=f"Loaded tokenizer from {tokenizer_checkpoint}")

model = models.get_model(model, device=device)
if distributed:
Contributor

This block of code is applying FSDP if the flag is True. is_distributed assumes that distributed training == FSDP, which is a bad assumption. This flag should be renamed to enable_fsdp or is_fsdp

model = models.get_model(model, device=device)
if distributed:
# TODO: initialize models for distributed on meta or cpu device to avoid OOMs
model = utils.get_distributed(
Contributor

A more appropriate name for this function is something like apply_fsdp.
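Something along these lines is presumably what's being suggested (an illustrative sketch built on stock FSDP APIs, not the PR's code):

```python
from typing import Optional, Set

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

def apply_fsdp(
    model: nn.Module,
    device: torch.device,
    strategy: str = "FULL_SHARD",
    wrap_classes: Optional[Set[type]] = None,
) -> nn.Module:
    """Wrap the model in FSDP; assumes the process group is already initialized."""
    return FSDP(
        model,
        sharding_strategy=getattr(ShardingStrategy, strategy),
        auto_wrap_policy=ModuleWrapPolicy(wrap_classes or set()),
        device_id=device,
    )
```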

loss_fn = get_loss(kwargs["loss"])
loss_fn = losses.get_loss(loss)

grad_scaler = utils.get_gradient_scaler(dtype, distributed) or GradScaler(
Contributor

A few unrelated comments on this:

I'm not sure why this needs to be so complicated (you lost me at Optional[Union[...]]). Since you've already called get_dtype and converted the string into torch.dtype, why does get_gradient_scaler need to take in a string? You're unnecessarily making a second call to get_dtype inside the function. Can we just pass in the resolved type and remove the Union? This also removes the need for this to be Optional (which doesn't make any sense since the function has a hard dependency on dtype).

I'm also not a fan of binary operators between a function which MAY return a None and an object - it's just not very readable code.

Zooming out a bit, I don't think this utility is saving the reader/user anything. You've added the cognitive load of trying to figure out when get_gradient_scaler returns None on a user who doesn't care about mixed_precision training. So they still need to read through a bunch of code, but you've made it harder since they need to navigate to an entirely different file. In this case, just unrolling the code here would have at least spared the user the need to find this extra file and figure out what's going on.

If I put on my "user hat", then the following would have made my life easier by explicitly differentiating between the scenario where I'm using mixed_precision and where I'm not.

if dtype == torch.float16:
    grad_scaler = utils.get_mixed_precision_scaler(enable_fsdp=True)
else:
    grad_scaler = GradScaler(enabled=False)

My larger point is that simply abstracting code out of the recipe into helper functions without thinking about whether it's making the reader's life easier or not is not going to lead to the outcomes you're hoping for. Our original hope was to reduce cognitive load by not unnecessarily littering the recipe with if-else blocks. I don't think this is achieving that goal.
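For contrast, the disabled-by-default scaler described in the changelog could look roughly like this (an assumed sketch, not the PR's exact code):

```python
import torch
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

def get_gradient_autoscaler(dtype: torch.dtype, fsdp: bool = False):
    """Real scaler only for fp16; otherwise a disabled GradScaler so call sites stay uniform."""
    if dtype == torch.float16:
        return ShardedGradScaler() if fsdp else GradScaler()
    return GradScaler(enabled=False)
```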

grad_scaler = utils.get_gradient_scaler(dtype, distributed) or GradScaler(
enabled=False
)
autocast = utils.get_autocast(dtype, device) or contextlib.nullcontext()
Contributor

I have exactly the same comments here. You already have resolved the dtype and device from the input string, there's no need to add Unions for these params and make the reader/user's life harder.

Same comment about cognitive load on the user. I'm not sure what we're saving here. Unrolling the code at least saves me a level of indirection.

autocast = contextlib.nullcontext()
if dtype in (torch.float16, torch.bfloat16):
    autocast = torch.autocast(
        device_type=device.type,
        dtype=dtype,
    )

from torchtune.utils.distributed import get_world_size_and_rank


class TestDistributed:
Contributor

I think you need to rebase on top of
https://github.com/pytorch-labs/torchtune/pull/193/files

This already adds a number of tests related to world size and rank.

cc: @gokulavasan who's the author of this PR

Contributor

Yeah, let's rebase and use Gokul's tests here which have distributed multiprocessing set up.

from torch import nn


def get_loss(loss: str) -> nn.Module:
Contributor

Having a config parsing function (you called this a "getter") inside a top-level file called losses.py doesn't make sense to me. losses.py is where I would expect loss implementations, especially in a top-level file. This function needs to move into config_utils.py. Folder structure will likely be fixed once this lands and we have a complete view of everything.

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

Same comment as losses.py

Contributor

Agree, these are more like getter code; feels like torchtune.optim should house something like memory-efficient optimizers, for example.

Comment on lines +84 to +87
raise RuntimeError(
f"You can't specify a device index when using distributed training. \
Device specified is {device} but was assigned cuda:{local_rank}"
)
Contributor

When does this check actually pass? Is this a user-side error when the device index is not the same as the machine this is running on?

Contributor

Would pass when we're in distributed, and the CUDA device this rank is using is properly set (i.e. to the local rank of the process we're on).

Contributor Author

Yes, if the user passes in "cuda:4" as their device but it's a distributed run so the local process has its own device index given from the environment.

torch.backends.cudnn.allow_tf32 = True
if device is None:
device = _get_device_type_from_env()
device = torch.device(device)
Contributor

_get_device_type_from_env returns torch.device and then we pass this torch.device to torch.device? That doesn't make any sense?

return 1, 0


def get_distributed(
Contributor

As I mentioned in the recipe, this function should be updated to something like apply_fsdp

cc: @rohan-varma

Contributor

agree

Contributor

@rohan-varma left a comment

Super awesome work, love how this is shaping up and thanks for painstakingly going through our feedback!

Left a bunch of comments and indicated which ones in particular can be punted on. Biggest concern is probably the just in time distributed init style, and the fact that there's no explicit call to distributed init in the recipe and I believe it happens in set_seed, which is an odd side effect.

@@ -17,5 +17,6 @@ optimizer: SGD
loss: CrossEntropyLoss
output_dir: /tmp/alpaca-llama2-finetune
device: cuda
fsdp: False
dtype: fp32
Contributor

[Can punt to next PR] Not clear to me what dtype this refers to. For example, if it is fp16, is the entire model loaded in fp16, is it autocast fp16, etc.

@@ -17,5 +17,6 @@ optimizer: SGD
loss: CrossEntropyLoss
output_dir: /tmp/alpaca-llama2-finetune
device: cuda
fsdp: False
dtype: fp32
distributed: False
Contributor

[Address in this PR] Should specify the parallelism algorithm instead of "distributed" which is very generic IMO

Contributor Author

I floated the "distributed" idea because I think it's a more general and well-understood name, and FSDP is actually a generalized data-parallel paradigm: if I run FSDP with NO_SHARD, it's not actually FSDP. But I also understand that keeping the name of the underlying library we're calling will better inform the user, so I'll change it back.

logger = logging.getLogger()
device = utils.get_device(device)
dtype = utils.get_dtype(dtype)
seed = utils.set_seed(seed)
Contributor

[Address in this PR] Think this reintroduces the same bug as #193.

Contributor

Nevermind, technically not, due to the side effect of init_distributed.

logger.info(msg=f"Loaded tokenizer from {tokenizer_checkpoint}")

model = models.get_model(model, device=device)
if distributed:
Contributor

[Address in this PR] Also it's not clear to me where distributed is being initialized, i.e. init_process_group call. This should be separate from setting up the parallelism algorithm, and should occur at the very beginning of the recipe.

opt.step()
grad_scaler.scale(loss).backward()
grad_scaler.step(opt)
grad_scaler.update()
Contributor

Still would prefer explicit use of the grad scaler and regular loss.backward(); optimizer.step(), but I think we've discussed this enough, and if the consensus advocates for this, then let's just do it.

def set_activation_checkpointing(
model: nn.Module, auto_wrap_policy: Optional[Set[nn.Module]] = None, **kwargs
) -> None:
"""Utility to setup activation checkpointing and setup the model sharding.
Contributor

Activation checkpointing won't shard the model so let's remove "set up the model sharding"


Args:
model (nn.Module): Model to setup activation checkpointing.
auto_wrap_policy (Optional[Set[nn.Module]]): Policy to wrap module for sharding
Contributor

sharding doesn't happen here



if seed is None:
seed = random.randint(min_val, max_val)
local_seed = seed + rank
Contributor

Re the concern around using this before distributed has been initialized: it's a very valid concern as this happened before, and it seems like it's actually happening in this PR again, but isn't, because distributed is sort of just-in-time initialized. I'm not a fan of just-in-time initializing distributed though; it should be an explicit call so that the user knows when it happens.

For example, in the current recipe, the function set_seed is what initializes distributed AFAICT, which seems like an odd side effect for such a function to have.

np.random.seed(local_seed)
random.seed(local_seed)

if debug_mode is not None:
Contributor

Yeah, looking at this the debug mode seems entirely separate. I'd just pull it out.

Contributor Author

Since it was already part of seed before this PR, I'll leave it as is for now and it can be split out later.

@pbontrager merged commit 1645f4a into main Jan 20, 2024
@pbontrager deleted the phil-device-fix branch January 20, 2024 17:12
@@ -4,61 +4,35 @@
# This source code is licensed under the BSD-style license found in the
Contributor

I see you updated the name of this file, are we planning to update finetune_llm as well to match?

import logging

import torch

from torchtune.models.llama2 import llama2_7b, llama2_tokenizer
from torchtune.utils.env import _get_device_from_env, seed
from torchtune.utils import get_device, set_seed, TuneArgumentParser
Contributor

In llama_generate we do from torchtune.utils import ... but in finetune_llm we do from torchtune import utils. Should we update finetune_llm to match this? (I think the way it's done in llama_generate is cleaner)

@joecummings mentioned this pull request Jan 22, 2024
@pbontrager linked an issue Jan 22, 2024 that may be closed by this pull request
Development

Successfully merging this pull request may close these issues.

Add autocast and scaler
7 participants