diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py
index 7b83d3e5b..3ffdc6356 100644
--- a/neuralmonkey/decoders/autoregressive.py
+++ b/neuralmonkey/decoders/autoregressive.py
@@ -15,6 +15,7 @@
 from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
 from neuralmonkey.logging import log, warn
 from neuralmonkey.model.sequence import EmbeddedSequence
+from neuralmonkey.model.stateful import TemporalStateful
 from neuralmonkey.nn.utils import dropout
 from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants
 from neuralmonkey.vocabulary import Vocabulary, START_TOKEN, UNK_TOKEN_INDEX
@@ -93,7 +94,7 @@ class DecoderFeedables(NamedTuple(
 
 
 # pylint: disable=too-many-public-methods,too-many-instance-attributes
-class AutoregressiveDecoder(ModelPart):
+class AutoregressiveDecoder(ModelPart, TemporalStateful):
 
     # pylint: disable=too-many-arguments
     def __init__(self,
@@ -475,3 +476,21 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
             fd[self.train_mask] = weights
 
         return fd
+
+    @tensor
+    def temporal_states(self) -> tf.Tensor:
+        # strip the last symbol which is </s>
+        return tf.cond(
+            self.train_mode,
+            lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1],
+            lambda: tf.transpose(
+                self.runtime_output_states, [1, 0, 2])[:, :-1])
+
+    @tensor
+    def temporal_mask(self) -> tf.Tensor:
+        # strip the last symbol which is </s>
+        return tf.cond(
+            self.train_mode,
+            lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1],
+            lambda: tf.to_float(tf.transpose(
+                self.runtime_mask, [1, 0])[:, :-1]))
diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py
index 89add3de3..f0a9b7974 100644
--- a/neuralmonkey/decoders/sequence_labeler.py
+++ b/neuralmonkey/decoders/sequence_labeler.py
@@ -1,38 +1,47 @@
-from typing import Optional, Union
-
 import tensorflow as tf
+from typeguard import check_argument_types
 
 from neuralmonkey.dataset import Dataset
 from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
-from neuralmonkey.encoders.recurrent import RecurrentEncoder
-from neuralmonkey.encoders.facebook_conv import SentenceEncoder
+from neuralmonkey.model.stateful import TemporalStateful
 from neuralmonkey.vocabulary import Vocabulary
 from neuralmonkey.decorators import tensor
 from neuralmonkey.tf_utils import get_variable
 
 
 class SequenceLabeler(ModelPart):
-    """Classifier assing a label to each encoder's state."""
+    """Classifier assigning a label to each input state.
+
+    If the input of the labeler itself has an input sequence with embeddings,
+    these embeddings are used as an additional input to the labeler.
+
+    Note that when the labeler is stacked on an autoregressive decoder, it
+    labels the symbol that is currently generated by the decoder, i.e. the
+    decoder state has not yet been updated by putting the decoded symbol on
+    its input.
+    """
 
     # pylint: disable=too-many-arguments
     def __init__(self,
                  name: str,
-                 encoder: Union[RecurrentEncoder, SentenceEncoder],
+                 input_sequence: TemporalStateful,
                  vocabulary: Vocabulary,
                  data_id: str,
                  dropout_keep_prob: float = 1.0,
-                 save_checkpoint: Optional[str] = None,
-                 load_checkpoint: Optional[str] = None,
+                 save_checkpoint: str = None,
+                 load_checkpoint: str = None,
                  initializers: InitializerSpecs = None) -> None:
+        check_argument_types()
         ModelPart.__init__(self, name, save_checkpoint, load_checkpoint,
                            initializers)
 
-        self.encoder = encoder
+        self.input_sequence = input_sequence
         self.vocabulary = vocabulary
         self.data_id = data_id
         self.dropout_keep_prob = dropout_keep_prob
 
-        self.rnn_size = int(self.encoder.temporal_states.get_shape()[-1])
+        self.input_size = int(
+            self.input_sequence.temporal_states.get_shape()[-1])
 
         with self.use_scope():
             self.train_targets = tf.placeholder(
@@ -45,7 +54,7 @@ def __init__(self,
     def decoding_w(self) -> tf.Variable:
         return get_variable(
             name="state_to_word_W",
-            shape=[self.rnn_size, len(self.vocabulary)],
+            shape=[self.input_size, len(self.vocabulary)],
             initializer=tf.glorot_normal_initializer())
 
     @tensor
@@ -57,7 +66,8 @@ def decoding_b(self) -> tf.Variable:
 
     @tensor
     def decoding_residual_w(self) -> tf.Variable:
-        input_dim = self.encoder.input_sequence.dimension
+        input_dim = (
+            self.input_sequence.input_sequence.dimension)  # type: ignore
         return get_variable(
             name="emb_to_word_W",
             shape=[input_dim, len(self.vocabulary)],
@@ -71,25 +81,27 @@ def logits(self) -> tf.Tensor:
 
         # TODO dropout needs to be revisited
 
-        encoder_states = tf.expand_dims(self.encoder.temporal_states, 2)
+        input_states = tf.expand_dims(self.input_sequence.temporal_states, 2)
         weights_4d = tf.expand_dims(tf.expand_dims(self.decoding_w, 0), 0)
 
         multiplication = tf.nn.conv2d(
-            encoder_states, weights_4d, [1, 1, 1, 1], "SAME")
+            input_states, weights_4d, [1, 1, 1, 1], "SAME")
         multiplication_3d = tf.squeeze(multiplication, squeeze_dims=[2])
 
         biases_3d = tf.expand_dims(tf.expand_dims(self.decoding_b, 0), 0)
+        logits = multiplication_3d + biases_3d
 
-        embedded_inputs = tf.expand_dims(
-            self.encoder.input_sequence.temporal_states, 2)
-        dweights_4d = tf.expand_dims(
-            tf.expand_dims(self.decoding_residual_w, 0), 0)
+        if hasattr(self.input_sequence, "input_sequence"):
+            inputs_input = self.input_sequence.input_sequence  # type: ignore
+            embedded_inputs = tf.expand_dims(inputs_input.temporal_states, 2)
+            dweights_4d = tf.expand_dims(
+                tf.expand_dims(self.decoding_residual_w, 0), 0)
 
-        dmultiplication = tf.nn.conv2d(
-            embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME")
-        dmultiplication_3d = tf.squeeze(dmultiplication, squeeze_dims=[2])
+            dmultiplication = tf.nn.conv2d(
+                embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME")
+            dmultiplication_3d = tf.squeeze(dmultiplication, squeeze_dims=[2])
 
-        logits = multiplication_3d + dmultiplication_3d + biases_3d
+            logits += dmultiplication_3d
         return logits
 
     @tensor
@@ -102,13 +114,20 @@ def decoded(self) -> tf.Tensor:
 
     @tensor
     def cost(self) -> tf.Tensor:
-        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
-            labels=self.train_targets, logits=self.logits)
-
-        # loss is now of shape [batch, time]. Need to mask it now by
-        # element-wise multiplication with weights placeholder
-        weighted_loss = loss * self.train_weights
-        return tf.reduce_sum(weighted_loss)
+        min_time = tf.minimum(tf.shape(self.train_targets)[1],
+                              tf.shape(self.logits)[1])
+
+        # In case the labeler is stacked on a decoder which also emits an end
+        # symbol (or for some reason emits more symbols than we have in the
+        # ground-truth labels), we trim the sequences to the length of the
+        # shorter one.
+
+        # pylint: disable=unsubscriptable-object
+        return tf.contrib.seq2seq.sequence_loss(
+            logits=self.logits[:, :min_time],
+            targets=self.train_targets[:, :min_time],
+            weights=self.input_sequence.temporal_mask[:, :min_time])
+        # pylint: enable=unsubscriptable-object
 
     @property
     def train_loss(self) -> tf.Tensor:
diff --git a/neuralmonkey/decoders/transformer.py b/neuralmonkey/decoders/transformer.py
index 00f517547..ed250974c 100644
--- a/neuralmonkey/decoders/transformer.py
+++ b/neuralmonkey/decoders/transformer.py
@@ -126,10 +126,15 @@ def __init__(self,
         self.encoder_states = get_attention_states(self.encoder)
         self.encoder_mask = get_attention_mask(self.encoder)
 
-        self.dimension = (
-            self.encoder_states.get_shape()[2].value)  # type: ignore
-        if self.embedding_size != self.dimension:
+        # The assertion (and the "int" type comment below) is here because
+        # mypy is not able to handle the tf.Tensor type.
+        assert self.encoder_states is not None
+
+        self.model_dimension = (
+            self.encoder_states.get_shape()[2].value)  # type: int
+
+        if self.embedding_size != self.model_dimension:
             raise ValueError("Model dimension and input embedding size"
                              "do not match")
 
@@ -140,7 +145,7 @@ def __init__(self,
 
     @property
     def output_dimension(self) -> int:
-        return self.dimension
+        return self.model_dimension
 
     def embed_inputs(self, inputs: tf.Tensor) -> tf.Tensor:
         embedded = tf.nn.embedding_lookup(self.embedding_matrix, inputs)
@@ -156,7 +161,7 @@ def embed_inputs(self, inputs: tf.Tensor) -> tf.Tensor:
         embedded *= math.sqrt(embedding_size)
 
         length = tf.shape(inputs)[1]
-        return embedded + position_signal(self.dimension, length)
+        return embedded + position_signal(self.model_dimension, length)
 
     @tensor
     def embedded_train_inputs(self) -> tf.Tensor:
@@ -241,7 +246,8 @@ def feedforward_sublayer(self, layer_input: tf.Tensor) -> tf.Tensor:
         ff_hidden = dropout(ff_hidden, self.dropout_keep_prob, self.train_mode)
 
         # Feed-forward output projection
-        ff_output = tf.layers.dense(ff_hidden, self.dimension, name="output")
+        ff_output = tf.layers.dense(
+            ff_hidden, self.model_dimension, name="output")
 
         # Apply dropout on the output projection
         ff_output = dropout(ff_output, self.dropout_keep_prob, self.train_mode)
diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py
index 97d9baced..0f8ffbacb 100644
--- a/neuralmonkey/runners/label_runner.py
+++ b/neuralmonkey/runners/label_runner.py
@@ -87,7 +87,7 @@ def get_executable(self, num_sessions: int) -> LabelRunExecutable:
 
         fetches = {
             "label_logprobs": self._decoder.logprobs,
-            "input_mask": self._decoder.encoder.input_sequence.temporal_mask}
+            "input_mask": self._decoder.input_sequence.temporal_mask}
 
         if compute_losses:
             fetches["loss"] = self._decoder.cost
diff --git a/tests/bpe.ini b/tests/bpe.ini
index 482aede74..c37baa598 100644
--- a/tests/bpe.ini
+++ b/tests/bpe.ini
@@ -10,7 +10,7 @@ epochs=2
 train_dataset=
 val_dataset=
 trainer=
-runners=[]
+runners=[,]
 evaluation=[("target", evaluators.BLEU), ("target_greedy", "target", evaluators.BLEU)]
 val_preview_num_examples=5
 val_preview_input_series=["source", "target", "target_bpe"]
@@ -94,10 +94,18 @@ data_id="target_bpe"
 max_output_len=10
 vocabulary=
 
+[labeler]
+class=decoders.sequence_labeler.SequenceLabeler
+name="tagger"
+input_sequence=
+data_id="target_bpe"
+dropout_keep_prob=0.5
+vocabulary=
+
 [trainer]
 ; This block just fills the arguments of the trainer __init__ method.
 class=trainers.cross_entropy_trainer.CrossEntropyTrainer
-decoders=[]
+decoders=[,]
 l2_weight=1.0e-8
 clip_norm=1.0
 optimizer=
@@ -114,3 +122,8 @@ class=runners.GreedyRunner
 decoder=
 postprocess=
 output_series="target_greedy"
+
+[lab_runner]
+class=runners.LabelRunner
+decoder=
+output_series="tags"
diff --git a/tests/labeler.ini b/tests/labeler.ini
index 8b2c8db64..a6da1069d 100644
--- a/tests/labeler.ini
+++ b/tests/labeler.ini
@@ -63,7 +63,7 @@ vocabulary=
 [decoder]
 class=decoders.sequence_labeler.SequenceLabeler
 name="tagger"
-encoder=
+input_sequence=
 data_id="tags"
 dropout_keep_prob=0.5
 vocabulary=
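
A usage sketch, not part of the patch: since SequenceLabeler now accepts any TemporalStateful object as its input, a labeler can be configured directly on top of a decoder. In the INI format used by the tests, such a setup could look roughly like the block below; the referenced names <decoder> and <tags_vocabulary> are placeholders for illustration, assuming blocks with those names exist elsewhere in the configuration.

    [labeler]
    class=decoders.sequence_labeler.SequenceLabeler
    name="tagger"
    input_sequence=<decoder>
    data_id="tags"
    dropout_keep_prob=0.5
    vocabulary=<tags_vocabulary>

    [lab_runner]
    class=runners.LabelRunner
    decoder=<labeler>
    output_series="tags"

The labeler only reads the temporal_states and temporal_mask tensors of its input, so the same block should work unchanged when input_sequence points to an encoder instead of a decoder.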