flake8 fix

parent 776acd9f13
commit 5550aed0a1

.github/workflows/pytest.yml (5 changes, vendored)

@@ -30,10 +30,7 @@ jobs:
     - name: Lint with flake8
       run: |
         pip install flake8
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        ./flake_check.sh
     - name: Test with pytest
       run: |
         pip install pytest

flake_check.sh (3 changes, new executable file)

@@ -0,0 +1,3 @@
+#!/bin/sh
+flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
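
The script preserves the workflow's two-tier gate: the first flake8 run selects only hard failures (E9 syntax errors, F63/F7/F82 misused or undefined names) and is meant to stop the build, while the second runs the full rule set with --exit-zero so style findings are reported without failing CI. Note that, as committed, the script's own exit status is that of its last command, which --exit-zero forces to 0; without `set -e` the hard gate's failure cannot actually fail the CI step. A minimal Python sketch of the intended gate, assuming flake8 is on PATH:

import subprocess
import sys

# Tier 1: syntax errors and undefined names are build failures.
hard = subprocess.run([
    "flake8", ".", "--count", "--select=E9,F63,F7,F82",
    "--show-source", "--statistics",
])

# Tier 2: everything else is reported but tolerated (mirrors --exit-zero).
subprocess.run([
    "flake8", ".", "--count", "--exit-zero", "--max-complexity=10",
    "--max-line-length=127", "--statistics",
])

sys.exit(hard.returncode)  # propagate only the hard gate's verdict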

setup.py (14 changes)

@@ -40,12 +40,12 @@ setup(
                        'examples', 'examples.*',
                        'docs', 'docs.*']),
     install_requires=[
-            'numpy',
-            'torch',
-            'tensorboard',
-            'tqdm',
-            # 'ray',
-            'gym',
+        'numpy',
+        'torch',
+        'tensorboard',
+        'tqdm',
+        # 'ray',
+        'gym',
         'cloudpickle'
     ],
-    )
+)

@@ -20,7 +20,7 @@ class ReplayBuffer(object):
             self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
         elif isinstance(inst, dict):
             self.__dict__[name] = np.array([{} for _ in range(self._maxsize)])
-        else: # assume `inst` is a number
+        else:  # assume `inst` is a number
             self.__dict__[name] = np.zeros([self._maxsize])
         self.__dict__[name][self._index] = inst

@@ -46,15 +46,21 @@ class ReplayBuffer(object):

     def sample(self, batch_size):
         indice = self.sample_index(batch_size)
-        return Batch(obs=self.obs[indice], act=self.act[indice], rew=self.rew[indice],
-                     done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice])
+        return Batch(
+            obs=self.obs[indice],
+            act=self.act[indice],
+            rew=self.rew[indice],
+            done=self.done[indice],
+            obs_next=self.obs_next[indice],
+            info=self.info[indice]
+        )


 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""
     def __init__(self, size):
         super().__init__(size)

     def add(self, obs, act, rew, done, obs_next, info={}, weight=None):
         raise NotImplementedError
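
One smell the selected flake8 rules do not catch survives in `add`: the mutable default `info={}` is created once at function definition time and shared by every call that omits the argument, so state can leak between calls. A quick illustration of the pitfall and the usual `None` idiom (hypothetical names, not the buffer's code):

def add(item, bucket=[]):      # one list object, created at def time
    bucket.append(item)
    return bucket

print(add(1))   # [1]
print(add(2))   # [1, 2]  <- the "fresh" default remembers the last call

def add_fixed(item, bucket=None):
    if bucket is None:         # allocate a new list per call instead
        bucket = []
    bucket.append(item)
    return bucket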

tianshou/env/wrapper.py (32 changes, vendored)

@@ -1,6 +1,10 @@
 import numpy as np
 from collections import deque
 from multiprocessing import Process, Pipe
+try:
+    import ray
+except ImportError:
+    pass

 from tianshou.utils import CloudpickleWrapper
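
The try/except guard makes ray an optional dependency, matching the commented-out `# 'ray'` in setup.py's install_requires. One caveat: with a bare `pass`, a failed import leaves the name `ray` unbound, so any later reference raises NameError rather than a helpful message. A common variant of the guard (a sketch, not what this commit does) binds the name explicitly:

try:
    import ray
except ImportError:
    ray = None  # optional dependency not installed

def require_ray():
    # Fail loudly, with an actionable message, at the point of use.
    if ray is None:
        raise ImportError("this feature needs ray: pip install ray")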

@@ -11,10 +15,10 @@ class EnvWrapper(object):

     def step(self, action):
         return self.env.step(action)

     def reset(self):
         self.env.reset()

     def seed(self, seed=None):
         if hasattr(self.env, 'seed'):
             self.env.seed(seed)

@@ -55,7 +59,7 @@ class VectorEnv(object):
         super().__init__()
         self.envs = [_() for _ in env_fns]
         self._reset_after_done = kwargs.get('reset_after_done', False)

     def __len__(self):
         return len(self.envs)

@@ -89,12 +93,15 @@ class VectorEnv(object):
 class SubprocVectorEnv(object):
     """docstring for SubProcVectorEnv"""
     def __init__(self, env_fns, **kwargs):
         super().__init__()
         self.env_num = len(env_fns)
         self.closed = False
         self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
-        self.processes = [Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
-                          for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)]
+        self.processes = [
+            Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
+            for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
+        ]
         for p in self.processes:
             p.start()
         for c in self.child_remote:
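
For orientation, the pattern being reformatted here is one Pipe per environment: the parent keeps one end of each pipe, every worker process gets the other, and commands travel as (cmd, data) tuples. A stripped-down, self-contained sketch of that round trip (hypothetical echo worker, not the diff's code):

from multiprocessing import Process, Pipe

def echo_worker(remote):
    while True:
        cmd, data = remote.recv()
        if cmd == 'close':
            remote.close()
            break
        remote.send((cmd, data))  # echo the command back to the parent

if __name__ == '__main__':
    parent, child = Pipe()
    p = Process(target=echo_worker, args=(child,), daemon=True)
    p.start()
    parent.send(('step', 42))
    print(parent.recv())          # ('step', 42)
    parent.send(('close', None))
    p.join()

One side effect of the `target=worker` to `target=self.worker` change: the target is now a bound method, so under the spawn start method the whole SubprocVectorEnv instance must be picklable; under the fork default on Linux this cost never materializes.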

@@ -102,27 +109,27 @@ class SubprocVectorEnv(object):

     def __len__(self):
         return self.env_num

-    def worker(parent, p, env_fn_wrapper, **kwargs):
+    def worker(self, parent, p, env_fn_wrapper, **kwargs):
         reset_after_done = kwargs.get('reset_after_done', True)
         parent.close()
         env = env_fn_wrapper.data()
         while True:
             cmd, data = p.recv()
-            if cmd is 'step':
+            if cmd == 'step':
                 obs, rew, done, info = env.step(data)
                 if reset_after_done and done:
                     # s_ is useless when episode finishes
                     obs = env.reset()
                 p.send([obs, rew, done, info])
-            elif cmd is 'reset':
+            elif cmd == 'reset':
                 p.send(env.reset())
-            elif cmd is 'close':
+            elif cmd == 'close':
                 p.close()
                 break
-            elif cmd is 'render':
+            elif cmd == 'render':
                 p.send(env.render())
-            elif cmd is 'seed':
+            elif cmd == 'seed':
                 p.send(env.seed(data))
             else:
                 raise NotImplementedError
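
The `is` to `==` replacements fix pyflakes F632 (use of `is` with a literal), one of the checks the hard gate above turns into a build failure via --select=E9,F63,F7,F82. `is` tests object identity, and equal strings are only sometimes the same object thanks to CPython interning, so `cmd is 'step'` can be False for a value that prints as 'step'. A quick demonstration:

a = 'step'
b = ''.join(['st', 'ep'])  # equal value, constructed at runtime

print(a == b)  # True  -- value comparison, always correct
print(a is b)  # False -- different objects; interning did not apply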

@@ -163,7 +170,6 @@ class SubprocVectorEnv(object):
             p.join()


-
 class RayVectorEnv(object):
     """docstring for RayVectorEnv"""
     def __init__(self, env_fns, **kwargs):

@@ -4,7 +4,9 @@ import cloudpickle
 class CloudpickleWrapper(object):
     def __init__(self, data):
         self.data = data
+
     def __getstate__(self):
         return cloudpickle.dumps(self.data)
+
     def __setstate__(self, data):
         self.data = cloudpickle.loads(data)
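
The two blank lines added above are cosmetic (flake8 E301, one blank line expected between methods), but the class itself is what lets SubprocVectorEnv ship environment factories to worker processes: `__getstate__`/`__setstate__` reroute ordinary pickling through cloudpickle, which, unlike the stdlib pickler, can serialize lambdas and closures. A small round-trip sketch, assuming cloudpickle is installed:

import pickle
import cloudpickle

class CloudpickleWrapper(object):
    def __init__(self, data):
        self.data = data

    def __getstate__(self):
        return cloudpickle.dumps(self.data)

    def __setstate__(self, data):
        self.data = cloudpickle.loads(data)

make_env = lambda: 'CartPole-v0'   # plain pickle rejects lambdas
wrapped = pickle.loads(pickle.dumps(CloudpickleWrapper(make_env)))
print(wrapped.data())              # 'CartPole-v0'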