add cache buf in collector

parent 543e57cdbd
commit c804662457

setup.py
@@ -33,7 +33,7 @@ setup(
         'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
-    packages=find_packages(exclude=['tests', 'tests.*',
+    packages=find_packages(exclude=['test', 'test.*',
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
@@ -14,7 +14,7 @@ def test_replaybuffer(size=10, bufsize=20):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
-    data, indice = buf.sample(4)
+    data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()
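The updated test deliberately asks for more samples than the buffer can hold (bufsize * 2) and only checks that every returned index is valid. A minimal numpy sketch of why over-sampling stays well defined when indices are drawn with replacement (an assumption about the sampling strategy, not code from this commit):

    import numpy as np

    buf_len = 10        # number of transitions currently stored
    batch_size = 40     # more than the buffer holds
    # drawing with replacement keeps every index inside [0, buf_len)
    indice = np.random.choice(buf_len, batch_size)
    assert (indice < buf_len).all()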
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 
 
@@ -8,6 +9,13 @@ class Batch(object):
         super().__init__()
         self.__dict__.update(kwargs)
 
+    def __getitem__(self, index):
+        b = Batch()
+        for k in self.__dict__.keys():
+            if self.__dict__[k] is not None:
+                b.update(k=self.__dict__[k][index])
+        return b
+
     def update(self, **kwargs):
         self.__dict__.update(kwargs)
 
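A short usage sketch (not part of the commit) of how indexing into a Batch is meant to behave; the field names are made up:

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs=np.arange(5), act=np.arange(5) * 2, info=None)
    sub = b[np.array([0, 2])]   # None fields are skipped by __getitem__
    # note: update(k=value) stores an attribute literally named 'k';
    # forwarding the loop key would need update(**{k: value})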
@@ -21,9 +29,12 @@ class Batch(object):
             elif isinstance(batch.__dict__[k], np.ndarray):
                 self.__dict__[k] = np.concatenate([
                     self.__dict__[k], batch.__dict__[k]])
+            elif isinstance(batch.__dict__[k], torch.Tensor):
+                self.__dict__[k] = torch.cat([
+                    self.__dict__[k], batch.__dict__[k]])
             elif isinstance(batch.__dict__[k], list):
                 self.__dict__[k] += batch.__dict__[k]
             else:
                 raise TypeError(
-                    'Do not support append with type {} in class Batch.'
+                    'No support for append with type {} in class Batch.'
                     .format(type(batch.__dict__[k])))
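With torch now imported, appending one Batch to another can concatenate tensor fields as well as numpy fields. A hedged sketch, assuming the surrounding method is Batch.append(batch) as the error message suggests; the field names are illustrative:

    import numpy as np
    import torch
    from tianshou.data import Batch

    a = Batch(rew=np.array([1.0, 2.0]), logits=torch.zeros(2, 4))
    b = Batch(rew=np.array([3.0]), logits=torch.ones(1, 4))
    a.append(b)   # ndarray fields use np.concatenate, tensor fields use torch.cat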
@@ -26,6 +26,12 @@ class ReplayBuffer(object):
             self.__dict__[name] = np.zeros([self._maxsize])
         self.__dict__[name][self._index] = inst
 
+    def update(self, buffer):
+        for i in range(len(buffer)):
+            self.add(
+                buffer.obs[i], buffer.act[i], buffer.rew[i],
+                buffer.done[i], buffer.obs_next[i], buffer.info[i])
+
     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         '''
         weight: importance weights, disabled here
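ReplayBuffer.update is what lets a cache buffer be replayed into a main buffer transition by transition. A minimal sketch, assuming ReplayBuffer is exported from tianshou.data and constructed with its maximum size; the transition values are illustrative:

    from tianshou.data import ReplayBuffer

    main_buf = ReplayBuffer(20)
    cache = ReplayBuffer(5)
    for t in range(3):
        cache.add(obs=t, act=0, rew=1.0, done=(t == 2), obs_next=t + 1, info={})
    main_buf.update(cache)   # copies cache's transitions into main_buf in order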
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 from copy import deepcopy
 
@@ -9,21 +10,26 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer):
+    def __init__(self, policy, env, buffer, contiguous=True):
         super().__init__()
         self.env = env
         self.env_num = 1
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
-        self.multi_env = isinstance(env, BaseVectorEnv)
-        if self.multi_env:
+        self._multi_env = isinstance(env, BaseVectorEnv)
+        self._multi_buf = False  # buf is a list
+        # need multiple cache buffers only if contiguous in one buffer
+        self._cached_buf = []
+        if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
                 assert len(self.buffer) == self.env_num,\
                     '# of data buffer does not match the # of input env.'
-            elif isinstance(self.buffer, ReplayBuffer):
-                self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
+                self._multi_buf = True
+            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+                self._cached_buf = [
+                    deepcopy(buffer) for _ in range(self.env_num)]
             else:
                 raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
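The constructor now distinguishes three buffer layouts: a list of buffers (one per env, _multi_buf), a single shared ReplayBuffer with contiguous=True (one cache buffer per env, flushed into the shared buffer), and anything else (rejected). A minimal sketch of the cache layout, assuming ReplayBuffer is importable from tianshou.data:

    from copy import deepcopy
    from tianshou.data import ReplayBuffer

    main_buf = ReplayBuffer(100)
    env_num = 4
    # what __init__ builds for a single buffer with contiguous=True
    cached_buf = [deepcopy(main_buf) for _ in range(env_num)]

Each env i writes into cached_buf[i]; the whole episode is copied into main_buf only once that env reports done, so every episode stays contiguous in the shared buffer.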
@@ -34,7 +40,7 @@ class Collector(object):
         self.stat_length = MovAvg()
 
     def clear_buffer(self):
-        if self.multi_env:
+        if self._multi_buf:
             for b in self.buffer:
                 b.reset()
         else:
@@ -43,17 +49,24 @@
     def reset_env(self):
         self._obs = self.env.reset()
         self._act = self._rew = self._done = self._info = None
-        if self.multi_env:
+        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
         else:
             self.reward, self.length = 0, 0
+        for b in self._cached_buf:
+            b.reset()
+
+    def _make_batch(data):
+        if isinstance(data, np.ndarray):
+            return data[None]
+        else:
+            return [data]
 
     def collect(self, n_step=0, n_episode=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
-        cur_step = 0
-        cur_episode = np.zeros(self.env_num) if self.multi_env else 0
+        cur_step, cur_episode = 0, 0
         while True:
             if self.multi_env:
                 batch_data = Batch(
@@ -61,41 +74,55 @@
                     done=self._done, obs_next=None, info=self._info)
             else:
                 batch_data = Batch(
-                    obs=[self._obs], act=[self._act], rew=[self._rew],
-                    done=[self._done], obs_next=None, info=[self._info])
+                    obs=self._make_batch(self._obs),
+                    act=self._make_batch(self._act),
+                    rew=self._make_batch(self._rew),
+                    done=self._make_batch(self._done),
+                    obs_next=None, info=self._make_batch(self._info))
             result = self.policy.act(batch_data, self.state)
-            self.state = result.state
+            self.state = result.state if hasattr(result, 'state') else None
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act)
-            cur_step += 1
             self.length += 1
             self.reward += self._rew
-            if self.multi_env:
+            if self._multi_env:
                 for i in range(self.env_num):
-                    if n_episode > 0 and \
-                            cur_episode[i] < n_episode or n_episode == 0:
-                        self.buffer[i].add(
-                            self._obs[i], self._act[i], self._rew[i],
-                            self._done[i], obs_next[i], self._info[i])
-                    if self._done[i]:
-                        cur_episode[i] += 1
-                        self.stat_reward.add(self.reward[i])
-                        self.stat_length.add(self.length[i])
-                        self.reward[i], self.length[i] = 0, 0
-                        if isinstance(self.state, list):
-                            self.state[i] = None
-                        else:
-                            self.state[i] = self.state[i] * 0
-                            if hasattr(self.state, 'detach'):
-                                # remove ref in torch
-                                self.state = self.state.detach()
-                if n_episode > 0 and (cur_episode >= n_episode).all():
+                    data = {
+                        'obs': self._obs[i], 'act': self._act[i],
+                        'rew': self._rew[i], 'done': self._done[i],
+                        'obs_next': obs_next[i], 'info': self._info[i]}
+                    if self._cached_buf:
+                        self._cached_buf[i].add(**data)
+                    elif self._multi_buf:
+                        self.buffer[i].add(**data)
+                        cur_step += 1
+                    else:
+                        self.buffer.add(**data)
+                        cur_step += 1
+                    if self._done[i]:
+                        cur_episode += 1
+                        self.stat_reward.add(self.reward[i])
+                        self.stat_length.add(self.length[i])
+                        self.reward[i], self.length[i] = 0, 0
+                        if self._cached_buf:
+                            self.buffer.update(self._cached_buf[i])
+                            cur_step += len(self._cached_buf[i])
+                            self._cached_buf[i].reset()
+                        if isinstance(self.state, list):
+                            self.state[i] = None
+                        else:
+                            self.state[i] = self.state[i] * 0
+                            if isinstance(self.state, torch.Tensor):
+                                # remove ref in torch (?)
+                                self.state = self.state.detach()
+                if n_episode > 0 and cur_episode >= n_episode:
                     break
             else:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
+                cur_step += 1
                 if self._done:
                     cur_episode += 1
                     self.stat_reward.add(self.reward)
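The collect loop now stages each env's transitions in its cache buffer and flushes them into the shared buffer when that env finishes an episode, adding the episode length to cur_step in one step. A standalone sketch of that bookkeeping (illustrative values, assuming ReplayBuffer comes from tianshou.data):

    from copy import deepcopy
    from tianshou.data import ReplayBuffer

    main_buf = ReplayBuffer(100)
    cache = deepcopy(main_buf)          # per-env cache buffer
    for t in range(3):                  # one illustrative episode
        cache.add(obs=t, act=0, rew=1.0, done=(t == 2), obs_next=t + 1, info={})
    # on done: flush the whole episode, count its steps, reuse the cache
    main_buf.update(cache)
    cur_step = len(cache)
    cache.reset()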
@@ -110,7 +137,7 @@ class Collector(object):
         self._obs = obs_next
 
     def sample(self, batch_size):
-        if self.multi_env:
+        if self._multi_buf:
             if batch_size > 0:
                 lens = [len(b) for b in self.buffer]
                 total = sum(lens)
@@ -21,9 +21,11 @@ class BasePolicy(ABC):
     def reset(self):
         pass
 
-    @staticmethod
-    def process_fn(batch, buffer, indice):
+    def process_fn(self, batch, buffer, indice):
        return batch
 
     def sync_weights(self):
         pass
+
+    def exploration(self):
+        pass
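process_fn is now an instance method instead of a @staticmethod, so a subclass can use its own state when pre-processing sampled batches. A tiny illustrative subclass (hypothetical, not part of this commit):

    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):
        def process_fn(self, batch, buffer, indice):
            # as an instance method it can read attributes such as a
            # discount factor stored on self
            return batch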
tianshou/policy/dqn.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+from copy import deepcopy
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class DQNPolicy(BasePolicy, nn.Module):
+    """docstring for DQNPolicy"""
+
+    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+                 use_target_network=True):
+        super().__init__()
+        self.model = model
+        self._gamma = discount_factor
+        self._n_step = estimation_step
+        self._target = use_target_network
+        if use_target_network:
+            self.model_old = deepcopy(self.model)
+
+    def act(self, batch, hidden_state=None):
+        batch_result = Batch()
+        return batch_result
+
+    def sync_weights(self):
+        if self._use_target_network:
+            for old, new in zip(
+                    self.model_old.parameters(), self.model.parameters()):
+                old.data.copy_(new.data)
+
+    def process_fn(self, batch, buffer, indice):
+        return batch
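sync_weights performs a hard update: every parameter of the frozen target copy is overwritten with the current online parameters. The same pattern in plain PyTorch, as a sketch independent of DQNPolicy:

    import torch
    from copy import deepcopy
    from torch import nn

    model = nn.Linear(4, 2)
    model_old = deepcopy(model)        # frozen target network
    # ... train `model` for a while, then hard-update the target ...
    with torch.no_grad():
        for old, new in zip(model_old.parameters(), model.parameters()):
            old.data.copy_(new.data)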
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 
 
@@ -8,8 +9,7 @@ class MovAvg(object):
         self.cache = []
 
     def add(self, x):
-        if hasattr(x, 'detach'):
-            # which means x is torch.Tensor (?)
+        if isinstance(x, torch.Tensor):
             x = x.detach().cpu().numpy()
         if x != np.inf:
             self.cache.append(x)
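MovAvg.add now checks for torch.Tensor explicitly instead of duck-typing on a detach attribute. A short usage sketch (the import path for MovAvg is taken from the collector hunk above):

    import torch
    from tianshou.utils import MovAvg

    stat = MovAvg()
    stat.add(torch.tensor(1.5))   # detached and converted to numpy first
    stat.add(2.5)                 # plain numbers are appended directly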