
Commit 1816149

Adding NVIDIA TFT model

1 parent ed58d27 commit 1816149

File tree

3 files changed: +171 -0 lines changed

images/tft_architecture.png

58.6 KB

nvidia_deeplearningexamples_tft.md

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
---
layout: hub_detail
background-class: hub-background
body-class: hub
title: Temporal Fusion Transformer
summary: The Temporal Fusion Transformer (TFT) model is a state-of-the-art architecture for interpretable, multi-horizon time-series prediction.
category: researchers
image: nvidia_logo.png
author: NVIDIA
tags: [forecasting]
github-link: https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Forecasting/TFT
github-id: NVIDIA/DeepLearningExamples
featured_image_1: tft_architecture.png
featured_image_2: no-image
accelerator: cuda
---

# Model Description
The Temporal Fusion Transformer ([TFT](https://arxiv.org/abs/1912.09363)) model is a state-of-the-art architecture for interpretable, multi-horizon time-series prediction. The model was first developed and [implemented by Google](https://github.com/google-research/google-research/tree/master/tft) in collaboration with the University of Oxford.

[This implementation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Forecasting/TFT) differs from the reference implementation in how it handles missing data, which is common in production datasets: missing values are either masked in the attention matrices or embedded as a special value in the latent space.

The model predicts confidence intervals for the future values of a time series at multiple future timesteps.
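
These intervals come from predicting several quantiles at once: the network is trained with a quantile (pinball) loss so that each output column corresponds to a different quantile level. Below is a minimal sketch of such a loss, for illustration only; it is not the code used in this repository, and the quantile levels and tensor shapes are assumptions.
```python
import torch

def quantile_loss(pred, target, quantiles=(0.1, 0.5, 0.9)):
    """Pinball loss averaged over the given quantile levels.

    pred:   [batch, time, n_quantiles] predicted quantiles
    target: [batch, time] ground-truth values
    """
    losses = []
    for i, q in enumerate(quantiles):
        diff = target - pred[..., i]
        # Penalize under-prediction with weight q and over-prediction with (1 - q).
        losses.append(torch.max(q * diff, (q - 1) * diff))
    return torch.mean(torch.stack(losses))
```
The 10th/50th/90th quantile columns plotted later in this example correspond to this kind of multi-quantile output.
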
# Example
In the following example, we will use the pretrained ***TFT*** model to perform inference on some preprocessed samples from the ***Electricity*** dataset. To run the example, you need some extra Python packages installed for data loading and visualization.
```python
!pip install scikit-learn==1.2.1
!pip install pandas==1.5.3
!pip install matplotlib==3.6.3
```

```python
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from mpl_toolkits.axes_grid1 import make_axes_locatable
import warnings
warnings.filterwarnings('ignore')
os.environ["TFT_SCRIPTING"] = "True"

# Run on GPU if one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    !nvidia-smi
else:
    device = torch.device("cpu")
print(f'Using {device} for inference')
```

Load the model pretrained on the ***Electricity*** dataset.
```python
tft_model = torch.hub.load("../../../../public", "nvidia_tft", dataset="electricity", pretrained=True, source="local")
utils = torch.hub.load("../../../../public", "nvidia_tft_data_utils", source="local")
```
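
The `source="local"` paths above are relative to the pytorch/hub repository checkout this page is built from. If you are running the example elsewhere, the same entry points can typically be loaded directly from GitHub; the branch and entry-point names below are an assumption about the NVIDIA/DeepLearningExamples `torchhub` branch, not something exercised on this page.
```python
# Hypothetical alternative to the local-source calls above (assumes the
# 'nvidia_tft' and 'nvidia_tft_data_utils' entry points exist on the
# NVIDIA/DeepLearningExamples 'torchhub' branch).
tft_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tft',
                           dataset="electricity", pretrained=True)
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tft_data_utils')
```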

Download and preprocess the data. This can take a few minutes.
```python
utils.download_data(torch.hub._get_torch_home())
```

```python
utils.preprocess(torch.hub._get_torch_home())
```
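
Both helpers cache their files under the torch hub directory. If you want to inspect the downloaded or preprocessed data yourself, you can print where that is:
```python
# Directory used by the download_data/preprocess calls above
# (typically ~/.cache/torch unless the TORCH_HOME environment variable is set).
print(torch.hub._get_torch_home())
```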

Initialize functions used to get interpretable attention graphs.
```python
activations = {}

def get_attention_heatmap_fig(heads, max_size=16, min_size=4):
    # Draw one heatmap per attention head, plus the mean head appended last.
    row_size = max(min_size, max_size / len(heads))
    fig, axes = plt.subplots(1, len(heads), figsize=(max_size, row_size))
    for i, (head, ax) in enumerate(zip(heads, axes), 1):
        im = ax.imshow(head, cmap='hot', interpolation='nearest')
        if i < len(heads):
            ax.set_title(f'HEAD {i}')
        else:
            ax.set_title('MEAN')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
    return fig

def get_attn_heads(activations, sample_number):
    # Extract the attention probabilities for one sample and scale them for display.
    heads = []
    _, attn_prob = activations
    sample_attn_prob = attn_prob[sample_number]
    n_heads = sample_attn_prob.shape[0]
    for head_index in range(n_heads):
        head = sample_attn_prob[head_index] * 255
        heads.append(head.detach().cpu())
    # Append the average over all heads as an extra panel.
    mean_head = torch.mean(sample_attn_prob, dim=0) * 255
    heads.append(mean_head.detach().cpu())
    fig = get_attention_heatmap_fig(heads)
    return fig

def _get_activation(name):
    # Forward hook that stores a module's output under the given name.
    def hook(model, input, output):
        activations[name] = output

    return hook
```

Register the hook on the model to save the attention data. The hook stores the raw output of the attention module; as `get_attn_heads` above assumes, that output is a tuple whose second element holds the per-head attention probabilities.
```python
tft_model.attention.register_forward_hook(_get_activation('attention'))
```
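
`register_forward_hook` returns a handle, so a small variation of the call above lets you detach the hook once the attention figures have been generated:
```python
# Keep the handle returned by register_forward_hook so the hook can be removed later.
hook_handle = tft_model.attention.register_forward_hook(_get_activation('attention'))
# ... run inference and build the attention figures, then:
# hook_handle.remove()
```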

Load the sample preprocessed batch of data.
```python
batch = utils.get_batch(torch.hub._get_torch_home())
```

```python
# Move every non-empty tensor to the chosen device; tensors with no elements become None.
batch = {key: tensor.to(device) if tensor.numel() else None for key, tensor in batch.items()}
```

Run inference on the ***TFT***.
```python
tft_model.to(device)
tft_model.eval()
with torch.no_grad():
    output = tft_model(batch)
```

Inspect the shapes of the prediction and of the target.
```python
output.shape
```

```python
batch['target'].shape
```
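
For the sample batch used here, the plotting code below assumes one row per future timestep and one column per predicted quantile in `output`, and the full 192-step window in `batch['target']`. The exact trailing dimensions are an assumption rather than something documented on this page.
```python
# Assumed layout, consistent with the plots below:
#   output[index]          -> 24 future timesteps x 3 quantiles (10th, 50th, 90th)
#   batch['target'][index] -> all 192 timesteps (168 past + 24 future)
print(output[index].shape, batch['target'][index].shape)
```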

Plot the full 192-timestep window of the ***Electricity*** dataset. We use the previous week of data to predict the following day of power usage. Since our data is organized by hour, this means we use the 168 previous time points to predict the following 24.
```python
index = 9
fig, ax = plt.subplots()
ax.plot(batch['target'][index].cpu().numpy(), label="Ground Truth")
ax.plot(np.arange(168, 192), output[index].detach().cpu().numpy(), label=["10th Quantile", "50th Quantile", "90th Quantile"])
ax.legend(loc='upper left')

ax.set_xlabel('Timestep')
ax.set_ylabel('Power Usage')
```

Below is the same graph as above, but focusing only on the prediction window, which contains the last 24 values.
```python
fig, ax = plt.subplots()
ax.plot(np.arange(168, 192), batch['target'][index][-24:].cpu().numpy(), label="Ground Truth")
ax.plot(np.arange(168, 192), output[index].detach().cpu().numpy(), label=["10th Quantile", "50th Quantile", "90th Quantile"])
ax.legend(loc='upper left')

ax.set_xlabel('Timestep')
ax.set_ylabel('Power Usage')
```
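
To see the confidence interval as a single shaded band, and to get a rough numeric sense of what it means, you can fill the area between the 10th and 90th quantile predictions and count how many of the 24 ground-truth points fall inside it. This sketch assumes the quantile column order implied by the labels above.
```python
truth = batch['target'][index][-24:].squeeze().cpu().numpy()
preds = output[index].detach().cpu().numpy()
low, mid, high = preds[:, 0], preds[:, 1], preds[:, 2]

fig, ax = plt.subplots()
ax.plot(np.arange(168, 192), truth, label="Ground Truth")
ax.fill_between(np.arange(168, 192), low, high, alpha=0.3, label="10th-90th Quantile Band")
ax.plot(np.arange(168, 192), mid, label="50th Quantile")
ax.legend(loc='upper left')
ax.set_xlabel('Timestep')
ax.set_ylabel('Power Usage')

# Fraction of ground-truth points covered by the predicted 10th-90th quantile band.
coverage = ((truth >= low) & (truth <= high)).mean()
print(f'{coverage:.0%} of the ground-truth points lie inside the predicted band')
```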

Using the hook we defined earlier, we can generate plots of the attention heads. There is a clear trend of more recent values being given more weight. In addition, the striations in the graphs appear every 24 hours, which indicates some correlation between data at the same hour on different days. Indeed, looking back at the graphs above, there is a cyclical pattern in the power usage that repeats every 24 hours.
```python
attn_graphs = get_attn_heads(activations['attention'], index)
```
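
`get_attn_heads` returns a regular matplotlib figure, so outside of a notebook you can write it to disk (the filename here is arbitrary):
```python
# Save the attention-head heatmaps to a PNG file.
attn_graphs.savefig('tft_attention_heads.png', bbox_inches='tight')
```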

### Details

For detailed information on model input and output, training recipes, inference and performance, visit:
[Deep Learning Examples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Forecasting/TFT)

### References

- [Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363)
- [Model on GitHub](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Forecasting/TFT)
- [Pretrained model on NGC (Electricity)](https://catalog.ngc.nvidia.com/orgs/nvidia/models/tft_pyt_ckpt_base_eletricity_amp)
- [Pretrained model on NGC (Traffic)](https://catalog.ngc.nvidia.com/orgs/nvidia/models/tft_pyt_ckpt_base_traffic_amp)

scripts/tags.py

Lines changed: 1 addition & 0 deletions
@@ -3,4 +3,5 @@
 'generative',
 'audio',
 'scriptable',
+'forecasting'
 ]
