add cache buf in collector

Trinkle23897 2020-03-14 21:48:31 +08:00
parent 543e57cdbd
commit c804662457
8 changed files with 119 additions and 40 deletions

View File

@@ -33,7 +33,7 @@ setup(
'Programming Language :: Python :: 3.8',
],
keywords='reinforcement learning platform',
packages=find_packages(exclude=['tests', 'tests.*',
packages=find_packages(exclude=['test', 'test.*',
'examples', 'examples.*',
'docs', 'docs.*']),
install_requires=[

View File

@@ -14,7 +14,7 @@ def test_replaybuffer(size=10, bufsize=20):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info)
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
data, indice = buf.sample(4)
data, indice = buf.sample(bufsize * 2)
assert (indice < len(buf)).all()
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()

View File

@@ -1,3 +1,4 @@
import torch
import numpy as np
@@ -8,6 +9,13 @@ class Batch(object):
super().__init__()
self.__dict__.update(kwargs)
def __getitem__(self, index):
b = Batch()
for k in self.__dict__.keys():
if self.__dict__[k] is not None:
b.update(**{k: self.__dict__[k][index]})
return b
def update(self, **kwargs):
self.__dict__.update(kwargs)
@@ -21,9 +29,12 @@ class Batch(object):
elif isinstance(batch.__dict__[k], np.ndarray):
self.__dict__[k] = np.concatenate([
self.__dict__[k], batch.__dict__[k]])
elif isinstance(batch.__dict__[k], torch.Tensor):
self.__dict__[k] = torch.cat([
self.__dict__[k], batch.__dict__[k]])
elif isinstance(batch.__dict__[k], list):
self.__dict__[k] += batch.__dict__[k]
else:
raise TypeError(
'Do not support append with type {} in class Batch.'
'No support for append with type {} in class Batch.'
.format(type(batch.__dict__[k])))
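
A quick sketch of how the extended Batch behaves (not part of the commit), assuming numpy values and the __getitem__/append logic shown above; the import path follows the one used in dqn.py below:

import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.array([1, 2, 3]), rew=np.array([0.0, 1.0, 0.5]))
print(b[1].obs, b[1].rew)    # __getitem__ picks index 1 out of every stored key
b.append(Batch(obs=np.array([4]), rew=np.array([2.0])))
print(len(b.obs))            # 4: ndarray values are concatenated on append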

View File

@@ -26,6 +26,12 @@ class ReplayBuffer(object):
self.__dict__[name] = np.zeros([self._maxsize])
self.__dict__[name][self._index] = inst
def update(self, buffer):
for i in range(len(buffer)):
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i],
buffer.done[i], buffer.obs_next[i], buffer.info[i])
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
'''
weight: importance weights, disabled here
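
A small illustration of the new ReplayBuffer.update (not in the diff); it assumes the buffer takes its maximum size as constructor argument, as in the tests above, and that ReplayBuffer is exported from tianshou.data like Batch:

from tianshou.data import ReplayBuffer

main_buf = ReplayBuffer(20)
cache = ReplayBuffer(10)
cache.add(obs=1, act=0, rew=1.0, done=False, obs_next=2, info={})
cache.add(obs=2, act=1, rew=0.0, done=True, obs_next=3, info={})
main_buf.update(cache)   # replays every cached transition into main_buf through add()
assert len(main_buf) == 2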

View File

@@ -1,3 +1,4 @@
import torch
import numpy as np
from copy import deepcopy
@@ -9,21 +10,26 @@ from tianshou.utils import MovAvg
class Collector(object):
"""docstring for Collector"""
def __init__(self, policy, env, buffer):
def __init__(self, policy, env, buffer, contiguous=True):
super().__init__()
self.env = env
self.env_num = 1
self.buffer = buffer
self.policy = policy
self.process_fn = policy.process_fn
self.multi_env = isinstance(env, BaseVectorEnv)
if self.multi_env:
self._multi_env = isinstance(env, BaseVectorEnv)
self._multi_buf = False  # whether self.buffer is a list of per-env buffers
# per-env cache buffers are needed only when all envs write into one contiguous buffer
self._cached_buf = []
if self._multi_env:
self.env_num = len(env)
if isinstance(self.buffer, list):
assert len(self.buffer) == self.env_num,\
'# of data buffer does not match the # of input env.'
elif isinstance(self.buffer, ReplayBuffer):
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer) and contiguous:
self._cached_buf = [
deepcopy(buffer) for _ in range(self.env_num)]
else:
raise TypeError('The buffer in data collector is invalid!')
self.reset_env()
@@ -34,7 +40,7 @@ class Collector(object):
self.stat_length = MovAvg()
def clear_buffer(self):
if self.multi_env:
if self._multi_buf:
for b in self.buffer:
b.reset()
else:
@@ -43,17 +49,24 @@ class Collector(object):
def reset_env(self):
self._obs = self.env.reset()
self._act = self._rew = self._done = self._info = None
if self.multi_env:
if self._multi_env:
self.reward = np.zeros(self.env_num)
self.length = np.zeros(self.env_num)
else:
self.reward, self.length = 0, 0
for b in self._cached_buf:
b.reset()
def _make_batch(self, data):
    # wrap a single transition so it carries a leading batch dimension
    if isinstance(data, np.ndarray):
        return data[None]
    else:
        return [data]
def collect(self, n_step=0, n_episode=0):
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
"One and only one collection number specification permitted!"
cur_step = 0
cur_episode = np.zeros(self.env_num) if self.multi_env else 0
cur_step, cur_episode = 0, 0
while True:
if self.multi_env:
batch_data = Batch(
@@ -61,41 +74,55 @@ class Collector(object):
done=self._done, obs_next=None, info=self._info)
else:
batch_data = Batch(
obs=[self._obs], act=[self._act], rew=[self._rew],
done=[self._done], obs_next=None, info=[self._info])
obs=self._make_batch(self._obs),
act=self._make_batch(self._act),
rew=self._make_batch(self._rew),
done=self._make_batch(self._done),
obs_next=None, info=self._make_batch(self._info))
result = self.policy.act(batch_data, self.state)
self.state = result.state
self.state = result.state if hasattr(result, 'state') else None
self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step(
self._act)
cur_step += 1
self.length += 1
self.reward += self._rew
if self.multi_env:
if self._multi_env:
for i in range(self.env_num):
if n_episode > 0 and \
cur_episode[i] < n_episode or n_episode == 0:
self.buffer[i].add(
self._obs[i], self._act[i], self._rew[i],
self._done[i], obs_next[i], self._info[i])
if self._done[i]:
cur_episode[i] += 1
self.stat_reward.add(self.reward[i])
self.stat_length.add(self.length[i])
self.reward[i], self.length[i] = 0, 0
if isinstance(self.state, list):
self.state[i] = None
else:
self.state[i] = self.state[i] * 0
if hasattr(self.state, 'detach'):
# remove ref in torch
self.state = self.state.detach()
if n_episode > 0 and (cur_episode >= n_episode).all():
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
'obs_next': obs_next[i], 'info': self._info[i]}
if self._cached_buf:
self._cached_buf[i].add(**data)
elif self._multi_buf:
self.buffer[i].add(**data)
cur_step += 1
else:
self.buffer.add(**data)
cur_step += 1
if self._done[i]:
cur_episode += 1
self.stat_reward.add(self.reward[i])
self.stat_length.add(self.length[i])
self.reward[i], self.length[i] = 0, 0
if self._cached_buf:
self.buffer.update(self._cached_buf[i])
cur_step += len(self._cached_buf[i])
self._cached_buf[i].reset()
if isinstance(self.state, list):
self.state[i] = None
else:
self.state[i] = self.state[i] * 0
if isinstance(self.state, torch.Tensor):
# detach so the cached hidden state does not keep the autograd graph alive
self.state = self.state.detach()
if n_episode > 0 and cur_episode >= n_episode:
break
else:
self.buffer.add(
self._obs, self._act[0], self._rew,
self._done, obs_next, self._info)
cur_step += 1
if self._done:
cur_episode += 1
self.stat_reward.add(self.reward)
@@ -110,7 +137,7 @@ class Collector(object):
self._obs = obs_next
def sample(self, batch_size):
if self.multi_env:
if self._multi_buf:
if batch_size > 0:
lens = [len(b) for b in self.buffer]
total = sum(lens)
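
A hedged usage sketch of the new cached-buffer collection path (not part of the commit); policy stands for any BasePolicy subclass, vector_env for any BaseVectorEnv with several parallel envs, and Collector/ReplayBuffer are assumed to be exported from tianshou.data:

from tianshou.data import Collector, ReplayBuffer

buf = ReplayBuffer(20000)
collector = Collector(policy, vector_env, buf, contiguous=True)
# with contiguous=True each env first writes into its own cached buffer
# (a deepcopy of buf); when an episode ends, buffer.update() flushes that
# cache so the episode's transitions stay contiguous in the shared buffer
collector.collect(n_episode=8)
batch = collector.sample(64)   # draw a training Batch from the shared buffer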

View File

@@ -21,9 +21,11 @@ class BasePolicy(ABC):
def reset(self):
pass
@staticmethod
def process_fn(batch, buffer, indice):
def process_fn(self, batch, buffer, indice):
return batch
def sync_weights(self):
pass
def exploration(self):
pass

tianshou/policy/dqn.py (new file, 33 lines)
View File

@@ -0,0 +1,33 @@
import torch
from torch import nn
from copy import deepcopy
from tianshou.data import Batch
from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
"""docstring for DQNPolicy"""
def __init__(self, model, discount_factor=0.99, estimation_step=1,
use_target_network=True):
super().__init__()
self.model = model
self._gamma = discount_factor
self._n_step = estimation_step
self._target = use_target_network
if use_target_network:
self.model_old = deepcopy(self.model)
def act(self, batch, hidden_state=None):
    # placeholder: Q-value computation and action selection still to be implemented
    batch_result = Batch()
    return batch_result
def sync_weights(self):
if self._target:
for old, new in zip(
self.model_old.parameters(), self.model.parameters()):
old.data.copy_(new.data)
def process_fn(self, batch, buffer, indice):
return batch
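
A minimal check of the target-network sync added here (not in the commit); it assumes DQNPolicy is exported from tianshou.policy like BasePolicy and uses a plain linear layer as the model:

import torch
from torch import nn
from tianshou.policy import DQNPolicy

net = nn.Linear(4, 2)
policy = DQNPolicy(net)        # use_target_network=True by default
with torch.no_grad():
    net.weight.add_(1.0)       # pretend the online network was updated
policy.sync_weights()          # copy online parameters into model_old
assert torch.equal(policy.model.weight, policy.model_old.weight)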

View File

@@ -1,3 +1,4 @@
import torch
import numpy as np
@@ -8,8 +9,7 @@ class MovAvg(object):
self.cache = []
def add(self, x):
if hasattr(x, 'detach'):
# which means x is torch.Tensor (?)
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if x != np.inf:
self.cache.append(x)
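
A tiny illustration of the MovAvg change above (not in the diff); get() is assumed to return the running average, since only add() appears in this hunk:

import torch
from tianshou.utils import MovAvg

stat = MovAvg()
stat.add(torch.tensor(1.5))   # tensors are detached and converted to numpy before caching
stat.add(2.5)
print(stat.get())             # average over the cached values -> 2.0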