preprocessing.py: 34 additions & 0 deletions

@@ -261,6 +261,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1,
    return dataset


def get_many_tpu_bt_input_tensors(games, games_nr, batch_size,
                                  start_at, num_datasets,
Contributor: nit: indenting is wrong

Contributor Author: Done.

                                  moves=2**21,
                                  window_size=500e3,
                                  window_increment=25000):
    dataset = None
    for i in range(num_datasets):
        # TODO(amj) mixin calibration games with some math. (from start_at that
        # is proportionally along compared to last_game_number? comparing
        # timestamps?)
        ds = games.moves_from_games(start_at + (i * window_increment),
                                    start_at + (i * window_increment) + window_size,
Contributor: nit indenting

Contributor Author: Done.

                                    moves=moves,
                                    shuffle=True,
                                    column_family=bigtable_input.TFEXAMPLE,
                                    column='example')
        ds = ds.repeat(1)
Contributor: you can probably move the repeat and map out of this loop

Contributor Author: Done.

        ds = ds.map(lambda row_name, s: s)
        dataset = dataset.concatenate(ds) if dataset else ds
Contributor: Regarding the general approach: if the training loop does multiple scans, I would expect to create a new dataset for each pass, rather than try to create a single enormous dataset, which I imagine would be harder to debug, inspect, etc.

Contributor Author: Yes, but multiple calls to TPUEstimator.train will create new graphs :( I am not sure what a good solution for lazy evaluation of these Datasets would be. As it is, it takes a really long time to build the datasets before training even starts -- I suspect the concatenate is doing something bad, as things get slower and slower.


    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.map(
        functools.partial(batch_parse_tf_example, batch_size))
    # Unbatch the dataset so we can rotate it
    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        _random_rotation_pure_tf,
        batch_size,
        drop_remainder=True))

    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset
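
A rough sketch, not part of this diff, of what the "move the repeat and map out of this loop" suggestion above could look like: repeat(1) is a no-op and can simply be dropped, and the row-name-stripping map can be applied once to the concatenated dataset. It assumes, as the original lambda implies, that moves_from_games yields (row_name, example) pairs.

    # Sketch only: per-window loop reduced to the Bigtable read + concatenate.
    dataset = None
    for i in range(num_datasets):
        window_start = start_at + (i * window_increment)
        ds = games.moves_from_games(window_start,
                                    window_start + window_size,
                                    moves=moves,
                                    shuffle=True,
                                    column_family=bigtable_input.TFEXAMPLE,
                                    column='example')
        dataset = dataset.concatenate(ds) if dataset else ds

    # repeat(1) was a no-op, so it is dropped; strip the row name once here.
    dataset = dataset.map(lambda row_name, example: example)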


def make_dataset_from_selfplay(data_extracts):
    '''
    Returns an iterable of tf.Examples.
train.py: 30 additions & 0 deletions

@@ -132,6 +132,36 @@ def after_run(self, run_context, run_values):
        self.before_weights = None


def train_many(start_at=1000000, num_datasets=3):
Contributor: Can you expose moves here also?

Contributor Author: What do you mean? Number of steps?

""" Trains on a set of bt_datasets, skipping eval for now.
(from preprocessing.get_many_tpu_bt_input_tensors)
"""
if not FLAGS.use_tpu and FLAGS.use_bt:
raise ValueError("Only tpu & bt mode supported")

tf.logging.set_verbosity(tf.logging.INFO)
estimator = dual_net.get_estimator()
effective_batch_size = FLAGS.train_batch_size * FLAGS.num_tpu_cores

def _input_fn(params):
games = bigtable_input.GameQueue(
FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
games_nr = bigtable_input.GameQueue(
FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')

return preprocessing.get_many_tpu_bt_input_tensors(
games, games_nr, params['batch_size'],
start_at=start_at, num_datasets=num_datasets)

hooks = []
steps = num_datasets * FLAGS.steps_to_train
logging.info("Training, steps = %s, batch = %s -> %s examples",
steps or '?', effective_batch_size,
(steps * effective_batch_size) if steps else '?')

estimator.train(_input_fn, steps=steps, hooks=hooks)
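
On the "can you expose moves here also" question: the reviewer most likely means the per-window move cap that get_many_tpu_bt_input_tensors defaults to 2**21, not the number of training steps. A minimal sketch, not part of this diff, of threading it through under that reading:

def train_many(start_at=1000000, num_datasets=3, moves=2**21):
    """Sketch: train_many with the per-window move cap exposed to the caller."""
    if not (FLAGS.use_tpu and FLAGS.use_bt):
        raise ValueError("Only tpu & bt mode supported")

    estimator = dual_net.get_estimator()

    def _input_fn(params):
        games = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
        games_nr = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')
        return preprocessing.get_many_tpu_bt_input_tensors(
            games, games_nr, params['batch_size'],
            start_at=start_at, num_datasets=num_datasets,
            moves=moves)  # the exposed cap, passed straight through

    estimator.train(_input_fn, steps=num_datasets * FLAGS.steps_to_train)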

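For comparison with the single concatenated dataset, here is a rough sketch, also not part of this diff, of the per-pass alternative raised in the preprocessing.py thread: build one window's dataset at a time and call estimator.train once per window. The function name is hypothetical; each train call rebuilds the TPU graph (the author's objection) but resumes from the latest checkpoint, so weights carry across passes.

def train_many_per_pass(start_at=1000000, num_datasets=3,
                        window_increment=25000):
    """Sketch: one window dataset and one estimator.train call per pass."""
    estimator = dual_net.get_estimator()

    for i in range(num_datasets):
        def _input_fn(params, window=start_at + i * window_increment):
            games = bigtable_input.GameQueue(
                FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
            games_nr = bigtable_input.GameQueue(
                FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')
            # Reuse the new helper for a single window per pass.
            return preprocessing.get_many_tpu_bt_input_tensors(
                games, games_nr, params['batch_size'],
                start_at=window, num_datasets=1)

        # Rebuilds the graph on each call (the concern above), but picks up
        # the latest checkpoint, so training state carries over between passes.
        estimator.train(_input_fn, steps=FLAGS.steps_to_train)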

def train(*tf_records: "Records to train on"):
"""Train on examples."""
tf.logging.set_verbosity(tf.logging.INFO)