clear
This commit is contained in:
parent
f16e05c0e7
commit
543e57cdbd
3
.gitignore
vendored
3
.gitignore
vendored
@ -132,3 +132,6 @@ dmypy.json
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# customize
|
||||
flake8.sh
|
||||
|
@ -1,3 +0,0 @@
|
||||
#!/bin/sh
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
|
@ -2,8 +2,8 @@ from tianshou import data, env, utils, policy
|
||||
|
||||
__version__ = '0.2.0'
|
||||
__all__ = [
|
||||
'data',
|
||||
'env',
|
||||
'data',
|
||||
'utils',
|
||||
'policy'
|
||||
]
|
||||
|
@ -25,5 +25,5 @@ class Batch(object):
|
||||
self.__dict__[k] += batch.__dict__[k]
|
||||
else:
|
||||
raise TypeError(
|
||||
'No support append method with {} in class Batch.'
|
||||
'Do not support append with type {} in class Batch.'
|
||||
.format(type(batch.__dict__[k])))
|
||||
|
@ -21,7 +21,7 @@ class Collector(object):
|
||||
self.env_num = len(env)
|
||||
if isinstance(self.buffer, list):
|
||||
assert len(self.buffer) == self.env_num,\
|
||||
'Data buffer number does not match the input env number.'
|
||||
'# of data buffer does not match the # of input env.'
|
||||
elif isinstance(self.buffer, ReplayBuffer):
|
||||
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
|
||||
else:
|
||||
|
17
tianshou/env/wrapper.py
vendored
17
tianshou/env/wrapper.py
vendored
@ -58,18 +58,18 @@ class FrameStack(EnvWrapper):
|
||||
|
||||
|
||||
class BaseVectorEnv(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, env_fns, reset_after_done):
|
||||
self._env_fns = env_fns
|
||||
self.env_num = len(env_fns)
|
||||
self._reset_after_done = reset_after_done
|
||||
|
||||
|
||||
class VectorEnv(BaseVectorEnv):
|
||||
"""docstring for VectorEnv"""
|
||||
|
||||
def __init__(self, env_fns, reset_after_done=False):
|
||||
super().__init__()
|
||||
super().__init__(env_fns, reset_after_done)
|
||||
self.envs = [_() for _ in env_fns]
|
||||
self.env_num = len(self.envs)
|
||||
self._reset_after_done = reset_after_done
|
||||
|
||||
def __len__(self):
|
||||
return len(self.envs)
|
||||
@ -132,8 +132,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""docstring for SubProcVectorEnv"""
|
||||
|
||||
def __init__(self, env_fns, reset_after_done=False):
|
||||
super().__init__()
|
||||
self.env_num = len(env_fns)
|
||||
super().__init__(env_fns, reset_after_done)
|
||||
self.closed = False
|
||||
self.parent_remote, self.child_remote = \
|
||||
zip(*[Pipe() for _ in range(self.env_num)])
|
||||
@ -193,9 +192,7 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
"""docstring for RayVectorEnv"""
|
||||
|
||||
def __init__(self, env_fns, reset_after_done=False):
|
||||
super().__init__()
|
||||
self.env_num = len(env_fns)
|
||||
self._reset_after_done = reset_after_done
|
||||
super().__init__(env_fns, reset_after_done)
|
||||
try:
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
|
@ -1,5 +1,7 @@
|
||||
from tianshou.policy.base import BasePolicy
|
||||
from tianshou.policy.dqn import DQNPolicy
|
||||
|
||||
__all__ = [
|
||||
'BasePolicy'
|
||||
'BasePolicy',
|
||||
'DQNPolicy',
|
||||
]
|
||||
|
@ -9,7 +9,7 @@ class BasePolicy(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def act(self, batch, hidden_state=None):
|
||||
# return {policy, action, hidden}
|
||||
# return Batch(policy, action, hidden)
|
||||
pass
|
||||
|
||||
def train(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user