Optimizing PyTorch training by wrapping torch.utils.data.Dataset with tensordict.TensorDict and tensordict.MemoryMappedTensor (memory-mapped, pinned, and loaded onto an Nvidia GPU) and passing the resulting TensorDict(Dataset) to torch.utils.data.DataLoader to boost model training speed.
To run the demo:
git clone https://github.com/OriYarden/pytorch_training_optimization_using_tensordict_memory_mapping
cd pytorch_training_optimization_using_tensordict_memory_mapping
python run_demo.py
What run_demo.py looks like (GIFs):
[GIF] torch.utils.data.Dataset: training 1 epoch
[GIF] tensordict.TensorDict.MemoryMappedTensor(torch.utils.data.Dataset): training 1 epoch
torch.utils.data.Dataset's POV:
The only thing you have to change in your code (along with possibly a few other minor changes; see the comments in the code):
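import torch
from torch.utils.data import DataLoader
# dataset_to_tensordict and Collate_Fn come from tensordict_packages in this repo.

DEVICE = torch.device("cuda")  # the Nvidia GPU the data is loaded onto
ds = Dataset()  # <--- potentially requires minor changes in the __getitem__ method
ds = dataset_to_tensordict(  # <--- Wraps here; this must be added to your existing code (from tensordict_packages).
    ds=ds,
    DEVICE=DEVICE,
)
loader = DataLoader(ds, collate_fn=collate_fn)  # <--- collate_fn must be the Collate_Fn wrapper (from tensordict_packages); pass it by keyword, since DataLoader's 2nd positional argument is batch_size.
# That's it! Just two lines of code.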
The TensorDict memory-mapping tools I've provided in tensordict_packages boost PyTorch model training speed.
However, the initial tensordict_packages wrapping step takes roughly as long as one epoch with a plain torch.utils.data.Dataset.
So tensordict_packages is unlikely to speed up PyTorch model inference on its own: a single pass over the data cannot recoup a wrapping cost that is already about one full epoch.
Still, training runs over many epochs, so the one-time wrapping cost is amortized and tensordict_packages can improve PyTorch model training speed by orders of magnitude.
We should therefore make the most of the available Nvidia GPU resources (i.e., memory) to speed up PyTorch model training,
reduce PyTorch model training cost, and shorten the gap between first developing PyTorch models and having them in production.
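The trade-off between the one-time wrapping cost and the faster epochs is simple arithmetic. The sketch below uses hypothetical timings (a plain epoch of 100 s, a wrapped epoch of 10 s, and a wrapping cost of about one plain epoch, matching the estimate above); they are illustrative numbers, not measurements from run_demo.py, and total_seconds and break_even_epochs are hypothetical helper names. It compares total time with and without wrapping and finds the smallest number of epochs after which wrapping pays off.

```python
import math


def total_seconds(epochs: int, epoch_s: float, wrap_s: float = 0.0) -> float:
    """Total wall-clock time: a one-time wrapping cost plus `epochs` full passes."""
    return wrap_s + epochs * epoch_s


def break_even_epochs(plain_epoch_s: float, fast_epoch_s: float, wrap_s: float) -> int:
    """Smallest whole n with wrap_s + n * fast_epoch_s <= n * plain_epoch_s."""
    if fast_epoch_s >= plain_epoch_s:
        raise ValueError("wrapping only pays off if wrapped epochs are faster")
    return math.ceil(wrap_s / (plain_epoch_s - fast_epoch_s))


# Hypothetical timings, in seconds (not measurements from run_demo.py).
PLAIN_EPOCH, WRAPPED_EPOCH = 100.0, 10.0
WRAP_COST = PLAIN_EPOCH  # wrapping takes "approximately 1 epoch" of the plain Dataset

# A single pass (e.g. inference only): wrapping is slower overall.
print(total_seconds(1, PLAIN_EPOCH), total_seconds(1, WRAPPED_EPOCH, WRAP_COST))  # 100.0 110.0
# Training: wrapping pays for itself after this many epochs.
print(break_even_epochs(PLAIN_EPOCH, WRAPPED_EPOCH, WRAP_COST))  # 2
# Over 10 epochs the totals diverge sharply.
print(total_seconds(10, PLAIN_EPOCH), total_seconds(10, WRAPPED_EPOCH, WRAP_COST))  # 1000.0 200.0
```

Because the break-even point is wrap_cost / (plain_epoch - wrapped_epoch) epochs, a multi-epoch training run with a large per-epoch speedup recovers the wrapping cost within a couple of epochs, while a single pass always pays it in full.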
And with the current AI boom, where LLMs and text-to-video PyTorch models require months of training, tensordict_packages' use of TensorDict and MemoryMappedTensors with torch.utils.data.DataLoader can save time, resources, and Nvidia GPU hours.