Yet another 3 fix (#160)
1. DQN learn should keep eps=0
2. Add a warning about env.seed in VecEnv
3. Fix #162: support multi-dimensional actions
This commit is contained in:
parent: bfeffe1f97
commit: 38a95c19da
@@ -48,6 +48,17 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
     env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
     venv = SubprocVectorEnv(env_fns)
 
+.. warning::
+
+    If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
+    ::
+
+        def seed(self, seed):
+            np.random.seed(seed)
+
+    Otherwise, the outputs of these envs may be the same with each other.
+
 .. _preprocess_fn:
 
 Handle Batched Data Stream in Collector
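(Not part of the diff: a self-contained sketch of the failure mode this warning targets, using a hypothetical toy env whose observations come from numpy's global random state.)

    import gym
    import numpy as np
    from gym.spaces import Box, Discrete
    from tianshou.env import SubprocVectorEnv

    class NoisyEnv(gym.Env):
        """Toy env: every observation is drawn from numpy's global RNG."""
        def __init__(self):
            self.action_space = Discrete(2)
            self.observation_space = Box(low=0., high=1., shape=(1,))

        def seed(self, seed=0):
            # Without this method each subprocess may inherit the same numpy
            # state and return identical "random" observations.
            np.random.seed(seed)

        def reset(self):
            return np.random.rand(1)

        def step(self, action):
            return np.random.rand(1), 0.0, True, {}

    venv = SubprocVectorEnv([NoisyEnv for _ in range(4)])
    venv.seed(3)        # the vector env typically forwards a distinct seed to each worker
    obs = venv.reset()  # observations now differ across workers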
@@ -40,6 +40,17 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_
 For the demonstration, here we use the second block of codes.
 
+.. warning::
+
+    If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
+    ::
+
+        def seed(self, seed):
+            np.random.seed(seed)
+
+    Otherwise, the outputs of these envs may be the same with each other.
+
 .. _build_the_network:
 
 Build the Network
@@ -1,20 +1,30 @@
 import gym
 import time
-from gym.spaces.discrete import Discrete
+import numpy as np
+from gym.spaces import Discrete, MultiDiscrete, Box
 
 
 class MyTestEnv(gym.Env):
     """This is a "going right" task. The task is to go right ``size`` steps.
     """
 
-    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
+    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
+                 multidiscrete_action=False):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
         self.ma_rew = ma_rew
-        self.action_space = Discrete(2)
+        self._md_action = multidiscrete_action
+        self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
+        if multidiscrete_action:
+            self.action_space = MultiDiscrete([2, 2])
+        else:
+            self.action_space = Discrete(2)
         self.reset()
 
+    def seed(self, seed=0):
+        np.random.seed(seed)
+
     def reset(self, state=0):
         self.done = False
         self.index = state
@@ -29,9 +39,12 @@ class MyTestEnv(gym.Env):
 
     def _get_dict_state(self):
         """Generate a dict_state if dict_state is True."""
-        return {'index': self.index} if self.dict_state else self.index
+        return {'index': self.index, 'rand': np.random.rand()} \
+            if self.dict_state else self.index
 
     def step(self, action):
+        if self._md_action:
+            action = action[0]
         if self.done:
             raise ValueError('step after done !!!')
         if self.sleep > 0:
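(Aside, not in the diff: a quick illustration of the multi-dimensional action case this change exercises. With ``MultiDiscrete([2, 2])`` the policy emits one sub-action per dimension, and the test env simply keeps the first component so its "go right" dynamics stay scalar. The snippet below is a hypothetical standalone check.)

    import numpy as np
    from gym.spaces import MultiDiscrete

    space = MultiDiscrete([2, 2])
    act = space.sample()        # e.g. array([1, 0]): one entry per sub-action
    assert act.shape == (2,)
    scalar_act = act[0]         # what MyTestEnv.step() uses when _md_action is set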
@@ -1,5 +1,5 @@
-import torch
 import copy
+import torch
 import pickle
 import pytest
 import numpy as np
@@ -113,8 +113,12 @@ def test_collector_with_dict_state():
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
                for i in [2, 3, 4, 5]]
     envs = VectorEnv(env_fns)
+    envs.seed(666)
+    obs = envs.reset()
+    assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
     c1 = Collector(policy, envs, ReplayBuffer(size=100),
                    Logger.single_preprocess_fn)
+    c1.seed(0)
     c1.collect(n_step=10)
     c1.collect(n_episode=[2, 1, 1, 2])
     batch = c1.sample(10)
@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
tianshou/env/basevecenv.py
@@ -27,6 +27,17 @@ class BaseVectorEnv(ABC, gym.Env):
         obs, rew, done, info = envs.step([1] * 8)  # step synchronously
         envs.render()  # render all environments
         envs.close()  # close all environments
 
+    .. warning::
+
+        If you use your own environment, please make sure the ``seed`` method
+        is set up properly, e.g.,
+        ::
+
+            def seed(self, seed):
+                np.random.seed(seed)
+
+        Otherwise, the outputs of these envs may be the same with each other.
+
     """
 
     def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
@@ -108,8 +108,9 @@ class A2CPolicy(PGPolicy):
                 v = self.critic(b.obs).flatten()
                 a = to_torch_as(b.act, v)
                 r = to_torch_as(b.returns, v)
-                a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
-                           ).mean()
+                log_prob = dist.log_prob(a).reshape(
+                    r.shape[0], -1).transpose(0, 1)
+                a_loss = -(log_prob * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r, v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
@@ -157,7 +157,7 @@ class DQNPolicy(BasePolicy):
         if self._target and self._cnt % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
-        q = self(batch).logits
+        q = self(batch, eps=0.).logits
         q = q[np.arange(len(q)), batch.act]
         r = to_torch_as(batch.returns, q).flatten()
         if hasattr(batch, 'update_weight'):
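(Reading of the one-line change above, not stated in the diff: passing ``eps=0.`` makes the learning-time forward pass purely greedy, so whatever exploration rate the collector configured on the policy cannot leak into the Q-values used for the TD loss. The standalone snippet below only illustrates the gather pattern on the next line of the hunk; shapes and values are made up.)

    import numpy as np
    import torch

    q = torch.tensor([[0.1, 0.9],
                      [0.4, 0.2]])        # network output, shape (batch, n_actions)
    act = np.array([1, 0])                # actions stored in the replay batch
    q_sa = q[np.arange(len(q)), act]      # Q(s, a) per sample -> tensor([0.9, 0.4])
    assert torch.allclose(q_sa, torch.tensor([0.9, 0.4]))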
@@ -90,7 +90,9 @@ class PGPolicy(BasePolicy):
                 dist = self(b).dist
                 a = to_torch_as(b.act, dist.logits)
                 r = to_torch_as(b.returns, dist.logits)
-                loss = -(dist.log_prob(a).flatten() * r).sum()
+                log_prob = dist.log_prob(a).reshape(
+                    r.shape[0], -1).transpose(0, 1)
+                loss = -(log_prob * r).mean()
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
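(Aside, not part of the diff: the reshape/transpose added here, and the analogous changes in A2C and PPO, handle distributions whose ``log_prob`` returns more than one value per transition, e.g. a factored multi-discrete action. A minimal sketch of the shape issue with a plain ``Categorical``; all names and shapes are illustrative.)

    import torch
    from torch.distributions import Categorical

    batch_size, action_dims = 4, 2
    # One categorical per action dimension -> log_prob has one entry per dim.
    dist = Categorical(logits=torch.randn(batch_size, action_dims, 2))
    act = torch.zeros(batch_size, action_dims, dtype=torch.long)
    r = torch.ones(batch_size)                   # one return per transition

    log_prob = dist.log_prob(act)                # shape (batch_size, action_dims)
    # Flattening would mis-align log_prob with the per-transition returns;
    # reshaping to (batch, -1) and transposing keeps broadcasting correct:
    log_prob = log_prob.reshape(r.shape[0], -1).transpose(0, 1)
    loss = -(log_prob * r).mean()                # (action_dims, batch) * (batch,)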
@@ -132,7 +132,7 @@ class PPOPolicy(PGPolicy):
                     to_torch_as(b.act, v[0])))
         batch.v = torch.cat(v, dim=0).flatten()  # old value
         batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
         batch.returns = to_torch_as(batch.returns, v[0])
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
@@ -147,8 +147,8 @@ class PPOPolicy(PGPolicy):
             for b in batch.split(batch_size):
                 dist = self(b).dist
                 value = self.critic(b.obs).flatten()
-                ratio = (dist.log_prob(b.act).flatten()
-                         - b.logp_old).exp().float()
+                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
                 surr2 = ratio.clamp(1. - self._eps_clip,
                                     1. + self._eps_clip) * b.adv
@@ -36,7 +36,7 @@ class Net(nn.Module):
     def forward(self, s, state=None, info={}):
         """s -> flatten -> logits"""
         s = to_torch(s, device=self.device, dtype=torch.float32)
-        s = s.flatten(1)
+        s = s.reshape(s.size(0), -1)
         logits = self.model(s)
         return logits, state
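(Aside, not in the diff: a plausible reason for swapping ``s.flatten(1)`` for ``s.reshape(s.size(0), -1)`` is that ``flatten(1)`` raises on 1-D inputs such as a batch of scalar observations, while the reshape handles them uniformly.)

    import torch

    s = torch.arange(4.)             # a batch of scalar observations, shape (4,)
    flat = s.reshape(s.size(0), -1)  # -> shape (4, 1), safe to feed a Linear layer
    # s.flatten(1) would raise here because dimension 1 does not exist.
    assert flat.shape == (4, 1)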