9 changes: 8 additions & 1 deletion README.md
@@ -33,7 +33,14 @@ python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_
python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5

```
## Experiments
## New Experiment Command
```
expid=bs_64
mkdir $expid
python train.py --batch_size 64 --learning_rate .004 --do_graph_conv --addaptadj --randomadj --save $expid | tee -a $expid.log
```

## Old Experiment Commands
Train models configured in Table 3 of the paper.

```
16 changes: 9 additions & 7 deletions engine.py
@@ -1,25 +1,27 @@
import torch.optim as optim
from model import *
import util
class trainer():
def __init__(self, scaler, in_dim, seq_length, num_nodes, nhid , dropout, lrate, wdecay, device, supports, gcn_bool, addaptadj, aptinit):
self.model = gwnet(device, num_nodes, dropout, supports=supports, gcn_bool=gcn_bool, addaptadj=addaptadj, aptinit=aptinit, in_dim=in_dim, out_dim=seq_length, residual_channels=nhid, dilation_channels=nhid, skip_channels=nhid * 8, end_channels=nhid * 16)
class Trainer():
def __init__(self, scaler, in_dim, seq_length, num_nodes, nhid , dropout, lrate, wdecay, device,
supports, gcn_bool, addaptadj, aptinit, clip=5):
# TODO(SS): pass model in.
self.model = GWNet(device, num_nodes, dropout, supports=supports, do_graph_conv=gcn_bool, addaptadj=addaptadj, aptinit=aptinit, in_dim=in_dim,
out_dim=seq_length, residual_channels=nhid, dilation_channels=nhid, skip_channels=nhid * 8, end_channels=nhid * 16)
self.model.to(device)
self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)
self.loss = util.masked_mae
self.scaler = scaler
self.clip = 5
self.clip = clip

def train(self, input, real_val):
self.model.train()
self.optimizer.zero_grad()
input = nn.functional.pad(input,(1,0,0,0))
output = self.model(input)
output = output.transpose(1,3)
#output = [batch_size,12,num_nodes,1]
real = torch.unsqueeze(real_val,dim=1)
predict = self.scaler.inverse_transform(output)

#output = [batch_size,12,num_nodes,1]
real = torch.unsqueeze(real_val, dim=1)
loss = self.loss(predict, real, 0.0)
loss.backward()
if self.clip is not None:
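For orientation (not part of the diff): a minimal sketch of driving one `Trainer` step on random tensors. It assumes `data/METR-LA` has been generated as in the README (only its scaler is reused), the 207-node / 12-step sizes are the METR-LA defaults, and the hyperparameters mirror the README command; treat it as illustrative rather than the repo's actual training loop.

```python
import torch
import util
from engine import Trainer

device = torch.device('cpu')
# Loaded only to reuse its scaler; the batch sizes here are arbitrary.
dataloader = util.load_dataset('data/METR-LA', 4, 4, 4)
scaler = dataloader['scaler']

engine = Trainer(scaler, in_dim=2, seq_length=12, num_nodes=207, nhid=32,
                 dropout=0.3, lrate=0.004, wdecay=0.0001, device=device,
                 supports=None, gcn_bool=True, addaptadj=True, aptinit=None)

x = torch.randn(4, 2, 207, 12)   # (batch, features, nodes, timesteps), as GWNet.forward expects
y = torch.randn(4, 207, 12)      # real_val; train() unsqueezes it to (batch, 1, nodes, horizon)
engine.train(x, y)               # pads, forwards, computes masked MAE, clips and steps
```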
129 changes: 48 additions & 81 deletions model.py
@@ -4,63 +4,51 @@
from torch.autograd import Variable
import sys


class nconv(nn.Module):
def __init__(self):
super(nconv,self).__init__()

def forward(self,x, A):
x = torch.einsum('ncvl,vw->ncwl',(x,A))
return x.contiguous()

class linear(nn.Module):
def __init__(self,c_in,c_out):
super(linear,self).__init__()
class GraphConvNet(nn.Module):
def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
super().__init__()
c_in = (order * support_len + 1) * c_in
self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=True)

def forward(self,x):
return self.mlp(x)

class gcn(nn.Module):
def __init__(self,c_in,c_out,dropout,support_len=3,order=2):
super(gcn,self).__init__()
self.nconv = nconv()
c_in = (order*support_len+1)*c_in
self.mlp = linear(c_in,c_out)
self.dropout = dropout
self.order = order

def forward(self,x,support):
def nconv(self, x, A):
return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()

def forward(self, x, support):
out = [x]
for a in support:
x1 = self.nconv(x,a)
out.append(x1)
for k in range(2, self.order + 1):
x2 = self.nconv(x1,a)
x2 = self.nconv(x1, a)
out.append(x2)
x1 = x2

h = torch.cat(out,dim=1)
h = torch.cat(out, dim=1)
h = self.mlp(h)
h = F.dropout(h, self.dropout, training=self.training)
return h
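As a hedged illustration of the merged `GraphConvNet` above (sizes are arbitrary; 207 nodes matches METR-LA): with `support_len=3` and `order=2`, the concatenation yields (2*3+1)*c_in channels before the 1x1 `mlp` mixes them back down to `c_out`.

```python
import torch
from model import GraphConvNet

gc = GraphConvNet(c_in=32, c_out=32, dropout=0.3, support_len=3, order=2)
x = torch.randn(4, 32, 207, 13)                 # (batch, channels, nodes, timesteps), the 'ncvl' of the einsum
supports = [torch.eye(207) for _ in range(3)]   # three (nodes, nodes) adjacency matrices, the 'vw'
out = gc(x, supports)
print(out.shape)  # torch.Size([4, 32, 207, 13]); internally 7 * 32 = 224 channels were concatenated
```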


class gwnet(nn.Module):
def __init__(self, device, num_nodes, dropout=0.3, supports=None, gcn_bool=True, addaptadj=True, aptinit=None, in_dim=2,out_dim=12,residual_channels=32,dilation_channels=32,skip_channels=256,end_channels=512,kernel_size=2,blocks=4,layers=2):
super(gwnet, self).__init__()
class GWNet(nn.Module):
def __init__(self, device, num_nodes, dropout=0.3, supports=None, do_graph_conv=True, addaptadj=True, aptinit=None,
in_dim=2, out_dim=12, residual_channels=32, dilation_channels=32, skip_channels=256,
end_channels=512, kernel_size=2, blocks=4, layers=2):
super().__init__()
self.dropout = dropout
self.blocks = blocks
self.layers = layers
self.gcn_bool = gcn_bool
self.do_graph_conv = do_graph_conv
self.addaptadj = addaptadj

# Each of these ModuleLists will hold one module per WaveNet layer.
self.filter_convs = nn.ModuleList()
self.gate_convs = nn.ModuleList()
self.residual_convs = nn.ModuleList()
self.skip_convs = nn.ModuleList()
self.bn = nn.ModuleList()
self.gconv = nn.ModuleList()
self.graph_convs = nn.ModuleList()

self.start_conv = nn.Conv2d(in_channels=in_dim,
out_channels=residual_channels,
@@ -73,34 +61,28 @@ def __init__(self, device, num_nodes, dropout=0.3, supports=None, gcn_bool=True,
if supports is not None:
self.supports_len += len(supports)

if gcn_bool and addaptadj:
if do_graph_conv and addaptadj:
if supports is None: self.supports = []
if aptinit is None:
if supports is None:
self.supports = []
self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10), requires_grad=True).to(device)
self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes), requires_grad=True).to(device)
self.supports_len +=1
else:
if supports is None:
self.supports = []
m, p, n = torch.svd(aptinit)
initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(device)
self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(device)
self.supports_len += 1




for b in range(blocks):
additional_scope = kernel_size - 1
new_dilation = 1
for i in range(layers):
# dilated convolutions
self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
out_channels=dilation_channels,
kernel_size=(1,kernel_size),dilation=new_dilation))
kernel_size=(1,kernel_size), dilation=new_dilation))

self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
out_channels=dilation_channels,
@@ -119,88 +101,73 @@ def __init__(self, device, num_nodes, dropout=0.3, supports=None, gcn_bool=True,
new_dilation *=2
receptive_field += additional_scope
additional_scope *= 2
if self.gcn_bool:
self.gconv.append(gcn(dilation_channels,residual_channels,dropout,support_len=self.supports_len))


if self.do_graph_conv:
self.graph_convs.append(GraphConvNet(dilation_channels, residual_channels, dropout, support_len=self.supports_len))

self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
out_channels=end_channels,
kernel_size=(1,1),
bias=True)
out_channels=end_channels,
kernel_size=(1, 1),
bias=True)

self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
out_channels=out_dim,
kernel_size=(1,1),
kernel_size=(1, 1),
bias=True)

self.receptive_field = receptive_field



def forward(self, input):
# Input shape is (bs, features, n_nodes, n_timesteps)
in_len = input.size(3)
if in_len<self.receptive_field:
x = nn.functional.pad(input,(self.receptive_field-in_len,0,0,0))
x = nn.functional.pad(input, (self.receptive_field - in_len,0,0,0))
else:
x = input
x = self.start_conv(x)
skip = 0

# calculate the current adaptive adj matrix once per iteration
new_supports = None
if self.gcn_bool and self.addaptadj and self.supports is not None:
if self.do_graph_conv and self.addaptadj and self.supports is not None:
adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
new_supports = self.supports + [adp]
adjacency_matrices = self.supports + [adp]
else:
adjacency_matrices = self.supports

# WaveNet layers
for i in range(self.blocks * self.layers):
# EACH BLOCK

# |----------------------------------------| *residual*
# | |
# | |-- conv -- tanh --| |
# -> dilate -|----| * ----|-- 1x1 -- + --> *input*
# |-- conv -- sigm --| |
# | |-dil_conv -- tanh --| |
# ---| * ----|-- 1x1 -- + --> *input*
# |-dil_conv -- sigm --| |
# 1x1
# |
# ---------------------------------------> + -------------> *skip*

#(dilation, init_dilation) = self.dilations[i]

#residual = dilation_func(x, dilation, init_dilation, i)
residual = x
# dilated convolution
filter = self.filter_convs[i](residual)
filter = torch.tanh(filter)
gate = self.gate_convs[i](residual)
gate = torch.sigmoid(gate)
filter = torch.tanh(self.filter_convs[i](residual))
gate = torch.sigmoid(self.gate_convs[i](residual))
x = filter * gate

# parametrized skip connection

s = x
s = self.skip_convs[i](s)
try:
skip = skip[:, :, :, -s.size(3):]
try: # if i > 0 this works
skip = skip[:, :, :, -s.size(3):] # TODO(SS): Mean/Max Pool?
except:
skip = 0
skip = s + skip


if self.gcn_bool and self.supports is not None:
if self.addaptadj:
x = self.gconv[i](x, new_supports)
else:
x = self.gconv[i](x,self.supports)
if self.do_graph_conv and self.supports is not None:
support = adjacency_matrices if self.addaptadj else self.supports
x = self.graph_convs[i](x, support)
else:
x = self.residual_convs[i](x)

x = x + residual[:, :, :, -x.size(3):]


x = x + residual[:, :, :, -x.size(3):] # TODO(SS): Mean/Max Pool?
x = self.bn[i](x)

x = F.relu(skip)
x = F.relu(skip) # ignore last X?
x = F.relu(self.end_conv_1(x))
x = self.end_conv_2(x)
return x
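One more hedged sketch (not repo code) of the adaptive adjacency that `GWNet.forward` builds from the two learned node embeddings above: a dense, row-normalised graph that is appended to the supplied supports.

```python
import torch
import torch.nn.functional as F

num_nodes, emb_dim = 207, 10            # 10 matches the embedding width used in GWNet.__init__
nodevec1 = torch.randn(num_nodes, emb_dim)
nodevec2 = torch.randn(emb_dim, num_nodes)

# Same expression as in GWNet.forward: ReLU keeps only positive affinities, softmax normalises each row.
adp = F.softmax(F.relu(torch.mm(nodevec1, nodevec2)), dim=1)
print(adp.shape, adp.sum(dim=1)[:3])    # torch.Size([207, 207]); every row sums to 1
```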
2 changes: 2 additions & 0 deletions requirements.txt
@@ -0,0 +1,2 @@
durbango
tqdm
6 changes: 3 additions & 3 deletions test.py
@@ -11,7 +11,7 @@
parser.add_argument('--data',type=str,default='data/METR-LA',help='data path')
parser.add_argument('--adjdata',type=str,default='data/sensor_graph/adj_mx.pkl',help='adj data path')
parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type')
parser.add_argument('--gcn_bool',action='store_true',help='whether to add graph convolution layer')
parser.add_argument('--do_graph_conv',action='store_true',help='whether to add graph convolution layer')
parser.add_argument('--aptonly',action='store_true',help='whether only adaptive adj')
parser.add_argument('--addaptadj',action='store_true',help='whether add adaptive adj')
parser.add_argument('--randomadj',action='store_true',help='whether random initialize adaptive adj')
@@ -45,13 +45,13 @@ def main():
if args.aptonly:
supports = None

model = gwnet(device, args.num_nodes, args.dropout, supports=supports, gcn_bool=args.gcn_bool, addaptadj=args.addaptadj, aptinit=adjinit)
model = GWNet(device, args.num_nodes, args.dropout, supports=supports, do_graph_conv=args.do_graph_conv, addaptadj=args.addaptadj, aptinit=adjinit)
model.to(device)
model.load_state_dict(torch.load(args.checkpoint))
model.eval()


print('model load successfully')
print('model loaded successfully')

dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size)
scaler = dataloader['scaler']
Binary file added test_args.pkl
19 changes: 19 additions & 0 deletions test_gwnet.py
@@ -0,0 +1,19 @@
from train import main
import unittest
from durbango import pickle_load
import pandas as pd
TEST_ARGS_PATH = 'test_args.pkl'

class TestTrain(unittest.TestCase):

def test_1_epoch(self):
args = pickle_load(TEST_ARGS_PATH)
args.epochs = 2
args.n_iters = 1
args.batch_size = 4
args.n_obs = 4
main(args)
df = pd.read_csv(f'{args.save}/metrics.csv', index_col=0)
self.assertEqual(df.shape, (2,6))
test_df = pd.read_csv(f'{args.save}/test_metrics.csv', index_col=0)
self.assertEqual(test_df.shape, (12, 3))
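The new test can be run with the stock runner, `python -m unittest test_gwnet`; the (2, 6) and (12, 3) shape checks presumably correspond to one metrics row per training epoch and one test-metrics row per forecast horizon. A programmatic equivalent, assuming `test_args.pkl` and the data paths it references are available locally:

```python
import unittest

# Load and run the TestTrain case added in this PR with a verbose text runner.
suite = unittest.defaultTestLoader.loadTestsFromName('test_gwnet')
unittest.TextTestRunner(verbosity=2).run(suite)
```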