fix qvalue mask_action error for obs_next (#310)

* fix #309
* remove for-loop in dqn expl_noise
n+e 2021-03-15 08:06:24 +08:00 committed by GitHub
parent 243ab43b3c
commit ec23c7efe9
5 changed files with 35 additions and 34 deletions
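The gist of the fix: previously the action mask was only applied in DQNPolicy.forward() when picking actions, so _target_q could still bootstrap from an illegal action's Q-value on obs_next. Moving the masking into compute_q_value (see the dqn.py hunks below) makes both the action-selection and the target-Q paths respect it. A minimal standalone sketch of the idea, with toy numbers rather than library code:

```python
import numpy as np

q_next = np.array([4.0, 9.0, 2.5])    # Q(s_{t+n}, a) from the online network (toy values)
mask = np.array([True, False, True])  # action 1 is illegal in s_{t+n}

# old behaviour: the bootstrap target maximized over *all* actions,
# so the illegal action 1 leaks into the target
old_target = q_next.max()             # 9.0 -- comes from a masked action

# new behaviour: compute_q_value() buries illegal actions first,
# using the same "min - max - 1" offset as the dqn.py diff below
min_value = q_next.min() - q_next.max() - 1.0
masked_q = q_next + (~mask) * min_value
new_target = masked_q.max()           # 4.0 -- best legal action
print(old_target, new_target)
```

The min - max - 1 offset pushes every masked entry strictly below the smallest legal Q-value, whatever the scale of the logits.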

View File

@@ -17,7 +17,7 @@ from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplay
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
-parser.add_argument('--seed', type=int, default=1)
+parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
@@ -134,7 +134,6 @@ def test_qrdqn(args=get_args()):
def test_pqrdqn(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
-args.seed = 1
test_qrdqn(args)

View File

@@ -287,9 +287,9 @@ class BaseVectorEnv(gym.Env):
def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
"""Normalize observations by statistics in obs_rms."""
-clip_max = 10.0 # this magic number is from openai baselines
-# see baselines/common/vec_env/vec_normalize.py#L10
if self.obs_rms and self.norm_obs:
+clip_max = 10.0 # this magic number is from openai baselines
+# see baselines/common/vec_env/vec_normalize.py#L10
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
obs = np.clip(obs, -clip_max, clip_max)
return obs
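For context, normalize_obs standardizes observations by the tracked statistics and clips the result, following baselines' VecNormalize. A self-contained sketch of that scheme, where _RunningStatsStub is only an illustrative stand-in for the real obs_rms (a RunningMeanStd exposing .mean and .var):

```python
import numpy as np

class _RunningStatsStub:
    """Illustrative stand-in for obs_rms: fixed mean/var instead of running updates."""
    def __init__(self, mean, var):
        self.mean = np.asarray(mean, dtype=np.float64)
        self.var = np.asarray(var, dtype=np.float64)

def normalize_obs(obs, rms, eps=1e-8, clip_max=10.0):
    # standardize by the tracked statistics, then clip, as in baselines' VecNormalize
    obs = (obs - rms.mean) / np.sqrt(rms.var + eps)
    return np.clip(obs, -clip_max, clip_max)

rms = _RunningStatsStub(mean=[0.0, 1.0], var=[1.0, 4.0])
print(normalize_obs(np.array([[0.5, 100.0]]), rms))  # [[0.5, 10.0]] -- the outlier saturates at clip_max
```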

View File

@@ -1,6 +1,6 @@
import torch
import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer
@@ -57,14 +57,13 @@ class C51Policy(DQNPolicy):
)
self.delta_z = (v_max - v_min) / (num_atoms - 1)
-def _target_q(
-self, buffer: ReplayBuffer, indice: np.ndarray
-) -> torch.Tensor:
+def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
return self.support.repeat(len(indice), 1) # shape: [bsz, num_atoms]
-def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-"""Compute the q value based on the network's raw output logits."""
-return (logits * self.support).sum(2)
+def compute_q_value(
+self, logits: torch.Tensor, mask: Optional[np.ndarray]
+) -> torch.Tensor:
+return super().compute_q_value((logits * self.support).sum(2), mask)
def _target_dist(self, batch: Batch) -> torch.Tensor:
if self._target:
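Shape-wise, the new C51 compute_q_value collapses the distributional output [bsz, num_actions, num_atoms] to scalar Q-values by taking the expectation over the fixed support, and only then does the parent DQNPolicy.compute_q_value apply the action mask. An illustrative sketch (random tensors, not the library's network):

```python
import torch

bsz, num_actions, num_atoms = 2, 3, 5
support = torch.linspace(-10.0, 10.0, num_atoms)  # fixed atom values z_i
# stand-in for the network output: a probability distribution over atoms per action
probs = torch.softmax(torch.randn(bsz, num_actions, num_atoms), dim=-1)

# expected return per action: sum_i p(z_i) * z_i  ->  shape [bsz, num_actions]
q = (probs * support).sum(2)
print(q.shape)  # torch.Size([2, 3]); the mask is applied to this 2-D tensor
```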

View File

@@ -73,13 +73,13 @@ class DQNPolicy(BasePolicy):
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
-# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+result = self(batch, input="obs_next")
if self._target:
-a = self(batch, input="obs_next").act
+# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
target_q = self(batch, model="model_old", input="obs_next").logits
-target_q = target_q[np.arange(len(a)), a]
else:
-target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
+target_q = result.logits
+target_q = target_q[np.arange(len(result.act)), result.act]
return target_q
def process_fn(
@@ -95,8 +95,14 @@ class DQNPolicy(BasePolicy):
self._gamma, self._n_step, self._rew_norm)
return batch
-def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-"""Compute the q value based on the network's raw output logits."""
+def compute_q_value(
+self, logits: torch.Tensor, mask: Optional[np.ndarray]
+) -> torch.Tensor:
+"""Compute the q value based on the network's raw output and action mask."""
+if mask is not None:
+# the masked q value should be smaller than logits.min()
+min_value = logits.min() - logits.max() - 1.0
+logits = logits + to_torch_as(1 - mask, logits) * min_value
return logits
def forward(
@@ -140,15 +146,10 @@ class DQNPolicy(BasePolicy):
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
logits, h = model(obs_, state=state, info=batch.info)
-q = self.compute_q_value(logits)
+q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1]
-act: np.ndarray = to_numpy(q.max(dim=1)[1])
-if hasattr(obs, "mask"):
-# some of actions are masked, they cannot be selected
-q_: np.ndarray = to_numpy(q)
-q_[~obs.mask] = -np.inf
-act = q_.argmax(axis=1)
+act = to_numpy(q.max(dim=1)[1])
return Batch(logits=logits, act=act, state=h)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
@@ -169,10 +170,11 @@ class DQNPolicy(BasePolicy):
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
if not np.isclose(self.eps, 0.0):
-for i in range(len(act)):
-if np.random.rand() < self.eps:
-q_ = np.random.rand(self.max_action_num)
-if hasattr(batch["obs"], "mask"):
-q_[~batch["obs"].mask[i]] = -np.inf
-act[i] = q_.argmax()
+bsz = len(act)
+rand_mask = np.random.rand(bsz) < self.eps
+q = np.random.rand(bsz, self.max_action_num) # [0, 1]
+if hasattr(batch.obs, "mask"):
+q += batch.obs.mask
+rand_act = q.argmax(axis=1)
+act[rand_mask] = rand_act[rand_mask]
return act
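The rewritten exploration_noise replaces the per-sample loop with one vectorized draw. The masked branch relies on a range trick: the random scores are uniform in [0, 1), so adding the 0/1 mask lifts every legal action into [1, 2) and argmax can never pick an illegal one. A small standalone demonstration with made-up sizes and values:

```python
import numpy as np

rng = np.random.default_rng(0)
bsz, n_act, eps = 4, 3, 0.5                # illustrative sizes and epsilon

act = np.array([0, 1, 2, 0])               # greedy actions chosen by the policy
mask = np.array([[1, 0, 1],                # 0/1 legal-action mask per sample
                 [1, 1, 0],
                 [0, 1, 1],
                 [1, 1, 1]])

rand_mask = rng.random(bsz) < eps          # which samples explore this step
q = rng.random((bsz, n_act))               # uniform scores in [0, 1)
q += mask                                  # legal actions now score in [1, 2)
rand_act = q.argmax(axis=1)                # argmax can only land on a legal action
act[rand_mask] = rand_act[rand_mask]       # overwrite only the exploring samples
print(act)
```

Samples not selected by rand_mask keep their greedy action, matching the old per-sample behaviour without the Python loop.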

View File

@@ -1,8 +1,8 @@
import torch
import warnings
import numpy as np
-from typing import Any, Dict
import torch.nn.functional as F
+from typing import Any, Dict, Optional
from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer
@@ -61,9 +61,10 @@ class QRDQNPolicy(DQNPolicy):
next_dist = next_dist[np.arange(len(a)), a, :]
return next_dist # shape: [bsz, num_quantiles]
-def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-"""Compute the q value based on the network's raw output logits."""
-return logits.mean(2)
+def compute_q_value(
+self, logits: torch.Tensor, mask: Optional[np.ndarray]
+) -> torch.Tensor:
+return super().compute_q_value(logits.mean(2), mask)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
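QRDQN mirrors the C51 change: the network output has shape [bsz, num_actions, num_quantiles], the mean over quantiles yields the scalar Q-values, and the parent class then handles the mask. Illustrative shapes only:

```python
import torch

bsz, num_actions, num_quantiles = 2, 3, 8
quantiles = torch.randn(bsz, num_actions, num_quantiles)  # stand-in for the network output

q = quantiles.mean(2)   # [bsz, num_actions]; DQNPolicy.compute_q_value then applies the mask
print(q.shape)          # torch.Size([2, 3])
```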