half of collector

This commit is contained in:
Trinkle23897 2020-03-12 22:20:33 +08:00
parent 4a1a7dd670
commit f58c1397c6
14 changed files with 194 additions and 30 deletions

View File

@@ -1,3 +1,4 @@
# Tianshou
![Python package](https://github.com/Trinkle23897/tianshou/workflows/Python%20package/badge.svg)

View File

@@ -20,14 +20,12 @@ setup(
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Alpha',
# Indicate who your project is intended for
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development :: Libraries :: Python Modules',
# Pick your license as you wish (should match "license" above)
'License :: OSI Approved :: MIT License',
# Specify the Python versions you support here. In particular, ensure
# that you indicate whether you support Python 2, Python 3 or both.
'Programming Language :: Python :: 3.6',
@@ -35,18 +33,16 @@ setup(
'Programming Language :: Python :: 3.8',
],
keywords='reinforcement learning platform',
# You can just specify the packages manually here if your project is
# simple. Or you can use find_packages().
packages=find_packages(exclude=['tests', 'tests.*',
'examples', 'examples.*',
'docs', 'docs.*']),
install_requires=[
'numpy',
'torch',
'tensorboard',
'tqdm',
# 'ray',
'gym',
'tqdm',
'numpy',
'torch',
'cloudpickle',
'tensorboard',
],
)

View File

@@ -1,7 +1,7 @@
from tianshou.data import ReplayBuffer
if __name__ == '__main__':
from test_env import MyTestEnv
else:
else: # pytest
from test.test_env import MyTestEnv

View File

@@ -1,4 +1,5 @@
from tianshou.data.batch import Batch
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer']
__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector']

View File

@@ -2,4 +2,8 @@ class Batch(object):
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
def __init__(self, **kwargs):
super().__init__()
self.obs_next = None
self.__dict__.update(kwargs)
def update(self, **kwargs):
self.__dict__.update(kwargs)
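
The Batch added here is a plain attribute container: keyword arguments become attributes, update() merges new ones in, and obs_next defaults to None. A minimal usage sketch:

from tianshou.data import Batch

b = Batch(obs=[0, 1], act=[1, 0], rew=[0., 1.], done=[False, True])
print(b.obs)            # [0, 1] -- keys are plain attributes
b.update(rew=[1., 1.])  # overwrite or add keys in place
print(b.obs_next)       # None until it is filled in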

View File

@@ -40,12 +40,20 @@ class ReplayBuffer(object):
def reset(self):
self._index = self._size = 0
self.indice = []
def sample_indice(self, batch_size):
return np.random.choice(self._size, batch_size)
if batch_size > 0:
self.indice = np.random.choice(self._size, batch_size)
else:
self.indice = np.arange(self._size)
return self.indice
def sample(self, batch_size):
indice = self.sample_indice(batch_size)
def sample(self, batch_size, indice=None):
if indice is None:
indice = self.sample_indice(batch_size)
else:
self.indice = indice
return Batch(
obs=self.obs[indice],
act=self.act[indice],
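
With this change sample_indice records the drawn indices in self.indice, a non-positive batch_size selects the whole buffer in storage order, and sample() accepts precomputed indices. A sketch of the call patterns, assuming the ReplayBuffer(size) constructor and the positional add(obs, act, rew, done, obs_next, info) call that the Collector below uses:

import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(20)
for i in range(3):
    buf.add(i, i % 2, 1., False, i + 1, {})

rand = buf.sample(2)                            # two random transitions
full = buf.sample(0)                            # batch_size <= 0: everything stored
fixed = buf.sample(0, indice=np.array([0, 2]))  # caller-chosen indices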

View File

@@ -0,0 +1,86 @@
import numpy as np
from copy import deepcopy
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer
from tianshou.utils import MovAvg
class Collector(object):
"""docstring for Collector"""
def __init__(self, policy, env, buffer):
super().__init__()
self.env = env
self.env_num = 1
self.buffer = buffer
self.policy = policy
self.process_fn = policy.process_fn
self.multi_env = isinstance(env, BaseVectorEnv)
if self.multi_env:
self.env_num = len(env)
if isinstance(self.buffer, list):
assert len(self.buffer) == self.env_num, 'The number of data buffers must match the number of input envs.'
elif isinstance(self.buffer, ReplayBuffer):
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
else:
raise TypeError('The buffer given to the collector must be a ReplayBuffer or a list of ReplayBuffer.')
self.reset_env()
self.clear_buffer()
# the recurrent state over the batch is a list, an np.ndarray, or a torch.Tensor (the latter two expose a 'shape' attribute)
self.state = None
def clear_buffer(self):
if self.multi_env:
for b in self.buffer:
b.reset()
else:
self.buffer.reset()
def reset_env(self):
self._obs = self.env.reset()
self._act = self._rew = self._done = self._info = None
def collect(self, n_step=0, n_episode=0, tqdm_hook=None):
assert sum([(n_step > 0), (n_episode > 0)]) == 1, 'Exactly one of n_step and n_episode must be specified.'
cur_step = 0
cur_episode = np.zeros(self.env_num) if self.multi_env else 0
while True:
if self.multi_env:
batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, info=self._info)
else:
batch_data = Batch(obs=[self._obs], act=[self._act], rew=[self._rew], done=[self._done], info=[self._info])
result = self.policy.act(batch_data, self.state)
self.state = result.state
self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step(self._act)
cur_step += 1
if self.multi_env:
for i in range(self.env_num):
if (n_episode > 0 and cur_episode[i] < n_episode) or n_episode == 0:
self.buffer[i].add(self._obs[i], self._act[i], self._rew[i], self._done[i], obs_next[i], self._info[i])
if self._done[i]:
cur_episode[i] += 1
if isinstance(self.state, list):
self.state[i] = None
else:
self.state[i] = self.state[i] * 0
if hasattr(self.state, 'detach'): # detach torch tensors so no autograd graph is kept alive
self.state = self.state.detach()
if n_episode > 0 and (cur_episode >= n_episode).all():
break
else:
self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info)
if self._done:
cur_episode += 1
self.state = None
if n_episode > 0 and cur_episode >= n_episode:
break
if n_step > 0 and cur_step >= n_step:
break
self._obs = obs_next # advance the stored observation inside the loop
self._obs = obs_next # after breaking out, keep the final observation
def sample(self):
pass
def stat(self):
pass
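
A sketch of how this half of the Collector is meant to be driven once a concrete policy exists; my_policy stands in for any BasePolicy subclass, and the CartPole factories are only illustrative:

import gym
from tianshou.env import VectorEnv
from tianshou.data import Collector, ReplayBuffer

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)],
                 reset_after_done=True)
# a single ReplayBuffer is deep-copied once per environment
collector = Collector(my_policy, envs, ReplayBuffer(10000))

collector.collect(n_step=100)   # run until at least 100 env steps are taken
collector.collect(n_episode=5)  # or until every env has finished 5 episodes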

View File

@@ -1,3 +1,3 @@
from tianshou.env.wrapper import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv
__all__ = ['FrameStack', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']

View File

@@ -1,5 +1,6 @@
import numpy as np
from collections import deque
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
try:
import ray
@@ -56,13 +57,18 @@ class FrameStack(EnvWrapper):
return np.stack(self._frames, axis=-1)
class VectorEnv(object):
class BaseVectorEnv(ABC):
def __init__(self):
pass
class VectorEnv(BaseVectorEnv):
"""docstring for VectorEnv"""
def __init__(self, env_fns, **kwargs):
def __init__(self, env_fns, reset_after_done=False):
super().__init__()
self.envs = [_() for _ in env_fns]
self.env_num = len(self.envs)
self._reset_after_done = kwargs.get('reset_after_done', False)
self._reset_after_done = reset_after_done
def __len__(self):
return len(self.envs)
@@ -97,8 +103,7 @@ class VectorEnv(object):
e.close()
def worker(parent, p, env_fn_wrapper, kwargs):
reset_after_done = kwargs.get('reset_after_done', True)
def worker(parent, p, env_fn_wrapper, reset_after_done):
parent.close()
env = env_fn_wrapper.data()
while True:
@@ -115,22 +120,22 @@ def worker(parent, p, env_fn_wrapper, kwargs):
p.close()
break
elif cmd == 'render':
p.send(env.render())
p.send(env.render() if hasattr(env, 'render') else None)
elif cmd == 'seed':
p.send(env.seed(data))
p.send(env.seed(data) if hasattr(env, 'seed') else None)
else:
raise NotImplementedError
class SubprocVectorEnv(object):
class SubprocVectorEnv(BaseVectorEnv):
"""docstring for SubProcVectorEnv"""
def __init__(self, env_fns, **kwargs):
def __init__(self, env_fns, reset_after_done=False):
super().__init__()
self.env_num = len(env_fns)
self.closed = False
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [
Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
]
for p in self.processes:
@@ -178,12 +183,12 @@ class SubprocVectorEnv(object):
p.join()
class RayVectorEnv(object):
class RayVectorEnv(BaseVectorEnv):
"""docstring for RayVectorEnv"""
def __init__(self, env_fns, **kwargs):
def __init__(self, env_fns, reset_after_done=False):
super().__init__()
self.env_num = len(env_fns)
self._reset_after_done = kwargs.get('reset_after_done', False)
self._reset_after_done = reset_after_done
try:
if not ray.is_initialized():
ray.init()
@@ -213,6 +218,8 @@ class RayVectorEnv(object):
return np.stack([ray.get(r) for r in result_obj])
def seed(self, seed=None):
if not hasattr(self.envs[0], 'seed'):
return
if np.isscalar(seed) or seed is None:
seed = [seed for _ in range(self.env_num)]
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
@@ -220,6 +227,8 @@
ray.get(r)
def render(self):
if not hasattr(self.envs[0], 'render'):
return
result_obj = [e.render.remote() for e in self.envs]
for r in result_obj:
ray.get(r)
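
Turning reset_after_done into an explicit keyword leaves existing call sites that passed it by name untouched, and the workers now guard render/seed with hasattr instead of crashing the pipe. A usage sketch, assuming gym is installed:

import gym
from tianshou.env import SubprocVectorEnv

envs = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(2)],
                        reset_after_done=True)
obs = envs.reset()
obs, rew, done, info = envs.step([0, 1])  # one action per worker process
envs.close()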

View File

@@ -0,0 +1,3 @@
from tianshou.policy.base import BasePolicy
__all__ = ['BasePolicy']

tianshou/policy/base.py (new file)
View File

@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
class BasePolicy(ABC):
"""docstring for BasePolicy"""
def __init__(self):
super().__init__()
@abstractmethod
def act(self, batch, hidden_state=None):
# return an object with fields like {policy, act, state}; Collector reads .act and .state
pass
def train(self):
pass
def eval(self):
pass
def reset(self):
pass
@staticmethod
def process_fn(batch, buffer, index):
pass
def exploration(self):
pass
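
A minimal concrete subclass, to illustrate the contract: the Collector above reads result.act and result.state from whatever act() returns, so a Batch serves as the return type. The random policy itself is purely illustrative:

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Picks a uniformly random discrete action for every observation."""
    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def act(self, batch, hidden_state=None):
        # one random action per observation in the batch
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act, state=hidden_state)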

View File

@@ -1,4 +1,9 @@
from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg
from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
__all__ = ['CloudpickleWrapper', 'tqdm_config']
__all__ = [
'CloudpickleWrapper',
'tqdm_config',
'MovAvg'
]

View File

@@ -0,0 +1,23 @@
import numpy as np
class MovAvg(object):
def __init__(self, size=100):
super().__init__()
self.size = size
self.cache = []
def add(self, x):
if hasattr(x, 'detach'):
# x is a torch.Tensor; move it to the CPU and convert to a numpy value
x = x.detach().cpu().numpy()
if x != np.inf: # skip inf values so they do not corrupt the average
self.cache.append(x)
if self.size > 0 and len(self.cache) > self.size:
self.cache = self.cache[-self.size:]
return self.get()
def get(self):
if len(self.cache) == 0:
return 0
return np.mean(self.cache)
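
Typical use is smoothing a noisy training scalar such as episode reward; with size=3 only the three most recent values contribute:

stat = MovAvg(size=3)
for r in [1., 2., 3., 4.]:
    stat.add(r)      # add() also returns the running mean
print(stat.get())    # 3.0, the mean of the last three values [2., 3., 4.]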