"sliding window" bigtable training mode #713
base: master
Changes from 10 commits
```diff
@@ -261,6 +261,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1,
     return dataset


+def get_many_tpu_bt_input_tensors(games, games_nr, batch_size,
+                                  start_at, num_datasets,
+                                  moves=2**21,
+                                  window_size=500e3,
+                                  window_increment=25000):
+    dataset = None
+    for i in range(num_datasets):
+        # TODO(amj) mixin calibration games with some math. (from start_at that
+        # is proportionally along compared to last_game_number? comparing
+        # timestamps?)
+        ds = games.moves_from_games(start_at + (i * window_increment),
+                                    start_at + (i * window_increment) + window_size,
+                                    moves=moves,
+                                    shuffle=True,
+                                    column_family=bigtable_input.TFEXAMPLE,
+                                    column='example')
+        ds = ds.repeat(1)
+        ds = ds.map(lambda row_name, s: s)
+        dataset = dataset.concatenate(ds) if dataset else ds
```
Review comment (on the `dataset.concatenate(ds)` line):

> Regarding the general approach: if the training loop does multiple scans, I would expect to create a new dataset for each pass, rather than try to create a single enormous dataset, which I imagine would be harder to debug, inspect, etc.

Reply:

> Yes, but multiple calls to TPUEstimator.train will create new graphs :( I'm not sure what a good solution for lazy evaluation of these Datasets would be. As it is, it takes a really long time to build the datasets before training even starts -- I suspect the concatenate is doing something bad, as things get slower and slower.
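For reference, a minimal sketch of the reviewer's per-pass alternative, assuming the FLAGS, dual_net, bigtable_input, and preprocessing helpers used elsewhere in this change (the function name is illustrative, not part of the PR): build one window-sized dataset per `estimator.train()` call instead of concatenating them all up front. As the reply notes, the trade-off is that each `train()` call rebuilds the TPU graph.

```python
# Illustrative sketch only -- not code from this PR. Assumes the FLAGS,
# dual_net, bigtable_input, and preprocessing modules used in this change.
def train_one_window_at_a_time(start_at=1000000, num_datasets=3,
                               window_increment=25000):
    estimator = dual_net.get_estimator()
    for i in range(num_datasets):
        window_start = start_at + (i * window_increment)

        # Bind window_start as a default arg to avoid late-binding surprises.
        def _input_fn(params, window_start=window_start):
            games = bigtable_input.GameQueue(
                FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
            games_nr = bigtable_input.GameQueue(
                FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')
            # Build only the i-th window instead of concatenating all of them.
            return preprocessing.get_many_tpu_bt_input_tensors(
                games, games_nr, params['batch_size'],
                start_at=window_start, num_datasets=1)

        # Each call rebuilds the TPU graph -- the cost the reply points out.
        estimator.train(_input_fn, steps=FLAGS.steps_to_train)
```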
```diff
+
+    dataset = dataset.batch(batch_size, drop_remainder=False)
+    dataset = dataset.map(
+        functools.partial(batch_parse_tf_example, batch_size))
+    # Unbatch the dataset so we can rotate it
+    dataset = dataset.apply(tf.contrib.data.unbatch())
+    dataset = dataset.apply(tf.contrib.data.map_and_batch(
+        _random_rotation_pure_tf,
+        batch_size,
+        drop_remainder=True))
+
+    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
+    return dataset
+
+
 def make_dataset_from_selfplay(data_extracts):
     '''
     Returns an iterable of tf.Examples.
```
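To make the windowing concrete: with the defaults above (`window_size=500e3`, `window_increment=25000`), consecutive windows overlap by 475,000 games, sliding forward 25,000 games per dataset. A small illustration (the `start_at` value is hypothetical):

```python
# Illustrative only: the game-number ranges each loop iteration above scans,
# using get_many_tpu_bt_input_tensors' default window arguments.
start_at = 20000000          # hypothetical starting game number
window_size = 500e3
window_increment = 25000

for i in range(3):
    lo = start_at + (i * window_increment)
    hi = lo + window_size
    print("window %d: games [%d, %d)" % (i, lo, hi))
# window 0: games [20000000, 20500000)
# window 1: games [20025000, 20525000)
# window 2: games [20050000, 20550000)
```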
```diff
@@ -132,6 +132,36 @@ def after_run(self, run_context, run_values):
         self.before_weights = None


+def train_many(start_at=1000000, num_datasets=3):
+    """ Trains on a set of bt_datasets, skipping eval for now.
+    (from preprocessing.get_many_tpu_bt_input_tensors)
+    """
+    if not (FLAGS.use_tpu and FLAGS.use_bt):
+        raise ValueError("Only tpu & bt mode supported")
+
+    tf.logging.set_verbosity(tf.logging.INFO)
+    estimator = dual_net.get_estimator()
+    effective_batch_size = FLAGS.train_batch_size * FLAGS.num_tpu_cores
+
+    def _input_fn(params):
+        games = bigtable_input.GameQueue(
+            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
+        games_nr = bigtable_input.GameQueue(
+            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')
+
+        return preprocessing.get_many_tpu_bt_input_tensors(
+            games, games_nr, params['batch_size'],
+            start_at=start_at, num_datasets=num_datasets)
+
+    hooks = []
+    steps = num_datasets * FLAGS.steps_to_train
+    logging.info("Training, steps = %s, batch = %s -> %s examples",
+                 steps or '?', effective_batch_size,
+                 (steps * effective_batch_size) if steps else '?')
+
+    estimator.train(_input_fn, steps=steps, hooks=hooks)
+
+
 def train(*tf_records: "Records to train on"):
     """Train on examples."""
     tf.logging.set_verbosity(tf.logging.INFO)
```
Review comment:

> nit: indenting is wrong

Reply:

> Done.
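For a sense of scale, here is the logging arithmetic from `train_many` worked through with hypothetical flag values (none of these are project defaults):

```python
# Hypothetical flag values, purely to illustrate train_many's logging math.
num_datasets = 3
steps_to_train = 1000        # assumed FLAGS.steps_to_train
train_batch_size = 1024      # assumed FLAGS.train_batch_size
num_tpu_cores = 8            # assumed FLAGS.num_tpu_cores

effective_batch_size = train_batch_size * num_tpu_cores   # 8192
steps = num_datasets * steps_to_train                     # 3000
examples_seen = steps * effective_batch_size              # 24,576,000
```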