# tianshou/tianshou/data/collector.py
import time
import torch
import numpy as np
from copy import deepcopy
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer
from tianshou.utils import MovAvg


class Collector(object):
    """Collect trajectories by running a policy in one or many environments.

    The collector steps ``policy`` through ``env`` (a plain gym-style env or
    a :class:`BaseVectorEnv`), stores the resulting transitions into one or
    more :class:`ReplayBuffer` instances, and keeps moving-average
    throughput statistics (steps/s, episodes/s) via :class:`MovAvg`.
    """

    def __init__(self, policy, env, buffer=None, stat_size=100):
        """
        :param policy: callable mapping a ``Batch`` (plus hidden state) to a
            result object with an ``act`` field; must expose ``process_fn``.
        :param env: a single environment or a ``BaseVectorEnv``.
        :param buffer: a ``ReplayBuffer``, a list of ``ReplayBuffer`` (one
            per sub-env, multi-env only), or ``None`` for a default
            ``ReplayBuffer(20000)``.
        :param stat_size: window length of the ``MovAvg`` speed statistics.
        """
        super().__init__()
        self.env = env
        self.env_num = 1
        # cumulative counters across all collect() calls
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(20000)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffer does not match the number of ' \
                    'input env.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                # one cache per sub-env; a finished episode is merged into
                # self.buffer so its transitions stay contiguous there
                self._cached_buf = [
                    deepcopy(self.buffer) for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        """Clear every replay buffer owned by this collector."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        """Return the number of (sub-)environments."""
        return self.env_num

    def reset_env(self):
        """Reset the env(s), cached buffers, and per-episode counters."""
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            # per-env running episode reward / length
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        """Seed the environment(s), if they support seeding."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render the environment(s), if they support rendering."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s), if they support closing."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        """Prepend a batch axis so single-env data looks vectorized."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        """Collect exactly ``n_step`` steps OR ``n_episode`` episodes.

        Exactly one of ``n_step`` / ``n_episode`` must be non-zero.
        ``n_episode`` may also be a list giving a per-env episode quota
        (multi-env only). ``render`` > 0 renders every frame and then
        sleeps that many seconds. Returns a dict of statistics:
        episode/step counts, speed moving averages, and mean episode
        reward/length over the episodes finished in this call.
        """
        if not self._multi_env:
            # a list quota makes no sense for a single env; collapse it
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            # assemble the policy input from the latest transition
            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
            else:
                batch_data = Batch(
                    obs=self._make_batch(self._obs),
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
                    obs_next=None,
                    info=self._make_batch(self._info))
            result = self.policy(batch_data, self.state)
            # recurrent policies carry hidden state between steps
            self.state = result.state if hasattr(result, 'state') else None
            # normalize the action to an np.ndarray
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i]}
                    if self._cached_buf:
                        # cache now; merged into self.buffer at episode end
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        # count the episode unless this env already met its
                        # per-env quota (list-valued n_episode)
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        # reset only the finished env's hidden state
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                # zero out this env's slice in-place
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    # only reset the sub-envs that finished this step
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    # stop when every per-env quota (list) or the total
                    # quota (scalar) has been reached
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(
                    self._obs, self._act[0], self._rew,
                    self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        # moving-average throughput statistics
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        # denominator for mean reward/length; max(..., 1) avoids div-by-zero
        # when no episode finished during this call
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        """Sample a processed batch of ``batch_size`` transitions.

        With multiple buffers, the per-buffer share is drawn proportionally
        to buffer length; ``batch_size <= 0`` takes everything. The batch is
        passed through ``policy.process_fn`` before being returned.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                # NOTE(review): np.random.choice requires p to have the same
                # length as its first argument; here p has len(self.buffer)
                # entries while the first argument is `total` — this looks
                # like it should draw buffer indices, i.e. choice over
                # len(self.buffer). Verify against upstream before relying
                # on the multi-buffer path.
                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                # number of samples assigned to buffer i
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data