design exploration and evaluators for off-policy algos
commit 54a7b1343d
parent 2eb056a721
@@ -78,13 +78,10 @@ if __name__ == '__main__':
     critic.sync_weights()
 
     start_time = time.time()
-    for i in range(100):
+    data_collector.collect(num_timesteps=1e3)  # warm-up
+    for i in range(int(1e8)):
         # collect data
-        data_collector.collect(num_episodes=50)
-
-        # print current return
-        print('Epoch {}:'.format(i))
-        data_collector.statistics()
+        data_collector.collect()
 
         # update network
         for _ in range(num_batches):
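The new loop assumes the collector supports both a timestep-count warm-up call and a lightweight default call per training step. A rough sketch of such an interface, assuming a gym-style env and a policy with an act() method (the names and buffer handling here are illustrative, not the repository's actual DataCollector):

class SimpleCollector(object):
    """Hypothetical collector matching the collect() calls in the hunk above."""

    def __init__(self, env, policy, buffer_size=int(1e6)):
        self.env = env
        self.policy = policy
        self.buffer = []  # (obs, action, reward, done) transitions
        self.buffer_size = buffer_size
        self.obs = env.reset()

    def collect(self, num_timesteps=1, num_episodes=0):
        """With no arguments, add a single transition; otherwise run the requested amount."""
        if num_episodes:
            for _ in range(int(num_episodes)):
                while not self._step():  # run one full episode
                    pass
        else:
            for _ in range(int(num_timesteps)):
                self._step()

    def _step(self):
        action = self.policy.act(self.obs)  # act() is an assumed policy method
        next_obs, reward, done, _ = self.env.step(action)
        self.buffer.append((self.obs, action, reward, done))
        if len(self.buffer) > self.buffer_size:
            self.buffer.pop(0)  # drop oldest transition when full
        self.obs = self.env.reset() if done else next_obs
        return done

With per-step collection, the range(int(1e8)) loop interleaves a single environment step with num_batches gradient updates, instead of the earlier 50-episodes-per-epoch pattern.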
@@ -92,3 +89,7 @@ if __name__ == '__main__':
             sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
 
         print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+
+        # test every 1000 training steps
+        if i % 1000 == 0:
+            test(env, actor)
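The test(env, actor) evaluator called above is not part of this diff. A minimal sketch of what it could look like, assuming a gym-style env and a policy exposing a test-mode act_test() method (following the act_test note in the hunk below); the episode count and printed statistics are illustrative:

import numpy as np

def test(env, policy, num_episodes=10):
    """Hypothetical evaluator: roll out test-mode episodes and report the mean return."""
    returns = []
    for _ in range(num_episodes):
        obs = env.reset()
        done, episode_return = False, 0.0
        while not done:
            action = policy.act_test(obs)  # test-mode action (e.g. greedy / low epsilon)
            obs, reward, done, _ = env.step(action)
            episode_return += reward
        returns.append(episode_return)
    print('Test return over {} episodes: {:.2f} +/- {:.2f}'.format(
        num_episodes, np.mean(returns), np.std(returns)))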
@@ -77,17 +77,28 @@ if __name__ == '__main__':
     pi.sync_weights()  # TODO: automate this for policies with target network
 
     start_time = time.time()
-    for i in range(100):
+    epsilon = 0.5
+    pi.set_epsilon_train(epsilon)
+    data_collector.collect(num_timesteps=1e3)  # warm-up
+    for i in range(int(1e8)):  # number of training steps
+        # anneal epsilon step-wise
+        if (i + 1) % 1e4 == 0 and epsilon > 0.1:
+            epsilon -= 0.1
+            pi.set_epsilon_train(epsilon)
+
         # collect data
         data_collector.collect()
 
-        # print current return
-        print('Epoch {}:'.format(i))
-        data_collector.statistics()
-
         # update network
         for _ in range(num_batches):
             feed_dict = data_collector.next_batch(batch_size)
             sess.run(train_op, feed_dict=feed_dict)
 
         print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+
+        # test every 1000 training steps
+        # tester could share some code with batch!
+        if i % 1000 == 0:
+            # epsilon 0.05 as in nature paper
+            pi.set_epsilon_test(0.05)
+            test(env, pi)  # go for act_test of pi, not act
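The step-wise annealing in this hunk can also be written as a pure schedule, which is easier to check in isolation; the sketch below only restates the values already in the diff (start at 0.5, subtract 0.1 every 1e4 steps, never below 0.1):

def epsilon_schedule(step, start=0.5, drop=0.1, interval=int(1e4), floor=0.1):
    # epsilon used during training step `step`, mirroring the in-loop decrement above
    return max(floor, start - drop * ((step + 1) // interval))

# e.g. epsilon_schedule(0) -> 0.5, epsilon_schedule(9999) -> 0.4, epsilon_schedule(int(1e5)) -> 0.1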
@@ -59,7 +59,7 @@ class DataCollector(object):
         if self.process_mode == 'minibatch':
             pass
 
-        # flatten rank-2 list to numpy array
+        # flatten rank-2 list to numpy array, construct feed_dict
 
         return
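The updated comment implies this branch should both flatten the rank-2 (per-episode) lists into arrays and assemble the feed_dict. A rough sketch of that step, assuming per-episode Python lists and TensorFlow placeholders whose names are purely illustrative:

import numpy as np

def flatten_and_build_feed_dict(episode_obs, episode_act, episode_ret,
                                obs_ph, act_ph, ret_ph):
    """Hypothetical processing step: concatenate per-episode lists along the
    time axis and map the resulting arrays to their placeholders."""
    return {
        obs_ph: np.concatenate([np.asarray(ep) for ep in episode_obs], axis=0),
        act_ph: np.concatenate([np.asarray(ep) for ep in episode_act], axis=0),
        ret_ph: np.concatenate([np.asarray(ep) for ep in episode_ret], axis=0),
    }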