clarify updating state (#224)
Adding an explicit indicator of the updating phase (the `self.updating` flag) makes it convenient to distinguish the policy's state. Meanwhile, `self.training` remains True throughout the training stage, so it cannot tell collecting apart from updating on its own. Related issue: #211

Others:
- fix a bug in DDQN: target_q must not be sampled via np.random.rand (no epsilon-greedy when computing the target)
- fix a bug in the DQN Atari net: a ReLU is needed before the last layer
- fix a bug in collector timing

Co-authored-by: n+e <463003665@qq.com>
parent eec0826fd3
commit bf39b9ef7d
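The change boils down to two flags on every policy: `self.training` (inherited from `torch.nn.Module`) and the new `self.updating`, which is raised only inside `BasePolicy.update`. A minimal illustration of how downstream code can combine them; the `should_explore` helper is hypothetical, not part of Tianshou:

```python
def should_explore(policy) -> bool:
    # Exploration only makes sense while collecting training data:
    # in training mode (policy.training is True) but not inside a
    # gradient update (policy.updating is True only within update()).
    return policy.training and not policy.updating
```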
@@ -27,6 +27,7 @@
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
+- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
 - Vanilla Imitation Learning
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
@@ -18,6 +18,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
+* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
@@ -75,6 +75,34 @@ A policy class typically has the following parts:
 * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes the data (such as computing the n-step return), learns from the data, and finally post-processes the data (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.
 
 
+.. _policy_state:
+
+States for policy
+^^^^^^^^^^^^^^^^^
+
+During the training process, the policy has two main states: training state and testing state. The training state can be further divided into the collecting state and updating state.
+
+The meaning of the training and testing states is straightforward: the agent interacts with the environment, collects training data and performs updates; that is the training state. The testing state evaluates the performance of the current policy during the training process.
+
+
+As for the collecting state, it is defined as interacting with environments and collecting training data into the buffer;
+we define the updating state as performing a model update by :meth:`~tianshou.policy.BasePolicy.update` during the training process.
+
+
+In order to distinguish these states, you can check the policy state by ``policy.training`` and ``policy.updating``. The state setting is as follows:
+
++-----------------------------------+-----------------+-----------------+
+| State for policy                  | policy.training | policy.updating |
++================+==================+=================+=================+
+|                | Collecting state | True            | False           |
+| Training state +------------------+-----------------+-----------------+
+|                | Updating state   | True            | True            |
++----------------+------------------+-----------------+-----------------+
+| Testing state                     | False           | False           |
++-----------------------------------+-----------------+-----------------+
+
+``policy.updating`` helps distinguish the different exploration behavior: for example, DQN does not need epsilon-greedy exploration during a pure network update, so ``policy.updating`` can be used to decide whether to apply epsilon in this case.
+
 policy.forward
 ^^^^^^^^^^^^^^
 
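To make the table concrete, here is a sketch that maps the two flags to the three named states; `policy_state` is an illustrative helper, not a Tianshou API:

```python
def policy_state(training: bool, updating: bool) -> str:
    """Mirror the table above: (training, updating) -> state name."""
    if not training:
        return "testing"        # False / False
    return "updating" if updating else "collecting"  # True/True or True/False


assert policy_state(True, False) == "collecting"
assert policy_state(True, True) == "updating"
assert policy_state(False, False) == "testing"
```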
@@ -129,10 +129,14 @@ class Collector(object):
             obs_next={}, policy={})
         self.reset_env()
         self.reset_buffer()
-        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
+        self.reset_stat()
         if self._action_noise is not None:
             self._action_noise.reset()
 
+    def reset_stat(self) -> None:
+        """Reset the statistic variables."""
+        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
+
     def reset_buffer(self) -> None:
         """Reset the main data buffer."""
         if self.buffer is not None:
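With `reset_stat()` split out, a collector can be reused across runs and its timing counters start from zero each time. A small usage sketch, assuming `collector` is an already constructed `tianshou.data.Collector` from this version:

```python
collector.reset_stat()  # also called by reset() and by the trainers below
assert collector.collect_time == 0.0
assert collector.collect_step == 0 and collector.collect_episode == 0
```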
@@ -60,6 +60,7 @@ class BasePolicy(ABC, nn.Module):
         self.observation_space = observation_space
         self.action_space = action_space
         self.agent_id = 0
+        self.updating = False
         self._compile()
 
     def set_agent_id(self, agent_id: int) -> None:
@@ -118,6 +119,13 @@ class BasePolicy(ABC, nn.Module):
 
         :return: A dict which includes loss and its corresponding label.
 
+        .. note::
+
+            In order to distinguish the collecting state, updating state and
+            testing state, you can check the policy state by ``self.training``
+            and ``self.updating``. Please refer to :ref:`policy_state` for more
+            detailed explanation.
+
         .. warning::
 
             If you use ``torch.distributions.Normal`` and
@@ -146,6 +154,10 @@ class BasePolicy(ABC, nn.Module):
         """Update the policy network and replay buffer.
 
         It includes 3 function steps: process_fn, learn, and post_process_fn.
+        In addition, this function will change the value of ``self.updating``:
+        it will be False before this function and will be True when executing
+        :meth:`update`. Please refer to :ref:`policy_state` for more detailed
+        explanation.
 
         :param int sample_size: 0 means it will extract all the data from the
             buffer, otherwise it will sample a batch with given sample_size.
@@ -154,9 +166,11 @@ class BasePolicy(ABC, nn.Module):
         if buffer is None:
             return {}
         batch, indice = buffer.sample(sample_size)
+        self.updating = True
         batch = self.process_fn(batch, buffer, indice)
         result = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indice)
+        self.updating = False
         return result
 
     @staticmethod
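The flag is raised and lowered symmetrically around the three steps. A self-contained stub (not the real `BasePolicy`, which also samples from the buffer) that reproduces the toggle:

```python
class StubPolicy:
    """Toy stand-in for BasePolicy.update's flag handling."""

    def __init__(self) -> None:
        self.updating = False

    def process_fn(self, batch):
        return batch                 # e.g. compute n-step returns

    def learn(self, batch):
        return {"loss": 0.0}         # e.g. one gradient step

    def post_process_fn(self, batch):
        pass                         # e.g. update PER priorities

    def update(self, batch):
        self.updating = True         # entering the updating state
        batch = self.process_fn(batch)
        result = self.learn(batch)
        self.post_process_fn(batch)
        self.updating = False        # back to collecting/testing
        return result


policy = StubPolicy()
assert policy.updating is False
policy.update({"obs": [0]})
assert policy.updating is False      # flag is lowered again on return
```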
@@ -103,9 +103,9 @@ class DDPGPolicy(BasePolicy):
     ) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         with torch.no_grad():
-            target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next',
-                explorating=False).act)
+            target_q = self.critic_old(
+                batch.obs_next,
+                self(batch, model='actor_old', input='obs_next').act)
         return target_q
 
     def process_fn(
@@ -124,7 +124,6 @@ class DDPGPolicy(BasePolicy):
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         model: str = "actor",
         input: str = "obs",
-        explorating: bool = True,
         **kwargs: Any,
     ) -> Batch:
         """Compute action over the given batch data.
@@ -143,7 +142,7 @@ class DDPGPolicy(BasePolicy):
         obs = batch[input]
         actions, h = model(obs, state=state, info=batch.info)
         actions += self._action_bias
-        if self._noise and self.training and explorating:
+        if self._noise and not self.updating:
             actions += to_torch_as(self._noise(actions.shape), actions)
         actions = actions.clamp(self._range[0], self._range[1])
         return Batch(act=actions, state=h)
@@ -158,7 +157,7 @@ class DDPGPolicy(BasePolicy):
         self.critic_optim.zero_grad()
         critic_loss.backward()
         self.critic_optim.step()
-        action = self(batch, explorating=False).act
+        action = self(batch).act
         actor_loss = -self.critic(batch.obs, action).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
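The same pattern replaces the old `explorating` keyword: exploration noise is injected only outside of an update. A standalone sketch of that gate; the helper name and the plain-callable `noise_fn` are illustrative, not Tianshou APIs:

```python
import numpy as np
import torch


def maybe_add_noise(actions: torch.Tensor, noise_fn, updating: bool) -> torch.Tensor:
    # Mirrors `if self._noise and not self.updating:` in DDPGPolicy.forward.
    if noise_fn is not None and not updating:
        noise = torch.as_tensor(noise_fn(actions.shape), dtype=actions.dtype)
        actions = actions + noise
    return actions


acts = torch.zeros(4, 2)
noisy = maybe_add_noise(acts, lambda shape: np.random.normal(0, 0.1, shape), updating=False)
clean = maybe_add_noise(acts, lambda shape: np.random.normal(0, 0.1, shape), updating=True)
assert torch.equal(clean, acts)  # no noise while updating
```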
@@ -80,7 +80,7 @@ class DQNPolicy(BasePolicy):
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(batch, input="obs_next", eps=0).act
+            a = self(batch, input="obs_next").act
             with torch.no_grad():
                 target_q = self(
                     batch, model="model_old", input="obs_next"
@@ -110,7 +110,6 @@ class DQNPolicy(BasePolicy):
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         model: str = "model",
         input: str = "obs",
-        eps: Optional[float] = None,
         **kwargs: Any,
     ) -> Batch:
         """Compute action over the given batch data.
@@ -152,12 +151,10 @@ class DQNPolicy(BasePolicy):
             q_: np.ndarray = to_numpy(q)
             q_[~obs.mask] = -np.inf
             act = q_.argmax(axis=1)
-        # add eps to act
-        if eps is None:
-            eps = self.eps
-        if not np.isclose(eps, 0.0):
+        # add eps to act in training or testing phase
+        if not self.updating and not np.isclose(self.eps, 0.0):
             for i in range(len(q)):
-                if np.random.rand() < eps:
+                if np.random.rand() < self.eps:
                     q_ = np.random.rand(*q[i].shape)
                     if hasattr(obs, "mask"):
                         q_[~obs.mask[i]] = -np.inf
@@ -169,7 +166,7 @@ class DQNPolicy(BasePolicy):
             self.sync_weight()
         self.optim.zero_grad()
         weight = batch.pop("weight", 1.0)
-        q = self(batch, eps=0.0).logits
+        q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
         r = to_torch_as(batch.returns.flatten(), q)
         td = r - q
@@ -110,7 +110,6 @@ class SACPolicy(DDPGPolicy):
         batch: Batch,
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         input: str = "obs",
-        explorating: bool = True,
         **kwargs: Any,
     ) -> Batch:
         obs = batch[input]
@@ -123,7 +122,7 @@ class SACPolicy(DDPGPolicy):
         y = self._action_scale * (1 - y.pow(2)) + self.__eps
         log_prob = dist.log_prob(x).unsqueeze(-1)
         log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
-        if self._noise is not None and self.training and explorating:
+        if self._noise is not None and not self.updating:
             act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
@@ -134,7 +133,7 @@ class SACPolicy(DDPGPolicy):
     ) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
-            obs_next_result = self(batch, input='obs_next', explorating=False)
+            obs_next_result = self(batch, input='obs_next')
             a_ = obs_next_result.act
             batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
@@ -167,7 +166,7 @@ class SACPolicy(DDPGPolicy):
         batch.weight = (td1 + td2) / 2.0  # prio-buffer
 
         # actor
-        obs_result = self(batch, explorating=False)
+        obs_result = self(batch)
         a = obs_result.act
         current_q1a = self.critic1(batch.obs, a).flatten()
         current_q2a = self.critic2(batch.obs, a).flatten()
@@ -75,6 +75,8 @@ def offpolicy_trainer(
     best_epoch, best_reward = -1, -1.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
+    train_collector.reset_stat()
+    test_collector.reset_stat()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
         # train
@@ -75,6 +75,8 @@ def onpolicy_trainer(
    best_epoch, best_reward = -1, -1.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
+    train_collector.reset_stat()
+    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
@@ -116,6 +116,7 @@ class DQN(nn.Module):
             nn.ReLU(inplace=True),
             nn.Flatten(),
             nn.Linear(linear_input_size, 512),
+            nn.ReLU(inplace=True),
             nn.Linear(512, np.prod(action_shape)),
         )
 
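For reference, the corrected head of the Atari DQN network now has a non-linearity between the two linear layers. A standalone sketch; the `3136` input size and 6 actions are illustrative placeholders, not values from the diff:

```python
import numpy as np
import torch
import torch.nn as nn

linear_input_size, action_shape = 3136, (6,)   # illustrative values
head = nn.Sequential(
    nn.Flatten(),
    nn.Linear(linear_input_size, 512),
    nn.ReLU(inplace=True),                     # the ReLU added by this commit
    nn.Linear(512, int(np.prod(action_shape))),
)
print(head(torch.zeros(1, 64, 7, 7)).shape)    # torch.Size([1, 6])
```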