maybe finished collector?

Trinkle23897 2020-03-13 17:49:22 +08:00
parent f58c1397c6
commit f16e05c0e7
15 changed files with 165 additions and 49 deletions

View File

@@ -33,7 +33,7 @@ jobs:
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
     - name: Test with pytest
       run: |
         pip install pytest pytest-cov

View File

@@ -1,3 +1,3 @@
 #!/bin/sh
 flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics

View File

@@ -42,7 +42,7 @@ setup(
         'tqdm',
         'numpy',
         'torch',
-        'cloudpickle'
+        'cloudpickle',
         'tensorboard',
     ],
 )
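
Without the trailing comma, Python's implicit string concatenation silently fused the two adjacent requirements into a single bogus package name. A standalone illustration (not from the repo):

install_requires = [
    'cloudpickle'
    'tensorboard',
]
# evaluates to ['cloudpickletensorboard'] -- one requirement, not two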

test/test_batch.py Normal file
View File

@@ -0,0 +1,16 @@
+import numpy as np
+
+from tianshou.data import Batch
+
+
+def test_batch():
+    batch = Batch(obs=[0], np=np.zeros([3, 4]))
+    batch.update(obs=[1])
+    assert batch.obs == [1]
+    batch.append(batch)
+    assert batch.obs == [1, 1]
+    assert batch.np.shape == (6, 4)
+
+
+if __name__ == '__main__':
+    test_batch()

View File

@@ -14,8 +14,7 @@ def test_replaybuffer(size=10, bufsize=20):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
-    indice = buf.sample_indice(4)
-    data = buf.sample(4)
+    data, indice = buf.sample(4)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()

View File

@@ -79,7 +79,9 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
             for a in action_list:
                 e.step([a] * num)
             t[i] = time.time() - t[i]
-        print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
+        print(f'VectorEnv: {t[0]:.6f}s')
+        print(f'SubprocVectorEnv: {t[1]:.6f}s')
+        print(f'RayVectorEnv: {t[2]:.6f}s')
         for v in venv:
             v.close()

View File

@@ -1,4 +1,9 @@
-from tianshou import data, env, utils
+from tianshou import data, env, utils, policy
 
 __version__ = '0.2.0'
-__all__ = ['data', 'env', 'utils']
+__all__ = [
+    'data',
+    'env',
+    'utils',
+    'policy'
+]

View File

@@ -2,4 +2,9 @@ from tianshou.data.batch import Batch
 from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
 
-__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector']
+__all__ = [
+    'Batch',
+    'ReplayBuffer',
+    'PrioritizedReplayBuffer',
+    'Collector'
+]

View File

@@ -1,9 +1,29 @@
+import numpy as np
+
+
 class Batch(object):
     """Suggested keys: [obs, act, rew, done, obs_next, info]"""
 
     def __init__(self, **kwargs):
         super().__init__()
+        self.obs_next = None
         self.__dict__.update(kwargs)
 
     def update(self, **kwargs):
         self.__dict__.update(kwargs)
+
+    def append(self, batch):
+        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
+        for k in batch.__dict__.keys():
+            if batch.__dict__[k] is None:
+                continue
+            if not hasattr(self, k) or self.__dict__[k] is None:
+                self.__dict__[k] = batch.__dict__[k]
+            elif isinstance(batch.__dict__[k], np.ndarray):
+                self.__dict__[k] = np.concatenate([
+                    self.__dict__[k], batch.__dict__[k]])
+            elif isinstance(batch.__dict__[k], list):
+                self.__dict__[k] += batch.__dict__[k]
+            else:
+                raise TypeError(
+                    'No support append method with {} in class Batch.'
+                    .format(type(batch.__dict__[k])))
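
A quick usage sketch of the append semantics added here (standalone, assuming the Batch class exactly as in the hunk above): lists are joined with +=, arrays with np.concatenate, None-valued keys such as the obs_next default are skipped, and any other type raises TypeError.

import numpy as np
from tianshou.data import Batch

b1 = Batch(act=[0, 1], rew=np.ones(2))
b2 = Batch(act=[2], rew=np.zeros(1))
b1.append(b2)
assert b1.act == [0, 1, 2]       # list += list
assert b1.rew.shape == (3,)      # np.concatenate([ones(2), zeros(1)])
assert b1.obs_next is None       # None on both sides, so append skips it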

View File

@@ -4,6 +4,7 @@ from tianshou.data.batch import Batch
 
 class ReplayBuffer(object):
     """docstring for ReplayBuffer"""
+
     def __init__(self, size):
         super().__init__()
         self._maxsize = size
@@ -19,7 +20,8 @@ class ReplayBuffer(object):
         if isinstance(inst, np.ndarray):
             self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
         elif isinstance(inst, dict):
-            self.__dict__[name] = np.array([{} for _ in range(self._maxsize)])
+            self.__dict__[name] = np.array(
+                [{} for _ in range(self._maxsize)])
         else:  # assume `inst` is a number
             self.__dict__[name] = np.zeros([self._maxsize])
         self.__dict__[name][self._index] = inst
@@ -28,7 +30,8 @@ class ReplayBuffer(object):
         '''
         weight: importance weights, disabled here
         '''
-        assert isinstance(info, dict), 'You should return a dict in the last argument of env.step function.'
+        assert isinstance(info, dict),\
+            'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
         self._add_to_buffer('act', act)
         self._add_to_buffer('rew', rew)
@@ -42,18 +45,11 @@ class ReplayBuffer(object):
         self._index = self._size = 0
         self.indice = []
 
-    def sample_indice(self, batch_size):
+    def sample(self, batch_size):
         if batch_size > 0:
-            self.indice = np.random.choice(self._size, batch_size)
+            indice = np.random.choice(self._size, batch_size)
         else:
-            self.indice = np.arange(self._size)
-        return self.indice
-
-    def sample(self, batch_size, indice=None):
-        if indice is None:
-            indice = self.sample_indice(batch_size)
-        else:
-            self.indice = indice
+            indice = np.arange(self._size)
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],
@@ -61,11 +57,12 @@ class ReplayBuffer(object):
             done=self.done[indice],
             obs_next=self.obs_next[indice],
             info=self.info[indice]
-        )
+        ), indice
 
 
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""
+
     def __init__(self, size):
         super().__init__(size)
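
A short sketch (standalone, not part of the commit) of the merged API: sample_indice() is gone and sample() now returns both the Batch and the indices it drew, which is what lets the Collector pass those indices on to process_fn.

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=20)
for i in range(3):
    # add(obs, act, rew, done, obs_next, info)
    buf.add(i, i, 0.0, 0, i + 1, {})
data, indice = buf.sample(2)   # previously two calls: sample_indice + sample
assert len(indice) == 2
assert (data.obs == buf.obs[indice]).all()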

View File

@@ -5,8 +5,10 @@ from tianshou.env import BaseVectorEnv
 from tianshou.data import Batch, ReplayBuffer
 from tianshou.utils import MovAvg
 
+
 class Collector(object):
     """docstring for Collector"""
+
     def __init__(self, policy, env, buffer):
         super().__init__()
         self.env = env
@@ -18,15 +20,18 @@ class Collector(object):
         if self.multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
-                assert len(self.buffer) == self.env_num, 'The data buffer number does not match the input env number.'
+                assert len(self.buffer) == self.env_num,\
+                    'Data buffer number does not match the input env number.'
             elif isinstance(self.buffer, ReplayBuffer):
                 self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
             else:
                 raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
         self.clear_buffer()
-        # state over batch is either a list, an np.ndarray, or torch.Tensor (hasattr 'shape')
+        # state over batch is either a list, an np.ndarray, or torch.Tensor
         self.state = None
+        self.stat_reward = MovAvg()
+        self.stat_length = MovAvg()
 
     def clear_buffer(self):
         if self.multi_env:
@@ -38,39 +43,64 @@ class Collector(object):
     def reset_env(self):
         self._obs = self.env.reset()
         self._act = self._rew = self._done = self._info = None
+        if self.multi_env:
+            self.reward = np.zeros(self.env_num)
+            self.length = np.zeros(self.env_num)
+        else:
+            self.reward, self.length = 0, 0
 
-    def collect(self, n_step=0, n_episode=0, tqdm_hook=None):
-        assert sum([(n_step > 0), (n_episode > 0)]) == 1, "One and only one collection number specification permitted!"
+    def collect(self, n_step=0, n_episode=0):
+        assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
+            "One and only one collection number specification permitted!"
         cur_step = 0
         cur_episode = np.zeros(self.env_num) if self.multi_env else 0
         while True:
             if self.multi_env:
-                batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, info=self._info)
+                batch_data = Batch(
+                    obs=self._obs, act=self._act, rew=self._rew,
+                    done=self._done, obs_next=None, info=self._info)
             else:
-                batch_data = Batch(obs=[self._obs], act=[self._act], rew=[self._rew], done=[self._done], info=[self_info])
+                batch_data = Batch(
+                    obs=[self._obs], act=[self._act], rew=[self._rew],
+                    done=[self._done], obs_next=None, info=[self._info])
             result = self.policy.act(batch_data, self.state)
             self.state = result.state
            self._act = result.act
-            obs_next, self._rew, self._done, self._info = self.env.step(self._act)
+            obs_next, self._rew, self._done, self._info = self.env.step(
+                self._act)
             cur_step += 1
+            self.length += 1
+            self.reward += self._rew
             if self.multi_env:
                 for i in range(self.env_num):
-                    if n_episode > 0 and cur_episode[i] < n_episode or n_episode == 0:
-                        self.buffer[i].add(self._obs[i], self._act[i], self._rew[i], self._done[i], obs_next[i], self._info[i])
+                    if n_episode > 0 and \
+                            cur_episode[i] < n_episode or n_episode == 0:
+                        self.buffer[i].add(
+                            self._obs[i], self._act[i], self._rew[i],
+                            self._done[i], obs_next[i], self._info[i])
                     if self._done[i]:
                         cur_episode[i] += 1
+                        self.stat_reward.add(self.reward[i])
+                        self.stat_length.add(self.length[i])
+                        self.reward[i], self.length[i] = 0, 0
                         if isinstance(self.state, list):
                             self.state[i] = None
                         else:
                             self.state[i] = self.state[i] * 0
-                            if hasattr(self.state, 'detach'):  # remove count in torch
+                            if hasattr(self.state, 'detach'):
+                                # remove ref in torch
                                 self.state = self.state.detach()
                 if n_episode > 0 and (cur_episode >= n_episode).all():
                     break
             else:
-                self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info)
+                self.buffer.add(
+                    self._obs, self._act[0], self._rew,
+                    self._done, obs_next, self._info)
                 if self._done:
                     cur_episode += 1
+                    self.stat_reward.add(self.reward)
+                    self.stat_length.add(self.length)
+                    self.reward, self.length = 0, 0
                     self.state = None
                 if n_episode > 0 and cur_episode >= n_episode:
                     break
@@ -79,8 +109,29 @@ class Collector(object):
             self._obs = obs_next
         self._obs = obs_next
 
-    def sample(self):
-        pass
+    def sample(self, batch_size):
+        if self.multi_env:
+            if batch_size > 0:
+                lens = [len(b) for b in self.buffer]
+                total = sum(lens)
+                ib = np.random.choice(
+                    total, batch_size, p=np.array(lens) / total)
+            else:
+                ib = np.array([])
+            batch_data = Batch()
+            for i, b in enumerate(self.buffer):
+                cur_batch = (ib == i).sum()
+                if batch_size and cur_batch or batch_size <= 0:
+                    batch, indice = b.sample(cur_batch)
+                    batch = self.process_fn(batch, b, indice)
+                    batch_data.append(batch)
+        else:
+            batch_data, indice = self.buffer.sample(batch_size)
+            batch_data = self.process_fn(batch_data, self.buffer, indice)
+        return batch_data
 
     def stat(self):
-        pass
+        return {
+            'reward': self.stat_reward.get(),
+            'length': self.stat_length.get(),
+        }
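
The multi-buffer branch of sample() draws one buffer id per requested transition, weighting each buffer by its share of the stored data, then asks every buffer for its quota and stitches the results together with Batch.append. A standalone sketch of that weighting idea (hypothetical helper, drawing over the buffer indices so that p has one entry per buffer):

import numpy as np

def split_batch_over_buffers(lens, batch_size):
    """Return how many samples to draw from each buffer."""
    p = np.array(lens) / sum(lens)  # longer buffers win more draws
    ib = np.random.choice(len(lens), batch_size, p=p)
    return [(ib == i).sum() for i in range(len(lens))]

# e.g. buffers of lengths 10/30/60 receive roughly 10%/30%/60% of the draws
counts = split_batch_over_buffers([10, 30, 60], 64)
assert sum(counts) == 64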

View File

@@ -1,3 +1,11 @@
-from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env.wrapper import FrameStack,\
+    BaseVectorEnv, VectorEnv, SubprocVectorEnv,\
+    RayVectorEnv
 
-__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
+__all__ = [
+    'FrameStack',
+    'BaseVectorEnv',
+    'VectorEnv',
+    'SubprocVectorEnv',
+    'RayVectorEnv'
+]

View File

@@ -1,6 +1,6 @@
 import numpy as np
-from abc import ABC
 from collections import deque
+from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
@@ -64,6 +64,7 @@ class BaseVectorEnv(ABC):
 
 class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""
+
     def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.envs = [_() for _ in env_fns]
@@ -129,14 +130,19 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
 
 class SubprocVectorEnv(BaseVectorEnv):
     """docstring for SubProcVectorEnv"""
+
     def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
         self.closed = False
-        self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
+        self.parent_remote, self.child_remote = \
+            zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
-            Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
-            for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
+            Process(target=worker, args=(
+                parent, child,
+                CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
+            for (parent, child, env_fn) in zip(
+                self.parent_remote, self.child_remote, env_fns)
         ]
         for p in self.processes:
             p.start()
@@ -185,6 +191,7 @@ class SubprocVectorEnv(BaseVectorEnv):
 
 class RayVectorEnv(BaseVectorEnv):
     """docstring for RayVectorEnv"""
+
     def __init__(self, env_fns, reset_after_done=False):
         super().__init__()
         self.env_num = len(env_fns)
@@ -193,8 +200,11 @@ class RayVectorEnv(BaseVectorEnv):
             if not ray.is_initialized():
                 ray.init()
         except NameError:
-            raise ImportError('Please install ray to support VectorEnv: pip3 install ray -U')
-        self.envs = [ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) for e in env_fns]
+            raise ImportError(
+                'Please install ray to support VectorEnv: pip3 install ray -U')
+        self.envs = [
+            ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
+            for e in env_fns]
 
     def __len__(self):
         return self.env_num
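
All three vectorized env classes share the same construction convention: env_fns is a list of zero-argument callables, so SubprocVectorEnv can pickle the factory (via CloudpickleWrapper) and build the env inside the worker process. A usage sketch (assumes gym is installed; not part of the diff):

import gym
from tianshou.env import SubprocVectorEnv

env_fns = [lambda: gym.make('CartPole-v0') for _ in range(4)]
venv = SubprocVectorEnv(env_fns, reset_after_done=True)
obs = venv.reset()                         # one observation per worker
obs, rew, done, info = venv.step([0] * 4)  # one action per worker
venv.close()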

View File

@ -1,3 +1,5 @@
from tianshou.policy import BasePolicy from tianshou.policy.base import BasePolicy
__all__ = ['BasePolicy'] __all__ = [
'BasePolicy'
]

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 
 class BasePolicy(ABC):
     """docstring for BasePolicy"""
+
     def __init__(self):
         super().__init__()
@@ -21,8 +22,8 @@ class BasePolicy(ABC):
         pass
 
     @staticmethod
-    def process_fn(batch, buffer, index):
-        pass
+    def process_fn(batch, buffer, indice):
+        return batch
 
     def exploration(self):
         pass
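
A minimal subclass sketch (hypothetical, not in this commit) of the contract the Collector relies on: act() takes a Batch plus the recurrent state and returns a Batch carrying `act` and `state`, while process_fn() is applied to every sampled batch before it is returned, so the new `return batch` default is a safe no-op.

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Picks a random binary action for each env in the batch."""

    def act(self, batch, state=None):
        n = len(batch.obs)
        return Batch(act=np.random.randint(0, 2, size=n), state=state)

    @staticmethod
    def process_fn(batch, buffer, indice):
        # hook for e.g. computing returns over the sampled indices; identity here
        return batch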