add policy.update to enable post process and remove collector.sample (#180)
* add policy.update to enable post process and remove collector.sample
* update doc in policy concept
* remove collector.sample in doc
* doc update of concepts
* docs
* polish
* polish policy
* remove collector.sample in docs
* minor fix
* Apply suggestions from code review (just a test)
* doc fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
parent 140b1c2cab
commit 7f3b817b24
@@ -139,12 +139,14 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page

### Modularized Policy

We decouple all of the algorithms into 4 parts:
We decouple all of the algorithms roughly into the following parts:

- `__init__`: initialize the policy;
- `forward`: to compute actions over given observations;
- `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
- `learn`: to learn from a given batch data.
- `learn`: to learn from a given batch data;
- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.

Within this API, we can interact with different policies conveniently.

@@ -165,7 +167,7 @@ result = collector.collect(n_episode=[1, 0, 3])

If you want to train the given policy with a sampled batch:

```python
result = policy.learn(collector.sample(batch_size))
result = policy.update(batch_size, collector.buffer)
```

You can check out the [documentation](https://tianshou.readthedocs.io) for further usage.
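For orientation, here is a minimal sketch of how the collector and the new `policy.update` interface fit together in a hand-written loop; the construction of `policy` and a (vectorized) `envs` object is assumed, and the numbers are illustrative:

```python
from tianshou.data import Collector, ReplayBuffer

# assumption: `policy` and `envs` are already constructed elsewhere
collector = Collector(policy, envs, ReplayBuffer(size=20000))
for _ in range(1000):
    collector.collect(n_step=10)                  # gather fresh transitions into the buffer
    result = policy.update(64, collector.buffer)  # process_fn -> learn -> post_process_fn
```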
@@ -64,12 +64,14 @@ Policy

Tianshou aims to modularize RL algorithms. It comes with several classes of policies. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.

A policy class typically has four parts:
A policy class typically has the following parts:

* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer;
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes it (such as computing the n-step return), learns from the data, and finally post-processes it (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.

Take 2-step return DQN as an example. The 2-step return DQN computes each frame's return as:
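The concrete formula is elided by this hunk; as a reference point, the standard 2-step return (a conventional formulation, not copied from the commit) is:

```latex
G_t = r_t + \gamma r_{t+1} + \gamma^2 \max_{a} Q(s_{t+2}, a)
```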
@@ -125,10 +127,8 @@ Collector
---------

The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
In short, :class:`~tianshou.data.Collector` has two main methods:

* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from the replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
:class:`~tianshou.data.Collector` has one main method, :meth:`~tianshou.data.Collector.collect`: it lets the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer.

Why do we mention **at least** here? For multiple environments, we cannot directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.

@@ -144,8 +144,6 @@ Once you have a collector and a policy, you can start writing the training metho

Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.

There will be more types of trainers, for instance, multi-agent trainer.

.. _pseudocode:
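As a rough illustration of the trainer interface discussed above: a call usually looks like the sketch below. The argument names follow this Tianshou version's test scripts and should be treated as assumptions rather than a definitive signature.

```python
from tianshou.trainer import offpolicy_trainer

# assumption: policy, train_collector, and test_collector already exist
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64)
print(result)  # aggregated training statistics
```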
@@ -165,7 +163,8 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
    buffer.store(s, a, s_, r, d)                   # collector.collect(...)
    s = s_                                         # collector.collect(...)
    if i % 1000 == 0:                              # done in trainer
        b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)        # collector.sample(batch_size)
        # the following is done in policy.update(batch_size, buffer)
        b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)        # buffer.sample(batch_size)
        # compute 2-step returns. How?
        b_ret = compute_2_step_return(buffer, b_r, b_d, ...)  # policy.process_fn(batch, buffer, indice)
        # update DQN policy

@@ -210,8 +210,8 @@ Tianshou supports user-defined training code. Here is the code snippet:
    # back to training eps
    policy.set_eps(0.1)

    # train policy with a sampled batch data
    losses = policy.learn(train_collector.sample(batch_size=64))
    # train policy with a sampled batch data from buffer
    losses = policy.update(64, train_collector.buffer)

For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.
@@ -174,7 +174,7 @@ def test_collector_with_dict_state():
    c1.seed(0)
    c1.collect(n_step=10)
    c1.collect(n_episode=[2, 1, 1, 2])
    batch = c1.sample(10)
    batch, _ = c1.buffer.sample(10)
    print(batch)
    c0.buffer.update(c1.buffer)
    assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([

@@ -184,7 +184,7 @@ def test_collector_with_dict_state():
    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                   Logger.single_preprocess_fn)
    c2.collect(n_episode=[0, 0, 0, 10])
    batch = c2.sample(10)
    batch, _ = c2.buffer.sample(10)
    print(batch['obs_next']['index'])

@@ -209,7 +209,7 @@ def test_collector_with_ma():
    assert np.asanyarray(r).size == 1 and r == 4.
    r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
    assert np.asanyarray(r).size == 1 and r == 4.
    batch = c1.sample(10)
    batch, _ = c1.buffer.sample(10)
    print(batch)
    c0.buffer.update(c1.buffer)
    obs = np.array(np.expand_dims([

@@ -226,7 +226,7 @@ def test_collector_with_ma():
                   Logger.single_preprocess_fn, reward_metric=reward_metric)
    r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
    assert np.asanyarray(r).size == 1 and r == 4.
    batch = c2.sample(10)
    batch, _ = c2.buffer.sample(10)
    print(batch['obs_next'])
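Note for anyone porting their own tests: unlike the removed `Collector.sample`, `ReplayBuffer.sample` returns a `(batch, indice)` tuple and does not run `process_fn`. A sketch of the replacement pattern (variable names illustrative):

```python
# old (removed in this commit):
# batch = collector.sample(64)

# new: sample raw data from the buffer; the sampled indices come along
batch, indice = collector.buffer.sample(64)
# if the pre-processed batch is needed, call process_fn explicitly
batch = policy.process_fn(batch, collector.buffer, indice)
```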
@@ -64,16 +64,6 @@ class Collector(object):
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        # supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        # clear the buffer
        collector.reset_buffer()

    Collected data always consist of full episodes. So if only the ``n_step``
    argument is given, the collector may return more data than the
    ``n_step`` limitation. Same as ``n_episode`` for the multiple environment
@@ -357,13 +347,18 @@ class Collector(object):

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
        final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        import warnings
        warnings.warn(
            'Collector.sample is deprecated and will cause error if you use '
            'prioritized experience replay! Collector.sample will be removed '
            'upon version 0.3. Use policy.update instead!', Warning)
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
@@ -93,9 +93,8 @@ class BasePolicy(ABC, nn.Module):

            # some code
            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
            # and in the sampled data batch, you can directly call
            # batch.policy.log_prob to get your data, although it is stored in
            # np.ndarray.
            # and in the sampled data batch, you can directly use
            # batch.policy.log_prob to get your data.
        """
        pass

@@ -123,6 +122,7 @@ class BasePolicy(ABC, nn.Module):
        v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        rew_norm: bool = False,
    ) -> Batch:
        """Compute returns over given full-length episodes, including the
        implementation of Generalized Advantage Estimator (arXiv:1506.02438).

@@ -136,6 +136,8 @@ class BasePolicy(ABC, nn.Module):
            to 0.99.
        :param float gae_lambda: the parameter for Generalized Advantage
            Estimation, should be in [0, 1], defaults to 0.95.
        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
            to ``False``.

        :return: a Batch. The result will be stored in batch.returns as a numpy
            array with shape (bsz, ).

@@ -150,6 +152,8 @@ class BasePolicy(ABC, nn.Module):
        for i in range(len(rew) - 1, -1, -1):
            gae = delta[i] + m[i] * gae
            returns[i] += gae
        if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
            returns = (returns - returns.mean()) / returns.std()
        batch.returns = returns
        return batch

@@ -196,7 +200,7 @@ class BasePolicy(ABC, nn.Module):
        if rew_norm:
            bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
            mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0):
            if np.isclose(std, 0, 1e-2):
                mean, std = 0, 1
        else:
            mean, std = 0, 1
@@ -216,9 +220,30 @@ class BasePolicy(ABC, nn.Module):
        batch.returns = target_q * gammas + returns
        # prio buffer update
        if isinstance(buffer, PrioritizedReplayBuffer):
            batch.update_weight = buffer.update_weight
            batch.indice = indice
            batch.weight = to_torch_as(batch.weight, target_q)
        else:
            batch.weight = torch.ones_like(target_q)
        return batch

    def post_process_fn(self, batch: Batch,
                        buffer: ReplayBuffer, indice: np.ndarray):
        """Post-process the data from the provided replay buffer. Typical
        usage is to update the sampling weight in prioritized experience
        replay. Check out :ref:`policy_concept` for more information.
        """
        if isinstance(buffer, PrioritizedReplayBuffer):
            buffer.update_weight(indice, batch.weight)

    def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
        """Update the policy network and replay buffer (if needed). It includes
        three function steps: process_fn, learn, and post_process_fn.

        :param int batch_size: 0 means it will extract all the data from the
            buffer, otherwise it will sample a batch with the given batch_size.
        :param ReplayBuffer buffer: the corresponding replay buffer.
        """
        batch, indice = buffer.sample(batch_size)
        batch = self.process_fn(batch, buffer, indice)
        result = self.learn(batch, *args, **kwargs)
        self.post_process_fn(batch, buffer, indice)
        return result
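The three hooks above are what make prioritized replay work end to end: `process_fn` attaches the importance weights, `learn` overwrites `batch.weight` with the TD error, and `post_process_fn` writes the new priorities back into the buffer. A rough usage sketch follows; the `PrioritizedReplayBuffer` constructor arguments shown are assumptions, so check the class signature in your version.

```python
from tianshou.data import Collector, PrioritizedReplayBuffer

# assumption: `policy` and `envs` already exist; alpha/beta values are illustrative
buf = PrioritizedReplayBuffer(size=20000, alpha=0.6, beta=0.4)
collector = Collector(policy, envs, buf)
collector.collect(n_step=1000)

# policy.update internally performs:
#   batch, indice = buf.sample(64)               # importance weights arrive in batch.weight
#   batch = policy.process_fn(batch, buf, indice)
#   result = policy.learn(batch)                 # learn() stores the new TD error in batch.weight
#   policy.post_process_fn(batch, buf, indice)   # buf.update_weight(indice, batch.weight)
result = policy.update(64, collector.buffer)
```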
@@ -67,7 +67,8 @@ class A2CPolicy(PGPolicy):
            v_.append(to_numpy(self.critic(b.obs_next)))
        v_ = np.concatenate(v_, axis=0)
        return self.compute_episodic_return(
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
            rew_norm=self._rew_norm)

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,

@@ -97,9 +98,6 @@ class A2CPolicy(PGPolicy):
    def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        self._batch = batch_size
        r = batch.returns
        if self._rew_norm and not np.isclose(r.std(), 0):
            batch.returns = (r - r.mean()) / r.std()
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size):
@@ -144,10 +144,8 @@ class DDPGPolicy(BasePolicy):
        current_q = self.critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        if hasattr(batch, 'update_weight'):  # prio-buffer
            batch.update_weight(batch.indice, td)
        critic_loss = (td.pow(2) * batch.weight).mean()
        # critic_loss = F.mse_loss(current_q, target_q)
        batch.weight = td  # prio-buffer
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
@@ -160,10 +160,8 @@ class DQNPolicy(BasePolicy):
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns, q).flatten()
        td = r - q
        if hasattr(batch, 'update_weight'):  # prio-buffer
            batch.update_weight(batch.indice, td)
        loss = (td.pow(2) * batch.weight).mean()
        # loss = F.mse_loss(q, r)
        batch.weight = td  # prio-buffer
        loss.backward()
        self.optim.step()
        self._cnt += 1
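A tiny self-contained check (values made up) of why the weighted loss used in the DQN/DDPG hunks degrades to a plain MSE when the buffer is not prioritized, since `process_fn` then sets the weights to ones:

```python
import torch

td = torch.tensor([0.5, -1.0, 2.0])           # illustrative TD errors
uniform_w = torch.ones_like(td)               # what process_fn sets for a plain ReplayBuffer
per_w = torch.tensor([0.7, 1.2, 0.4])         # illustrative importance-sampling weights

# plain buffer: weighted loss equals the ordinary MSE on the TD error
assert torch.allclose((td.pow(2) * uniform_w).mean(), td.pow(2).mean())
# prioritized buffer: the same expression becomes an importance-weighted MSE
weighted_loss = (td.pow(2) * per_w).mean()
```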
@@ -51,7 +51,7 @@ class PGPolicy(BasePolicy):
        # batch.returns = self._vectorized_returns(batch)
        # return batch
        return self.compute_episodic_return(
            batch, gamma=self._gamma, gae_lambda=1.)
            batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm)

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,

@@ -81,9 +81,6 @@ class PGPolicy(BasePolicy):
    def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        losses = []
        r = batch.returns
        if self._rew_norm and not np.isclose(r.std(), 0):
            batch.returns = (r - r.mean()) / r.std()
        for _ in range(repeat):
            for b in batch.split(batch_size):
                self.optim.zero_grad()
@@ -79,18 +79,29 @@ class PPOPolicy(PGPolicy):
                   indice: np.ndarray) -> Batch:
        if self._rew_norm:
            mean, std = batch.rew.mean(), batch.rew.std()
            if not np.isclose(std, 0):
            if not np.isclose(std, 0, 1e-2):
                batch.rew = (batch.rew - mean) / std
        if self._lambda in [0, 1]:
            return self.compute_episodic_return(
                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
        v_ = []
        v, v_, old_log_prob = [], [], []
        with torch.no_grad():
            for b in batch.split(self._batch, shuffle=False):
                v_.append(self.critic(b.obs_next))
                v.append(self.critic(b.obs))
                old_log_prob.append(self(b).dist.log_prob(
                    to_torch_as(b.act, v[0])))
        v_ = to_numpy(torch.cat(v_, dim=0))
        return self.compute_episodic_return(
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
        batch = self.compute_episodic_return(
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
            rew_norm=self._rew_norm)
        batch.v = torch.cat(v, dim=0).flatten()  # old value
        batch.act = to_torch_as(batch.act, v[0])
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        batch.returns = to_torch_as(batch.returns, v[0])
        batch.adv = batch.returns - batch.v
        if self._rew_norm:
            mean, std = batch.adv.mean(), batch.adv.std()
            if not np.isclose(std.item(), 0, 1e-2):
                batch.adv = (batch.adv - mean) / std
        return batch

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,

@@ -123,26 +134,6 @@ class PPOPolicy(PGPolicy):
              **kwargs) -> Dict[str, List[float]]:
        self._batch = batch_size
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        v = []
        old_log_prob = []
        with torch.no_grad():
            for b in batch.split(batch_size, shuffle=False):
                v.append(self.critic(b.obs))
                old_log_prob.append(self(b).dist.log_prob(
                    to_torch_as(b.act, v[0])))
        batch.v = torch.cat(v, dim=0).flatten()  # old value
        batch.act = to_torch_as(batch.act, v[0])
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        batch.returns = to_torch_as(batch.returns, v[0])
        if self._rew_norm:
            mean, std = batch.returns.mean(), batch.returns.std()
            if not np.isclose(std.item(), 0):
                batch.returns = (batch.returns - mean) / std
        batch.adv = batch.returns - batch.v
        if self._rew_norm:
            mean, std = batch.adv.mean(), batch.adv.std()
            if not np.isclose(std.item(), 0):
                batch.adv = (batch.adv - mean) / std
        for _ in range(repeat):
            for b in batch.split(batch_size):
                dist = self(b).dist
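For orientation, the fields prepared in `process_fn` above (`logp_old`, `adv`) feed the standard PPO clipped surrogate inside `learn`. The sketch below uses made-up tensors and a hypothetical `eps_clip` value; it is not copied from the file:

```python
import torch

# stand-ins for one minibatch `b` from batch.split(...)
logp_new = torch.tensor([-1.0, -0.5, -2.0])   # dist.log_prob(b.act) under the current policy
logp_old = torch.tensor([-1.1, -0.7, -1.8])   # b.logp_old, cached in process_fn above
adv = torch.tensor([0.3, -0.2, 1.0])          # b.adv, computed in process_fn above
eps_clip = 0.2

ratio = (logp_new - logp_old).exp()
surr1 = ratio * adv
surr2 = ratio.clamp(1. - eps_clip, 1. + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()   # PPO clipped surrogate objective
```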
@@ -154,9 +154,7 @@ class SACPolicy(DDPGPolicy):
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        # prio-buffer
        if hasattr(batch, 'update_weight'):
            batch.update_weight(batch.indice, (td1 + td2) / 2.)
        batch.weight = (td1 + td2) / 2.  # prio-buffer
        # actor
        obs_result = self(batch, explorating=False)
        a = obs_result.act

@@ -132,8 +132,7 @@ class TD3Policy(DDPGPolicy):
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        if hasattr(batch, 'update_weight'):  # prio-buffer
            batch.update_weight(batch.indice, (td1 + td2) / 2.)
        batch.weight = (td1 + td2) / 2.  # prio-buffer
        if self._cnt % self._freq == 0:
            actor_loss = -self.critic1(
                batch.obs, self(batch, eps=0).act).mean()
@@ -28,7 +28,8 @@ def offpolicy_trainer(
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for off-policy trainer procedure.
    """A wrapper for off-policy trainer procedure. The ``step`` in trainer
    means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.

@@ -47,8 +48,9 @@ def offpolicy_trainer(
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param int update_per_step: the number of times the policy network would
        be updated after frames be collected. In other words, collect some
        frames and do some policy network update.
        be updated after frames are collected. For example, setting it to 256
        means the policy is updated 256 times each time ``collect_per_step``
        frames have been collected.
    :param function train_fn: a function receives the current number of epoch
        index and performs some operations at the beginning of training in this
        epoch.

@@ -103,7 +105,7 @@ def offpolicy_trainer(
                for i in range(update_per_step * min(
                        result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    losses = policy.update(batch_size, train_collector.buffer)
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
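Illustrative arithmetic for the update loop above (the numbers are invented; the real loop is additionally capped by the remaining steps in the epoch, `t.total - t.n`):

```python
update_per_step = 1
collect_per_step = 10        # target number of transitions per collect call
frames_collected = 95        # what result['n/st'] might report for one collect call

updates = update_per_step * (frames_collected // collect_per_step)
print(updates)               # 9 policy.update(...) calls before collecting again
```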
@@ -28,7 +28,8 @@ def onpolicy_trainer(
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.
    """A wrapper for on-policy trainer procedure. The ``step`` in trainer means
    a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.

@@ -101,8 +102,8 @@ def onpolicy_trainer(
                policy.train()
                if train_fn:
                    train_fn(epoch)
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                losses = policy.update(
                    0, train_collector.buffer, batch_size, repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():