clarify updating state (#224)
Adding an indicator of the learning stage (i.e. `self.learning`) is convenient for distinguishing the policy's state; meanwhile, the meaning of `self.training` stays unambiguous during the training stage. Related issue: #211

Other fixes:
- fix a bug in DDQN: target_q should not be sampled from np.random.rand
- fix a bug in the DQN Atari net: a ReLU should be added before the last layer
- fix a bug in collector timing

Co-authored-by: n+e <463003665@qq.com>
This commit is contained in:
parent eec0826fd3
commit bf39b9ef7d
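Before the diff, a minimal sketch of the pattern this commit introduces (illustrative only, not Tianshou API; the class and attribute names below are assumptions): the policy toggles an `updating` flag around `update()`, and `forward()` adds exploration noise only when that flag is off, so callers no longer need to pass `explorating` or `eps=0`.

```python
import torch


class SketchPolicy(torch.nn.Module):
    """Illustrative only: condensed view of the updating-state pattern."""

    def __init__(self, actor: torch.nn.Module, noise_std: float = 0.1):
        super().__init__()
        self.actor = actor
        self.noise_std = noise_std
        self.updating = False  # True only while update() is running

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        act = self.actor(obs)
        # exploration noise is applied while collecting/testing, never inside update()
        if not self.updating:
            act = act + self.noise_std * torch.randn_like(act)
        return act

    def update(self, batch) -> dict:
        self.updating = True   # enter the updating state
        result = self.learn(batch)
        self.updating = False  # leave the updating state
        return result

    def learn(self, batch) -> dict:
        return {"loss": 0.0}  # placeholder for the real gradient step
```

The actual changes to `BasePolicy`, `DDPGPolicy`, `DQNPolicy`, and `SACPolicy` below follow this shape.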
@@ -27,6 +27,7 @@
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
@@ -18,6 +18,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
@@ -75,6 +75,34 @@ A policy class typically has the following parts:

* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes the data (such as computing the n-step return), learns from the data, and finally post-processes the data (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. A short usage sketch follows below.
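A hedged usage sketch, not part of the diff: assuming an existing ``policy`` and a replay ``buffer``, one call to :meth:`update` looks roughly like the snippet below; passing ``0`` as the sample size would consume the whole buffer instead.

```python
# Hypothetical snippet; `policy` and `buffer` are assumed to already exist.
losses = policy.update(64, buffer)  # sample 64 transitions: process_fn -> learn -> post_process_fn
print(losses)                       # dict of scalars returned by learn(), e.g. {"loss": 0.12}
```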

.. _policy_state:

States for policy
^^^^^^^^^^^^^^^^^

During the training process, the policy has two main states: the training state and the testing state. The training state can be further divided into the collecting state and the updating state.

The meaning of the training and testing states is straightforward: in the training state, the agent interacts with the environment, collects training data, and updates the model; the testing state evaluates the performance of the current policy during the training process.

The collecting state is defined as interacting with environments and collecting training data into the buffer; the updating state is defined as performing a model update through :meth:`~tianshou.policy.BasePolicy.update` during the training process.

To distinguish these states, you can check ``policy.training`` and ``policy.updating``. The states are set as follows:

+-----------------------------------+-----------------+-----------------+
| State for policy                  | policy.training | policy.updating |
+================+==================+=================+=================+
|                | Collecting state | True            | False           |
| Training state +------------------+-----------------+-----------------+
|                | Updating state   | True            | True            |
+----------------+------------------+-----------------+-----------------+
| Testing state                     | False           | False           |
+-----------------------------------+-----------------+-----------------+

``policy.updating`` is helpful for distinguishing the different exploration states: for example, in DQN there is no need for epsilon-greedy exploration during a pure network update, so ``policy.updating`` can be used to decide whether to apply epsilon-greedy in that case.
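As a simplified sketch, not part of the diff, assuming a DQN-style discrete policy with an ``eps`` attribute (and ignoring action masks), the exploration branch can be guarded by ``self.updating`` instead of an extra ``eps`` argument:

```python
import numpy as np

def _epsilon_greedy(self, q: np.ndarray) -> np.ndarray:
    """Greedy actions; epsilon-greedy exploration is skipped while updating."""
    act = q.argmax(axis=1)
    # no epsilon-greedy inside a pure network update (self.updating is True there)
    if not self.updating and not np.isclose(self.eps, 0.0):
        for i in range(len(q)):
            if np.random.rand() < self.eps:
                act[i] = np.random.randint(q.shape[1])
    return act
```

This mirrors the change to ``DQNPolicy.forward`` in the diff below.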

policy.forward
^^^^^^^^^^^^^^
@@ -129,10 +129,14 @@ class Collector(object):
obs_next={}, policy={})
self.reset_env()
self.reset_buffer()
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
self.reset_stat()
if self._action_noise is not None:
self._action_noise.reset()

def reset_stat(self) -> None:
"""Reset the statistic variables."""
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0

def reset_buffer(self) -> None:
"""Reset the main data buffer."""
if self.buffer is not None:
@@ -60,6 +60,7 @@ class BasePolicy(ABC, nn.Module):
self.observation_space = observation_space
self.action_space = action_space
self.agent_id = 0
self.updating = False
self._compile()

def set_agent_id(self, agent_id: int) -> None:
@@ -118,6 +119,13 @@ class BasePolicy(ABC, nn.Module):

:return: A dict which includes loss and its corresponding label.

.. note::

In order to distinguish the collecting state, updating state, and
testing state, you can check the policy state by ``self.training``
and ``self.updating``. Please refer to :ref:`policy_state` for a more
detailed explanation.

.. warning::

If you use ``torch.distributions.Normal`` and
@@ -146,6 +154,10 @@ class BasePolicy(ABC, nn.Module):
"""Update the policy network and replay buffer.

It includes 3 function steps: process_fn, learn, and post_process_fn.
In addition, this function will change the value of ``self.updating``:
it will be False before this function and will be True when executing
:meth:`update`. Please refer to :ref:`policy_state` for a more detailed
explanation.

:param int sample_size: 0 means it will extract all the data from the
buffer, otherwise it will sample a batch with the given sample_size.
@@ -154,9 +166,11 @@ class BasePolicy(ABC, nn.Module):
if buffer is None:
return {}
batch, indice = buffer.sample(sample_size)
self.updating = True
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indice)
self.updating = False
return result

@staticmethod
@@ -103,9 +103,9 @@ class DDPGPolicy(BasePolicy):
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next',
explorating=False).act)
target_q = self.critic_old(
batch.obs_next,
self(batch, model='actor_old', input='obs_next').act)
return target_q

def process_fn(
@@ -124,7 +124,6 @@ class DDPGPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "actor",
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
@@ -143,7 +142,7 @@ class DDPGPolicy(BasePolicy):
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
if self._noise and self.training and explorating:
if self._noise and not self.updating:
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
@@ -158,7 +157,7 @@ class DDPGPolicy(BasePolicy):
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
action = self(batch, explorating=False).act
action = self(batch).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
@@ -80,7 +80,7 @@ class DQNPolicy(BasePolicy):
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(batch, input="obs_next", eps=0).act
a = self(batch, input="obs_next").act
with torch.no_grad():
target_q = self(
batch, model="model_old", input="obs_next"
@@ -110,7 +110,6 @@ class DQNPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
eps: Optional[float] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
@@ -152,12 +151,10 @@ class DQNPolicy(BasePolicy):
q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act
if eps is None:
eps = self.eps
if not np.isclose(eps, 0.0):
# add eps to act in training or testing phase
if not self.updating and not np.isclose(self.eps, 0.0):
for i in range(len(q)):
if np.random.rand() < eps:
if np.random.rand() < self.eps:
q_ = np.random.rand(*q[i].shape)
if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
@@ -169,7 +166,7 @@ class DQNPolicy(BasePolicy):
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns.flatten(), q)
td = r - q
@@ -110,7 +110,6 @@ class SACPolicy(DDPGPolicy):
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
obs = batch[input]
@@ -123,7 +122,7 @@ class SACPolicy(DDPGPolicy):
y = self._action_scale * (1 - y.pow(2)) + self.__eps
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
if self._noise is not None and self.training and explorating:
if self._noise is not None and not self.updating:
act += to_torch_as(self._noise(act.shape), act)
act = act.clamp(self._range[0], self._range[1])
return Batch(
@@ -134,7 +133,7 @@ class SACPolicy(DDPGPolicy):
) -> torch.Tensor:
batch = buffer[indice] # batch.obs: s_{t+n}
with torch.no_grad():
obs_next_result = self(batch, input='obs_next', explorating=False)
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
batch.act = to_torch_as(batch.act, a_)
target_q = torch.min(
@@ -167,7 +166,7 @@ class SACPolicy(DDPGPolicy):
batch.weight = (td1 + td2) / 2.0 # prio-buffer

# actor
obs_result = self(batch, explorating=False)
obs_result = self(batch)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a).flatten()
current_q2a = self.critic2(batch.obs, a).flatten()
@@ -75,6 +75,8 @@ def offpolicy_trainer(
best_epoch, best_reward = -1, -1.0
stat: Dict[str, MovAvg] = {}
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
@@ -75,6 +75,8 @@ def onpolicy_trainer(
best_epoch, best_reward = -1, -1.0
stat: Dict[str, MovAvg] = {}
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
@@ -116,6 +116,7 @@ class DQN(nn.Module):
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(linear_input_size, 512),
nn.ReLU(inplace=True),
nn.Linear(512, np.prod(action_shape)),
)