 # limitations under the License.
 
 import argparse
-import os
 
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module
 from trainer.distill_trainer import EagleSGLTrainer, EagleTPTrainer
 from transformers import AutoTokenizer
 
 torch.manual_seed(0)
 
 
-def _setup_distributed(rank, args, backend="nccl"):
-    """Initialize distributed environment"""
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = args.master_port
-    os.environ["LOCAL_RANK"] = str(rank)
-    # Initialize process group
-    dist.init_process_group(backend, rank=rank, world_size=args.world_size)
+def _check_args(args):
+    """Sanity check for arguments."""
+    # TODO: (hg)
+
+
+def _setup_pgroups(args):
+    """Initialize student/teacher process groups and set devices."""
+    rank = dist.get_rank()
+    args.teacher_ranks = list(range(len(args.teacher_devices)))
+    args.student_ranks = list(
+        range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices))
+    )
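+    # Ranks are laid out teachers-first: global ranks [0, len(teacher_ranks)) bind
+    # to teacher_devices, and the remaining ranks bind to student_devices (hence
+    # the rank - len(teacher_ranks) offset below).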
     if rank in args.teacher_ranks:
         torch.cuda.set_device(args.teacher_devices[rank])
     else:
         torch.cuda.set_device(args.student_devices[rank - len(args.teacher_ranks)])
     print(
-        f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}"
+        f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={dist.get_world_size()}"
     )
     args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
     args.student_pgroup = dist.new_group(ranks=args.student_ranks)
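+    # These per-role subgroups presumably scope collectives (e.g. gradient
+    # all-reduce among student ranks) to teachers and students separately.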
 
 
-def train(rank, args):
-    _setup_distributed(rank, args)
-
+def train(args):
+    """Entry point for training."""
     tokenizer = AutoTokenizer.from_pretrained(
         args.model_path, model_max_length=args.training_seq_len
     )
@@ -55,19 +57,24 @@ def train(rank, args):
         args.offline_data_path = None
     data_module = make_eagle_supervised_data_module(tokenizer, args)
 
+    # Ensure different ranks load the same data
+    g = torch.Generator()
+    g.manual_seed(0)
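+    # With shuffle=True, the DataLoader's sampler draws its permutation from this
+    # generator, so seeding it identically on every rank yields the same batch order.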
+
     train_dataloader = torch.utils.data.DataLoader(
         data_module["train_dataset"],
         batch_size=args.batch_size,
         shuffle=True,
         num_workers=0,
         collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len),
         drop_last=True,
+        generator=g,
     )
     trainer_cls = {
         "sglang": EagleSGLTrainer,
         "hf": EagleTPTrainer,
     }[args.teacher_backend]
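+    # The chosen backend presumably controls how the teacher model is served:
+    # an SGLang engine ("sglang") or an HF tensor-parallel model ("hf").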
-    trainer = trainer_cls(rank, args, tokenizer, train_dataloader)
+    trainer = trainer_cls(dist.get_rank(), args, tokenizer, train_dataloader)
     trainer.train()
     trainer.save(args.out_path)
 
@@ -76,7 +83,7 @@ def main():
     parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")
 
     # Training args
-    parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    parser.add_argument("--model_path", type=str, required=True, help="Target model path.")
     parser.add_argument("--data_path", type=str, required=True, help="Training dataset.")
     parser.add_argument("--training_seq_len", type=int, default=1024)
     parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json")
@@ -103,26 +110,16 @@ def main():
     parser.add_argument(
         "--total_steps", type=int, default=60000, help="Total number of steps for debugging."
     )
+    parser.add_argument("--master_addr", type=str, default="localhost")
     parser.add_argument("--master_port", type=str, default="12357")
 
     args = parser.parse_args()
-    # TODO: add sanity check for args
-
-    def set_ranks(args):
-        args.world_size = len(args.teacher_devices) + len(args.student_devices)
-        args.teacher_ranks = list(range(len(args.teacher_devices)))
-        args.student_ranks = list(
-            range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices))
-        )
-
-    set_ranks(args)
-    # Launch multiple processes
-    mp.spawn(
-        train,
-        args=(args,),
-        nprocs=args.world_size,
-        join=True,
-    )
+
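+    # With mp.spawn removed, one process per device is presumably started by an
+    # external launcher such as torchrun, which supplies RANK, WORLD_SIZE,
+    # MASTER_ADDR, and MASTER_PORT through the environment, e.g.:
+    #   torchrun --nproc_per_node=<num_teacher+num_student_devices> <this_script> --data_path ...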
+    dist.init_process_group("nccl")
+
+    _check_args(args)
+    _setup_pgroups(args)
+    train(args)
 
 
 if __name__ == "__main__":