Skip to content

Commit 9c23024

Browse files
committed
with accelerator API
1 parent c3ac0e9 commit 9c23024

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

swin_transformer/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ It includes:
2525
Install dependencies:
2626

2727
```bash
28-
pip install torch torchvision
28+
pip install -r requirements.txt
2929
```
3030

3131
---
@@ -43,7 +43,7 @@ python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001
4343
Testing is done automatically after each epoch. To only test, run with:
4444

4545
```bash
46-
python swin_transformer.py --epochs 0
46+
python swin_transformer.py --epochs 1
4747
```
4848

4949
### Save the model

swin_transformer/swin_transformer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,10 @@ def main():
172172
parser.add_argument('--save-model', action='store_true')
173173
args = parser.parse_args()
174174

175-
use_cuda = torch.cuda.is_available()
176-
device = torch.device("cuda" if use_cuda else "cpu")
177-
175+
use_accel = torch.accelerator.is_available()
176+
device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu")
177+
print(f"Using device: {device}")
178+
178179
torch.manual_seed(args.seed)
179180

180181
transform = transforms.Compose([

0 commit comments

Comments
 (0)