half of collector

This commit is contained in:
Trinkle23897 2020-03-12 22:20:33 +08:00
parent 4a1a7dd670
commit f58c1397c6
14 changed files with 194 additions and 30 deletions

README.md

@@ -1,3 +1,4 @@
 # Tianshou
 ![Python package](https://github.com/Trinkle23897/tianshou/workflows/Python%20package/badge.svg)

setup.py

@@ -20,14 +20,12 @@ setup(
         # 4 - Beta
         # 5 - Production/Stable
         'Development Status :: 3 - Alpha',
         # Indicate who your project is intended for
         'Intended Audience :: Science/Research',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development :: Libraries :: Python Modules',
         # Pick your license as you wish (should match "license" above)
         'License :: OSI Approved :: MIT License',
         # Specify the Python versions you support here. In particular, ensure
         # that you indicate whether you support Python 2, Python 3 or both.
         'Programming Language :: Python :: 3.6',
@@ -35,18 +33,16 @@ setup(
         'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
-    # You can just specify the packages manually here if your project is
-    # simple. Or you can use find_packages().
     packages=find_packages(exclude=['tests', 'tests.*',
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
-        'numpy',
-        'torch',
-        'tensorboard',
-        'tqdm',
         # 'ray',
         'gym',
+        'tqdm',
+        'numpy',
+        'torch',
         'cloudpickle',
+        'tensorboard',
     ],
 )

test/test_buffer.py

@@ -1,7 +1,7 @@
 from tianshou.data import ReplayBuffer

 if __name__ == '__main__':
     from test_env import MyTestEnv
-else:
+else:  # pytest
     from test.test_env import MyTestEnv

tianshou/data/__init__.py

@@ -1,4 +1,5 @@
 from tianshou.data.batch import Batch
 from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.collector import Collector

-__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer']
+__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector']

tianshou/data/batch.py

@@ -2,4 +2,8 @@ class Batch(object):
     """Suggested keys: [obs, act, rew, done, obs_next, info]"""
     def __init__(self, **kwargs):
         super().__init__()
+        self.obs_next = None
+        self.__dict__.update(kwargs)
+
+    def update(self, **kwargs):
         self.__dict__.update(kwargs)
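For context, a minimal sketch (not part of this commit) of how the extended Batch behaves: obs_next now defaults to None, and update() merges new keys into the instance dict.

from tianshou.data import Batch

b = Batch(obs=0, act=1, rew=0.5, done=False, info={})
assert b.obs_next is None   # default installed by __init__
b.update(obs_next=2)        # update() merges kwargs into __dict__
assert b.obs_next == 2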

tianshou/data/buffer.py

@@ -40,12 +40,20 @@ class ReplayBuffer(object):

     def reset(self):
         self._index = self._size = 0
+        self.indice = []

     def sample_indice(self, batch_size):
-        return np.random.choice(self._size, batch_size)
+        if batch_size > 0:
+            self.indice = np.random.choice(self._size, batch_size)
+        else:
+            self.indice = np.arange(self._size)
+        return self.indice

-    def sample(self, batch_size):
-        indice = self.sample_indice(batch_size)
+    def sample(self, batch_size, indice=None):
+        if indice is None:
+            indice = self.sample_indice(batch_size)
+        else:
+            self.indice = indice
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],
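A hypothetical usage sketch of the new sample() signature (not in this commit; assumes the buffer is constructed as ReplayBuffer(size) and already holds data):

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(100)
# ... after some buf.add(...) calls ...
batch = buf.sample(16)              # random indices, cached in buf.indice
whole = buf.sample(0)               # batch_size <= 0 returns every stored frame
replay = buf.sample(0, buf.indice)  # reuse previously drawn indices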

tianshou/data/collector.py

@@ -0,0 +1,86 @@
import numpy as np
from copy import deepcopy

from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer
from tianshou.utils import MovAvg  # not used yet


class Collector(object):
    """docstring for Collector"""

    def __init__(self, policy, env, buffer):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self.multi_env = isinstance(env, BaseVectorEnv)
        if self.multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The data buffer number does not match the input env number.'
            elif isinstance(self.buffer, ReplayBuffer):
                self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.clear_buffer()
        # state over the batch is either a list, an np.ndarray,
        # or a torch.Tensor (hasattr 'shape')
        self.state = None

    def clear_buffer(self):
        if self.multi_env:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def reset_env(self):
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None

    def collect(self, n_step=0, n_episode=0, tqdm_hook=None):
        assert sum([(n_step > 0), (n_episode > 0)]) == 1, \
            'One and only one collection number specification is permitted!'
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self.multi_env else 0
        while True:
            if self.multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, info=self._info)
            else:
                batch_data = Batch(
                    obs=[self._obs], act=[self._act], rew=[self._rew],
                    done=[self._done], info=[self._info])
            result = self.policy.act(batch_data, self.state)
            self.state = result.state
            self._act = result.act
            obs_next, self._rew, self._done, self._info = \
                self.env.step(self._act)
            cur_step += 1
            if self.multi_env:
                for i in range(self.env_num):
                    if (n_episode > 0 and cur_episode[i] < n_episode) \
                            or n_episode == 0:
                        self.buffer[i].add(
                            self._obs[i], self._act[i], self._rew[i],
                            self._done[i], obs_next[i], self._info[i])
                        if self._done[i]:
                            cur_episode[i] += 1
                            if isinstance(self.state, list):
                                self.state[i] = None
                            else:
                                self.state[i] = self.state[i] * 0
                                if hasattr(self.state, 'detach'):
                                    # detach the hidden state from the
                                    # torch autograd graph
                                    self.state = self.state.detach()
                if n_episode > 0 and (cur_episode >= n_episode).all():
                    break
            else:
                self.buffer.add(
                    self._obs, self._act[0], self._rew,
                    self._done, obs_next, self._info)
                if self._done:
                    cur_episode += 1
                    self.state = None
                if n_episode > 0 and cur_episode >= n_episode:
                    break
            if n_step > 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next

    def sample(self):
        pass

    def stat(self):
        pass
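A hypothetical end-to-end sketch of the new Collector (not part of this commit; MyPolicy is a stand-in for any BasePolicy subclass whose act() returns a Batch with act and state fields):

import gym
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)],
                 reset_after_done=True)
collector = Collector(MyPolicy(), envs, ReplayBuffer(1000))
collector.collect(n_step=100)    # step-based collection across the 4 envs
collector.collect(n_episode=8)   # or episode-based, but never both at once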

tianshou/env/__init__.py

@@ -1,3 +1,3 @@
-from tianshou.env.wrapper import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv

-__all__ = ['FrameStack', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
+__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']

tianshou/env/wrapper.py

@@ -1,5 +1,6 @@
 import numpy as np
 from collections import deque
+from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
@@ -56,13 +57,18 @@ class FrameStack(EnvWrapper):
         return np.stack(self._frames, axis=-1)


-class VectorEnv(object):
+class BaseVectorEnv(ABC):
+    def __init__(self):
+        pass
+
+
+class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""

-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.envs = [_() for _ in env_fns]
         self.env_num = len(self.envs)
-        self._reset_after_done = kwargs.get('reset_after_done', False)
+        self._reset_after_done = reset_after_done

     def __len__(self):
         return len(self.envs)
@@ -97,8 +103,7 @@ class VectorEnv(object):
             e.close()


-def worker(parent, p, env_fn_wrapper, kwargs):
-    reset_after_done = kwargs.get('reset_after_done', True)
+def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     while True:
@@ -115,22 +120,22 @@ def worker(parent, p, env_fn_wrapper, kwargs):
             p.close()
             break
         elif cmd == 'render':
-            p.send(env.render())
+            p.send(env.render() if hasattr(env, 'render') else None)
         elif cmd == 'seed':
-            p.send(env.seed(data))
+            p.send(env.seed(data) if hasattr(env, 'seed') else None)
         else:
             raise NotImplementedError


-class SubprocVectorEnv(object):
+class SubprocVectorEnv(BaseVectorEnv):
     """docstring for SubProcVectorEnv"""

-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
         self.closed = False
         self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
-            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
+            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
             for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
         ]
         for p in self.processes:
@@ -178,12 +183,12 @@ class SubprocVectorEnv(object):
             p.join()


-class RayVectorEnv(object):
+class RayVectorEnv(BaseVectorEnv):
     """docstring for RayVectorEnv"""

-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
-        self._reset_after_done = kwargs.get('reset_after_done', False)
+        self._reset_after_done = reset_after_done
         try:
             if not ray.is_initialized():
                 ray.init()
@@ -213,6 +218,8 @@ class RayVectorEnv(object):
         return np.stack([ray.get(r) for r in result_obj])

     def seed(self, seed=None):
+        if not hasattr(self.envs[0], 'seed'):
+            return
         if np.isscalar(seed) or seed is None:
             seed = [seed for _ in range(self.env_num)]
         result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
@@ -220,6 +227,8 @@ class RayVectorEnv(object):
             ray.get(r)

     def render(self):
+        if not hasattr(self.envs[0], 'render'):
+            return
         result_obj = [e.render.remote() for e in self.envs]
         for r in result_obj:
             ray.get(r)
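For reference, a usage sketch of the explicit reset_after_done keyword that replaces the old **kwargs plumbing (not part of this commit; assumes the reset/step/close methods defined elsewhere in this file):

import gym
import numpy as np
from tianshou.env import SubprocVectorEnv

train_envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(8)],
    reset_after_done=True)
obs = train_envs.reset()
obs, rew, done, info = train_envs.step(np.zeros(8, dtype=int))  # one action per env
train_envs.close()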

tianshou/policy/__init__.py

@@ -0,0 +1,3 @@
from tianshou.policy.base import BasePolicy

__all__ = ['BasePolicy']

tianshou/policy/base.py

@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod


class BasePolicy(ABC):
    """docstring for BasePolicy"""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def act(self, batch, hidden_state=None):
        # return {policy, action, hidden}
        pass

    def train(self):
        pass

    def eval(self):
        pass

    def reset(self):
        pass

    @staticmethod
    def process_fn(batch, buffer, index):
        pass

    def exploration(self):
        pass
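A minimal concrete subclass sketch (hypothetical, not in this commit), wired to what Collector expects back from act(), namely result.act and result.state:

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def act(self, batch, hidden_state=None):
        # one random discrete action per observation in the batch
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(policy=None, act=act, state=hidden_state)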


tianshou/utils/__init__.py

@@ -1,4 +1,9 @@
+from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
 from tianshou.utils.config import tqdm_config
-from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
+from tianshou.utils.moving_average import MovAvg

-__all__ = ['CloudpickleWrapper', 'tqdm_config']
+__all__ = [
+    'CloudpickleWrapper',
+    'tqdm_config',
+    'MovAvg'
+]

tianshou/utils/moving_average.py

@@ -0,0 +1,23 @@
import numpy as np


class MovAvg(object):
    def __init__(self, size=100):
        super().__init__()
        self.size = size
        self.cache = []

    def add(self, x):
        if hasattr(x, 'detach'):
            # x is a torch.Tensor: detach it and move it to numpy
            x = x.detach().cpu().numpy()
        if x != np.inf:  # ignore +inf placeholder values
            self.cache.append(x)
        if self.size > 0 and len(self.cache) > self.size:
            self.cache = self.cache[-self.size:]
        return self.get()

    def get(self):
        if len(self.cache) == 0:
            return 0
        return np.mean(self.cache)
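Usage sketch (not part of this commit): add() returns the running mean over the cached window, and +inf samples are dropped.

import numpy as np
from tianshou.utils import MovAvg

stat = MovAvg(size=100)
stat.add(1.0)       # -> 1.0
stat.add(3.0)       # -> 2.0
stat.add(np.inf)    # ignored, mean stays 2.0
print(stat.get())   # 2.0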