Skip to content

Commit 5478049

Browse files
committed
fixes requirements,code and readme
1 parent 9c23024 commit 5478049

File tree

4 files changed

+9
-19
lines changed

4 files changed

+9
-19
lines changed

run_python_examples.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,9 @@ function stop() {
195195
time_sequence_prediction/traindata.pt \
196196
word_language_model/model.pt \
197197
gcn/cora/ \
198-
gat/cora/ || error "couldn't clean up some files"
199-
swin_transformer/swin_cifar10.pt || error "couldn't clean up some files"
200-
198+
gat/cora/ || error "couldn't clean up some files" \
199+
swin_transformer/ \
200+
swin_trasformer/swin_cifar10.pt || error "command swin_transformer/swin_cifar10.pt not found" \
201201
git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
202202

203203
base_stop "$1"
@@ -225,7 +225,7 @@ function run_all() {
225225
run fx
226226
run gcn
227227
run gat
228-
run swin
228+
run swin_transformer
229229
}
230230

231231
# by default, run all examples

swin_transformer/README.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ pip install -r requirements.txt
3232

3333
## Usage
3434

35-
### Train the model
35+
### Train & Save the model
3636

3737
```bash
38-
python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001
38+
python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001 --save-model
3939
```
4040

4141
### Test the model
@@ -44,13 +44,7 @@ Testing is done automatically after each epoch. To only test, run with:
4444

4545
```bash
4646
python swin_transformer.py --epochs 1
47-
```
48-
49-
### Save the model
50-
51-
```bash
52-
python swin_transformer.py --save-model
53-
```
47+
``
5448

5549
The model will be saved as `swin_cifar10.pt`.
5650

swin_transformer/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
torch
1+
torch>=2.6
22
torchvision

swin_transformer/swin_transformer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from __future__ import print_function
21
import argparse
3-
42
import torch
53
import torch.nn as nn
64
import torch.nn.functional as F
@@ -202,6 +200,4 @@ def main():
202200

203201
if args.save_model:
204202
torch.save(model.state_dict(), "swin_cifar10.pt")
205-
206-
if __name__ == '__main__':
207-
main()
203+
main()

0 commit comments

Comments
 (0)