minor reformat (#2)

* update atari.py
* fix setup.py to pass the pytest
This commit is contained in:
parent fdc969b830
commit 3c0a09fefd
setup.py
@@ -37,11 +37,19 @@ setup(
         'examples', 'examples.*',
         'docs', 'docs.*']),
     install_requires=[
-        'gym',
+        'gym>=0.15.0',
         'tqdm',
         'numpy',
         'cloudpickle',
         'tensorboard',
         'torch>=1.4.0',
     ],
+    extras_require={
+        'atari': [
+            'atari_py',
+        ],
+        'mujoco': [
+            'mujoco_py',
+        ]
+    },
 )
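Note on the new extras_require block: atari_py and mujoco_py stay out of the default dependency set and are pulled in on demand through the standard setuptools extras syntax, e.g. pip install tianshou[atari] or pip install tianshou[mujoco].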
@@ -1,4 +1,5 @@
 from tianshou.data import ReplayBuffer
+
 if __name__ == '__main__':
     from env import MyTestEnv
 else:  # pytest
@@ -11,6 +11,7 @@ else:  # pytest

 class MyPolicy(BasePolicy):
     """docstring for MyPolicy"""
+
     def __init__(self):
         super().__init__()

@@ -1,4 +1,4 @@
-from tianshou import data, env, utils, policy, trainer,\
+from tianshou import data, env, utils, policy, trainer, \
     exploration

 __version__ = '0.2.0'
@@ -37,7 +37,7 @@ class Batch(object):
         else:
             raise TypeError(
                 'No support for append with type {} in class Batch.'
                 .format(type(batch.__dict__[k])))

     def split(self, size=None, permute=True):
         length = min([
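For context on split(): the on-policy learners further down iterate minibatches with for b in batch.split(batch_size). A minimal sketch of the idea, independent of tianshou's actual implementation (names and details here are illustrative only):

import numpy as np

def split_indices(length, size=None, permute=True):
    # Yield minibatch index arrays covering [0, length).
    # With permute=True the order is reshuffled, as an
    # on-policy learner would want between epochs.
    size = length if size is None else size
    order = np.random.permutation(length) if permute else np.arange(length)
    for i in range(0, length, size):
        yield order[i:i + size]

# Example: minibatches of at most 4 over 10 samples (sizes 4, 4, 2).
for idx in split_indices(10, size=4):
    print(idx)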
@@ -47,7 +47,7 @@ class ReplayBuffer(object):
         '''
         weight: importance weights, disabled here
         '''
-        assert isinstance(info, dict),\
+        assert isinstance(info, dict), \
             'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
         self._add_to_buffer('act', act)
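The recurring change in this commit is the same one-character style fix: adding a space between the comma and the trailing line-continuation backslash. This is presumably what lets the lint-on-pytest run pass, since a comma immediately followed by a backslash trips pycodestyle-style checks (E231, missing whitespace after comma); the exact failing rule is not shown in this capture.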
@@ -31,8 +31,8 @@ class Collector(object):
         if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
-                assert len(self.buffer) == self.env_num,\
-                    'The number of data buffer does not match the number of '\
+                assert len(self.buffer) == self.env_num, \
+                    'The number of data buffer does not match the number of ' \
                     'input env.'
                 self._multi_buf = True
             elif isinstance(self.buffer, ReplayBuffer):
@@ -87,7 +87,7 @@ class Collector(object):
         if not self._multi_env:
             n_episode = np.sum(n_episode)
         start_time = time.time()
-        assert sum([(n_step != 0), (n_episode != 0)]) == 1,\
+        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
             "One and only one collection number specification permitted!"
         cur_step = 0
         cur_episode = np.zeros(self.env_num) if self._multi_env else 0
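The assert being rewrapped here encodes "exactly one of n_step / n_episode may be non-zero" by summing booleans (True counts as 1). A standalone check of the pattern:

def exactly_one_set(n_step, n_episode):
    # True == 1 and False == 0 in Python, so summing the two
    # comparisons counts how many arguments are non-zero.
    return sum([(n_step != 0), (n_episode != 0)]) == 1

assert exactly_one_set(100, 0)       # only n_step given: ok
assert exactly_one_set(0, 5)         # only n_episode given: ok
assert not exactly_one_set(100, 5)   # both given: rejected
assert not exactly_one_set(0, 0)     # neither given: rejected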
tianshou/env/__init__.py
@@ -1,6 +1,6 @@
 from tianshou.env.utils import CloudpickleWrapper
 from tianshou.env.common import EnvWrapper, FrameStack
-from tianshou.env.vecenv import BaseVectorEnv, VectorEnv,\
+from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
     SubprocVectorEnv, RayVectorEnv

 __all__ = [
tianshou/env/vecenv.py
@@ -1,6 +1,7 @@
 import numpy as np
 from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
+
 try:
     import ray
 except ImportError:
@@ -122,7 +123,7 @@ class SubprocVectorEnv(BaseVectorEnv):
             zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
             Process(target=worker, args=(
                 parent, child, CloudpickleWrapper(env_fn)), daemon=True)
             for (parent, child, env_fn) in zip(
                 self.parent_remote, self.child_remote, env_fns)
         ]
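The zip(*[Pipe() for _ in range(...)]) idiom in this hunk splits a list of (parent, child) connection pairs into two parallel tuples. A self-contained sketch of the same pipe-per-worker pattern; the echoing worker below is made up purely for illustration:

from multiprocessing import Process, Pipe

def worker(parent, child):
    # Runs in the subprocess: close the inherited parent end,
    # then echo back whatever command arrives on the child end.
    parent.close()
    while True:
        cmd = child.recv()
        if cmd == 'close':
            child.close()
            break
        child.send(('echo', cmd))

if __name__ == '__main__':
    n = 3
    # One Pipe() per env; zip(*...) turns the pairs into
    # a tuple of parent ends and a tuple of child ends.
    parents, children = zip(*[Pipe() for _ in range(n)])
    procs = [Process(target=worker, args=(p, c), daemon=True)
             for p, c in zip(parents, children)]
    for proc in procs:
        proc.start()
    for i, p in enumerate(parents):
        p.send(f'step-{i}')
    print([p.recv() for p in parents])
    for p in parents:
        p.send('close')
    for proc in procs:
        proc.join()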
@@ -14,7 +14,7 @@ class OUNoise(object):
         if self.x is None or self.x.shape != size:
             self.x = 0
         self.x = self.x + self.alpha * (mu - self.x) + \
             self.beta * np.random.normal(size=size)
         return self.x

     def reset(self):
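The rewrapped pair of lines is an Ornstein-Uhlenbeck-style update, x <- x + alpha * (mu - x) + beta * N(0, 1), giving temporally correlated exploration noise for continuous actions. A minimal standalone version (the parameter values are illustrative, not tianshou's defaults):

import numpy as np

class SimpleOUNoise:
    # x drifts toward mu at rate alpha and is perturbed by
    # Gaussian noise scaled by beta, so successive samples
    # are correlated rather than independent.
    def __init__(self, alpha=0.15, beta=0.3):
        self.alpha, self.beta = alpha, beta
        self.x = None

    def __call__(self, size, mu=0.0):
        if self.x is None or self.x.shape != size:
            self.x = np.zeros(size)
        self.x = self.x + self.alpha * (mu - self.x) \
            + self.beta * np.random.normal(size=size)
        return self.x

    def reset(self):
        self.x = None

noise = SimpleOUNoise()
print(noise(size=(2,)))  # first correlated sample, shape (2,)
print(noise(size=(2,)))  # next sample drifts from the previous one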
@@ -40,8 +40,8 @@ class A2CPolicy(PGPolicy):
         vf_loss = F.mse_loss(r[:, None], v)
         ent_loss = dist.entropy().mean()
         loss = actor_loss \
             + self._w_vf * vf_loss \
             - self._w_ent * ent_loss
         loss.backward()
         if self._grad_norm:
             nn.utils.clip_grad_norm_(
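The expression being rewrapped is the usual A2C objective: actor (policy-gradient) loss, plus a weighted value-function loss, minus a weighted entropy bonus. A toy computation with stand-in numbers (the coefficients are illustrative, not the policy's defaults):

import torch

# Stand-in values; in A2CPolicy these come from the sampled batch.
actor_loss = torch.tensor(0.52)
vf_loss = torch.tensor(1.30)
ent_loss = torch.tensor(1.61)
w_vf, w_ent = 0.5, 0.01  # illustrative coefficients

# Same combination as in the diff: minimize actor and value losses
# while maximizing entropy (hence the minus sign).
loss = actor_loss + w_vf * vf_loss - w_ent * ent_loss
print(loss)  # ~ tensor(1.1539)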
@@ -5,6 +5,8 @@ import torch.nn.functional as F

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
+
+
 # from tianshou.exploration import OUNoise


@@ -35,7 +35,7 @@ class PGPolicy(BasePolicy):
     def learn(self, batch, batch_size=None, repeat=1):
         losses = []
         batch.returns = (batch.returns - batch.returns.mean()) \
             / (batch.returns.std() + self._eps)
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
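The wrapped expression standardizes the returns to roughly zero mean and unit variance before the update, with self._eps guarding against division by a zero standard deviation. In isolation (the eps value here is assumed; the diff does not show it):

import numpy as np

returns = np.array([1.0, 3.0, 5.0, 7.0])
eps = np.finfo(np.float32).eps.item()  # tiny guard constant (assumed)
normalized = (returns - returns.mean()) / (returns.std() + eps)
print(normalized.mean())  # ~0.0
print(normalized.std())   # ~1.0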
@@ -59,7 +59,7 @@ class PPOPolicy(PGPolicy):
     def learn(self, batch, batch_size=None, repeat=1):
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         batch.returns = (batch.returns - batch.returns.mean()) \
             / (batch.returns.std() + self._eps)
         batch.act = torch.tensor(batch.act)
         batch.returns = torch.tensor(batch.returns)[:, None]
         for _ in range(repeat):
@@ -82,13 +82,13 @@ class PPOPolicy(PGPolicy):
                 ent_loss = dist.entropy().mean()
                 ent_losses.append(ent_loss.detach().cpu().numpy())
                 loss = clip_loss \
                     + self._w_vf * vf_loss - self._w_ent * ent_loss
                 losses.append(loss.detach().cpu().numpy())
                 self.optim.zero_grad()
                 loss.backward()
                 nn.utils.clip_grad_norm_(list(
                     self.actor.parameters()) + list(self.critic.parameters()),
                     self._max_grad_norm)
                 self.optim.step()
                 self.sync_weight()
             return {
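The clip_grad_norm_ call above concatenates the actor and critic parameter lists so both networks are clipped under a single global gradient norm. A small sketch of the same pattern; the two Linear layers stand in for the real networks:

import torch
import torch.nn as nn

actor = nn.Linear(4, 2)   # placeholder for the real actor network
critic = nn.Linear(4, 1)  # placeholder for the real critic network
max_grad_norm = 0.5       # illustrative clip threshold

loss = actor(torch.randn(8, 4)).pow(2).mean() \
    + critic(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Clip the concatenated parameter list so the *joint* gradient
# norm is at most max_grad_norm, as in PPOPolicy.learn above.
nn.utils.clip_grad_norm_(
    list(actor.parameters()) + list(critic.parameters()),
    max_grad_norm)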