diff --git a/vae/README.md b/vae/README.md index cda6a33672..81e6458d7a 100644 --- a/vae/README.md +++ b/vae/README.md @@ -8,14 +8,12 @@ pip install -r requirements.txt python main.py ``` -The main.py script accepts the following arguments: +The main.py script accepts the following optional arguments: ```bash -optional arguments: - --batch-size input batch size for training (default: 128) - --epochs number of epochs to train (default: 10) - --no-cuda enables CUDA training - --mps enables GPU on macOS - --seed random seed (default: 1) - --log-interval how many batches to wait before logging training status +--batch-size input batch size for training (default: 128) +--epochs number of epochs to train (default: 10) +--accel use accelerator +--seed random seed (default: 1) +--log-interval how many batches to wait before logging training status ``` \ No newline at end of file diff --git a/vae/main.py b/vae/main.py index d69833fbe0..9a6850ccd1 100644 --- a/vae/main.py +++ b/vae/main.py @@ -13,28 +13,31 @@ help='input batch size for training (default: 128)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') -parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') -parser.add_argument('--no-mps', action='store_true', default=False, - help='disables macOS GPU training') +parser.add_argument('--accel', action='store_true', + help='use accelerator') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') args = parser.parse_args() -args.cuda = not args.no_cuda and torch.cuda.is_available() -use_mps = not args.no_mps and torch.backends.mps.is_available() + torch.manual_seed(args.seed) -if args.cuda: - device = torch.device("cuda") -elif use_mps: - device = torch.device("mps") +if args.accel and not torch.accelerator.is_available(): + print("ERROR: accelerator is not available, try running on CPU") + sys.exit(1) +if not args.accel and torch.accelerator.is_available(): + print("WARNING: accelerator is available, run with --accel to enable it") + +if args.accel: + device = torch.accelerator.current_accelerator() else: device = torch.device("cpu") -kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} +print(f"Using device: {device}") + +kwargs = {'num_workers': 1, 'pin_memory': True} if device=="cuda" else {} train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()), diff --git a/vae/requirements.txt b/vae/requirements.txt index 9a7236577d..73348074bf 100644 --- a/vae/requirements.txt +++ b/vae/requirements.txt @@ -1,4 +1,4 @@ torch -torchvision==0.20.0 +torchvision tqdm six