Yet another 3 fix (#160)
1. DQN learn should keep eps=0 2. Add a warning of env.seed in VecEnv 3. fix #162 of multi-dim action
This commit is contained in:
parent
bfeffe1f97
commit
38a95c19da
@ -48,6 +48,17 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
|
||||
venv = SubprocVectorEnv(env_fns)
|
||||
|
||||
.. warning::
|
||||
|
||||
If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
|
||||
|
||||
::
|
||||
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
.. _preprocess_fn:
|
||||
|
||||
Handle Batched Data Stream in Collector
|
||||
|
@ -40,6 +40,17 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_
|
||||
|
||||
For the demonstration, here we use the second block of codes.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
|
||||
|
||||
::
|
||||
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
.. _build_the_network:
|
||||
|
||||
Build the Network
|
||||
|
@ -1,20 +1,30 @@
|
||||
import gym
|
||||
import time
|
||||
from gym.spaces.discrete import Discrete
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, MultiDiscrete, Box
|
||||
|
||||
|
||||
class MyTestEnv(gym.Env):
|
||||
"""This is a "going right" task. The task is to go right ``size`` steps.
|
||||
"""
|
||||
|
||||
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
|
||||
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
|
||||
multidiscrete_action=False):
|
||||
self.size = size
|
||||
self.sleep = sleep
|
||||
self.dict_state = dict_state
|
||||
self.ma_rew = ma_rew
|
||||
self.action_space = Discrete(2)
|
||||
self._md_action = multidiscrete_action
|
||||
self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
|
||||
if multidiscrete_action:
|
||||
self.action_space = MultiDiscrete([2, 2])
|
||||
else:
|
||||
self.action_space = Discrete(2)
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=0):
|
||||
np.random.seed(seed)
|
||||
|
||||
def reset(self, state=0):
|
||||
self.done = False
|
||||
self.index = state
|
||||
@ -29,9 +39,12 @@ class MyTestEnv(gym.Env):
|
||||
|
||||
def _get_dict_state(self):
|
||||
"""Generate a dict_state if dict_state is True."""
|
||||
return {'index': self.index} if self.dict_state else self.index
|
||||
return {'index': self.index, 'rand': np.random.rand()} \
|
||||
if self.dict_state else self.index
|
||||
|
||||
def step(self, action):
|
||||
if self._md_action:
|
||||
action = action[0]
|
||||
if self.done:
|
||||
raise ValueError('step after done !!!')
|
||||
if self.sleep > 0:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch
|
||||
import pickle
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
@ -113,8 +113,12 @@ def test_collector_with_dict_state():
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
|
||||
for i in [2, 3, 4, 5]]
|
||||
envs = VectorEnv(env_fns)
|
||||
envs.seed(666)
|
||||
obs = envs.reset()
|
||||
assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
|
||||
c1 = Collector(policy, envs, ReplayBuffer(size=100),
|
||||
Logger.single_preprocess_fn)
|
||||
c1.seed(0)
|
||||
c1.collect(n_step=10)
|
||||
c1.collect(n_episode=[2, 1, 1, 2])
|
||||
batch = c1.sample(10)
|
||||
|
@ -86,7 +86,7 @@ def get_args():
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
|
11
tianshou/env/basevecenv.py
vendored
11
tianshou/env/basevecenv.py
vendored
@ -27,6 +27,17 @@ class BaseVectorEnv(ABC, gym.Env):
|
||||
obs, rew, done, info = envs.step([1] * 8) # step synchronously
|
||||
envs.render() # render all environments
|
||||
envs.close() # close all environments
|
||||
|
||||
.. warning::
|
||||
|
||||
If you use your own environment, please make sure the ``seed`` method
|
||||
is set up properly, e.g.,
|
||||
::
|
||||
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
|
@ -108,8 +108,9 @@ class A2CPolicy(PGPolicy):
|
||||
v = self.critic(b.obs).flatten()
|
||||
a = to_torch_as(b.act, v)
|
||||
r = to_torch_as(b.returns, v)
|
||||
a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
|
||||
).mean()
|
||||
log_prob = dist.log_prob(a).reshape(
|
||||
r.shape[0], -1).transpose(0, 1)
|
||||
a_loss = -(log_prob * (r - v).detach()).mean()
|
||||
vf_loss = F.mse_loss(r, v)
|
||||
ent_loss = dist.entropy().mean()
|
||||
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
|
||||
|
@ -157,7 +157,7 @@ class DQNPolicy(BasePolicy):
|
||||
if self._target and self._cnt % self._freq == 0:
|
||||
self.sync_weight()
|
||||
self.optim.zero_grad()
|
||||
q = self(batch).logits
|
||||
q = self(batch, eps=0.).logits
|
||||
q = q[np.arange(len(q)), batch.act]
|
||||
r = to_torch_as(batch.returns, q).flatten()
|
||||
if hasattr(batch, 'update_weight'):
|
||||
|
@ -90,7 +90,9 @@ class PGPolicy(BasePolicy):
|
||||
dist = self(b).dist
|
||||
a = to_torch_as(b.act, dist.logits)
|
||||
r = to_torch_as(b.returns, dist.logits)
|
||||
loss = -(dist.log_prob(a).flatten() * r).sum()
|
||||
log_prob = dist.log_prob(a).reshape(
|
||||
r.shape[0], -1).transpose(0, 1)
|
||||
loss = -(log_prob * r).mean()
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
losses.append(loss.item())
|
||||
|
@ -132,7 +132,7 @@ class PPOPolicy(PGPolicy):
|
||||
to_torch_as(b.act, v[0])))
|
||||
batch.v = torch.cat(v, dim=0).flatten() # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0])
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
@ -147,8 +147,8 @@ class PPOPolicy(PGPolicy):
|
||||
for b in batch.split(batch_size):
|
||||
dist = self(b).dist
|
||||
value = self.critic(b.obs).flatten()
|
||||
ratio = (dist.log_prob(b.act).flatten()
|
||||
- b.logp_old).exp().float()
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1. - self._eps_clip,
|
||||
1. + self._eps_clip) * b.adv
|
||||
|
@ -36,7 +36,7 @@ class Net(nn.Module):
|
||||
def forward(self, s, state=None, info={}):
|
||||
"""s -> flatten -> logits"""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
s = s.flatten(1)
|
||||
s = s.reshape(s.size(0), -1)
|
||||
logits = self.model(s)
|
||||
return logits, state
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user