Skip to content

Commit 7fe7e01

Browse files
committed
time distributed blocks
1 parent bb176b6 commit 7fe7e01

File tree

4 files changed

+139
-11
lines changed

4 files changed

+139
-11
lines changed

keras_resnet/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .block import (
22
basic,
3-
bottleneck,
4-
shortcut
3+
bottleneck
54
)
65

76
from .models import (

keras_resnet/block.py renamed to keras_resnet/block/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def f(x):
4545
y = keras.layers.Conv2D(filters, (3, 3), padding="same", **parameters)(y)
4646

4747
y = keras.layers.BatchNormalization(axis=axis)(y)
48-
y = shortcut(x, y)
48+
y = _shortcut(x, y)
4949
y = keras.layers.Activation("relu")(y)
5050

5151
return y
@@ -89,29 +89,29 @@ def f(x):
8989
y = keras.layers.Conv2D(filters * 4, (1, 1), **parameters)(y)
9090

9191
y = keras.layers.BatchNormalization(axis=axis)(y)
92-
y = shortcut(x, y)
92+
y = _shortcut(x, y)
9393
y = keras.layers.Activation("relu")(y)
9494

9595
return y
9696

9797
return f
9898

9999

100-
def shortcut(a, b):
100+
def _shortcut(a, b):
101101
a_shape = keras.backend.int_shape(a)
102102
b_shape = keras.backend.int_shape(b)
103103

104104
if keras.backend.image_data_format() == "channels_last":
105-
x = int(round(a_shape[1] / b_shape[1]))
106-
y = int(round(a_shape[2] / b_shape[2]))
105+
x = int(round(a_shape[1] // b_shape[1]))
106+
y = int(round(a_shape[2] // b_shape[2]))
107107

108108
if x > 1 or y > 1 or not a_shape[3] == b_shape[3]:
109109
a = keras.layers.Conv2D(b_shape[3], (1, 1), strides=(x, y), padding="same", **parameters)(a)
110110

111111
a = keras.layers.BatchNormalization(axis=3)(a)
112112
else:
113-
x = int(round(a_shape[2] / b_shape[2]))
114-
y = int(round(a_shape[3] / b_shape[3]))
113+
x = int(round(a_shape[2] // b_shape[2]))
114+
y = int(round(a_shape[3] // b_shape[3]))
115115

116116
if x > 1 or y > 1 or not a_shape[1] == b_shape[1]:
117117
a = keras.layers.Conv2D(b_shape[1], (1, 1), strides=(x, y), padding="same", **parameters)(a)

keras_resnet/block/temporal.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
5+
keras_resnet.block.temporal
6+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
7+
8+
This module implements a number of popular time distributed residual blocks.
9+
10+
"""
11+
12+
import keras.layers
13+
import keras.regularizers
14+
15+
parameters = {
16+
"kernel_initializer": "he_normal"
17+
}
18+
19+
20+
def basic(filters, strides=(1, 1), first=False):
21+
"""
22+
23+
A time distributed basic block.
24+
25+
:param filters: the output’s feature space
26+
:param strides: the convolution’s stride
27+
:param first: whether this is the first instance inside a residual block
28+
29+
Usage::
30+
>>> import keras_resnet.block.temporal
31+
>>> keras_resnet.block.temporal.basic(64)
32+
33+
"""
34+
def f(x):
35+
if keras.backend.image_data_format() == "channels_last":
36+
axis = 3
37+
else:
38+
axis = 1
39+
40+
y = keras.layers.TimeDistributed(keras.layers.Conv2D(filters, (3, 3), strides=strides, padding="same", **parameters))(x)
41+
42+
y = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=axis))(y)
43+
y = keras.layers.TimeDistributed(keras.layers.Activation("relu"))(y)
44+
45+
y = keras.layers.TimeDistributed(keras.layers.Conv2D(filters, (3, 3), padding="same", **parameters))(y)
46+
47+
y = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=axis))(y)
48+
y = _shortcut(x, y)
49+
y = keras.layers.TimeDistributed(keras.layers.Activation("relu"))(y)
50+
51+
return y
52+
53+
return f
54+
55+
56+
def bottleneck(filters, strides=(1, 1), first=False):
57+
"""
58+
59+
A time distributed bottleneck block.
60+
61+
:param filters: the output’s feature space
62+
:param strides: the convolution’s stride
63+
:param first: whether this is the first instance inside a residual block
64+
65+
Usage::
66+
>>> import keras_resnet.block.temporal
67+
>>> keras_resnet.block.temporal.bottleneck(64)
68+
69+
"""
70+
def f(x):
71+
if keras.backend.image_data_format() == "channels_last":
72+
axis = 3
73+
else:
74+
axis = 1
75+
76+
if first:
77+
y = keras.layers.TimeDistributed(keras.layers.Conv2D(filters, (1, 1), strides=strides, padding="same", **parameters))(x)
78+
else:
79+
y = keras.layers.TimeDistributed(keras.layers.Conv2D(filters, (3, 3), strides=strides, padding="same", **parameters))(x)
80+
81+
y = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=axis))(y)
82+
y = keras.layers.TimeDistributed(keras.layers.Activation("relu"))(y)
83+
84+
y = keras.layers.TimeDistributed(keras.layers.Conv2D(filters, (3, 3), padding="same", **parameters))(y)
85+
86+
y = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=axis))(y)
87+
y = keras.layers.TimeDistributed(keras.layers.Activation("relu"))(y)
88+
89+
y = keras.layers.TimeDistributed(keras.layers.Conv2D(filters * 4, (1, 1), **parameters))(y)
90+
91+
y = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=axis))(y)
92+
y = _shortcut(x, y)
93+
y = keras.layers.TimeDistributed(keras.layers.Activation("relu"))(y)
94+
95+
return y
96+
97+
return f
98+
99+
100+
def _shortcut(a, b):
101+
a_shape = keras.backend.int_shape(a)
102+
b_shape = keras.backend.int_shape(b)
103+
104+
if keras.backend.image_data_format() == "channels_last":
105+
x = int(round(a_shape[1] // b_shape[1]))
106+
y = int(round(a_shape[2] // b_shape[2]))
107+
108+
if x > 1 or y > 1 or not a_shape[3] == b_shape[3]:
109+
a = keras.layers.TimeDistributed(keras.layers.Conv2D(b_shape[3], (1, 1), strides=(x, y), padding="same", **parameters))(a)
110+
111+
a = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=3))(a)
112+
else:
113+
x = int(round(a_shape[2] // b_shape[2]))
114+
y = int(round(a_shape[3] // b_shape[3]))
115+
116+
if x > 1 or y > 1 or not a_shape[1] == b_shape[1]:
117+
a = keras.layers.TimeDistributed(keras.layers.Conv2D(b_shape[1], (1, 1), strides=(x, y), padding="same", **parameters))(a)
118+
119+
a = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=1))(a)
120+
121+
return keras.layers.add([a, b])

keras_resnet/models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ def __init__(self, inputs, blocks, block):
4242
else:
4343
axis = 1
4444

45-
x = keras.layers.Conv2D(64, (7, 7), strides=(2, 2), padding="same")(inputs)
45+
x = keras.layers.Conv2D(64, (7, 7), strides=(2, 2), padding="same")(
46+
inputs)
4647
x = keras.layers.BatchNormalization(axis=axis)(x)
4748
x = keras.layers.Activation("relu")(x)
48-
x = keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)
49+
x = keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding="same")(
50+
x)
4951

5052
features = 64
5153

@@ -94,6 +96,7 @@ class ResNet18(ResNet):
9496
>>> model.compile("adam", "categorical_crossentropy", ["accuracy"])
9597
9698
"""
99+
97100
def __init__(self, inputs):
98101
block = keras_resnet.block.basic
99102

@@ -115,6 +118,7 @@ class ResNet34(ResNet):
115118
>>> model.compile("adam", "categorical_crossentropy", ["accuracy"])
116119
117120
"""
121+
118122
def __init__(self, inputs):
119123
block = keras_resnet.block.basic
120124

@@ -136,6 +140,7 @@ class ResNet50(ResNet):
136140
>>> model.compile("adam", "categorical_crossentropy", ["accuracy"])
137141
138142
"""
143+
139144
def __init__(self, inputs):
140145
block = keras_resnet.block.bottleneck
141146

@@ -157,6 +162,7 @@ class ResNet101(ResNet):
157162
>>> model.compile("adam", "categorical_crossentropy", ["accuracy"])
158163
159164
"""
165+
160166
def __init__(self, inputs):
161167
block = keras_resnet.block.bottleneck
162168

@@ -178,6 +184,7 @@ class ResNet152(ResNet):
178184
>>> model.compile("adam", "categorical_crossentropy", ["accuracy"])
179185
180186
"""
187+
181188
def __init__(self, inputs):
182189
block = keras_resnet.block.bottleneck
183190

@@ -199,6 +206,7 @@ class ResNet200(ResNet):
199206
>>> model.compile("adam", "categorical_crossentropy", ["accuracy"])
200207
201208
"""
209+
202210
def __init__(self, inputs):
203211
block = keras_resnet.block.bottleneck
204212

0 commit comments

Comments
 (0)