fix qvalue mask_action error for obs_next (#310)
* fix #309
* remove for-loop in dqn expl_noise
commit ec23c7efe9 (parent 243ab43b3c)
@@ -17,7 +17,7 @@ from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplay
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -134,7 +134,6 @@ def test_qrdqn(args=get_args()):
 def test_pqrdqn(args=get_args()):
     args.prioritized_replay = True
     args.gamma = .95
-    args.seed = 1
     test_qrdqn(args)
 
 
tianshou/env/venvs.py
@@ -287,9 +287,9 @@ class BaseVectorEnv(gym.Env):
 
     def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
         """Normalize observations by statistics in obs_rms."""
-        clip_max = 10.0  # this magic number is from openai baselines
-        # see baselines/common/vec_env/vec_normalize.py#L10
         if self.obs_rms and self.norm_obs:
+            clip_max = 10.0  # this magic number is from openai baselines
+            # see baselines/common/vec_env/vec_normalize.py#L10
             obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
             obs = np.clip(obs, -clip_max, clip_max)
         return obs
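Note: the two moved lines only change where clip_max is defined (inside the branch that actually uses it), not what normalize_obs returns. A minimal standalone sketch of the normalize-and-clip recipe, with mean, var, and eps standing in for obs_rms.mean, obs_rms.var, and the private epsilon:

import numpy as np

def normalize_and_clip(obs, mean, var, eps=1e-8, clip_max=10.0):
    # same recipe as the method above: standardize by running stats, then clip to [-10, 10]
    obs = (obs - mean) / np.sqrt(var + eps)
    return np.clip(obs, -clip_max, clip_max)

obs = np.array([[0.5, 200.0], [-0.5, -200.0]])
print(normalize_and_clip(obs, mean=np.zeros(2), var=np.ones(2)))
# the extreme second column saturates at +/-10 instead of +/-200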
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -57,14 +57,13 @@ class C51Policy(DQNPolicy):
         )
         self.delta_z = (v_max - v_min) / (num_atoms - 1)
 
-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]
 
-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return (logits * self.support).sum(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value((logits * self.support).sum(2), mask)
 
     def _target_dist(self, batch: Batch) -> torch.Tensor:
         if self._target:
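In C51 the network output is a probability distribution over num_atoms return atoms per action, so the override first collapses it to the expected value and then reuses DQNPolicy.compute_q_value for the masking. A hedged sketch of that reduction (shapes invented for illustration):

import torch

bsz, n_act, n_atoms = 2, 3, 5
support = torch.linspace(-10.0, 10.0, n_atoms)               # atom values z_i (v_min..v_max)
probs = torch.softmax(torch.randn(bsz, n_act, n_atoms), -1)  # network output p(z_i | s, a)
q = (probs * support).sum(2)                                 # E[Z(s, a)], shape [bsz, n_act]
assert q.shape == (bsz, n_act)                               # then masked by the base class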
@@ -73,13 +73,13 @@ class DQNPolicy(BasePolicy):
 
     def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
-        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+        result = self(batch, input="obs_next")
         if self._target:
-            a = self(batch, input="obs_next").act
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
             target_q = self(batch, model="model_old", input="obs_next").logits
-            target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
+            target_q = result.logits
+        target_q = target_q[np.arange(len(result.act)), result.act]
         return target_q
 
     def process_fn(
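This is the core of the #309 fix: forward now applies the action mask when evaluating obs_next, so result.act is already the masked greedy action, and both the double-DQN and vanilla branches pick the target value by column-gathering instead of the old unmasked .max(dim=1)[0]. A toy sketch of that gather, with invented numbers:

import torch
import numpy as np

target_q = torch.tensor([[0.3, 0.9, 0.1],
                         [0.7, 0.2, 0.4]])       # Q(s_{t+n}, .), shape [bsz, n_act]
act = np.array([1, 0])                           # masked greedy actions from forward()
chosen = target_q[np.arange(len(act)), act]      # one value per row: Q(s_{t+n}, a*)
assert torch.allclose(chosen, torch.tensor([0.9, 0.7]))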
@@ -95,8 +95,14 @@ class DQNPolicy(BasePolicy):
             self._gamma, self._n_step, self._rew_norm)
         return batch
 
-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        """Compute the q value based on the network's raw output and action mask."""
+        if mask is not None:
+            # the masked q value should be smaller than logits.min()
+            min_value = logits.min() - logits.max() - 1.0
+            logits = logits + to_torch_as(1 - mask, logits) * min_value
         return logits
 
     def forward(
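Why the logits.min() - logits.max() - 1.0 offset is enough: any masked entry ends up at most logits.min() - 1.0, i.e. strictly below every unmasked q value, so a plain argmax can never pick a masked action. A standalone sketch in plain PyTorch (masked_q is a name chosen here for illustration, not library API):

import numpy as np
import torch

def masked_q(logits: torch.Tensor, mask: np.ndarray) -> torch.Tensor:
    # mirrors the logic above: push masked entries below the global minimum
    min_value = logits.min() - logits.max() - 1.0
    return logits + torch.as_tensor(1 - mask, dtype=logits.dtype) * min_value

logits = torch.tensor([[1.0, 5.0, -3.0],
                       [0.5, -0.5, 2.0]])
mask = np.array([[1, 0, 1],
                 [0, 1, 1]])                     # 1 = legal action, 0 = masked out
q = masked_q(logits, mask)
assert q.argmax(dim=1).tolist() == [0, 2]        # the masked 5.0 and 0.5 are never chosen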
@@ -140,15 +146,10 @@ class DQNPolicy(BasePolicy):
         obs = batch[input]
         obs_ = obs.obs if hasattr(obs, "obs") else obs
         logits, h = model(obs_, state=state, info=batch.info)
-        q = self.compute_q_value(logits)
+        q = self.compute_q_value(logits, getattr(obs, "mask", None))
         if not hasattr(self, "max_action_num"):
             self.max_action_num = q.shape[1]
-        act: np.ndarray = to_numpy(q.max(dim=1)[1])
-        if hasattr(obs, "mask"):
-            # some of actions are masked, they cannot be selected
-            q_: np.ndarray = to_numpy(q)
-            q_[~obs.mask] = -np.inf
-            act = q_.argmax(axis=1)
+        act = to_numpy(q.max(dim=1)[1])
         return Batch(logits=logits, act=act, state=h)
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
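With the mask folded into compute_q_value, the argmax here needs no special casing anymore. For context, a hedged sketch (shapes invented) of the observation layout this path expects for masked environments: batch.obs is a Batch carrying the raw observation under .obs plus a boolean .mask, while plain arrays fall back to mask=None:

import numpy as np
from tianshou.data import Batch

obs = Batch(obs=np.zeros((2, 4), dtype=np.float32),           # raw observation, 2 envs
            mask=np.array([[True, False, True],
                           [True, True, True]]))              # 3 actions, one masked
assert getattr(obs, "mask", None) is not None                 # forwarded to compute_q_value
assert getattr(np.zeros((2, 4)), "mask", None) is None        # unmasked envs: mask is None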
@@ -169,10 +170,11 @@ class DQNPolicy(BasePolicy):
 
     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if not np.isclose(self.eps, 0.0):
-            for i in range(len(act)):
-                if np.random.rand() < self.eps:
-                    q_ = np.random.rand(self.max_action_num)
-                    if hasattr(batch["obs"], "mask"):
-                        q_[~batch["obs"].mask[i]] = -np.inf
-                    act[i] = q_.argmax()
+            bsz = len(act)
+            rand_mask = np.random.rand(bsz) < self.eps
+            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
+            if hasattr(batch.obs, "mask"):
+                q += batch.obs.mask
+            rand_act = q.argmax(axis=1)
+            act[rand_mask] = rand_act[rand_mask]
         return act
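The vectorized ε-greedy works because the random scores lie in [0, 1): adding the boolean mask lifts every legal action into [1, 2), so the row-wise argmax is a uniform draw among legal actions only, and one boolean rand_mask replaces the per-sample loop. A self-contained numpy sketch (array contents invented):

import numpy as np

rng = np.random.default_rng(0)
eps, bsz, n_act = 0.5, 4, 3
act = np.array([0, 1, 2, 0])                     # greedy actions from forward()
mask = np.array([[True, True, False],
                 [False, True, True],
                 [True, False, True],
                 [True, True, True]])            # True = legal action

rand_mask = rng.random(bsz) < eps                # rows that explore
q = rng.random((bsz, n_act))                     # scores in [0, 1)
q += mask                                        # legal actions now score in [1, 2)
rand_act = q.argmax(axis=1)                      # uniform choice among legal actions
act[rand_mask] = rand_act[rand_mask]
assert mask[np.arange(bsz), rand_act].all()      # exploration never picks a masked action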
@@ -1,8 +1,8 @@
 import torch
 import warnings
 import numpy as np
-from typing import Any, Dict
 import torch.nn.functional as F
+from typing import Any, Dict, Optional
 
 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -61,9 +61,10 @@ class QRDQNPolicy(DQNPolicy):
         next_dist = next_dist[np.arange(len(a)), a, :]
         return next_dist  # shape: [bsz, num_quantiles]
 
-    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
-        """Compute the q value based on the network's raw output logits."""
-        return logits.mean(2)
+    def compute_q_value(
+        self, logits: torch.Tensor, mask: Optional[np.ndarray]
+    ) -> torch.Tensor:
+        return super().compute_q_value(logits.mean(2), mask)
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._iter % self._freq == 0:
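Same pattern as C51: QRDQN's scalar q value is the mean over the predicted quantiles, and the mask handling is delegated to DQNPolicy.compute_q_value. A minimal sketch (shapes invented):

import torch

bsz, n_act, n_quantiles = 2, 3, 8
quantiles = torch.randn(bsz, n_act, n_quantiles)  # network output: one quantile set per action
q = quantiles.mean(2)                             # shape [bsz, n_act], then masked by the base class
assert q.shape == (bsz, n_act)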