Skip to content

Commit a14730f

Browse files
author
Ervin T
authored
[feature] Add small CNN for grids 5x5 and up (#4434)
1 parent 70476af commit a14730f

File tree

8 files changed

+105
-10
lines changed

8 files changed

+105
-10
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ and this project adheres to
2222
Note that PyTorch 1.6.0 or greater should be installed to use this feature; see
2323
[the PyTorch website](https://pytorch.org/) for installation instructions. (#4335)
2424
- The minimum supported version of TensorFlow was increased to 1.14.0. (#4411)
25+
- A CNN (`vis_encode_type: match3`) for smaller grids, e.g. board games, has been added.
26+
(#4434)
2527

2628
### Bug Fixes
2729
#### com.unity.ml-agents (C#)

docs/Training-Configuration-File.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ choice of the trainer (which we review on subsequent sections).
4141
| `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Correspond to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger. <br><br> Typical range: `32` - `512` |
4242
| `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems. <br><br> Typical range: `1` - `3` |
4343
| `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
44-
| `network_settings -> vis_encoder_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. |
44+
| `network_settings -> vis_encoder_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
4545

4646

4747
## Trainer-specific Configurations

ml-agents/mlagents/trainers/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def as_dict(self):
5959

6060

6161
class EncoderType(Enum):
62+
MATCH3 = "match3"
6263
SIMPLE = "simple"
6364
NATURE_CNN = "nature_cnn"
6465
RESNET = "resnet"

ml-agents/mlagents/trainers/tests/test_simple_rl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,15 @@ def test_visual_ppo(num_visual, use_discrete):
192192

193193

194194
@pytest.mark.parametrize("num_visual", [1, 2])
195-
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
195+
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
196196
def test_visual_advanced_ppo(vis_encode_type, num_visual):
197197
env = SimpleEnvironment(
198198
[BRAIN_NAME],
199199
use_discrete=True,
200200
num_visual=num_visual,
201201
num_vector=0,
202202
step_size=0.5,
203-
vis_obs_size=(36, 36, 3),
203+
vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
204204
)
205205
new_networksettings = attr.evolve(
206206
SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
@@ -271,15 +271,15 @@ def test_visual_sac(num_visual, use_discrete):
271271

272272

273273
@pytest.mark.parametrize("num_visual", [1, 2])
274-
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
274+
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
275275
def test_visual_advanced_sac(vis_encode_type, num_visual):
276276
env = SimpleEnvironment(
277277
[BRAIN_NAME],
278278
use_discrete=True,
279279
num_visual=num_visual,
280280
num_vector=0,
281281
step_size=0.5,
282-
vis_obs_size=(36, 36, 3),
282+
vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
283283
)
284284
new_networksettings = attr.evolve(
285285
SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,15 @@ def test_visual_ppo(num_visual, use_discrete):
193193

194194

195195
@pytest.mark.parametrize("num_visual", [1, 2])
196-
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
196+
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
197197
def test_visual_advanced_ppo(vis_encode_type, num_visual):
198198
env = SimpleEnvironment(
199199
[BRAIN_NAME],
200200
use_discrete=True,
201201
num_visual=num_visual,
202202
num_vector=0,
203203
step_size=0.5,
204-
vis_obs_size=(36, 36, 3),
204+
vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
205205
)
206206
new_networksettings = attr.evolve(
207207
SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
@@ -272,15 +272,15 @@ def test_visual_sac(num_visual, use_discrete):
272272

273273

274274
@pytest.mark.parametrize("num_visual", [1, 2])
275-
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
275+
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
276276
def test_visual_advanced_sac(vis_encode_type, num_visual):
277277
env = SimpleEnvironment(
278278
[BRAIN_NAME],
279279
use_discrete=True,
280280
num_visual=num_visual,
281281
num_vector=0,
282282
step_size=0.5,
283-
vis_obs_size=(36, 36, 3),
283+
vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
284284
)
285285
new_networksettings = attr.evolve(
286286
SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)

ml-agents/mlagents/trainers/tf/models.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ModelUtils:
3232
# Minimum supported side for each encoder type. If refactoring an encoder, please
3333
# adjust these also.
3434
MIN_RESOLUTION_FOR_ENCODER = {
35+
EncoderType.MATCH3: 5,
3536
EncoderType.SIMPLE: 20,
3637
EncoderType.NATURE_CNN: 36,
3738
EncoderType.RESNET: 15,
@@ -211,7 +212,10 @@ def create_normalizer(vector_obs: tf.Tensor) -> NormalizerTensors:
211212
dtype=tf.float32,
212213
initializer=tf.ones_initializer(),
213214
)
214-
initialize_normalization, update_normalization = ModelUtils.create_normalizer_update(
215+
(
216+
initialize_normalization,
217+
update_normalization,
218+
) = ModelUtils.create_normalizer_update(
215219
vector_obs, steps, running_mean, running_variance
216220
)
217221
return NormalizerTensors(
@@ -346,6 +350,53 @@ def create_visual_observation_encoder(
346350
)
347351
return hidden_flat
348352

353+
@staticmethod
354+
def create_match3_visual_observation_encoder(
355+
image_input: tf.Tensor,
356+
h_size: int,
357+
activation: ActivationFunction,
358+
num_layers: int,
359+
scope: str,
360+
reuse: bool,
361+
) -> tf.Tensor:
362+
"""
363+
Builds a CNN with the architecture used by King for Candy Crush. Optimized
364+
for grid-shaped boards, such as with Match-3 games.
365+
:param image_input: The placeholder for the image input to use.
366+
:param h_size: Hidden layer size.
367+
:param activation: What type of activation function to use for layers.
368+
:param num_layers: number of hidden layers to create.
369+
:param scope: The scope of the graph within which to create the ops.
370+
:param reuse: Whether to re-use the weights within the same scope.
371+
:return: List of hidden layer tensors.
372+
"""
373+
with tf.variable_scope(scope):
374+
conv1 = tf.layers.conv2d(
375+
image_input,
376+
35,
377+
kernel_size=[3, 3],
378+
strides=[1, 1],
379+
activation=tf.nn.elu,
380+
reuse=reuse,
381+
name="conv_1",
382+
)
383+
conv2 = tf.layers.conv2d(
384+
conv1,
385+
144,
386+
kernel_size=[3, 3],
387+
strides=[1, 1],
388+
activation=tf.nn.elu,
389+
reuse=reuse,
390+
name="conv_2",
391+
)
392+
hidden = tf.layers.flatten(conv2)
393+
394+
with tf.variable_scope(scope + "/" + "flat_encoding"):
395+
hidden_flat = ModelUtils.create_vector_observation_encoder(
396+
hidden, h_size, activation, num_layers, scope, reuse
397+
)
398+
return hidden_flat
399+
349400
@staticmethod
350401
def create_nature_cnn_visual_observation_encoder(
351402
image_input: tf.Tensor,
@@ -475,6 +526,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction:
475526
EncoderType.SIMPLE: ModelUtils.create_visual_observation_encoder,
476527
EncoderType.NATURE_CNN: ModelUtils.create_nature_cnn_visual_observation_encoder,
477528
EncoderType.RESNET: ModelUtils.create_resnet_visual_observation_encoder,
529+
EncoderType.MATCH3: ModelUtils.create_match3_visual_observation_encoder,
478530
}
479531
return ENCODER_FUNCTION_BY_TYPE.get(
480532
encoder_type, ModelUtils.create_visual_observation_encoder

ml-agents/mlagents/trainers/torch/encoders.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,43 @@ def update_normalization(self, inputs: torch.Tensor) -> None:
107107
self.normalizer.update(inputs)
108108

109109

110+
class SmallVisualEncoder(nn.Module):
111+
"""
112+
CNN architecture used by King in their Candy Crush predictor
113+
https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning
114+
"""
115+
116+
def __init__(
117+
self, height: int, width: int, initial_channels: int, output_size: int
118+
):
119+
super().__init__()
120+
self.h_size = output_size
121+
conv_1_hw = conv_output_shape((height, width), 3, 1)
122+
conv_2_hw = conv_output_shape(conv_1_hw, 3, 1)
123+
self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 144
124+
125+
self.conv_layers = nn.Sequential(
126+
nn.Conv2d(initial_channels, 35, [3, 3], [1, 1]),
127+
nn.LeakyReLU(),
128+
nn.Conv2d(35, 144, [3, 3], [1, 1]),
129+
nn.LeakyReLU(),
130+
)
131+
self.dense = nn.Sequential(
132+
linear_layer(
133+
self.final_flat,
134+
self.h_size,
135+
kernel_init=Initialization.KaimingHeNormal,
136+
kernel_gain=1.0,
137+
),
138+
nn.LeakyReLU(),
139+
)
140+
141+
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
142+
hidden = self.conv_layers(visual_obs)
143+
hidden = torch.reshape(hidden, (-1, self.final_flat))
144+
return self.dense(hidden)
145+
146+
110147
class SimpleVisualEncoder(nn.Module):
111148
def __init__(
112149
self, height: int, width: int, initial_channels: int, output_size: int

ml-agents/mlagents/trainers/torch/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SimpleVisualEncoder,
77
ResNetVisualEncoder,
88
NatureVisualEncoder,
9+
SmallVisualEncoder,
910
VectorInput,
1011
)
1112
from mlagents.trainers.settings import EncoderType, ScheduleType
@@ -18,6 +19,7 @@ class ModelUtils:
1819
# Minimum supported side for each encoder type. If refactoring an encoder, please
1920
# adjust these also.
2021
MIN_RESOLUTION_FOR_ENCODER = {
22+
EncoderType.MATCH3: 5,
2123
EncoderType.SIMPLE: 20,
2224
EncoderType.NATURE_CNN: 36,
2325
EncoderType.RESNET: 15,
@@ -123,6 +125,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
123125
EncoderType.SIMPLE: SimpleVisualEncoder,
124126
EncoderType.NATURE_CNN: NatureVisualEncoder,
125127
EncoderType.RESNET: ResNetVisualEncoder,
128+
EncoderType.MATCH3: SmallVisualEncoder,
126129
}
127130
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
128131

0 commit comments

Comments
 (0)