add policy.update to enable post process and remove collector.sample (#180)
* add policy.update to enable post process and remove collector.sample
* update doc in policy concept
* remove collector.sample in doc
* doc update of concepts
* docs
* polish
* polish policy
* remove collector.sample in docs
* minor fix
* Apply suggestions from code review (just a test)
* doc fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
parent 140b1c2cab
commit 7f3b817b24
@@ -139,12 +139,14 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page

 ### Modularized Policy

-We decouple all of the algorithms into 4 parts:
+We decouple all of the algorithms roughly into the following parts:

 - `__init__`: initialize the policy;
 - `forward`: to compute actions over given observations;
 - `process_fn`: to preprocess data from the replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
-- `learn`: to learn from a given batch of data.
+- `learn`: to learn from a given batch of data;
+- `post_process_fn`: to update the replay buffer from the learning process (e.g., a prioritized replay buffer needs to update the weight);
+- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.

 Within this API, we can interact with different policies conveniently.
@@ -165,7 +167,7 @@ result = collector.collect(n_episode=[1, 0, 3])
 If you want to train the given policy with a sampled batch:

 ```python
-result = policy.learn(collector.sample(batch_size))
+result = policy.update(batch_size, collector.buffer)
 ```

 You can check out the [documentation](https://tianshou.readthedocs.io) for further usage.
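For orientation, a minimal end-to-end sketch of the new call; the environment, network, and policy construction are assumed to exist already and are not part of this patch:

```python
# hedged sketch: how the new training call fits together.
# `policy` and `env` are assumed to be constructed elsewhere.
from tianshou.data import Collector, ReplayBuffer

buffer = ReplayBuffer(size=20000)
collector = Collector(policy, env, buffer)

collector.collect(n_step=1000)              # fill the buffer with transitions
result = policy.update(64, collector.buffer)  # sample -> process_fn -> learn -> post_process_fn
```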
@@ -64,12 +64,14 @@ Policy

 Tianshou aims to modularize RL algorithms. It provides several classes of policies. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.

-A policy class typically has four parts:
+A policy class typically has the following parts:

-* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
+* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
 * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
-* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with the replay buffer);
+* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer;
 * :meth:`~tianshou.policy.BasePolicy.learn`: update the policy with a given batch of data.
+* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data.
+* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes the data (such as computing the n-step return), learns from the data, and finally post-processes the data (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.

 Take 2-step return DQN as an example. The 2-step return DQN computes each frame's return as:
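The new ``update`` entry point is, in effect, the three steps chained together. A conceptual sketch, mirroring the ``BasePolicy.update`` body added further down in this diff:

```python
# conceptual sketch of policy.update(batch_size, buffer)
batch, indice = buffer.sample(batch_size)         # batch_size == 0 -> take the whole buffer
batch = policy.process_fn(batch, buffer, indice)  # e.g. compute n-step / episodic returns
result = policy.learn(batch)                      # gradient step(s); returns a dict of losses
policy.post_process_fn(batch, buffer, indice)     # e.g. write new priorities back to the buffer
```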
@@ -125,10 +127,8 @@ Collector
 ---------

 The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
-In short, :class:`~tianshou.data.Collector` has two main methods:

-* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
-* :meth:`~tianshou.data.Collector.sample`: sample a data batch from the replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
+:class:`~tianshou.data.Collector` has one main method, :meth:`~tianshou.data.Collector.collect`: it lets the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and stores the data in the replay buffer.

 Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
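Since ``Collector.sample`` is on its way out, batches now come from the buffer itself; ``buffer.sample`` returns the batch together with the sampled indices, as the updated tests later in this diff show. A small hedged sketch:

```python
# hedged sketch: sampling directly from the collector's buffer
collector.collect(n_step=1000)               # gather transitions into collector.buffer
batch, indice = collector.buffer.sample(64)  # raw batch plus the sampled indices
# note: process_fn is no longer applied here; it now runs inside policy.update
```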
@@ -144,8 +144,6 @@ Once you have a collector and a policy, you can start writing the training method

 Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.

-There will be more types of trainers, for instance, multi-agent trainer.
-

 .. _pseudocode:

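For reference, a hedged sketch of how the off-policy trainer is typically invoked; the argument names follow contemporary Tianshou examples and should be treated as assumptions rather than part of this patch:

```python
# hedged sketch; parameter names are assumptions based on Tianshou examples
# of this era, not something this diff introduces
from tianshou.trainer import offpolicy_trainer

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64)
```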
@@ -165,7 +163,8 @@ We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
         buffer.store(s, a, s_, r, d)                          # collector.collect(...)
         s = s_                                                # collector.collect(...)
         if i % 1000 == 0:                                     # done in trainer
-            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)    # collector.sample(batch_size)
+            # the following is done in policy.update(batch_size, buffer)
+            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)    # buffer.sample(batch_size)
             # compute 2-step returns. How?
             b_ret = compute_2_step_return(buffer, b_r, b_d, ...)  # policy.process_fn(batch, buffer, indice)
             # update DQN policy
@@ -210,8 +210,8 @@ Tianshou supports user-defined training code. Here is the code snippet:
         # back to training eps
         policy.set_eps(0.1)

-        # train the policy with a sampled batch of data
-        losses = policy.learn(train_collector.sample(batch_size=64))
+        # train the policy with a batch of data sampled from the buffer
+        losses = policy.update(64, train_collector.buffer)

 For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.

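Putting the snippet in context, a hedged sketch of a hand-written DQN-style training loop after this change; the epsilon value and step counts are illustrative only:

```python
# hedged sketch of a custom training loop after this patch;
# the numbers are illustrative, not prescriptive
for step in range(10000):
    policy.set_eps(0.1)                                 # exploration during collection
    train_collector.collect(n_step=10)                  # gather fresh transitions
    losses = policy.update(64, train_collector.buffer)  # sample -> process_fn -> learn -> post_process_fn
```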
@@ -174,7 +174,7 @@ def test_collector_with_dict_state():
     c1.seed(0)
     c1.collect(n_step=10)
     c1.collect(n_episode=[2, 1, 1, 2])
-    batch = c1.sample(10)
+    batch, _ = c1.buffer.sample(10)
     print(batch)
     c0.buffer.update(c1.buffer)
     assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([
@@ -184,7 +184,7 @@ def test_collector_with_dict_state():
     c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                    Logger.single_preprocess_fn)
     c2.collect(n_episode=[0, 0, 0, 10])
-    batch = c2.sample(10)
+    batch, _ = c2.buffer.sample(10)
     print(batch['obs_next']['index'])


@@ -209,7 +209,7 @@ def test_collector_with_ma():
     assert np.asanyarray(r).size == 1 and r == 4.
     r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
     assert np.asanyarray(r).size == 1 and r == 4.
-    batch = c1.sample(10)
+    batch, _ = c1.buffer.sample(10)
     print(batch)
     c0.buffer.update(c1.buffer)
     obs = np.array(np.expand_dims([
@@ -226,7 +226,7 @@ def test_collector_with_ma():
                    Logger.single_preprocess_fn, reward_metric=reward_metric)
     r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
     assert np.asanyarray(r).size == 1 and r == 4.
-    batch = c2.sample(10)
+    batch, _ = c2.buffer.sample(10)
     print(batch['obs_next'])


@@ -64,16 +64,6 @@ class Collector(object):
         # sleep time between rendering consecutive frames)
         collector.collect(n_episode=1, render=0.03)

-        # sample data with a given number of batch-size:
-        batch_data = collector.sample(batch_size=64)
-        # policy.learn(batch_data)  # btw, vanilla policy gradient only
-        # supports on-policy training, so here we pick all data in the buffer
-        batch_data = collector.sample(batch_size=0)
-        policy.learn(batch_data)
-        # on-policy algorithms use the collected data only once, so here we
-        # clear the buffer
-        collector.reset_buffer()
-
     Collected data always consist of full episodes. So if only the ``n_step``
     argument is given, the collector may return the data more than the
     ``n_step`` limitation. Same as ``n_episode`` for the multiple environment
@@ -357,13 +347,18 @@ class Collector(object):

     def sample(self, batch_size: int) -> Batch:
         """Sample a data batch from the internal replay buffer. It will call
-        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
-        the final batch data.
+        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
+        final batch data.

         :param int batch_size: ``0`` means it will extract all the data from
             the buffer, otherwise it will extract the data with the given
             batch_size.
         """
+        import warnings
+        warnings.warn(
+            'Collector.sample is deprecated and will cause error if you use '
+            'prioritized experience replay! Collector.sample will be removed '
+            'upon version 0.3. Use policy.update instead!', Warning)
         batch_data, indice = self.buffer.sample(batch_size)
         batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
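In user code the deprecated call maps directly onto the new entry point; a minimal migration sketch:

```python
# before (deprecated; breaks prioritized experience replay):
#   batch = collector.sample(batch_size)
#   losses = policy.learn(batch)
# after this patch:
losses = policy.update(batch_size, collector.buffer)
```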
@@ -93,9 +93,8 @@ class BasePolicy(ABC, nn.Module):

             # some code
             return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
-            # and in the sampled data batch, you can directly call
-            # batch.policy.log_prob to get your data, although it is stored in
-            # np.ndarray.
+            # and in the sampled data batch, you can directly use
+            # batch.policy.log_prob to get your data.
         """
         pass

@@ -123,6 +122,7 @@ class BasePolicy(ABC, nn.Module):
             v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
             gamma: float = 0.99,
             gae_lambda: float = 0.95,
+            rew_norm: bool = False,
     ) -> Batch:
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@@ -136,6 +136,8 @@ class BasePolicy(ABC, nn.Module):
             to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
+        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
+            to ``False``.

         :return: a Batch. The result will be stored in batch.returns as a numpy
             array with shape (bsz, ).
@@ -150,6 +152,8 @@ class BasePolicy(ABC, nn.Module):
         for i in range(len(rew) - 1, -1, -1):
             gae = delta[i] + m[i] * gae
             returns[i] += gae
+        if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
+            returns = (returns - returns.mean()) / returns.std()
         batch.returns = returns
         return batch

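The new ``rew_norm`` flag standardizes the computed returns unless their standard deviation is (near) zero. A small self-contained sketch of the same guard, for reference only:

```python
import numpy as np

def normalize_returns(returns: np.ndarray) -> np.ndarray:
    """Standardize returns, skipping the degenerate constant-return case,
    mirroring the guard added in this patch."""
    if not np.isclose(returns.std(), 0, 1e-2):
        returns = (returns - returns.mean()) / returns.std()
    return returns

# example: returns with non-zero spread come out with zero mean / unit std
print(normalize_returns(np.array([1.0, 2.0, 3.0])))
```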
@@ -196,7 +200,7 @@ class BasePolicy(ABC, nn.Module):
         if rew_norm:
             bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
             mean, std = bfr.mean(), bfr.std()
-            if np.isclose(std, 0):
+            if np.isclose(std, 0, 1e-2):
                 mean, std = 0, 1
         else:
             mean, std = 0, 1
@@ -216,9 +220,30 @@ class BasePolicy(ABC, nn.Module):
         batch.returns = target_q * gammas + returns
         # prio buffer update
         if isinstance(buffer, PrioritizedReplayBuffer):
-            batch.update_weight = buffer.update_weight
-            batch.indice = indice
             batch.weight = to_torch_as(batch.weight, target_q)
         else:
             batch.weight = torch.ones_like(target_q)
         return batch

+    def post_process_fn(self, batch: Batch,
+                        buffer: ReplayBuffer, indice: np.ndarray):
+        """Post-process the data from the provided replay buffer. Typical
+        usage is to update the sampling weight in prioritized experience
+        replay. Check out :ref:`policy_concept` for more information.
+        """
+        if isinstance(buffer, PrioritizedReplayBuffer):
+            buffer.update_weight(indice, batch.weight)
+
+    def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
+        """Update the policy network and replay buffer (if needed). It includes
+        three function steps: process_fn, learn, and post_process_fn.
+
+        :param int batch_size: 0 means it will extract all the data from the
+            buffer, otherwise it will sample a batch with the given batch_size.
+        :param ReplayBuffer buffer: the corresponding replay buffer.
+        """
+        batch, indice = buffer.sample(batch_size)
+        batch = self.process_fn(batch, buffer, indice)
+        result = self.learn(batch, *args, **kwargs)
+        self.post_process_fn(batch, buffer, indice)
+        return result
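With ``learn`` now attaching the TD error as ``batch.weight`` and ``post_process_fn`` writing the priorities back, prioritized replay works through the same entry point. A hedged usage sketch; the ``alpha``/``beta`` values and sizes are illustrative assumptions, and ``policy``/``train_envs`` are assumed to exist:

```python
# hedged sketch: prioritized replay through the new update() path
from tianshou.data import Collector, PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=20000, alpha=0.6, beta=0.4)
collector = Collector(policy, train_envs, buf)
collector.collect(n_step=1000)

# sample -> process_fn (n-step returns + IS weights) -> learn (sets batch.weight)
# -> post_process_fn (buffer.update_weight with the new TD errors)
losses = policy.update(64, buf)
```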
@@ -67,7 +67,8 @@ class A2CPolicy(PGPolicy):
                 v_.append(to_numpy(self.critic(b.obs_next)))
         v_ = np.concatenate(v_, axis=0)
         return self.compute_episodic_return(
-            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
+            rew_norm=self._rew_norm)

     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -97,9 +98,6 @@ class A2CPolicy(PGPolicy):
     def learn(self, batch: Batch, batch_size: int, repeat: int,
               **kwargs) -> Dict[str, List[float]]:
         self._batch = batch_size
-        r = batch.returns
-        if self._rew_norm and not np.isclose(r.std(), 0):
-            batch.returns = (r - r.mean()) / r.std()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
@@ -144,10 +144,8 @@ class DDPGPolicy(BasePolicy):
         current_q = self.critic(batch.obs, batch.act).flatten()
         target_q = batch.returns.flatten()
         td = current_q - target_q
-        if hasattr(batch, 'update_weight'):  # prio-buffer
-            batch.update_weight(batch.indice, td)
         critic_loss = (td.pow(2) * batch.weight).mean()
-        # critic_loss = F.mse_loss(current_q, target_q)
+        batch.weight = td  # prio-buffer
         self.critic_optim.zero_grad()
         critic_loss.backward()
         self.critic_optim.step()
@@ -160,10 +160,8 @@ class DQNPolicy(BasePolicy):
         q = q[np.arange(len(q)), batch.act]
         r = to_torch_as(batch.returns, q).flatten()
         td = r - q
-        if hasattr(batch, 'update_weight'):  # prio-buffer
-            batch.update_weight(batch.indice, td)
         loss = (td.pow(2) * batch.weight).mean()
-        # loss = F.mse_loss(q, r)
+        batch.weight = td  # prio-buffer
         loss.backward()
         self.optim.step()
         self._cnt += 1
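The commented-out ``F.mse_loss`` lines in DDPG and DQN disappear because the weighted form already covers them: when ``batch.weight`` is all ones (the non-prioritized case handled in ``process_fn``), the weighted loss reduces to plain MSE. A quick numerical check, assuming PyTorch:

```python
import torch
import torch.nn.functional as F

q = torch.tensor([1.0, 2.0, 3.0])
r = torch.tensor([1.5, 1.0, 2.5])
td = r - q
weight = torch.ones_like(q)  # non-prioritized case: weights are all ones

weighted = (td.pow(2) * weight).mean()
assert torch.allclose(weighted, F.mse_loss(q, r))  # identical when weight == 1
```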
@@ -51,7 +51,7 @@ class PGPolicy(BasePolicy):
         # batch.returns = self._vectorized_returns(batch)
         # return batch
         return self.compute_episodic_return(
-            batch, gamma=self._gamma, gae_lambda=1.)
+            batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm)

     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -81,9 +81,6 @@ class PGPolicy(BasePolicy):
     def learn(self, batch: Batch, batch_size: int, repeat: int,
               **kwargs) -> Dict[str, List[float]]:
         losses = []
-        r = batch.returns
-        if self._rew_norm and not np.isclose(r.std(), 0):
-            batch.returns = (r - r.mean()) / r.std()
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
@@ -79,18 +79,29 @@ class PPOPolicy(PGPolicy):
                    indice: np.ndarray) -> Batch:
         if self._rew_norm:
             mean, std = batch.rew.mean(), batch.rew.std()
-            if not np.isclose(std, 0):
+            if not np.isclose(std, 0, 1e-2):
                 batch.rew = (batch.rew - mean) / std
-        if self._lambda in [0, 1]:
-            return self.compute_episodic_return(
-                batch, None, gamma=self._gamma, gae_lambda=self._lambda)
-        v_ = []
+        v, v_, old_log_prob = [], [], []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False):
                 v_.append(self.critic(b.obs_next))
+                v.append(self.critic(b.obs))
+                old_log_prob.append(self(b).dist.log_prob(
+                    to_torch_as(b.act, v[0])))
         v_ = to_numpy(torch.cat(v_, dim=0))
-        return self.compute_episodic_return(
-            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
+        batch = self.compute_episodic_return(
+            batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
+            rew_norm=self._rew_norm)
+        batch.v = torch.cat(v, dim=0).flatten()  # old value
+        batch.act = to_torch_as(batch.act, v[0])
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
+        batch.returns = to_torch_as(batch.returns, v[0])
+        batch.adv = batch.returns - batch.v
+        if self._rew_norm:
+            mean, std = batch.adv.mean(), batch.adv.std()
+            if not np.isclose(std.item(), 0, 1e-2):
+                batch.adv = (batch.adv - mean) / std
+        return batch

     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -123,26 +134,6 @@ class PPOPolicy(PGPolicy):
               **kwargs) -> Dict[str, List[float]]:
         self._batch = batch_size
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        v = []
-        old_log_prob = []
-        with torch.no_grad():
-            for b in batch.split(batch_size, shuffle=False):
-                v.append(self.critic(b.obs))
-                old_log_prob.append(self(b).dist.log_prob(
-                    to_torch_as(b.act, v[0])))
-        batch.v = torch.cat(v, dim=0).flatten()  # old value
-        batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, v[0])
-        if self._rew_norm:
-            mean, std = batch.returns.mean(), batch.returns.std()
-            if not np.isclose(std.item(), 0):
-                batch.returns = (batch.returns - mean) / std
-        batch.adv = batch.returns - batch.v
-        if self._rew_norm:
-            mean, std = batch.adv.mean(), batch.adv.std()
-            if not np.isclose(std.item(), 0):
-                batch.adv = (batch.adv - mean) / std
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 dist = self(b).dist
@@ -154,9 +154,7 @@ class SACPolicy(DDPGPolicy):
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
         self.critic2_optim.step()
-        # prio-buffer
-        if hasattr(batch, 'update_weight'):
-            batch.update_weight(batch.indice, (td1 + td2) / 2.)
+        batch.weight = (td1 + td2) / 2.  # prio-buffer
         # actor
         obs_result = self(batch, explorating=False)
         a = obs_result.act
@@ -132,8 +132,7 @@ class TD3Policy(DDPGPolicy):
         self.critic2_optim.zero_grad()
         critic2_loss.backward()
         self.critic2_optim.step()
-        if hasattr(batch, 'update_weight'):  # prio-buffer
-            batch.update_weight(batch.indice, (td1 + td2) / 2.)
+        batch.weight = (td1 + td2) / 2.  # prio-buffer
         if self._cnt % self._freq == 0:
             actor_loss = -self.critic1(
                 batch.obs, self(batch, eps=0).act).mean()
@@ -28,7 +28,8 @@ def offpolicy_trainer(
         verbose: bool = True,
         test_in_train: bool = True,
 ) -> Dict[str, Union[float, str]]:
-    """A wrapper for off-policy trainer procedure.
+    """A wrapper for off-policy trainer procedure. The ``step`` in trainer
+    means a policy network update.

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
         class.
@@ -47,8 +48,9 @@ def offpolicy_trainer(
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
     :param int update_per_step: the number of times the policy network would
-        be updated after frames be collected. In other words, collect some
-        frames and do some policy network update.
+        be updated after frames are collected; for example, setting it to 256
+        means the policy is updated 256 times once ``collect_per_step`` frames
+        have been collected.
     :param function train_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of training in this
         epoch.
@@ -103,7 +105,7 @@ def offpolicy_trainer(
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += 1
-                    losses = policy.learn(train_collector.sample(batch_size))
+                    losses = policy.update(batch_size, train_collector.buffer)
                     for k in result.keys():
                         data[k] = f'{result[k]:.2f}'
                     if writer and global_step % log_interval == 0:
@@ -28,7 +28,8 @@ def onpolicy_trainer(
         verbose: bool = True,
         test_in_train: bool = True,
 ) -> Dict[str, Union[float, str]]:
-    """A wrapper for on-policy trainer procedure.
+    """A wrapper for on-policy trainer procedure. The ``step`` in trainer means
+    a policy network update.

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
         class.
@@ -101,8 +102,8 @@ def onpolicy_trainer(
                         policy.train()
                         if train_fn:
                             train_fn(epoch)
-                losses = policy.learn(
-                    train_collector.sample(0), batch_size, repeat_per_collect)
+                losses = policy.update(
+                    0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
                 step = 1
                 for k in losses.keys():
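The two trainers now call the same entry point with different sampling behaviour: off-policy samples a fixed-size minibatch each gradient step, while on-policy passes ``batch_size=0`` to pull the whole buffer and then clears it. A hedged side-by-side sketch of the call patterns introduced by this patch:

```python
# off-policy (e.g. DQN): sample a minibatch each gradient step
losses = policy.update(batch_size, train_collector.buffer)

# on-policy (e.g. PG/A2C/PPO): 0 -> use everything collected, then drop it;
# extra positional args are forwarded to learn(batch, batch_size, repeat)
losses = policy.update(0, train_collector.buffer, batch_size, repeat_per_collect)
train_collector.reset_buffer()
```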