half of collector

commit f58c1397c6
parent 4a1a7dd670

README.md

@@ -1,3 +1,4 @@
 # Tianshou

setup.py

@@ -20,14 +20,12 @@ setup(
         # 4 - Beta
         # 5 - Production/Stable
         'Development Status :: 3 - Alpha',
-
         # Indicate who your project is intended for
         'Intended Audience :: Science/Research',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development :: Libraries :: Python Modules',
-
         # Pick your license as you wish (should match "license" above)
         'License :: OSI Approved :: MIT License',
         # Specify the Python versions you support here. In particular, ensure
         # that you indicate whether you support Python 2, Python 3 or both.
         'Programming Language :: Python :: 3.6',
@@ -35,18 +33,16 @@ setup(
         'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
     # You can just specify the packages manually here if your project is
     # simple. Or you can use find_packages().
     packages=find_packages(exclude=['tests', 'tests.*',
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
-        'numpy',
-        'torch',
-        'tensorboard',
-        'tqdm',
-        # 'ray',
+        'gym',
+        'tqdm',
+        'numpy',
+        'torch',
+        'cloudpickle'
+        'tensorboard',
     ],
 )

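Note the missing comma after 'cloudpickle' in the new install_requires list:
Python's implicit string-literal concatenation silently merges it with the
next entry, so the list declares a single dependency named
'cloudpickletensorboard' instead of two. A quick standalone illustration:

deps = [
    'cloudpickle'
    'tensorboard',
]
print(deps)       # ['cloudpickletensorboard'] -- one entry, not two
print(len(deps))  # 1
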
test/test_buffer.py

@@ -1,7 +1,7 @@
 from tianshou.data import ReplayBuffer
 
 if __name__ == '__main__':
     from test_env import MyTestEnv
-else:
+else:  # pytest
     from test.test_env import MyTestEnv

tianshou/data/__init__.py

@@ -1,4 +1,5 @@
 from tianshou.data.batch import Batch
 from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.collector import Collector
 
-__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer']
+__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector']

tianshou/data/batch.py

@@ -2,4 +2,8 @@ class Batch(object):
     """Suggested keys: [obs, act, rew, done, obs_next, info]"""
     def __init__(self, **kwargs):
         super().__init__()
+        self.obs_next = None
         self.__dict__.update(kwargs)
+
+    def update(self, **kwargs):
+        self.__dict__.update(kwargs)

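Taken together, the Batch changes guarantee that obs_next always exists and
let callers merge extra keys after construction. A minimal standalone sketch
of the resulting behavior (re-declaring the class rather than importing it):

class Batch(object):
    """Suggested keys: [obs, act, rew, done, obs_next, info]"""
    def __init__(self, **kwargs):
        super().__init__()
        self.obs_next = None
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        self.__dict__.update(kwargs)

b = Batch(obs=[1, 2], act=[0, 1])
assert b.obs_next is None   # default added by this commit
b.update(rew=[1.0, 0.0])    # new update() merges keys in place
assert b.rew == [1.0, 0.0]
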
tianshou/data/buffer.py

@@ -40,12 +40,20 @@ class ReplayBuffer(object):
 
     def reset(self):
         self._index = self._size = 0
+        self.indice = []
 
     def sample_indice(self, batch_size):
-        return np.random.choice(self._size, batch_size)
+        if batch_size > 0:
+            self.indice = np.random.choice(self._size, batch_size)
+        else:
+            self.indice = np.arange(self._size)
+        return self.indice
 
-    def sample(self, batch_size):
-        indice = self.sample_indice(batch_size)
+    def sample(self, batch_size, indice=None):
+        if indice is None:
+            indice = self.sample_indice(batch_size)
+        else:
+            self.indice = indice
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],

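The new sampling rule: a positive batch_size draws that many indices at
random (with replacement), while batch_size == 0 returns every stored
transition in order; sample() can also replay caller-supplied indices. A
standalone numpy sketch of just the index logic:

import numpy as np

def sample_indice(size, batch_size):
    if batch_size > 0:
        return np.random.choice(size, batch_size)
    return np.arange(size)

print(sample_indice(5, 3))  # e.g. [2 0 2] -- sampled with replacement
print(sample_indice(5, 0))  # [0 1 2 3 4] -- the whole buffer, in order
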
tianshou/data/collector.py (new file)

@@ -0,0 +1,86 @@
+import numpy as np
+from copy import deepcopy
+
+from tianshou.env import BaseVectorEnv
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.utils import MovAvg
+
+
+class Collector(object):
+    """docstring for Collector"""
+    def __init__(self, policy, env, buffer):
+        super().__init__()
+        self.env = env
+        self.env_num = 1
+        self.buffer = buffer
+        self.policy = policy
+        self.process_fn = policy.process_fn
+        self.multi_env = isinstance(env, BaseVectorEnv)
+        if self.multi_env:
+            self.env_num = len(env)
+            if isinstance(self.buffer, list):
+                assert len(self.buffer) == self.env_num, 'The data buffer number does not match the input env number.'
+            elif isinstance(self.buffer, ReplayBuffer):
+                self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
+            else:
+                raise TypeError('The buffer in data collector is invalid!')
+        self.reset_env()
+        self.clear_buffer()
+        # state over batch is either a list, an np.ndarray, or torch.Tensor (hasattr 'shape')
+        self.state = None
+
+    def clear_buffer(self):
+        if self.multi_env:
+            for b in self.buffer:
+                b.reset()
+        else:
+            self.buffer.reset()
+
+    def reset_env(self):
+        self._obs = self.env.reset()
+        self._act = self._rew = self._done = self._info = None
+
+    def collect(self, n_step=0, n_episode=0, tqdm_hook=None):
+        assert sum([(n_step > 0), (n_episode > 0)]) == 1, "One and only one collection number specification permitted!"
+        cur_step = 0
+        cur_episode = np.zeros(self.env_num) if self.multi_env else 0
+        while True:
+            if self.multi_env:
+                batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, info=self._info)
+            else:
+                batch_data = Batch(obs=[self._obs], act=[self._act], rew=[self._rew], done=[self._done], info=[self._info])
+            result = self.policy.act(batch_data, self.state)
+            self.state = result.state
+            self._act = result.act
+            obs_next, self._rew, self._done, self._info = self.env.step(self._act)
+            cur_step += 1
+            if self.multi_env:
+                for i in range(self.env_num):
+                    if n_episode > 0 and cur_episode[i] < n_episode or n_episode == 0:
+                        self.buffer[i].add(self._obs[i], self._act[i], self._rew[i], self._done[i], obs_next[i], self._info[i])
+                        if self._done[i]:
+                            cur_episode[i] += 1
+                            if isinstance(self.state, list):
+                                self.state[i] = None
+                            else:
+                                self.state[i] = self.state[i] * 0
+                                if hasattr(self.state, 'detach'):  # remove count in torch
+                                    self.state = self.state.detach()
+                if n_episode > 0 and (cur_episode >= n_episode).all():
+                    break
+            else:
+                self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info)
+                if self._done:
+                    cur_episode += 1
+                    self.state = None
+                if n_episode > 0 and cur_episode >= n_episode:
+                    break
+            if n_step > 0 and cur_step >= n_step:
+                break
+            self._obs = obs_next
+        self._obs = obs_next
+
+    def sample(self):
+        pass
+
+    def stat(self):
+        pass

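A hypothetical usage sketch of the new Collector (make_env and MyPolicy are
placeholders, not part of this commit, and ReplayBuffer is assumed to take
its capacity as the first argument; the policy's act() must return an object
exposing .act and .state):

from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv

# make_env and MyPolicy are hypothetical placeholders
envs = VectorEnv([lambda: make_env() for _ in range(4)], reset_after_done=True)
collector = Collector(MyPolicy(), envs, ReplayBuffer(10000))  # one deep copy of the buffer per env
collector.collect(n_step=100)   # or collector.collect(n_episode=10)
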
tianshou/env/__init__.py

@@ -1,3 +1,3 @@
-from tianshou.env.wrapper import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv
 
-__all__ = ['FrameStack', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
+__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']

tianshou/env/wrapper.py

@@ -1,5 +1,6 @@
 import numpy as np
 from collections import deque
+from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
@@ -56,13 +57,18 @@ class FrameStack(EnvWrapper):
         return np.stack(self._frames, axis=-1)
 
 
-class VectorEnv(object):
+class BaseVectorEnv(ABC):
+    def __init__(self):
+        pass
+
+
+class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""
-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.envs = [_() for _ in env_fns]
         self.env_num = len(self.envs)
-        self._reset_after_done = kwargs.get('reset_after_done', False)
+        self._reset_after_done = reset_after_done
 
     def __len__(self):
         return len(self.envs)
@@ -97,8 +103,7 @@ class VectorEnv(object):
             e.close()
 
 
-def worker(parent, p, env_fn_wrapper, kwargs):
-    reset_after_done = kwargs.get('reset_after_done', True)
+def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     while True:
@@ -115,22 +120,22 @@ def worker(parent, p, env_fn_wrapper, kwargs):
             p.close()
             break
         elif cmd == 'render':
-            p.send(env.render())
+            p.send(env.render() if hasattr(env, 'render') else None)
         elif cmd == 'seed':
-            p.send(env.seed(data))
+            p.send(env.seed(data) if hasattr(env, 'seed') else None)
         else:
             raise NotImplementedError
 
 
-class SubprocVectorEnv(object):
+class SubprocVectorEnv(BaseVectorEnv):
     """docstring for SubProcVectorEnv"""
-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
         self.closed = False
         self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
-            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
+            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
             for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
         ]
         for p in self.processes:
@@ -178,12 +183,12 @@ class SubprocVectorEnv(object):
             p.join()
 
 
-class RayVectorEnv(object):
+class RayVectorEnv(BaseVectorEnv):
     """docstring for RayVectorEnv"""
-    def __init__(self, env_fns, **kwargs):
+    def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
-        self._reset_after_done = kwargs.get('reset_after_done', False)
+        self._reset_after_done = reset_after_done
         try:
             if not ray.is_initialized():
                 ray.init()
@@ -213,6 +218,8 @@ class RayVectorEnv(object):
         return np.stack([ray.get(r) for r in result_obj])
 
     def seed(self, seed=None):
+        if not hasattr(self.envs[0], 'seed'):
+            return
         if np.isscalar(seed) or seed is None:
             seed = [seed for _ in range(self.env_num)]
         result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
@@ -220,6 +227,8 @@ class RayVectorEnv(object):
             ray.get(r)
 
     def render(self):
+        if not hasattr(self.envs[0], 'render'):
+            return
         result_obj = [e.render.remote() for e in self.envs]
         for r in result_obj:
             ray.get(r)

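The recurring change across VectorEnv, worker, SubprocVectorEnv, and
RayVectorEnv replaces **kwargs plus kwargs.get(...) with an explicit
reset_after_done parameter. Besides removing the inconsistent defaults (the
old worker fell back to True, the classes to False), explicit keywords
surface typos that a kwargs dict swallows silently. An illustrative sketch,
not from the commit:

def init_with_kwargs(**kwargs):
    return kwargs.get('reset_after_done', False)

def init_explicit(reset_after_done=False):
    return reset_after_done

print(init_with_kwargs(reset_after_don=True))   # False: the typo is ignored
try:
    init_explicit(reset_after_don=True)
except TypeError as err:
    print(err)  # unexpected keyword argument 'reset_after_don'
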
tianshou/policy/__init__.py (new file)

@@ -0,0 +1,3 @@
+from tianshou.policy.base import BasePolicy
+
+__all__ = ['BasePolicy']

tianshou/policy/base.py (new file)

@@ -0,0 +1,28 @@
+from abc import ABC, abstractmethod
+
+
+class BasePolicy(ABC):
+    """docstring for BasePolicy"""
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def act(self, batch, hidden_state=None):
+        # return {policy, action, hidden}
+        pass
+
+    def train(self):
+        pass
+
+    def eval(self):
+        pass
+
+    def reset(self):
+        pass
+
+    @staticmethod
+    def process_fn(batch, buffer, index):
+        pass
+
+    def exploration(self):
+        pass

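A hypothetical minimal subclass, to show how the abstract interface is meant
to be filled in (RandomPolicy is illustrative, not part of this commit). The
Collector reads .act and .state off whatever act() returns, so a Batch works
as the return type:

import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Illustrative only: uniform random discrete actions."""
    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def act(self, batch, hidden_state=None):
        # one random action per observation in the batch
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act, state=hidden_state)
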
tianshou/policy/reward_processor.py (new, empty file)

tianshou/utils/__init__.py

@@ -1,4 +1,9 @@
-from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
 from tianshou.utils.config import tqdm_config
+from tianshou.utils.moving_average import MovAvg
+from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper
 
-__all__ = ['CloudpickleWrapper', 'tqdm_config']
+__all__ = [
+    'CloudpickleWrapper',
+    'tqdm_config',
+    'MovAvg'
+]

tianshou/utils/moving_average.py (new file)

@@ -0,0 +1,23 @@
+import numpy as np
+
+
+class MovAvg(object):
+    def __init__(self, size=100):
+        super().__init__()
+        self.size = size
+        self.cache = []
+
+    def add(self, x):
+        if hasattr(x, 'detach'):
+            # which means x is torch.Tensor (?)
+            x = x.detach().cpu().numpy()
+        if x != np.inf:
+            self.cache.append(x)
+        if self.size > 0 and len(self.cache) > self.size:
+            self.cache = self.cache[-self.size:]
+        return self.get()
+
+    def get(self):
+        if len(self.cache) == 0:
+            return 0
+        return np.mean(self.cache)

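A quick usage sketch of MovAvg:

import numpy as np
from tianshou.utils import MovAvg

stat = MovAvg(size=3)
stat.add(1.0)
print(stat.add(2.0))  # 1.5 -- add() returns the running mean
stat.add(np.inf)      # inf values are silently dropped
stat.add(3.0)
stat.add(4.0)
print(stat.get())     # 3.0 -- mean of the last three kept values [2, 3, 4]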