diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index bae20b1..44e109a 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -48,6 +48,17 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
     env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
     venv = SubprocVectorEnv(env_fns)
 
+.. warning::
+
+    If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
+
+    ::
+
+        def seed(self, seed):
+            np.random.seed(seed)
+
+    Otherwise, the outputs of these envs may be the same with each other.
+
 .. _preprocess_fn:
 
 Handle Batched Data Stream in Collector
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index 9cbb243..6106ca5 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -40,6 +40,17 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_
 
 For the demonstration, here we use the second block of codes.
 
+.. warning::
+
+    If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
+
+    ::
+
+        def seed(self, seed):
+            np.random.seed(seed)
+
+    Otherwise, the outputs of these envs may be the same with each other.
+
 .. _build_the_network:
 
 Build the Network
diff --git a/test/base/env.py b/test/base/env.py
index b096215..40f0597 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -1,20 +1,30 @@
 import gym
 import time
-from gym.spaces.discrete import Discrete
+import numpy as np
+from gym.spaces import Discrete, MultiDiscrete, Box
 
 
 class MyTestEnv(gym.Env):
     """This is a "going right" task. The task is to go right ``size`` steps.
     """
 
-    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
+    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
+                 multidiscrete_action=False):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
         self.ma_rew = ma_rew
-        self.action_space = Discrete(2)
+        self._md_action = multidiscrete_action
+        self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
+        if multidiscrete_action:
+            self.action_space = MultiDiscrete([2, 2])
+        else:
+            self.action_space = Discrete(2)
         self.reset()
 
+    def seed(self, seed=0):
+        np.random.seed(seed)
+
     def reset(self, state=0):
         self.done = False
         self.index = state
@@ -29,9 +39,12 @@ class MyTestEnv(gym.Env):
 
     def _get_dict_state(self):
         """Generate a dict_state if dict_state is True."""
-        return {'index': self.index} if self.dict_state else self.index
+        return {'index': self.index, 'rand': np.random.rand()} \
+            if self.dict_state else self.index
 
     def step(self, action):
+        if self._md_action:
+            action = action[0]
         if self.done:
             raise ValueError('step after done !!!')
         if self.sleep > 0:
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 77d1343..a823491 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -1,5 +1,5 @@
-import torch
 import copy
+import torch
 import pickle
 import pytest
 import numpy as np
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 9fa37b6..b4a69c1 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -113,8 +113,12 @@ def test_collector_with_dict_state():
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
                for i in [2, 3, 4, 5]]
     envs = VectorEnv(env_fns)
+    envs.seed(666)
+    obs = envs.reset()
+    assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
     c1 = Collector(policy, envs, ReplayBuffer(size=100),
                    Logger.single_preprocess_fn)
+    c1.seed(0)
     c1.collect(n_step=10)
     c1.collect(n_episode=[2, 1, 1, 2])
     batch = c1.sample(10)
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index c820762..fabfdc9 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
diff --git a/tianshou/env/basevecenv.py b/tianshou/env/basevecenv.py
index 60394e3..b6c160d 100644
--- a/tianshou/env/basevecenv.py
+++ b/tianshou/env/basevecenv.py
@@ -27,6 +27,17 @@ class BaseVectorEnv(ABC, gym.Env):
         obs, rew, done, info = envs.step([1] * 8)  # step synchronously
         envs.render()  # render all environments
         envs.close()  # close all environments
+
+    .. warning::
+
+        If you use your own environment, please make sure the ``seed`` method
+        is set up properly, e.g.,
+        ::
+
+            def seed(self, seed):
+                np.random.seed(seed)
+
+        Otherwise, the outputs of these envs may be the same with each other.
     """
 
     def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 8b98f70..2a5c123 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -108,8 +108,9 @@ class A2CPolicy(PGPolicy):
                 v = self.critic(b.obs).flatten()
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
-                           ).mean()
+                log_prob = dist.log_prob(a).reshape(
+                    r.shape[0], -1).transpose(0, 1)
+                a_loss = -(log_prob * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 1627199..9bdbd11 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -157,7 +157,7 @@ class DQNPolicy(BasePolicy):
         if self._target and self._cnt % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
-        q = self(batch).logits
+        q = self(batch, eps=0.).logits
         q = q[np.arange(len(q)), batch.act]
         r = to_torch_as(batch.returns, q).flatten()
         if hasattr(batch, 'update_weight'):
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 09f90ea..d6176e6 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -90,7 +90,9 @@ class PGPolicy(BasePolicy):
                 dist = self(b).dist
                 a = to_torch_as(b.act, dist.logits)
                 r = to_torch_as(b.returns, dist.logits)
-                loss = -(dist.log_prob(a).flatten() * r).sum()
+                log_prob = dist.log_prob(a).reshape(
+                    r.shape[0], -1).transpose(0, 1)
+                loss = -(log_prob * r).mean()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index bae6907..2d1dece 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -132,7 +132,7 @@ class PPOPolicy(PGPolicy):
                     to_torch_as(b.act, v[0])))
         batch.v = torch.cat(v, dim=0).flatten()  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
         batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
@@ -147,8 +147,8 @@ class PPOPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 dist = self(b).dist
                 value = self.critic(b.obs).flatten()
-                ratio = (dist.log_prob(b.act).flatten()
-                         - b.logp_old).exp().float()
+                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
                 surr2 = ratio.clamp(1. - self._eps_clip,
                                     1. + self._eps_clip) * b.adv
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index dc85ac5..2401ebd 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -36,7 +36,7 @@ class Net(nn.Module):
     def forward(self, s, state=None, info={}):
         """s -> flatten -> logits"""
         s = to_torch(s, device=self.device, dtype=torch.float32)
-        s = s.flatten(1)
+        s = s.reshape(s.size(0), -1)
         logits = self.model(s)
         return logits, state
 
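
Note (illustrative, not part of the patch): the pg.py, a2c.py and ppo.py hunks all replace .flatten() on the log-probabilities with reshape(batch, -1).transpose(0, 1). The sketch below shows the shape problem being addressed, assuming a MultiDiscrete-style action modelled as a batched torch.distributions.Categorical; the names and shapes here are made up for illustration and are not taken from tianshou itself.

    # Shape sketch for the log_prob change in pg.py / a2c.py / ppo.py.
    # Assumption: one Categorical per action dimension, batched together.
    import torch

    batch_size, act_dims = 4, 2
    logits = torch.randn(batch_size, act_dims, 3)   # e.g. a MultiDiscrete([3, 3]) action
    dist = torch.distributions.Categorical(logits=logits)
    act = dist.sample()                             # shape (batch, act_dims)
    returns = torch.randn(batch_size)               # shape (batch,)

    log_prob = dist.log_prob(act)                   # shape (batch, act_dims)
    # .flatten() would give shape (batch * act_dims,), which no longer
    # broadcasts against returns; reshape + transpose gives (act_dims, batch).
    log_prob = log_prob.reshape(returns.shape[0], -1).transpose(0, 1)
    loss = -(log_prob * returns).mean()             # broadcasts to (act_dims, batch)
    print(log_prob.shape, loss)                     # torch.Size([2, 4]), scalar loss

With a single Discrete action the reshape is a no-op up to the transpose, so the scalar loss is unchanged; the point of the rewrite is that the same code path also handles per-dimension log-probabilities.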