
Commit 0fbd562

Added SAG-ViT.md and images

1 parent c7895df commit 0fbd562

File tree

5 files changed: +172 -0 lines changed

SAG-ViT.md

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
---
layout: hub_detail
background-class: hub-background
body-class: hub
category: researchers
title: SAG-ViT
summary: SAG-ViT improves image classification by combining multi-scale feature extraction with high-fidelity patching and graph attention for Transformer architectures.
image: SAG-ViT.png
author: Shravan Venkatraman
tags: [vision]
github-link: https://github.com/shravan-18/SAG-ViT/blob/main/sag_vit_model.py
github-id: shravan-18/SAG-ViT
featured_image_1: SAG-ViT.png
featured_image_2: SAG-ViT_Ablation.png
accelerator: "cuda-optional"
---

### Model Description

Implementation of the ***SAG-ViT*** model as proposed in the [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420) paper.

It is a novel transformer framework designed to enhance Vision Transformers (ViT) with scale-awareness and refined patch-level feature embeddings. It extracts multi-scale features with an EfficientNetV2 backbone, organizes the resulting patches into a graph based on their spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.
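The pipeline described above can be summarized in a compact schematic. The sketch below is only illustrative, not the authors' implementation: the module name (`SAGViTSketch`), the embedding width, and the layer counts are our assumptions, and the actual model lives in [`sag_vit_model.py`](https://github.com/shravan-18/SAG-ViT/blob/main/sag_vit_model.py).

```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s
from torch_geometric.nn import GATConv
from torch_geometric.utils import grid

class SAGViTSketch(nn.Module):
    """Illustrative CNN -> graph attention -> Transformer pipeline."""
    def __init__(self, num_classes=10, embed_dim=256):
        super().__init__()
        # 1. Multi-scale feature extraction (classifier head removed)
        self.backbone = efficientnet_v2_s(weights=None).features
        self.proj = nn.Conv2d(1280, embed_dim, kernel_size=1)
        # 2. Graph attention over patches linked by spatial adjacency
        self.gat = GATConv(embed_dim, embed_dim, heads=1)
        # 3. Transformer encoder for global, long-range integration
        layer = nn.TransformerEncoderLayer(embed_dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        f = self.proj(self.backbone(x))              # (N, D, h, w) feature map
        n, d, h, w = f.shape
        nodes = f.flatten(2).transpose(1, 2)         # (N, h*w, D) patch embeddings
        edge_index, _ = grid(h, w, device=x.device)  # 8-neighbour grid graph
        # GATConv operates on one graph at a time, so loop over the batch
        nodes = torch.stack([self.gat(z, edge_index) for z in nodes])
        z = self.encoder(nodes)                      # global dependencies
        return self.head(z.mean(dim=1))              # classification logits

# Quick shape check on a random batch
print(SAGViTSketch()(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 10])
```

The actual model also incorporates the paper's high-fidelity patching scheme, which this sketch omits.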

### Model Architecture

<p align="center">
  <img src="images/SAG-ViT.png" alt="SAG-ViT Architecture Overview">
</p>

_Image source: [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420)_

### Usage

SAG-ViT expects input images normalized in the same way as other torchvision pre-trained models,
i.e. mini-batches of 3-channel RGB images of shape `(N, 3, H, W)`, where `N` is the number of images and `H` and `W` are expected to be at least `49` pixels.
The images have to be loaded into a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]`
and `std = [0.229, 0.224, 0.225]`.

To train or run inference on our model, you will need to install the following Python packages:

```bash
pip install -q torch-geometric==2.6.1 networkx==3.3 torch==2.4.0 torchvision==0.19.0 scikit-learn==1.2.2 numpy==1.26.4 pandas==2.2.3 matplotlib==3.7.5
```

Load the model pretrained on the CIFAR-10 dataset:

```python
import torch

# Load the SAG-ViT model from PyTorch Hub
model = torch.hub.load('shravan-18/SAG-ViT', 'SAGViT', pretrained=True)

# Switch to eval mode for inference
model.eval()
```
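
Optionally, you can list the entrypoints the repository exposes through PyTorch Hub (this assumes a `hubconf.py` at the repository root):

```python
import torch

# Print the callable entrypoints defined in the repository's hubconf.py
print(torch.hub.list('shravan-18/SAG-ViT'))
```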

Sample execution to predict on an input image:

```python
from PIL import Image
import torch
from torchvision import transforms

def predict_image(model, img_tensor, device='cpu'):
    """
    Predicts the class label and confidence for the given image tensor.
    """
    with torch.no_grad():
        img_tensor = img_tensor.to(device)
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        _, preds = torch.max(probs, 1)
    return preds.item(), probs[0, preds.item()].item()

# Run on GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Set input image path to predict
image_path = "path/to/input/image"

# Preprocess the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
img = Image.open(image_path).convert("RGB")
img_tensor = transform(img)
img_tensor = img_tensor.unsqueeze(0)  # Add batch dimension

# Predict
pred_class, confidence = predict_image(model, img_tensor, device)

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

predicted_label = class_names[pred_class]
print(f"Predicted class: {predicted_label} with confidence: {confidence:.4f}")
```

### Running Tests

If you clone our [repository](https://github.com/shravan-18/SAG-ViT), the *'tests'* folder contains unit tests for each of our model's modules. Make sure you have a proper Python environment with the required dependencies installed, then run:

```bash
python -m unittest discover -s tests
```

or, if you are using `pytest`:

```bash
pytest tests
```
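
As a quick sanity check outside the test suite, you can verify that the model produces logits of the expected shape; this minimal sketch assumes the CIFAR-10 pretrained weights with 10 output classes:

```python
import torch

model = torch.hub.load('shravan-18/SAG-ViT', 'SAGViT', pretrained=True)
model.eval()

# A random batch of two 224x224 RGB images
x = torch.rand(2, 3, 224, 224)
with torch.no_grad():
    out = model(x)

print(out.shape)  # expected: torch.Size([2, 10]) for the CIFAR-10 weights
```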

### Results

We evaluated SAG-ViT on diverse datasets:
- **CIFAR-10** (natural images)
- **GTSRB** (traffic sign recognition)
- **NCT-CRC-HE-100K** (histopathological images)
- **NWPU-RESISC45** (remote sensing imagery)
- **PlantVillage** (agricultural imagery)

SAG-ViT achieves state-of-the-art results across all benchmarks, as shown in the table below (F1 scores):

<center>

| Backbone               | CIFAR-10 | GTSRB  | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
|------------------------|----------|--------|-----------------|---------------|--------------|
| DenseNet201            | 0.5427   | 0.9862 | 0.9214          | 0.4493        | 0.8725       |
| VGG16                  | 0.5345   | 0.8180 | 0.8234          | 0.4114        | 0.7064       |
| VGG19                  | 0.5307   | 0.7551 | 0.8178          | 0.3844        | 0.6811       |
| DenseNet121            | 0.5290   | 0.9813 | 0.9247          | 0.4381        | 0.8321       |
| AlexNet                | 0.6126   | 0.9059 | 0.8743          | 0.4397        | 0.7684       |
| Inception              | 0.7734   | 0.8934 | 0.8707          | 0.8707        | 0.8216       |
| ResNet                 | 0.9172   | 0.9134 | 0.9478          | 0.9103        | 0.8905       |
| MobileNet              | 0.9169   | 0.3006 | 0.4965          | 0.1667        | 0.2213       |
| ViT-S                  | 0.8465   | 0.8542 | 0.8234          | 0.6116        | 0.8654       |
| ViT-L                  | 0.8637   | 0.8613 | 0.8345          | 0.8358        | 0.8842       |
| MNASNet1_0             | 0.1032   | 0.0024 | 0.0212          | 0.0011        | 0.0049       |
| ShuffleNet_V2_x1_0     | 0.3523   | 0.4244 | 0.4598          | 0.1808        | 0.3190       |
| SqueezeNet1_0          | 0.4328   | 0.8392 | 0.7843          | 0.3913        | 0.6638       |
| GoogLeNet              | 0.4954   | 0.9455 | 0.8631          | 0.3720        | 0.7726       |
| **Proposed (SAG-ViT)** | **0.9574** | **0.9958** | **0.9861** | **0.9549** | **0.9772** |

</center>
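
To reproduce a score of this kind yourself, a minimal evaluation loop over the CIFAR-10 test split could look like the sketch below. The macro averaging is our assumption; the paper's exact evaluation protocol may differ:

```python
import torch
from torchvision import datasets, transforms
from sklearn.metrics import f1_score

# Standard preprocessing, matching the Usage section above
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
test_set = datasets.CIFAR10(root='data', train=False, download=True,
                            transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.hub.load('shravan-18/SAG-ViT', 'SAGViT', pretrained=True)
model = model.to(device).eval()

# Collect predictions and labels over the whole test split
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        all_preds += preds.tolist()
        all_labels += labels.tolist()

print("Macro F1:", f1_score(all_labels, all_preds, average='macro'))
```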
### Ablation

In our ablation study on the CIFAR-10 dataset, we examined the impact of each component in our model. Removing the Transformer encoder while keeping the EfficientNet backbone and GAT resulted in a drop in the F1 score to 0.7785, highlighting the importance of the Transformer for capturing global dependencies. Excluding the GAT while keeping the EfficientNet backbone and Transformer encoder led to an F1 score of 0.7593, emphasizing the role of the GAT in refining local feature representations. When the EfficientNet backbone was removed, leaving only the GAT and Transformer encoder, the F1 score drastically decreased to 0.5032, demonstrating the critical role of EfficientNet in providing rich feature embeddings. The ablation results are summarized in the table below.

<center>

| **Model**                        | **F1** | **RAM (GB)** | **GPU VRAM (GB)** | **Time per Epoch** |
|----------------------------------|--------|--------------|-------------------|--------------------|
| Backbone + GAT (no Transformer)  | 0.7785 | 5.6          | 4.9               | 14 mins 30 sec     |
| Backbone + Transformer (no GAT)  | 0.7593 | 3.1          | 4.5               | 16 mins 7 sec      |
| GAT + Transformer (no backbone)  | 0.5032 | 4.3          | 5.3               | 1 hour 33 mins     |

</center>

### Citation

If you find our [paper](https://arxiv.org/abs/2411.09420) and [code](https://github.com/shravan-18/SAG-ViT) helpful for your research, please consider citing our work and giving the repository a star:

```bibtex
@misc{SAGViT,
      title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers},
      author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
      year={2024},
      eprint={2411.09420},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.09420},
}
```

images/SAG-ViT.png (349 KB)

images/SAG-ViT_Ablation.png (547 KB)

Binary file (279 Bytes) not shown.

scripts/install_deps.sh

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@ pip install -q --upgrade google-api-python-client
 pip install pytorchvideo
 pip install -q prefetch_generator # yolop
 pip install -q pretrainedmodels efficientnet_pytorch webcolors # hybridnets
+pip install -q networkx torch-geometric scikit-learn # SAG-ViT
