clear
This commit is contained in:
parent
f16e05c0e7
commit
543e57cdbd
3
.gitignore
vendored
3
.gitignore
vendored
@ -132,3 +132,6 @@ dmypy.json
|
|||||||
|
|
||||||
# pytype static type analyzer
|
# pytype static type analyzer
|
||||||
.pytype/
|
.pytype/
|
||||||
|
|
||||||
|
# customize
|
||||||
|
flake8.sh
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
|
||||||
flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
|
|
@ -2,8 +2,8 @@ from tianshou import data, env, utils, policy
|
|||||||
|
|
||||||
__version__ = '0.2.0'
|
__version__ = '0.2.0'
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'data',
|
|
||||||
'env',
|
'env',
|
||||||
|
'data',
|
||||||
'utils',
|
'utils',
|
||||||
'policy'
|
'policy'
|
||||||
]
|
]
|
||||||
|
@ -25,5 +25,5 @@ class Batch(object):
|
|||||||
self.__dict__[k] += batch.__dict__[k]
|
self.__dict__[k] += batch.__dict__[k]
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'No support append method with {} in class Batch.'
|
'Do not support append with type {} in class Batch.'
|
||||||
.format(type(batch.__dict__[k])))
|
.format(type(batch.__dict__[k])))
|
||||||
|
@ -21,7 +21,7 @@ class Collector(object):
|
|||||||
self.env_num = len(env)
|
self.env_num = len(env)
|
||||||
if isinstance(self.buffer, list):
|
if isinstance(self.buffer, list):
|
||||||
assert len(self.buffer) == self.env_num,\
|
assert len(self.buffer) == self.env_num,\
|
||||||
'Data buffer number does not match the input env number.'
|
'# of data buffer does not match the # of input env.'
|
||||||
elif isinstance(self.buffer, ReplayBuffer):
|
elif isinstance(self.buffer, ReplayBuffer):
|
||||||
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
|
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
|
||||||
else:
|
else:
|
||||||
|
17
tianshou/env/wrapper.py
vendored
17
tianshou/env/wrapper.py
vendored
@ -58,18 +58,18 @@ class FrameStack(EnvWrapper):
|
|||||||
|
|
||||||
|
|
||||||
class BaseVectorEnv(ABC):
|
class BaseVectorEnv(ABC):
|
||||||
def __init__(self):
|
def __init__(self, env_fns, reset_after_done):
|
||||||
pass
|
self._env_fns = env_fns
|
||||||
|
self.env_num = len(env_fns)
|
||||||
|
self._reset_after_done = reset_after_done
|
||||||
|
|
||||||
|
|
||||||
class VectorEnv(BaseVectorEnv):
|
class VectorEnv(BaseVectorEnv):
|
||||||
"""docstring for VectorEnv"""
|
"""docstring for VectorEnv"""
|
||||||
|
|
||||||
def __init__(self, env_fns, reset_after_done=False):
|
def __init__(self, env_fns, reset_after_done=False):
|
||||||
super().__init__()
|
super().__init__(env_fns, reset_after_done)
|
||||||
self.envs = [_() for _ in env_fns]
|
self.envs = [_() for _ in env_fns]
|
||||||
self.env_num = len(self.envs)
|
|
||||||
self._reset_after_done = reset_after_done
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.envs)
|
return len(self.envs)
|
||||||
@ -132,8 +132,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
"""docstring for SubProcVectorEnv"""
|
"""docstring for SubProcVectorEnv"""
|
||||||
|
|
||||||
def __init__(self, env_fns, reset_after_done=False):
|
def __init__(self, env_fns, reset_after_done=False):
|
||||||
super().__init__()
|
super().__init__(env_fns, reset_after_done)
|
||||||
self.env_num = len(env_fns)
|
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.parent_remote, self.child_remote = \
|
self.parent_remote, self.child_remote = \
|
||||||
zip(*[Pipe() for _ in range(self.env_num)])
|
zip(*[Pipe() for _ in range(self.env_num)])
|
||||||
@ -193,9 +192,7 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
"""docstring for RayVectorEnv"""
|
"""docstring for RayVectorEnv"""
|
||||||
|
|
||||||
def __init__(self, env_fns, reset_after_done=False):
|
def __init__(self, env_fns, reset_after_done=False):
|
||||||
super().__init__()
|
super().__init__(env_fns, reset_after_done)
|
||||||
self.env_num = len(env_fns)
|
|
||||||
self._reset_after_done = reset_after_done
|
|
||||||
try:
|
try:
|
||||||
if not ray.is_initialized():
|
if not ray.is_initialized():
|
||||||
ray.init()
|
ray.init()
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from tianshou.policy.base import BasePolicy
|
from tianshou.policy.base import BasePolicy
|
||||||
|
from tianshou.policy.dqn import DQNPolicy
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BasePolicy'
|
'BasePolicy',
|
||||||
|
'DQNPolicy',
|
||||||
]
|
]
|
||||||
|
@ -9,7 +9,7 @@ class BasePolicy(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def act(self, batch, hidden_state=None):
|
def act(self, batch, hidden_state=None):
|
||||||
# return {policy, action, hidden}
|
# return Batch(policy, action, hidden)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user