diff --git a/setup.py b/setup.py
index ae5ce99..fa5d74f 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@ setup(
         'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
-    packages=find_packages(exclude=['tests', 'tests.*',
+    packages=find_packages(exclude=['test', 'test.*',
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
diff --git a/test/test_buffer.py b/test/test_buffer.py
index 4477482..41ab3e4 100644
--- a/test/test_buffer.py
+++ b/test/test_buffer.py
@@ -14,7 +14,7 @@ def test_replaybuffer(size=10, bufsize=20):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
-    data, indice = buf.sample(4)
+    data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 7dd2ca8..07fdbe4 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 
 
@@ -8,6 +9,13 @@ class Batch(object):
         super().__init__()
         self.__dict__.update(kwargs)
 
+    def __getitem__(self, index):
+        b = Batch()
+        for k in self.__dict__.keys():
+            if self.__dict__[k] is not None:
+                b.update(**{k: self.__dict__[k][index]})
+        return b
+
     def update(self, **kwargs):
         self.__dict__.update(kwargs)
 
@@ -21,9 +29,12 @@ class Batch(object):
             elif isinstance(batch.__dict__[k], np.ndarray):
                 self.__dict__[k] = np.concatenate([
                     self.__dict__[k], batch.__dict__[k]])
+            elif isinstance(batch.__dict__[k], torch.Tensor):
+                self.__dict__[k] = torch.cat([
+                    self.__dict__[k], batch.__dict__[k]])
             elif isinstance(batch.__dict__[k], list):
                 self.__dict__[k] += batch.__dict__[k]
             else:
                 raise TypeError(
-                    'Do not support append with type {} in class Batch.'
+                    'No support for append with type {} in class Batch.'
                     .format(type(batch.__dict__[k])))
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 0cf5168..80fb6f7 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -26,6 +26,12 @@ class ReplayBuffer(object):
                 self.__dict__[name] = np.zeros([self._maxsize])
             self.__dict__[name][self._index] = inst
 
+    def update(self, buffer):
+        for i in range(len(buffer)):
+            self.add(
+                buffer.obs[i], buffer.act[i], buffer.rew[i],
+                buffer.done[i], buffer.obs_next[i], buffer.info[i])
+
     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         '''
         weight: importance weights, disabled here
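A rough usage sketch of the Batch indexing, the torch-aware append, and ReplayBuffer.update shown above. It is illustrative only: the field values are made up, the concatenation logic is assumed to live in Batch.append (as the error message suggests), and ReplayBuffer is assumed to take its capacity as the sole constructor argument.

    import torch
    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    # __getitem__ slices every non-None field and returns a new Batch
    b = Batch(obs=np.arange(4), act=torch.zeros(4), info=None)
    b0 = b[0]  # Batch with obs=0 and act=tensor(0.); the None field is skipped

    # append concatenates torch.Tensor fields with torch.cat, next to the
    # existing np.concatenate / list handling
    b.append(Batch(obs=np.arange(4), act=torch.ones(4), info=None))

    # ReplayBuffer.update copies every stored transition from another buffer
    main_buf, cache_buf = ReplayBuffer(100), ReplayBuffer(10)
    cache_buf.add(obs=1, act=1, rew=1., done=0, obs_next=2, info={})
    main_buf.update(cache_buf)
    assert len(main_buf) == 1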
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index c5da90c..1c7b448 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 from copy import deepcopy
 
@@ -9,21 +10,26 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer):
+    def __init__(self, policy, env, buffer, contiguous=True):
         super().__init__()
         self.env = env
         self.env_num = 1
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
-        self.multi_env = isinstance(env, BaseVectorEnv)
-        if self.multi_env:
+        self._multi_env = isinstance(env, BaseVectorEnv)
+        self._multi_buf = False  # True if self.buffer is a list of buffers
+        # cache buffers are needed only when storing contiguously in one buffer
+        self._cached_buf = []
+        if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
                 assert len(self.buffer) == self.env_num,\
                     '# of data buffer does not match the # of input env.'
-            elif isinstance(self.buffer, ReplayBuffer):
-                self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
+                self._multi_buf = True
+            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+                self._cached_buf = [
+                    deepcopy(buffer) for _ in range(self.env_num)]
             else:
                 raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
@@ -34,7 +40,7 @@ class Collector(object):
         self.stat_length = MovAvg()
 
     def clear_buffer(self):
-        if self.multi_env:
+        if self._multi_buf:
             for b in self.buffer:
                 b.reset()
         else:
@@ -43,17 +49,24 @@ class Collector(object):
     def reset_env(self):
         self._obs = self.env.reset()
         self._act = self._rew = self._done = self._info = None
-        if self.multi_env:
+        if self._multi_env:
             self.reward = np.zeros(self.env_num)
             self.length = np.zeros(self.env_num)
         else:
             self.reward, self.length = 0, 0
+        for b in self._cached_buf:
+            b.reset()
+
+    def _make_batch(self, data):
+        if isinstance(data, np.ndarray):
+            return data[None]
+        else:
+            return [data]
 
     def collect(self, n_step=0, n_episode=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
-        cur_step = 0
-        cur_episode = np.zeros(self.env_num) if self.multi_env else 0
+        cur_step, cur_episode = 0, 0
         while True:
-            if self.multi_env:
+            if self._multi_env:
                 batch_data = Batch(
@@ -61,41 +74,55 @@ class Collector(object):
                     done=self._done, obs_next=None, info=self._info)
             else:
                 batch_data = Batch(
-                    obs=[self._obs], act=[self._act], rew=[self._rew],
-                    done=[self._done], obs_next=None, info=[self._info])
+                    obs=self._make_batch(self._obs),
+                    act=self._make_batch(self._act),
+                    rew=self._make_batch(self._rew),
+                    done=self._make_batch(self._done),
+                    obs_next=None, info=self._make_batch(self._info))
             result = self.policy.act(batch_data, self.state)
-            self.state = result.state
+            self.state = result.state if hasattr(result, 'state') else None
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act)
-            cur_step += 1
             self.length += 1
             self.reward += self._rew
-            if self.multi_env:
+            if self._multi_env:
                 for i in range(self.env_num):
-                    if n_episode > 0 and \
-                            cur_episode[i] < n_episode or n_episode == 0:
-                        self.buffer[i].add(
-                            self._obs[i], self._act[i], self._rew[i],
-                            self._done[i], obs_next[i], self._info[i])
-                    if self._done[i]:
-                        cur_episode[i] += 1
-                        self.stat_reward.add(self.reward[i])
-                        self.stat_length.add(self.length[i])
-                        self.reward[i], self.length[i] = 0, 0
-                        if isinstance(self.state, list):
-                            self.state[i] = None
-                        else:
-                            self.state[i] = self.state[i] * 0
-                            if hasattr(self.state, 'detach'):
-                                # remove ref in torch
-                                self.state = self.state.detach()
-                if n_episode > 0 and (cur_episode >= n_episode).all():
+                    data = {
+                        'obs': self._obs[i], 'act': self._act[i],
+                        'rew': self._rew[i], 'done': self._done[i],
+                        'obs_next': obs_next[i], 'info': self._info[i]}
+                    if self._cached_buf:
+                        self._cached_buf[i].add(**data)
+                    elif self._multi_buf:
+                        self.buffer[i].add(**data)
+                        cur_step += 1
+                    else:
+                        self.buffer.add(**data)
+                        cur_step += 1
+                    if self._done[i]:
+                        cur_episode += 1
+                        self.stat_reward.add(self.reward[i])
+                        self.stat_length.add(self.length[i])
+                        self.reward[i], self.length[i] = 0, 0
+                        if self._cached_buf:
+                            self.buffer.update(self._cached_buf[i])
+                            cur_step += len(self._cached_buf[i])
+                            self._cached_buf[i].reset()
+                        if isinstance(self.state, list):
+                            self.state[i] = None
+                        elif self.state is not None:
+                            self.state[i] = self.state[i] * 0
+                            if isinstance(self.state, torch.Tensor):
+                                # detach to drop the autograd reference
+                                self.state = self.state.detach()
+                if n_episode > 0 and cur_episode >= n_episode:
                     break
             else:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
+                cur_step += 1
                 if self._done:
                     cur_episode += 1
                     self.stat_reward.add(self.reward)
@@ -110,7 +137,7 @@ class Collector(object):
             self._obs = obs_next
 
     def sample(self, batch_size):
-        if self.multi_env:
+        if self._multi_buf:
             if batch_size > 0:
                 lens = [len(b) for b in self.buffer]
                 total = sum(lens)
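A minimal sketch of how the reworked Collector could be driven. VectorEnv, the CartPole setup, and RandomPolicy are stand-ins rather than part of this diff, and it assumes act() is the only method a BasePolicy subclass must override.

    import gym
    import numpy as np
    from tianshou.data import Batch, Collector, ReplayBuffer
    from tianshou.env import VectorEnv
    from tianshou.policy import BasePolicy


    class RandomPolicy(BasePolicy):
        """Stand-in policy that picks uniformly random CartPole actions."""

        def act(self, batch, hidden_state=None):
            return Batch(act=np.random.randint(0, 2, size=len(batch.obs)))


    envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
    buf = ReplayBuffer(20000)
    # contiguous=True keeps one shared buffer plus per-env cache buffers, so
    # every finished episode is flushed into `buf` as one contiguous block
    collector = Collector(RandomPolicy(), envs, buf, contiguous=True)
    collector.collect(n_step=1000)   # or collector.collect(n_episode=10)
    batch = collector.sample(batch_size=64)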
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index afbeacd..c94c5e0 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -21,9 +21,11 @@ class BasePolicy(ABC):
     def reset(self):
         pass
 
-    @staticmethod
-    def process_fn(batch, buffer, indice):
+    def process_fn(self, batch, buffer, indice):
         return batch
 
+    def sync_weights(self):
+        pass
+
     def exploration(self):
         pass
diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py
new file mode 100644
index 0000000..3179ba6
--- /dev/null
+++ b/tianshou/policy/dqn.py
@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+from copy import deepcopy
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class DQNPolicy(BasePolicy, nn.Module):
+    """Deep Q-Network policy with an optional target network."""
+
+    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+                 use_target_network=True):
+        super().__init__()
+        self.model = model
+        self._gamma = discount_factor
+        self._n_step = estimation_step
+        self._target = use_target_network
+        if use_target_network:
+            self.model_old = deepcopy(self.model)
+
+    def act(self, batch, hidden_state=None):
+        batch_result = Batch()
+        return batch_result
+
+    def sync_weights(self):
+        if self._target:
+            for old, new in zip(
+                    self.model_old.parameters(), self.model.parameters()):
+                old.data.copy_(new.data)
+
+    def process_fn(self, batch, buffer, indice):
+        return batch
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
index 01d0c91..b9e4250 100644
--- a/tianshou/utils/moving_average.py
+++ b/tianshou/utils/moving_average.py
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 
 
@@ -8,8 +9,7 @@ class MovAvg(object):
         self.cache = []
 
     def add(self, x):
-        if hasattr(x, 'detach'):
-            # which means x is torch.Tensor (?)
+        if isinstance(x, torch.Tensor):
             x = x.detach().cpu().numpy()
         if x != np.inf:
             self.cache.append(x)
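Finally, a small sketch of the intended use of sync_weights and the torch-aware MovAvg. The linear stand-in network and the training loop around it are assumptions, not part of this diff.

    import torch
    from torch import nn
    from tianshou.policy import DQNPolicy
    from tianshou.utils import MovAvg

    net = nn.Linear(4, 2)   # stand-in Q-network: 4 observation features, 2 actions
    policy = DQNPolicy(net, discount_factor=0.99, use_target_network=True)
    # ...after some gradient updates on policy.model, refresh the frozen copy:
    policy.sync_weights()   # copies policy.model weights into policy.model_old

    stat = MovAvg()
    stat.add(torch.tensor(1.5))   # tensors are detached and converted to numpy
    stat.add(2.5)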