clarify updating state (#224)

Adding an indicator of the updating state (i.e. `self.updating`) makes it convenient to distinguish the policy's current state.
Meanwhile, `self.training` stays True throughout the whole training stage, so on its own it cannot tell the collecting state apart from the updating state.
Related issue: #211 

Others:
- fix a bug in DDQN: target_q should not be affected by actions drawn from epsilon-greedy exploration (`np.random.rand`)
- fix a bug in the DQN Atari net: a ReLU should be added before the last layer
- fix a bug in collector timing

Co-authored-by: n+e <463003665@qq.com>
rocknamx 2020-09-22 16:28:46 +08:00 committed by GitHub
parent eec0826fd3
commit bf39b9ef7d
11 changed files with 67 additions and 19 deletions

View File

@ -27,6 +27,7 @@
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)

View File

@ -18,6 +18,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

View File

@ -75,6 +75,34 @@ A policy class typically has the following parts:
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes it (such as computing the n-step return), learns from the data, and finally post-processes it (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.
.. _policy_state:
States for policy
^^^^^^^^^^^^^^^^^
During the training process, the policy is in one of two main states: the training state and the testing state. The training state can be further divided into the collecting state and the updating state.
The difference between the training and testing state is straightforward: in the training state the agent interacts with the environment, collects training data and performs model updates, while the testing state only evaluates the performance of the current policy during training.
The collecting state is defined as interacting with environments and storing the collected training data into the buffer;
the updating state is defined as performing a model update via :meth:`~tianshou.policy.BasePolicy.update` during the training process.
In order to distinguish these states, you can check ``policy.training`` and ``policy.updating``. The flags are set as follows:
+-----------------------------------+-----------------+-----------------+
| State for policy                  | policy.training | policy.updating |
+================+==================+=================+=================+
| Training state | Collecting state | True            | False           |
|                +------------------+-----------------+-----------------+
|                | Updating state   | True            | True            |
+----------------+------------------+-----------------+-----------------+
| Testing state                     | False           | False           |
+-----------------------------------+-----------------+-----------------+
``policy.updating`` is helpful for distinguishing different exploration behavior: for example, in DQN there is no need for epsilon-greedy exploration during a pure network update, so ``policy.updating`` can be checked when deciding whether to apply epsilon in this case.
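As an illustration only (this toy policy is not part of the library; the class name, ``noise_std`` and the zero-action behaviour are made up for the example), a custom policy could consult the two flags like this::

    import numpy as np

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy


    class NoisyZeroPolicy(BasePolicy):
        """Toy policy: zero action plus Gaussian exploration noise."""

        def __init__(self, action_dim, noise_std=0.1, **kwargs):
            super().__init__(**kwargs)
            self.action_dim = action_dim
            self.noise_std = noise_std

        def forward(self, batch, state=None, **kwargs):
            act = np.zeros((len(batch.obs), self.action_dim))
            # add exploration noise only in the collecting state,
            # i.e. self.training is True and self.updating is False
            if self.training and not self.updating:
                act += self.noise_std * np.random.randn(*act.shape)
            return Batch(act=act, state=state)

        def learn(self, batch, **kwargs):
            return {}  # nothing to learn in this toy example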
policy.forward
^^^^^^^^^^^^^^

View File

@ -129,10 +129,14 @@ class Collector(object):
obs_next={}, policy={})
self.reset_env()
self.reset_buffer()
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
self.reset_stat()
if self._action_noise is not None:
self._action_noise.reset()
def reset_stat(self) -> None:
"""Reset the statistic variables."""
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
def reset_buffer(self) -> None:
"""Reset the main data buffer."""
if self.buffer is not None:

View File

@ -60,6 +60,7 @@ class BasePolicy(ABC, nn.Module):
self.observation_space = observation_space
self.action_space = action_space
self.agent_id = 0
self.updating = False
self._compile()
def set_agent_id(self, agent_id: int) -> None:
@ -118,6 +119,13 @@ class BasePolicy(ABC, nn.Module):
:return: A dict which includes loss and its corresponding label.
.. note::
In order to distinguish the collecting state, updating state and
testing state, you can check the policy state by ``self.training``
and ``self.updating``. Please refer to :ref:`policy_state` for more
detailed explanation.
.. warning::
If you use ``torch.distributions.Normal`` and
@ -146,6 +154,10 @@ class BasePolicy(ABC, nn.Module):
"""Update the policy network and replay buffer.
It includes 3 function steps: process_fn, learn, and post_process_fn.
In addition, this function will change the value of ``self.updating``: it is
set to True while :meth:`update` is executing and reset to False afterwards.
Please refer to :ref:`policy_state` for a more detailed explanation.
:param int sample_size: 0 means it will extract all the data from the
buffer, otherwise it will sample a batch with given sample_size.
@ -154,9 +166,11 @@ class BasePolicy(ABC, nn.Module):
if buffer is None:
return {}
batch, indice = buffer.sample(sample_size)
self.updating = True
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indice)
self.updating = False
return result
@staticmethod
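
For reference, a hedged usage sketch of how the flag behaves around `update` (it assumes an already constructed `policy`; the buffer is filled with dummy transitions):

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=100)
    for _ in range(20):
        buf.add(obs=np.zeros(4), act=0, rew=0.0, done=False,
                obs_next=np.zeros(4))

    assert not policy.updating       # collecting / idle state
    result = policy.update(16, buf)  # the flag is True only inside this call
    assert not policy.updating       # reset to False before update() returns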

View File

@ -103,9 +103,9 @@ class DDPGPolicy(BasePolicy):
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next',
explorating=False).act)
target_q = self.critic_old(
batch.obs_next,
self(batch, model='actor_old', input='obs_next').act)
return target_q
def process_fn(
@ -124,7 +124,6 @@ class DDPGPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "actor",
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
@ -143,7 +142,7 @@ class DDPGPolicy(BasePolicy):
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
if self._noise and self.training and explorating:
if self._noise and not self.updating:
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
@ -158,7 +157,7 @@ class DDPGPolicy(BasePolicy):
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
action = self(batch, explorating=False).act
action = self(batch).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
actor_loss.backward()

View File

@ -80,7 +80,7 @@ class DQNPolicy(BasePolicy):
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(batch, input="obs_next", eps=0).act
a = self(batch, input="obs_next").act
with torch.no_grad():
target_q = self(
batch, model="model_old", input="obs_next"
@ -110,7 +110,6 @@ class DQNPolicy(BasePolicy):
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
eps: Optional[float] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
@ -152,12 +151,10 @@ class DQNPolicy(BasePolicy):
q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act
if eps is None:
eps = self.eps
if not np.isclose(eps, 0.0):
# add eps to act in training or testing phase
if not self.updating and not np.isclose(self.eps, 0.0):
for i in range(len(q)):
if np.random.rand() < eps:
if np.random.rand() < self.eps:
q_ = np.random.rand(*q[i].shape)
if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
@ -169,7 +166,7 @@ class DQNPolicy(BasePolicy):
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns.flatten(), q)
td = r - q
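
Since `forward` no longer accepts an `eps` argument, epsilon-greedy exploration is now controlled only through the policy attribute, e.g. (sketch; assumes an existing `DQNPolicy` instance `policy`, and the epsilon values are arbitrary):

    policy.set_eps(0.1)   # exploration while collecting training data
    # ... collect with a Collector ...
    policy.set_eps(0.05)  # e.g. a smaller epsilon for test-time collection
    # learn() and _target_q() need no eps handling anymore, because forward()
    # skips epsilon-greedy whenever self.updating is True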

View File

@ -110,7 +110,6 @@ class SACPolicy(DDPGPolicy):
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
obs = batch[input]
@ -123,7 +122,7 @@ class SACPolicy(DDPGPolicy):
y = self._action_scale * (1 - y.pow(2)) + self.__eps
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
if self._noise is not None and self.training and explorating:
if self._noise is not None and not self.updating:
act += to_torch_as(self._noise(act.shape), act)
act = act.clamp(self._range[0], self._range[1])
return Batch(
@ -134,7 +133,7 @@ class SACPolicy(DDPGPolicy):
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
obs_next_result = self(batch, input='obs_next', explorating=False)
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
batch.act = to_torch_as(batch.act, a_)
target_q = torch.min(
@ -167,7 +166,7 @@ class SACPolicy(DDPGPolicy):
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
obs_result = self(batch, explorating=False)
obs_result = self(batch)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a).flatten()
current_q2a = self.critic2(batch.obs, a).flatten()

View File

@ -75,6 +75,8 @@ def offpolicy_trainer(
best_epoch, best_reward = -1, -1.0
stat: Dict[str, MovAvg] = {}
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train

View File

@ -75,6 +75,8 @@ def onpolicy_trainer(
best_epoch, best_reward = -1, -1.0
stat: Dict[str, MovAvg] = {}
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train

View File

@ -116,6 +116,7 @@ class DQN(nn.Module):
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(linear_input_size, 512),
nn.ReLU(inplace=True),
nn.Linear(512, np.prod(action_shape)),
)