add policy.update to enable post process and remove collector.sample (#180)

* add policy.update to enable post process and remove collector.sample

* update doc in policy concept

* remove collector.sample in doc

* doc update of concepts

* docs

* polish

* polish policy

* remove collector.sample in docs

* minor fix

* Apply suggestions from code review

just a test

* doc fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
Author: youkaichao, 2020-08-15 16:10:42 +08:00 (committed by GitHub)
Commit: 7f3b817b24 (parent: 140b1c2cab)
15 changed files with 92 additions and 89 deletions


@ -139,12 +139,14 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page
### Modularized Policy
We decouple all of the algorithms into 4 parts:
We decouple all of the algorithms roughly into the following parts:
- `__init__`: initialize the policy;
- `forward`: to compute actions over given observations;
- `process_fn`: to preprocess data from the replay buffer (since we have reformulated all algorithms as replay-buffer based algorithms);
- `learn`: to learn from a given batch data.
- `learn`: to learn from a given batch of data;
- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.
With this API, we can conveniently interact with different policies.
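To make the division concrete, here is a minimal, hypothetical policy that only fills in `__init__`, `forward`, and `learn` (the class and its random behavior are invented for this sketch; real algorithms also override `process_fn`/`post_process_fn` as needed):
```python
import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class RandomDiscretePolicy(BasePolicy):
    """Toy policy: uniformly random discrete actions, no real learning."""

    def __init__(self, n_action):
        super().__init__()
        self.n_action = n_action

    def forward(self, batch, state=None, **kwargs):
        # compute actions over the given observations
        act = np.random.randint(0, self.n_action, size=len(batch.obs))
        return Batch(act=act, state=state)

    def learn(self, batch, **kwargs):
        # "learn" from a batch that `update` has sampled and pre-processed
        return {"loss": 0.0}
```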
@ -165,7 +167,7 @@ result = collector.collect(n_episode=[1, 0, 3])
If you want to train the given policy with a sampled batch:
```python
result = policy.learn(collector.sample(batch_size))
result = policy.update(batch_size, collector.buffer)
```
You can check out the [documentation](https://tianshou.readthedocs.io) for further usage.


@ -64,12 +64,14 @@ Policy
Tianshou aims to modularize RL algorithms. There are several classes of policies in Tianshou, and all of them must inherit :class:`~tianshou.policy.BasePolicy`.
A policy class typically has four parts:
A policy class typically has the following parts:
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer;
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.
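Roughly speaking, and mirroring the implementation added in this commit, ``update`` is just a thin wrapper that chains the hooks above. A sketch, assuming a ``policy`` and a filled ``buffer`` already exist:
```python
# what policy.update(batch_size, buffer) boils down to, step by step
batch, indice = buffer.sample(batch_size)          # draw a batch of transitions
batch = policy.process_fn(batch, buffer, indice)   # e.g. compute the n-step return
result = policy.learn(batch)                       # one or more gradient steps
policy.post_process_fn(batch, buffer, indice)      # e.g. update prioritized weights
```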
Take 2-step return DQN as an example. The 2-step return DQN computes each frame's return as:
@ -125,10 +127,8 @@ Collector
---------
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
In short, :class:`~tianshou.data.Collector` has two main methods:
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
:class:`~tianshou.data.Collector` has one main method :meth:`~tianshou.data.Collector.collect`: it lets the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and stores the data in the replay buffer.
Why do we mention **at least** here? For multiple environments, we cannot directly store the collected data in the replay buffer, since that would break the principle of storing data chronologically.
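As a rough sketch (assuming a ``policy`` and a vectorized ``envs`` are already set up), a single call can therefore bring back more transitions than requested:
```python
from tianshou.data import Collector, ReplayBuffer

collector = Collector(policy, envs, ReplayBuffer(size=20000))
result = collector.collect(n_step=100)
# with several environments stepped in parallel, the actual count can
# exceed the requested 100 steps, hence the "at least" above
print(result['n/st'], len(collector.buffer))
```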
@ -144,8 +144,6 @@ Once you have a collector and a policy, you can start writing the training metho
Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.
There will be more types of trainers, for instance, a multi-agent trainer.
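A minimal off-policy call might look like the sketch below; it assumes a ``policy`` plus ``train_collector``/``test_collector`` already exist, and the keyword arguments follow the examples of this era, so their names may differ in other versions:
```python
from tianshou.trainer import offpolicy_trainer

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64)
print(result)  # a dict of training statistics, e.g. best reward and duration
```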
.. _pseudocode:
@ -165,7 +163,8 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
if i % 1000 == 0: # done in trainer
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # collector.sample(batch_size)
# the following is done in policy.update(batch_size, buffer)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # buffer.sample(batch_size)
# compute 2-step returns. How?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice)
# update DQN policy


@ -210,8 +210,8 @@ Tianshou supports user-defined training code. Here is the code snippet:
# back to training eps
policy.set_eps(0.1)
# train policy with a sampled batch data
losses = policy.learn(train_collector.sample(batch_size=64))
# train policy with a sampled batch data from buffer
losses = policy.update(64, train_collector.buffer)
For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.
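For on-policy algorithms the same interface applies with ``batch_size=0`` (take everything in the buffer), after which the buffer is cleared. A sketch mirroring the on-policy trainer changed later in this commit, for a policy whose ``learn`` takes ``batch_size`` and ``repeat``:
```python
# use all collected data once, then drop it (on-policy)
losses = policy.update(0, train_collector.buffer, 64, 1)  # batch_size=64, repeat=1
train_collector.reset_buffer()
```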


@ -174,7 +174,7 @@ def test_collector_with_dict_state():
c1.seed(0)
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([
@ -184,7 +184,7 @@ def test_collector_with_dict_state():
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
Logger.single_preprocess_fn)
c2.collect(n_episode=[0, 0, 0, 10])
batch = c2.sample(10)
batch, _ = c2.buffer.sample(10)
print(batch['obs_next']['index'])
@ -209,7 +209,7 @@ def test_collector_with_ma():
assert np.asanyarray(r).size == 1 and r == 4.
r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
assert np.asanyarray(r).size == 1 and r == 4.
batch = c1.sample(10)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
obs = np.array(np.expand_dims([
@ -226,7 +226,7 @@ def test_collector_with_ma():
Logger.single_preprocess_fn, reward_metric=reward_metric)
r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
assert np.asanyarray(r).size == 1 and r == 4.
batch = c2.sample(10)
batch, _ = c2.buffer.sample(10)
print(batch['obs_next'])


@ -64,16 +64,6 @@ class Collector(object):
# sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)
# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data) # btw, vanilla policy gradient only
# supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
# clear the buffer
collector.reset_buffer()
Collected data always consist of full episodes. So if only the ``n_step``
argument is given, the collector may return more data than the
``n_step`` limitation. Same as ``n_episode`` for the multiple environment
@ -357,13 +347,18 @@ class Collector(object):
def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer. It will call
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
final batch data.
:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
import warnings
warnings.warn(
'Collector.sample is deprecated and will cause error if you use '
'prioritized experience replay! Collector.sample will be removed '
'upon version 0.3. Use policy.update instead!', Warning)
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
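The migration suggested by the warning is mechanical, and the replacement additionally runs ``post_process_fn``, which is exactly what prioritized replay needs (sketch):
```python
# before (deprecated; errors with prioritized experience replay):
batch = collector.sample(batch_size)
losses = policy.learn(batch)

# after:
losses = policy.update(batch_size, collector.buffer)
```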


@ -93,9 +93,8 @@ class BasePolicy(ABC, nn.Module):
# some code
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
# and in the sampled data batch, you can directly call
# batch.policy.log_prob to get your data, although it is stored in
# np.ndarray.
# and in the sampled data batch, you can directly use
# batch.policy.log_prob to get your data.
"""
pass
@ -123,6 +122,7 @@ class BasePolicy(ABC, nn.Module):
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
rew_norm: bool = False,
) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@ -136,6 +136,8 @@ class BasePolicy(ABC, nn.Module):
to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
to ``False``.
:return: a Batch. The result will be stored in batch.returns as a numpy
array with shape (bsz, ).
@ -150,6 +152,8 @@ class BasePolicy(ABC, nn.Module):
for i in range(len(rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
returns[i] += gae
if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
batch.returns = returns
return batch
@ -196,7 +200,7 @@ class BasePolicy(ABC, nn.Module):
if rew_norm:
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if np.isclose(std, 0):
if np.isclose(std, 0, 1e-2):
mean, std = 0, 1
else:
mean, std = 0, 1
@ -216,9 +220,30 @@ class BasePolicy(ABC, nn.Module):
batch.returns = target_q * gammas + returns
# prio buffer update
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
batch.indice = indice
batch.weight = to_torch_as(batch.weight, target_q)
else:
batch.weight = torch.ones_like(target_q)
return batch
def post_process_fn(self, batch: Batch,
buffer: ReplayBuffer, indice: np.ndarray):
"""Post-process the data from the provided replay buffer. Typical
usage is to update the sampling weight in prioritized experience
replay. Check out :ref:`policy_concept` for more information.
"""
if isinstance(buffer, PrioritizedReplayBuffer):
buffer.update_weight(indice, batch.weight)
def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
"""Update the policy network and replay buffer (if needed). It includes
three function steps: process_fn, learn, and post_process_fn.
:param int batch_size: 0 means it will extract all the data from the
buffer, otherwise it will sample a batch with the given batch_size.
:param ReplayBuffer buffer: the corresponding replay buffer.
"""
batch, indice = buffer.sample(batch_size)
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, *args, **kwargs)
self.post_process_fn(batch, buffer, indice)
return result
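As a usage sketch (the buffer arguments are illustrative, and a ``policy`` such as DQN plus vectorized ``train_envs`` are assumed to exist), prioritized replay is the case where the new ``post_process_fn`` step pays off:
```python
from tianshou.data import Collector, PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=20000, alpha=0.6, beta=0.4)
train_collector = Collector(policy, train_envs, buf)
train_collector.collect(n_step=1000)
# sample -> process_fn -> learn -> post_process_fn, so the new
# TD-error-based weights are written back into the buffer
losses = policy.update(64, train_collector.buffer)
```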


@ -67,7 +67,8 @@ class A2CPolicy(PGPolicy):
v_.append(to_numpy(self.critic(b.obs_next)))
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
rew_norm=self._rew_norm)
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@ -97,9 +98,6 @@ class A2CPolicy(PGPolicy):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
r = batch.returns
if self._rew_norm and not np.isclose(r.std(), 0):
batch.returns = (r - r.mean()) / r.std()
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size):


@ -144,10 +144,8 @@ class DDPGPolicy(BasePolicy):
current_q = self.critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
td = current_q - target_q
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, td)
critic_loss = (td.pow(2) * batch.weight).mean()
# critic_loss = F.mse_loss(current_q, target_q)
batch.weight = td # prio-buffer
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()


@ -160,10 +160,8 @@ class DQNPolicy(BasePolicy):
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
td = r - q
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, td)
loss = (td.pow(2) * batch.weight).mean()
# loss = F.mse_loss(q, r)
batch.weight = td # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1


@ -51,7 +51,7 @@ class PGPolicy(BasePolicy):
# batch.returns = self._vectorized_returns(batch)
# return batch
return self.compute_episodic_return(
batch, gamma=self._gamma, gae_lambda=1.)
batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm)
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@ -81,9 +81,6 @@ class PGPolicy(BasePolicy):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
losses = []
r = batch.returns
if self._rew_norm and not np.isclose(r.std(), 0):
batch.returns = (r - r.mean()) / r.std()
for _ in range(repeat):
for b in batch.split(batch_size):
self.optim.zero_grad()


@ -79,18 +79,29 @@ class PPOPolicy(PGPolicy):
indice: np.ndarray) -> Batch:
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
if not np.isclose(std, 0):
if not np.isclose(std, 0, 1e-2):
batch.rew = (batch.rew - mean) / std
if self._lambda in [0, 1]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
v, v_, old_log_prob = [], [], []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False):
v_.append(self.critic(b.obs_next))
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
v_ = to_numpy(torch.cat(v_, dim=0))
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
batch = self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
rew_norm=self._rew_norm)
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
batch.adv = batch.returns - batch.v
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
if not np.isclose(std.item(), 0, 1e-2):
batch.adv = (batch.adv - mean) / std
return batch
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@ -123,26 +134,6 @@ class PPOPolicy(PGPolicy):
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
v = []
old_log_prob = []
with torch.no_grad():
for b in batch.split(batch_size, shuffle=False):
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
if self._rew_norm:
mean, std = batch.returns.mean(), batch.returns.std()
if not np.isclose(std.item(), 0):
batch.returns = (batch.returns - mean) / std
batch.adv = batch.returns - batch.v
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
if not np.isclose(std.item(), 0):
batch.adv = (batch.adv - mean) / std
for _ in range(repeat):
for b in batch.split(batch_size):
dist = self(b).dist


@ -154,9 +154,7 @@ class SACPolicy(DDPGPolicy):
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
# prio-buffer
if hasattr(batch, 'update_weight'):
batch.update_weight(batch.indice, (td1 + td2) / 2.)
batch.weight = (td1 + td2) / 2. # prio-buffer
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act


@ -132,8 +132,7 @@ class TD3Policy(DDPGPolicy):
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, (td1 + td2) / 2.)
batch.weight = (td1 + td2) / 2. # prio-buffer
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(
batch.obs, self(batch, eps=0).act).mean()


@ -28,7 +28,8 @@ def offpolicy_trainer(
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for off-policy trainer procedure.
"""A wrapper for off-policy trainer procedure. The ``step`` in trainer
means a policy network update.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
@ -47,8 +48,9 @@ def offpolicy_trainer(
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param int update_per_step: the number of times the policy network would
be updated after frames be collected. In other words, collect some
frames and do some policy network update.
be updated after frames are collected; for example, setting it to 256
means the policy will be updated 256 times once ``collect_per_step``
frames are collected.
:param function train_fn: a function that receives the current epoch index
and performs some operations at the beginning of training in this epoch.
@ -103,7 +105,7 @@ def offpolicy_trainer(
for i in range(update_per_step * min(
result['n/st'] // collect_per_step, t.total - t.n)):
global_step += 1
losses = policy.learn(train_collector.sample(batch_size))
losses = policy.update(batch_size, train_collector.buffer)
for k in result.keys():
data[k] = f'{result[k]:.2f}'
if writer and global_step % log_interval == 0:


@ -28,7 +28,8 @@ def onpolicy_trainer(
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.
"""A wrapper for on-policy trainer procedure. The ``step`` in trainer means
a policy network update.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
@ -101,8 +102,8 @@ def onpolicy_trainer(
policy.train()
if train_fn:
train_fn(epoch)
losses = policy.learn(
train_collector.sample(0), batch_size, repeat_per_collect)
losses = policy.update(
0, train_collector.buffer, batch_size, repeat_per_collect)
train_collector.reset_buffer()
step = 1
for k in losses.keys():