fix qvalue mask_action error for obs_next (#310)

* fix #309
* remove for-loop in dqn expl_noise
Author: n+e, 2021-03-15 08:06:24 +08:00 (committed by GitHub)
Parent: 243ab43b3c
Commit: ec23c7efe9
5 changed files with 35 additions and 34 deletions


@@ -17,7 +17,7 @@ from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplay
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)

@@ -134,7 +134,6 @@ def test_qrdqn(args=get_args()):
 def test_pqrdqn(args=get_args()):
     args.prioritized_replay = True
     args.gamma = .95
-    args.seed = 1
     test_qrdqn(args)


@@ -287,9 +287,9 @@ class BaseVectorEnv(gym.Env):
     def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
         """Normalize observations by statistics in obs_rms."""
-        clip_max = 10.0  # this magic number is from openai baselines
-        # see baselines/common/vec_env/vec_normalize.py#L10
         if self.obs_rms and self.norm_obs:
+            clip_max = 10.0  # this magic number is from openai baselines
+            # see baselines/common/vec_env/vec_normalize.py#L10
             obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
             obs = np.clip(obs, -clip_max, clip_max)
         return obs
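
For reference, a minimal standalone sketch of the normalization applied above: standardize by running mean/variance, then clip to the Baselines-style bound of 10.0. The function name and the explicit mean/var/eps parameters are illustrative stand-ins for the obs_rms statistics and self.__eps used in BaseVectorEnv.

import numpy as np

def normalize_obs(obs: np.ndarray, mean: np.ndarray, var: np.ndarray,
                  eps: float = 1e-8, clip_max: float = 10.0) -> np.ndarray:
    """Standardize obs by running statistics, then clip to [-clip_max, clip_max]."""
    obs = (obs - mean) / np.sqrt(var + eps)
    return np.clip(obs, -clip_max, clip_max)

# toy usage on a batch of two observations
obs = np.array([[0.5, 100.0], [1.5, -50.0]])
print(normalize_obs(obs, mean=np.array([1.0, 0.0]), var=np.array([0.25, 1.0])))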


@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer

@@ -57,14 +57,13 @@ class C51Policy(DQNPolicy):
         )
         self.delta_z = (v_max - v_min) / (num_atoms - 1)

-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return (logits * self.support).sum(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value((logits * self.support).sum(2), mask)

     def _target_dist(self, batch: Batch) -> torch.Tensor:
         if self._target:
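
To illustrate the C51 change: the distributional output of shape [bsz, num_actions, num_atoms] is collapsed to scalar Q-values by taking the expectation over the support, and only that [bsz, num_actions] tensor is handed to the (now mask-aware) base-class compute_q_value. A toy sketch with made-up sizes, not the policy's actual forward path:

import torch

bsz, num_actions, num_atoms = 4, 3, 51
v_min, v_max = -10.0, 10.0
support = torch.linspace(v_min, v_max, num_atoms)  # atom values z_i
probs = torch.softmax(torch.randn(bsz, num_actions, num_atoms), dim=-1)

# expected value per action: sum_i p_i * z_i -> shape [bsz, num_actions]
q = (probs * support).sum(2)
print(q.shape)  # torch.Size([4, 3])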


@@ -73,13 +73,13 @@ class DQNPolicy(BasePolicy):
     def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
-        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+        result = self(batch, input="obs_next")
         if self._target:
-            a = self(batch, input="obs_next").act
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
             target_q = self(batch, model="model_old", input="obs_next").logits
-            target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
+            target_q = result.logits
+        target_q = target_q[np.arange(len(result.act)), result.act]
         return target_q

     def process_fn(
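
This hunk is the heart of the #309 fix: the greedy action for obs_next now comes from result.act, which already respects the action mask (the masking itself is added to compute_q_value in the next hunk), and the same gather is applied whether or not a target network is used. A rough numpy sketch of that selection step, with hypothetical q_next/q_target arrays standing in for the online and target network outputs:

import numpy as np

rng = np.random.default_rng(0)
q_next = rng.normal(size=(2, 4))                # online-net Q(s_{t+n}, .)
q_target = rng.normal(size=(2, 4))              # target-net Q_old(s_{t+n}, .)
mask = np.array([[1, 1, 0, 1], [0, 1, 1, 0]])   # 1 = action allowed

# greedy action under the mask (same rule compute_q_value applies)
min_value = q_next.min() - q_next.max() - 1.0
act = (q_next + (1 - mask) * min_value).argmax(axis=1)

# gather the chosen actions' values from the target-network output
target_q = q_target[np.arange(len(act)), act]
print(act, target_q)
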
@@ -95,8 +95,14 @@ class DQNPolicy(BasePolicy):
             self._gamma, self._n_step, self._rew_norm)
         return batch

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        """Compute the q value based on the network's raw output and action mask."""
+        if mask is not None:
+            # the masked q value should be smaller than logits.min()
+            min_value = logits.min() - logits.max() - 1.0
+            logits = logits + to_torch_as(1 - mask, logits) * min_value
         return logits

     def forward(
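
A quick sanity check of the offset used above, mirrored with torch.as_tensor in place of tianshou's to_torch_as: every masked entry gets logits.min() - logits.max() - 1.0 added to it, so it ends up strictly below logits.min(), and the subsequent argmax can never pick it.

import numpy as np
import torch
from typing import Optional

def masked_q(logits: torch.Tensor, mask: Optional[np.ndarray]) -> torch.Tensor:
    # same rule as above: push illegal actions strictly below logits.min()
    if mask is not None:
        min_value = logits.min() - logits.max() - 1.0
        logits = logits + torch.as_tensor(1 - mask, dtype=logits.dtype) * min_value
    return logits

logits = torch.tensor([[5.0, -3.0, 2.0]])
mask = np.array([[0, 1, 1]])          # action 0 is illegal here
q = masked_q(logits, mask)
assert q.argmax(dim=1).item() == 2    # argmax never lands on the masked action
print(q)                              # masked entry is below the original minimum
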
@@ -140,15 +146,10 @@ class DQNPolicy(BasePolicy):
         obs = batch[input]
         obs_ = obs.obs if hasattr(obs, "obs") else obs
         logits, h = model(obs_, state=state, info=batch.info)
-        q = self.compute_q_value(logits)
+        q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
-        act: np.ndarray = to_numpy(q.max(dim=1)[1])
-        if hasattr(obs, "mask"):
-            # some of actions are masked, they cannot be selected
-            q_: np.ndarray = to_numpy(q)
-            q_[~obs.mask] = -np.inf
-            act = q_.argmax(axis=1)
+        act = to_numpy(q.max(dim=1)[1])
         return Batch(logits=logits, act=act, state=h)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
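
The forward change above expects observations that may wrap the raw features together with an action mask, e.g. a Batch with obs and mask fields; forward feeds obs.obs to the network and hands obs.mask to compute_q_value. A hedged sketch of that observation layout only (sizes are made up, and no network or policy is constructed here):

import numpy as np
from tianshou.data import Batch

# two parallel envs, 3-dim features, 4 discrete actions
obs = Batch(
    obs=np.random.randn(2, 3).astype(np.float32),
    mask=np.array([[True, True, False, True],
                   [False, True, True, True]]),
)
batch = Batch(obs=obs, info={})

raw = batch.obs.obs if hasattr(batch.obs, "obs") else batch.obs  # fed to the network
mask = getattr(batch.obs, "mask", None)                          # fed to compute_q_value
print(raw.shape, mask.shape)  # (2, 3) (2, 4)
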
@@ -169,10 +170,11 @@ class DQNPolicy(BasePolicy):
     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if not np.isclose(self.eps, 0.0):
-            for i in range(len(act)):
-                if np.random.rand() < self.eps:
-                    q_ = np.random.rand(self.max_action_num)
-                    if hasattr(batch["obs"], "mask"):
-                        q_[~batch["obs"].mask[i]] = -np.inf
-                    act[i] = q_.argmax()
+            bsz = len(act)
+            rand_mask = np.random.rand(bsz) < self.eps
+            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
+            if hasattr(batch.obs, "mask"):
+                q += batch.obs.mask
+            rand_act = q.argmax(axis=1)
+            act[rand_mask] = rand_act[rand_mask]
         return act
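
The vectorized exploration above works because the random scores are drawn from [0, 1), so adding the 0/1 mask lifts every legal action into [1, 2) and the argmax is guaranteed to land on a legal action; rand_mask then decides, per sample, whether the random action replaces the greedy one. A self-contained numpy sketch with an illustrative signature (mask and max_action_num passed in explicitly rather than read from the batch and policy):

import numpy as np
from typing import Optional

def exploration_noise(act: np.ndarray, mask: Optional[np.ndarray],
                      eps: float, max_action_num: int) -> np.ndarray:
    bsz = len(act)
    rand_mask = np.random.rand(bsz) < eps      # which samples explore this step
    q = np.random.rand(bsz, max_action_num)    # random scores in [0, 1)
    if mask is not None:
        q += mask                              # legal actions now score in [1, 2)
    rand_act = q.argmax(axis=1)                # random but always-legal action
    act = act.copy()
    act[rand_mask] = rand_act[rand_mask]
    return act

act = np.array([0, 1, 2])
mask = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 1]])
print(exploration_noise(act, mask, eps=0.5, max_action_num=3))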


@@ -1,8 +1,8 @@
 import torch
 import warnings
 import numpy as np
-from typing import Any, Dict
 import torch.nn.functional as F
+from typing import Any, Dict, Optional

 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer

@@ -61,9 +61,10 @@ class QRDQNPolicy(DQNPolicy):
         next_dist = next_dist[np.arange(len(a)), a, :]
         return next_dist  # shape: [bsz, num_quantiles]

-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return logits.mean(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value(logits.mean(2), mask)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
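
Same pattern as C51: QR-DQN's quantile output of shape [bsz, num_actions, num_quantiles] is reduced to scalar Q-values by averaging over the quantile dimension before being routed through the shared mask-aware compute_q_value. A toy sketch with made-up sizes:

import torch

bsz, num_actions, num_quantiles = 4, 3, 200
quantiles = torch.randn(bsz, num_actions, num_quantiles)

q = quantiles.mean(2)  # expected return per action, shape [bsz, num_actions]
print(q.shape)         # torch.Size([4, 3])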