tensorflow - SyncReplicasOptimizer Not Looping Over All Epochs
After running the following code, the global_step variable does not equal num_batches * epochs as expected; instead, training stops at the end of the first epoch. The asynchronous version of this code, without SyncReplicasOptimizer, works as expected. What am I doing wrong? Here's the code (args.steps is greater than num_batches * epochs):
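(Aside: the loop below references chief_queue_runners, init_tokens_op, and sv without showing where they are created. They would typically come from the SyncReplicasOptimizer and a Supervisor along these lines; this is a minimal sketch, and the base optimizer, num_workers, and args.model_dir are assumptions, since that part of the program isn't shown.)

    # Hypothetical setup for the names used in the training loop below;
    # the Adagrad base optimizer, num_workers, and logdir are assumptions.
    opt = tf.train.SyncReplicasOptimizer(
        tf.train.AdagradOptimizer(0.01),
        replicas_to_aggregate=num_workers,
        total_num_replicas=num_workers)
    train_step = opt.minimize(total_loss, global_step=global_step)

    # Only the chief worker runs the sync token-queue machinery.
    chief_queue_runners = [opt.get_chief_queue_runner()]
    init_tokens_op = opt.get_init_tokens_op()

    sv = tf.train.Supervisor(is_chief=is_chief,
                             logdir=args.model_dir,
                             global_step=global_step)

The training loop itself: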
    sess = sv.prepare_or_wait_for_session(server.target)
    queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
    sv.start_queue_runners(sess, queue_runners)
    tf.logging.info('Started %d queues for processing input data.', len(queue_runners))

    if is_chief:
        sv.start_queue_runners(sess, chief_queue_runners)
        sess.run(init_tokens_op)

    print("{0} session ready".format(datetime.now().isoformat()))

    #####################################################################
    ########################### training loop ###########################
    _current_state = np.zeros((batch_size, state_size))

    for batch_idx in range(args.steps):
        if sv.should_stop() or tf_feed.should_stop():
            break
        batchX, batchY = feed_dict(tf_feed.next_batch(batch_size))
        print('==========================================================')
        print(_current_state)

        if args.mode == "train":
            _total_loss, _train_step, _current_state, _predictions_series, _global_step = sess.run(
                [total_loss, train_step, current_state, predictions_series, global_step],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    init_state: _current_state
                })
            print(_global_step, batch_idx)
            print(_current_state)
            print('==========================================================')
            if _global_step % 5 == 0:
                print("Step", _global_step, "Loss", _total_loss)
        # else:
        #     TODO: code for checking

    #################################################################
    if sv.should_stop() or batch_idx > args.steps:
        tf_feed.terminate()

    print("{0} stopping supervisor".format(datetime.now().isoformat()))
    sv.stop()