Utils Refactor #180
Conversation
✅ Deploy Preview for torchtune-preview ready!
@@ -17,5 +17,6 @@ optimizer: SGD
loss: CrossEntropyLoss
output_dir: /tmp/alpaca-llama2-finetune
device: cuda
dtype: bf16
Unrelated changes; keep to a separate PR?
Oops, that should be fp32
torchtune/utils/data.py
Outdated
_DEFAULT_LABEL_PADDING_IDX: int = -100

class ReproducibleDataLoader(DataLoader):
FYI this is being removed in #161
recipes/finetune_llm.py
Outdated
@@ -82,6 +62,7 @@ def recipe(kwargs):
    {TransformerDecoderLayer}
)  # TODO: remove model specific components
if kwargs["fsdp"]:
    utils.init_distributed(device)
For some reason I'm not able to comment on the entire block, but this block of code around FSDP and activation checkpointing is really ugly. If our goal is to abstract this complexity away from the user and to allow them to peel back the layers as much as they want (IMO that should absolutely be the goal), then this isn't achieving that goal.
Curious if @rohan-varma has any thoughts here; in its current form, this is really ugly.
Everything related to FSDP is out of scope for this PR; this PR reworks the existing utils. Moving the distributed code to utils is for the followup PR. Specifically, "init_distributed" is a temporary function containing the distributed portion of "init_from_env", since the rest of the logic was moved to get_device. init_distributed will likely be removed in the followup PR.
See my comment on the PR - I'd just roll in all of the changes in this PR since in its current form it's creating more confusion than is necessary.
recipes/finetune_llm.py
Outdated
@@ -134,10 +116,10 @@ def recipe(kwargs):
input_ids = input_ids.to(device)
labels = labels.to(device)

# Note: context manager for autocast is only applied in forward pass.
# Automatically handles mixed precision when given a low precision dtype.
Do we need this level of magic? Why can't we make things a bit more explicit, i.e. ask the user to specify whether they want mixed precision or not?
IMO, this is the level of magic decided on by the core team with the introduction of torch.autocast. We're just extending autocast to include dtype=fp32
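To make that behavior concrete, here is a minimal sketch of what such a wrapper could look like, assuming a helper along the lines of the PR's get_autocast (an illustration, not the actual implementation):

    import contextlib
    import torch

    def get_autocast(dtype: torch.dtype, device: torch.device):
        # Low-precision dtypes get a real autocast context; fp32 degenerates to a
        # no-op context, which is the "extension to fp32" described above.
        if dtype in (torch.float16, torch.bfloat16):
            return torch.autocast(device_type=device.type, dtype=dtype)
        return contextlib.nullcontext()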
from torchtune.utils.generation import GenerationUtils
from torchtune.utils.seed import set_seed
Thank you for changing the name; seed(..) was really confusing.
tests/torchtune/utils/test_device.py
Outdated
assert device.type == "cpu"
assert device.index is None
@patch("torch.backends.mps.is_available", return_value=False)
def test_get_cpu_device(self, mock_cuda, mock_mps):
Sorry for the noob question, but I thought the order of the patches needs to be reversed since they are applied bottom-up? If this is true, then the two params need to exchange positions.
Ref: https://stackoverflow.com/questions/47042196/mock-patches-appearing-in-the-wrong-order
This hasn't been addressed. Am I wrong or is the test wrong? Also, can we remove the mentions of MPS?
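For reference, a small illustration of the ordering rule from the linked answer (a hypothetical test, assuming the get_device import path from this PR): stacked @patch decorators are applied bottom-up, so the decorator closest to the function supplies the first mock argument after self.

    from unittest.mock import patch
    from torchtune.utils import get_device  # assumed import path

    class TestGetDevice:
        @patch("torch.backends.mps.is_available", return_value=False)  # outer patch -> mock_mps
        @patch("torch.cuda.is_available", return_value=False)          # inner patch -> mock_cuda
        def test_get_cpu_device(self, mock_cuda, mock_mps):
            device = get_device()
            assert device.type == "cpu"
            assert device.index is None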
class TestSeed:
    def test_seed_range(self) -> None:
Sorry for the noob question, but it seems like this should be a validation on the input config rather than a test. What am I missing?
@@ -20,7 +21,7 @@
def get_model(name: str, device: Union[str, torch.device], **kwargs) -> Module:
Pardon the noob question, but why are all of these functions in __init__.py instead of models/config_utils.py or something similar? All of these functions are essentially parsing the config, right?
This refers back to our config philosophy in general. I think that discussion shouldn't happen here with this PR. But as a reference, recipe-level configs are meant to be "one level deep", i.e. the user can select the model they want but not configure the model from the config. This forces the user to define a new builder function for a model if they want a new configuration of a model.
This still doesn't answer my question of why this logic is in __init__.py. What am I missing?
@@ -20,7 +21,7 @@
def get_model(name: str, device: Union[str, torch.device], **kwargs) -> Module:
    """Get known supported models by name"""
    if name in _MODEL_DICT:
This seems really clunky. If we support 100 models, will we manually expand this dict to 100 key-value pairs? What alternatives have we considered here? I'm not sold on this being the best way to instantiate the model from the config.
I've mentioned this a few times, but I'm really missing Hydra's ability to just instantiate from a class name over here.
As mentioned, this discussion is best on the ArgParse PR. But we made an explicit choice to avoid powerful config parsers like Hydra. The alternative to this is a model registry, where you add a decorator on new models to register them, but @ebsmothers had concerns about that approach causing all models to get imported when a recipe does "from torchtune import models".
@pbontrager I'd like to have this discussion again. Can you link the ArgsParse PR here as well as any design doc where we spoke about configs? This is the entry point for every user bucket and we have to get it right. So I'd like to understand the option space and the associated pros-and-cons.
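For context, the registry alternative mentioned above would look roughly like the sketch below (hypothetical names, not code from this PR); the import-time concern is that every model module must be imported for its decorator to run:

    from typing import Callable, Dict
    from torch.nn import Module

    _MODEL_REGISTRY: Dict[str, Callable[..., Module]] = {}

    def register_model(name: str):
        # Decorator that records a model builder under a string name.
        def wrapper(builder: Callable[..., Module]) -> Callable[..., Module]:
            _MODEL_REGISTRY[name] = builder
            return builder
        return wrapper

    @register_model("llama2_7b")
    def llama2_7b(**kwargs) -> Module:
        ...

    def get_model(name: str, **kwargs) -> Module:
        if name not in _MODEL_REGISTRY:
            raise ValueError(f"Unknown model: {name}")
        return _MODEL_REGISTRY[name](**kwargs)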
torchtune/utils/__init__.py
Outdated
from .device import get_device
from .distributed import init_distributed
from .precision import autocast, get_dtype, get_gradient_autoscaler, list_dtypes
from .seed import set_seed

__all__ = [
    "TuneArgumentParser",
nit: TuneArgumentParser is a really weird name. Tune is a verb, so this makes it sound like we're tuning the argument parser. Can we either expand this to TorchTuneArgParser or something better?
Can you create an issue for this?
from torch.utils.data.dataloader import _get_distributed_settings

# TokenPair is a pair (tuple) of two lists: tokenized text inputs and labels.
TokenPair = Tuple[List[int], List[int]]
nit: This should be TokenPairs or TokensPairs. We're defining a tuple over lists and so the singular form of this doesn't make sense.
See comment below, I don't really want to update data utils in this PR.
Made a comment on this. Let's roll in all of the changes into this PR.
@pbontrager can we address this comment?
torchtune/utils/data.py
Outdated
TokenPair = Tuple[List[int], List[int]]

_DEFAULT_INPUT_PADDING_IDX: int = 0
_DEFAULT_LABEL_PADDING_IDX: int = -100
What's the reasoning behind these magical values?
I agree, I don't think we really need these, but I don't really want to make changes to the data utils in this PR. I was just moving them into a consistent namespace.
torchtune/utils/data.py
Outdated
input_ids = pad_sequence(
    [torch.tensor(x[0]) for x in batch],
    batch_first=True,
    padding_value=input_padding_idx,
)
labels = pad_sequence(
    [torch.tensor(x[1]) for x in batch],
    batch_first=True,
    padding_value=label_padding_idx,
)
Just making a note of how inefficient this is. We have two for loops of size B (for creating the input_ids and the labels sequences) and then each call to pad_sequence is going to iterate through B inputs to find the longest sequence and pad to that length. Just because we're calling built-in functions doesn't mean this is efficient. Can we do this better?
Not sure who the original author here is but cc: @rohan-varma, @gokulavasan
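One possible direction for the efficiency concern, sketched under the assumption that the collate function receives TokenPair batches as defined in this file (not a drop-in replacement for the PR's code): compute the max lengths once and fill pre-allocated tensors instead of building intermediate tensor lists for pad_sequence.

    from typing import List, Tuple
    import torch

    def padded_collate(
        batch: List[Tuple[List[int], List[int]]],
        input_padding_idx: int = 0,
        label_padding_idx: int = -100,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Pre-allocate padded tensors and copy each sample in a single pass.
        max_input_len = max(len(x[0]) for x in batch)
        max_label_len = max(len(x[1]) for x in batch)
        bsz = len(batch)
        input_ids = torch.full((bsz, max_input_len), input_padding_idx, dtype=torch.long)
        labels = torch.full((bsz, max_label_len), label_padding_idx, dtype=torch.long)
        for i, (tokens, label_tokens) in enumerate(batch):
            input_ids[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
            labels[i, : len(label_tokens)] = torch.tensor(label_tokens, dtype=torch.long)
        return input_ids, labels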
@@ -5,14 +5,15 @@
# LICENSE file in the root directory of this source tree.
Sorry if I'm grossly misunderstanding this section, but this file is very poorly written.
- The functions are not cleanly defined i.e. both of them seem to be reading from and validating environment variables.
- MPS support is thrown in as an after-thought without validation, testing or any plan for future support.
- get_device seems to be a generalized entry point (the signature takes in a Union) but I don't see where this generalization is coming from. The only call is from the recipe, which takes in the string from the config.
This is intended as a rework of the device logic in "init_from_env". The issue with init_from_env is that it doesn't give the user much control over the device they want to use. get_device() lets a user request a device, checks the environment to see whether that device is possible, and fills in missing information (like the CUDA index). This utility makes it easy for the recipe and other utilities to all work with devices the same way and safely: any utility that needs a device can take a string or a torch.device, convert it to a torch.device, and avoid a lot of conditional logic for converting between types.
- get_device has to validate that the requested device is possible, while _get_device_from_env is responsible for returning a device if the user didn't provide it. I can provide a separate _validate_device_from_env for better separation.
- I'm not sure why this appears as an afterthought. Before, we were explicitly blocking users from using mps for no reason, which is very frustrating to users who want to take advantage of their hardware. Having our util automatically select mps over cpu is how you'd expect the util to work independent of the recipe; it's up to the user whether they want to run the recipe on mps.
- Other utils use this internally; if we require that it only takes a string, then we implicitly enforce an order in which get_device has to be called in people's scripts, which makes it more brittle.
This utility makes it easy for the recipe and other utilities to all work with devices the same way and safely
If this is the case, then the core logic should be separated out into a private function, and get_device (exposed to the user) should only do what the name suggests, i.e. take in a string, validate this requirement, and return a torch.device. Functions called from other utilities should then wrap around this private function and should themselves be private.
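A rough sketch of the public/private split being suggested here (assumed structure and names, not this PR's code):

    import os
    from typing import Optional
    import torch

    def _get_device_from_env() -> torch.device:
        # Private helper: pick a device from the environment when none is given.
        if torch.cuda.is_available():
            return torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))
        return torch.device("cpu")

    def _validate_device_from_env(device: torch.device) -> None:
        # Private helper: confirm the requested device exists in this environment.
        if device.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError(f"{device} was requested but CUDA is not available")

    def get_device(device: Optional[str] = None) -> torch.device:
        # Public entry point: resolve a config string to a validated torch.device.
        resolved = torch.device(device) if device is not None else _get_device_from_env()
        _validate_device_from_env(resolved)
        if resolved.type == "cuda":
            torch.cuda.set_device(resolved)
        return resolved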
torchtune/utils/device.py
Outdated
import torch


def _get_device_from_env() -> torch.device:
    """Function that gets the torch.device based on the current environment.

    This currently supports only CPU and GPU devices. If CUDA is available, this function also sets the CUDA device.
    This currently supports CPU, GPU and MPS. If CUDA is available, this function also sets the CUDA device.
Have we validated that MPS backend works? Do we have tests for this? Are we going to provide first-class support for this by consistently checking and making sure things aren't breaking when there are updates from Apple? If the answer to any of this is NO (which I believe it is), then please remove any references to MPS.
Our device util should support every backend that torch core supports (currently only added mps here but we should add all of them). This does not mean that our recipes work on those devices, but it does mean that users can experiment with them.
Our device util should support every backend that torch core supports
We need to align on what "supporting every backend" means. Just adding the device doesn't mean anything. You need to make sure all of the functionality works, numerics are on par, and then have solid tests to make sure nothing breaks. We haven't done any of this. In the future we need to support MPS, AMD, Windows etc. All of this will need work and thought on how we support it. For now, please remove mentions of MPS.
torchtune/utils/device.py
Outdated
else:
    device = torch.device("cpu")
return device


def set_float32_precision(precision: str = "high") -> None:
    """Sets the precision of float32 matrix multiplications and convolution operations.
def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
Sorry, but I don't understand this function signature. The recipe is calling get_device on the value specified within the config. Why is this a union? When is it called with a torch.device as input?
Also, what is the set of valid values for device? This is taking arbitrary strings as input and we need to check for typos.
get_device isn't just used from the config; it is used by other utils, and the user might not be interacting with it from a recipe at all but instead from a Jupyter notebook. The set of valid strings is whatever is supported by torch.device. Think of this as an extension of torch.device that takes both string names and devices as input; it just adds additional logic to get/check that the device is compatible with your env settings (including distributed).
Refer to my comment about cleanly separating public and private functions.
torchtune/utils/device.py
Outdated
else:
    torch.backends.cudnn.allow_tf32 = True
# Convert device string to torch.device
if type(device) != torch.device:
Along the lines of the comment above, I don't understand when this check will fail
If the user passes in a torch.device instead of a string
I didn't mean logically :)
torchtune/utils/device.py
Outdated
torch.backends.cudnn.allow_tf32 = True
# Convert device string to torch.device
if type(device) != torch.device:
    if device is None:
What's the case when device is None? When the user doesn't specify this in the config? Why is that a valid setting?
Our default device value in our configs should likely be None. Then if a user runs it on a machine without cuda, it'll automatically set the device to cpu, though they can set it in the config. This would become more useful if we add ROCm support to _get_device_from_env in the future.
This would become more useful if we add ROCm support to _get_device_from_env
For now can we define the utilities for the features we support? We can define how to generalize when we get there.
If the feature is kept, what happens when None is passed should be documented.
torchtune/utils/device.py
Outdated
device = torch.device(device)

# Get device rank for cuda devices if not provided, and set Cuda device
if device.type == "cuda":
Why is this separate from the block above? I.e., we're parsing a string from the config, so _get_device_from_env will be called and the index will be set?
The string could be "cuda" or "cuda:4", for example. Neither of these calls _get_device_from_env; they instead call torch.device(device_str). Also, as mentioned, the input might not be a string at all. This section ensures that device.index is set if it hasn't been, and that the index is possible if it was set. (One thing we have to test for is the user setting cuda:4, which is useful for a single-process run, when it's actually a distributed run and the distributed launcher assigned that process to "cuda:7".)
torchtune/utils/device.py
Outdated
torch.cuda.set_device(device)

# Check if the device index is correct when distributed training
local_rank = os.environ.get("LOCAL_RANK", None)
Why is this not checked inside _get_device_from_env?
This is checking if a specific index provided by the user matches the local rank. If _get_device_from_env is called, then the index would be set from LOCAL_RANK and there wouldn't be an issue.
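In other words, the check described here amounts to something like the following sketch of the described behavior (a hypothetical helper, not the exact code):

    import os
    import torch

    def _check_index_matches_local_rank(device: torch.device) -> None:
        # If the launcher set LOCAL_RANK, a user-supplied CUDA index must match it.
        local_rank = os.environ.get("LOCAL_RANK", None)
        if (
            local_rank is not None
            and device.index is not None
            and device.index != int(local_rank)
        ):
            raise RuntimeError(
                f"Device {device} does not match the assigned device cuda:{local_rank}"
            )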
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
fyi this file has been deleted upstream
Drive-by review, mainly waiting for the next push
recipes/finetune_llm.py
Outdated
world_size, rank = utils.get_world_size_and_rank()
seed = utils.set_seed(kwargs["seed"])
+1
recipes/finetune_llm.py
Outdated
@@ -88,6 +65,7 @@ def recipe(kwargs):
    {TransformerDecoderLayer}
)  # TODO: remove model specific components
if kwargs["fsdp"]:
    utils.init_distributed(device)
Prefer to have distributed initialization at the very beginning. We need to avoid bugs such as get_world_size being called before distributed is initialized, and more nuanced things such as how seed setting works in a distributed scenario.
I've updated the distributed functions to call init_distributed themselves now, so whichever function gets called first will initialize the process group.
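Roughly, the just-in-time pattern described here looks like the sketch below (assumed structure, not the PR's exact code; the (1, 0) fallback matches the non-distributed case shown later in the diff):

    import os
    from typing import Tuple
    import torch.distributed as dist

    def init_distributed() -> None:
        # Initialize the default process group on first use, only when a launcher
        # has set up the distributed environment.
        if dist.is_available() and not dist.is_initialized() and "LOCAL_RANK" in os.environ:
            dist.init_process_group(backend="nccl")

    def get_world_size_and_rank() -> Tuple[int, int]:
        # The first distributed util to run sets up the process group; this is the
        # implicit side effect flagged later in the review.
        init_distributed()
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(), dist.get_rank()
        return 1, 0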
if seed is None:
    seed = random.randint(min_val, max_val)
local_seed = seed + rank
Makes sense, though I guess if dist is not initialized, rank should be 0 as returned by get_world_size_and_rank. If we don't do this here, where do you suggest we do it?
torchtune/utils/seed.py
Outdated
if seed is None:
    seed = random.randint(min_val, max_val)
local_seed = seed + rank
_log.debug(f"Setting seed to {seed} and rank local seed to {local_seed}")
I think the log should clarify whether distributed is initialized or not, and if it is, clarify we're offsetting the local seed by the rank.
Also, this is confusing because it says "Setting seed to {seed}", but what's actually set in the code is the manual_seed.
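A minimal sketch of the rank-offset seeding being discussed, with a log message along the lines requested above (assumed signature, not the PR's exact function):

    import logging
    import random
    import numpy as np
    import torch

    _log = logging.getLogger(__name__)

    def set_seed(seed: int, rank: int = 0) -> int:
        # Offset the base seed by the (distributed) rank so each process gets a
        # distinct but reproducible stream.
        local_seed = seed + rank
        _log.debug(
            f"Base seed {seed}; setting manual seed to {local_seed} (offset by rank {rank})"
        )
        torch.manual_seed(local_seed)
        np.random.seed(local_seed)
        random.seed(local_seed)
        return seed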
_log: logging.Logger = logging.getLogger(__name__)


def set_seed(
Wonder what our principle for utils is in general on them doing multiple (but somewhat related) things: in this particular case, setting the random seed and the deterministic debug mode. Would it be simpler and lead to better testable components if we had separate utils to do these two tasks, or do we feel that these are intertwined enough to have to be done together?
I would be fine with pulling the debug_mode stuff into its own separate function, but then I wouldn't want it included by default in the recipe. In general I'm on the fence when it comes to functions that do multiple things; they should do just one thing as much as possible, but sometimes we need higher level functions that stitch them together.
Think this is useful to discuss, and I'd vote for pulling the determinism to a separate function as it makes this function more in line with its name and less complicated. But feel free to do this in a separate PR.
recipes/finetune_llm.py
Outdated
model=model,
device=device,
dtype=dtype,
strategy="FUll_SHARD",
typo? How does this pass the getattr check in _get_sharding_strategy?
Good catch, it was a typo. It was not getting passed to _get_sharding_strategy because of an incorrect "not" in "if strategy is None", which resulted in it running with the default "NO_SHARD".
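For reference, the getattr-based lookup being discussed probably has roughly this shape (a sketch under that assumption); with the corrected None check, a typo like "FUll_SHARD" fails loudly instead of silently falling back to the default:

    from typing import Optional
    from torch.distributed.fsdp import ShardingStrategy

    def _get_sharding_strategy(strategy: Optional[str] = None) -> ShardingStrategy:
        if strategy is None:
            return ShardingStrategy.NO_SHARD
        # getattr raises AttributeError for misspellings such as "FUll_SHARD".
        return getattr(ShardingStrategy, strategy)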
recipes/finetune_llm.py
Outdated
auto_wrap_policy=auto_wrap_policy,
device_id=device,
param_init_fn=lambda m: m.to_empty(device=device, recurse=False),
logger = logging.getLogger()
nit: All functions are following the format "get_" except this one. Let's be consistent.
This isn't our function; it's the Python logger. I wonder if, for a user-facing script, we should just use "print".
recipes/finetune_llm.py
Outdated
logger.info(msg=f"Loaded tokenizer from {tokenizer_checkpoint}")

model = models.get_model(model, device=device)
if distributed:
This block of code is applying FSDP if the flag is True. is_distributed assumes that distributed training == FSDP, which is a bad assumption. This flag should be renamed to enable_fsdp or is_fsdp.
recipes/finetune_llm.py
Outdated
model = models.get_model(model, device=device)
if distributed:
    # TODO: initialize models for distributed on meta or cpu device to avoid OOMs
    model = utils.get_distributed(
A more appropriate name for this function is something like apply_fsdp.
recipes/finetune_llm.py
Outdated
loss_fn = get_loss(kwargs["loss"])
loss_fn = losses.get_loss(loss)

grad_scaler = utils.get_gradient_scaler(dtype, distributed) or GradScaler(
A few unrelated comments on this:
I'm not sure why this needs to be so complicated (you lost me at Optional[Union[...]]). Since you've already called get_dtype and converted the string into a torch.dtype, why does get_gradient_scaler need to take in a string? You're unnecessarily making a second call to get_dtype inside the function. Can we just pass in the resolved type and remove the Union? This also removes the need for this to be Optional, which doesn't make any sense since the function has a hard dependency on dtype.
I'm also not a fan of binary operators between a function which MAY return None and an object; it's just not very readable code.
Zooming out a bit, I don't think this utility is saving the reader/user anything. You've added the cognitive load of figuring out when get_gradient_scaler returns None onto a user who doesn't care about mixed-precision training. So they still need to read through a bunch of code, but you've made it harder since they need to navigate to an entirely different file. In this case, just unrolling the code here would at least have spared the user the need to find this extra file and figure out what's going on.
If I put on my "user hat", then the following would have made my life easier by explicitly differentiating between the scenario where I'm using mixed precision and where I'm not:

    if dtype == 'fp16':
        grad_scaler = utils.get_mixed_precision_scaler(enable_fsdp=True)
    else:
        grad_scaler = GradScaler(enabled=False)

My larger point is that simply abstracting code out of the recipe into helper functions, without thinking about whether it makes the reader's life easier, is not going to lead to the outcomes you're hoping for. Our original hope was to reduce cognitive load by not unnecessarily littering the recipe with if-else blocks. I don't think this is achieving that goal.
recipes/finetune_llm.py
Outdated
grad_scaler = utils.get_gradient_scaler(dtype, distributed) or GradScaler(
    enabled=False
)
autocast = utils.get_autocast(dtype, device) or contextlib.nullcontext()
I have exactly the same comments here. You have already resolved the dtype and device from the input strings; there's no need to add Unions for these params and make the reader/user's life harder.
Same comment about cognitive load on the user. I'm not sure what we're saving here. Unrolling the code at least saves me a level of indirection:

    autocast = contextlib.nullcontext()
    if dtype == 'fp16' or dtype == 'bf16':
        autocast = torch.autocast(
            device_type=device.type,
            dtype=dtype,
        )
from torchtune.utils.distributed import get_world_size_and_rank


class TestDistributed:
I think you need to rebase on top of https://github.com/pytorch-labs/torchtune/pull/193/files, which already adds a number of tests related to world size and rank.
cc: @gokulavasan, who's the author of that PR
Yeah, let's rebase and use Gokul's tests here which have distributed multiprocessing set up.
from torch import nn


def get_loss(loss: str) -> nn.Module:
Having a config parsing function (you called this a "getter") inside a top-level file called losses.py doesn't make sense to me. losses.py is where I would expect loss implementations, especially in a top-level file. This function needs to move into config_utils.py. Folder structure will likely be fixed once this lands and we have a complete view of everything.
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Same comment as losses.py
Agree, these are more getter code; it feels like torchtune.optim should house something like memory-efficient optimizers, for example.
raise RuntimeError(
    f"You can't specify a device index when using distributed training. \
    Device specified is {device} but was assigned cuda:{local_rank}"
)
When does this check actually pass? Is this a user-side error when the device index is not the same as the machine this is running on?
Would pass when we're in distributed, and the CUDA device this rank is using is properly set (i.e. to the local rank of the process we're on).
Yes, if the user passes in "cuda:4" as their device but it's a distributed run so the local process has its own device index given from the environment.
torch.backends.cudnn.allow_tf32 = True
if device is None:
    device = _get_device_type_from_env()
device = torch.device(device)
_get_device_type_from_env returns a torch.device and then we pass this torch.device to torch.device? That doesn't make any sense?
torchtune/utils/distributed.py
Outdated
return 1, 0


def get_distributed(
As I mentioned in the recipe, this function should be updated to something like apply_fsdp.
cc: @rohan-varma
agree
Super awesome work, love how this is shaping up and thanks for painstakingly going through our feedback!
Left a bunch of comments and indicated which ones in particular can be punted on. Biggest concern is probably the just-in-time distributed init style, and the fact that there's no explicit call to distributed init in the recipe; I believe it happens in set_seed, which is an odd side effect.
@@ -17,5 +17,6 @@ optimizer: SGD
loss: CrossEntropyLoss
output_dir: /tmp/alpaca-llama2-finetune
device: cuda
fsdp: False
dtype: fp32
[Can punt to next PR] Not clear to me what dtype this refers to. For example, if it is fp16, is the entire model loaded in fp16, is it autocast fp16, etc.
@@ -17,5 +17,6 @@ optimizer: SGD
loss: CrossEntropyLoss
output_dir: /tmp/alpaca-llama2-finetune
device: cuda
fsdp: False
dtype: fp32
distributed: False
[Address in this PR] Should specify the parallelism algorithm instead of "distributed" which is very generic IMO
I floated the "distributed" idea since I think it's a more general and well-understood name, and FSDP is actually a generalized data parallel paradigm: if I run FSDP with NO_SHARD, it's not actually FSDP. But I also understand that keeping the name of the underlying library we're calling will better inform the user, so I'll change it back.
logger = logging.getLogger()
device = utils.get_device(device)
dtype = utils.get_dtype(dtype)
seed = utils.set_seed(seed)
[Address in this PR] Think this reintroduces the same bug as #193.
Never mind; technically it doesn't, due to the side effect of init_distributed.
recipes/finetune_llm.py
Outdated
logger.info(msg=f"Loaded tokenizer from {tokenizer_checkpoint}")

model = models.get_model(model, device=device)
if distributed:
[Address in this PR] Also, it's not clear to me where distributed is being initialized, i.e. the init_process_group call. This should be separate from setting up the parallelism algorithm, and should occur at the very beginning of the recipe.
opt.step()
grad_scaler.scale(loss).backward()
grad_scaler.step(opt)
grad_scaler.update()
Still would prefer explicit use of the grad scaler vs. regular loss.backward(); optimizer.step(), but I think we've discussed this enough, and if the consensus advocates for this, then let's just do it.
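For comparison, the explicit version being advocated here would read something like this in the recipe (a sketch; scaling_enabled is a hypothetical flag derived from the dtype):

    if scaling_enabled:
        # Mixed precision: scale the loss and step through the scaler.
        grad_scaler.scale(loss).backward()
        grad_scaler.step(opt)
        grad_scaler.update()
    else:
        # Full precision: plain backward and optimizer step.
        loss.backward()
        opt.step()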
torchtune/utils/memory.py
Outdated
def set_activation_checkpointing(
    model: nn.Module, auto_wrap_policy: Optional[Set[nn.Module]] = None, **kwargs
) -> None:
    """Utility to setup activation checkpointing and setup the model sharding.
Activation checkpointing won't shard the model so let's remove "set up the model sharding"
torchtune/utils/memory.py
Outdated
Args:
    model (nn.Module): Model to setup activation checkpointing.
    auto_wrap_policy (Optional[Set[nn.Module]]): Policy to wrap module for sharding
sharding doesn't happen here
_log: logging.Logger = logging.getLogger(__name__)


def set_seed(
Think this is useful to discuss, and I'd vote for pulling the determinism to a separate function as it makes this function more in line with its name and less complicated. But feel free to do this in a separate PR.
if seed is None:
    seed = random.randint(min_val, max_val)
local_seed = seed + rank
Re the concern around using this before distributed has been initialized: it's a very valid concern as this happened before, and it seems like it's actually happening in this PR again, but it isn't, because distributed is sort of just-in-time initialized. I'm not a fan of just-in-time initializing distributed though; it should be an explicit call so that the user knows when it happens.
For example, in the current recipe, the function set_seed is what initializes distributed AFAICT, which seems like an odd side effect for such a function to have.
np.random.seed(local_seed)
random.seed(local_seed)

if debug_mode is not None:
Yeah, looking at this the debug mode seems entirely separate. I'd just pull it out.
Since it was already part of seed before this PR, I'll leave it as is for now and it can be split out later.
@@ -4,61 +4,35 @@
# This source code is licensed under the BSD-style license found in the
I see you updated the name of this file, are we planning to update finetune_llm as well to match?
import logging

import torch

from torchtune.models.llama2 import llama2_7b, llama2_tokenizer
from torchtune.utils.env import _get_device_from_env, seed
from torchtune.utils import get_device, set_seed, TuneArgumentParser
In llama_generate we do "from torchtune.utils import ..." but in finetune_llm we do "from torchtune import utils". Should we update finetune_llm to match this? (I think the way it's done in llama_generate is cleaner.)
In an effort to update the recipe to match the design of the recipe RFC, I needed to update the current utils to function in a similar way to the utils in the RFC. This PR refactors existing utils; it doesn't add new utils or update the recipe, which will be a followup PR. These utils are meant to hide a lot of complexity from the user so they can add features without thinking too much about it; if they want more granular control, they can skip the utils in the recipe and use core PyTorch directly.
Changelog
Utils folder after this PR: argparse, data, device, distributed, precision, seed, generation, logits_transforms
(Utils are not meant to be finetuning specific; logits_transforms might still need to be recategorized.)
Test plan