fix the env -> self._env bug

This commit is contained in:
Dong Yan 2018-02-10 03:42:00 +08:00
parent 50b2d98d0a
commit 2163d18728
3 changed files with 6 additions and 6 deletions

View File

@@ -79,4 +79,4 @@ if __name__ == '__main__':
feed_dict = data_collector.next_batch(batch_size) feed_dict = data_collector.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict) sess.run(train_op, feed_dict=feed_dict)
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@@ -83,4 +83,4 @@ if __name__ == '__main__':
# assigning actor to pi_old # assigning actor to pi_old
pi.update_weights() pi.update_weights()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@@ -39,10 +39,10 @@ class Batch(object):
if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines
t = 0 t = 0
ac = self.env.action_space.sample() # not used, just so we have the datatype ac = self._env.action_space.sample() # not used, just so we have the datatype
new = True # marks if we're on first timestep of an episode new = True # marks if we're on first timestep of an episode
if self.is_first_collect: if self.is_first_collect:
ob = self.env.reset() ob = self._env.reset()
self.is_first_collect = False self.is_first_collect = False
else: else:
ob = self.raw_data['observations'][0] # last observation! ob = self.raw_data['observations'][0] # last observation!
@@ -69,7 +69,7 @@ class Batch(object):
actions[i] = ac actions[i] = ac
prevacs[i] = prevac prevacs[i] = prevac
ob, rew, new, _ = env.step(ac) ob, rew, new, _ = self._env.step(ac)
rewards[i] = rew rewards[i] = rew
cur_ep_ret += rew cur_ep_ret += rew
@@ -79,7 +79,7 @@ class Batch(object):
ep_lens.append(cur_ep_len) ep_lens.append(cur_ep_len)
cur_ep_ret = 0 cur_ep_ret = 0
cur_ep_len = 0 cur_ep_len = 0
ob = env.reset() ob = self._env.reset()
t += 1 t += 1
if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail