flake8 fix

This commit is contained in:
Trinkle23897 2020-03-11 09:38:14 +08:00
parent 776acd9f13
commit 5550aed0a1
6 changed files with 42 additions and 28 deletions

View File

@ -30,10 +30,7 @@ jobs:
- name: Lint with flake8
run: |
pip install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
./flake_check.sh
- name: Test with pytest
run: |
pip install pytest

3
flake_check.sh Executable file
View File

@ -0,0 +1,3 @@
#!/bin/sh
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

View File

@ -40,12 +40,12 @@ setup(
'examples', 'examples.*',
'docs', 'docs.*']),
install_requires=[
'numpy',
'torch',
'tensorboard',
'tqdm',
# 'ray',
'gym',
'numpy',
'torch',
'tensorboard',
'tqdm',
# 'ray',
'gym',
'cloudpickle'
],
)
)

View File

@ -20,7 +20,7 @@ class ReplayBuffer(object):
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
elif isinstance(inst, dict):
self.__dict__[name] = np.array([{} for _ in range(self._maxsize)])
else: # assume `inst` is a number
else: # assume `inst` is a number
self.__dict__[name] = np.zeros([self._maxsize])
self.__dict__[name][self._index] = inst
@ -46,15 +46,21 @@ class ReplayBuffer(object):
def sample(self, batch_size):
indice = self.sample_index(batch_size)
return Batch(obs=self.obs[indice], act=self.act[indice], rew=self.rew[indice],
done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice])
return Batch(
obs=self.obs[indice],
act=self.act[indice],
rew=self.rew[indice],
done=self.done[indice],
obs_next=self.obs_next[indice],
info=self.info[indice]
)
class PrioritizedReplayBuffer(ReplayBuffer):
"""docstring for PrioritizedReplayBuffer"""
def __init__(self, size):
super().__init__(size)
def add(self, obs, act, rew, done, obs_next, info={}, weight=None):
raise NotImplementedError

View File

@ -1,6 +1,10 @@
import numpy as np
from collections import deque
from multiprocessing import Process, Pipe
try:
import ray
except ImportError:
pass
from tianshou.utils import CloudpickleWrapper
@ -11,10 +15,10 @@ class EnvWrapper(object):
def step(self, action):
return self.env.step(action)
def reset(self):
self.env.reset()
def seed(self, seed=None):
if hasattr(self.env, 'seed'):
self.env.seed(seed)
@ -55,7 +59,7 @@ class VectorEnv(object):
super().__init__()
self.envs = [_() for _ in env_fns]
self._reset_after_done = kwargs.get('reset_after_done', False)
def __len__(self):
return len(self.envs)
@ -89,12 +93,15 @@ class VectorEnv(object):
class SubprocVectorEnv(object):
"""docstring for SubProcVectorEnv"""
def __init__(self, env_fns, **kwargs):
super().__init__()
self.env_num = len(env_fns)
self.closed = False
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)]
self.processes = [
Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
]
for p in self.processes:
p.start()
for c in self.child_remote:
@ -102,27 +109,27 @@ class SubprocVectorEnv(object):
def __len__(self):
return self.env_num
def worker(parent, p, env_fn_wrapper, **kwargs):
def worker(self, parent, p, env_fn_wrapper, **kwargs):
reset_after_done = kwargs.get('reset_after_done', True)
parent.close()
env = env_fn_wrapper.data()
while True:
cmd, data = p.recv()
if cmd is 'step':
if cmd == 'step':
obs, rew, done, info = env.step(data)
if reset_after_done and done:
# s_ is useless when episode finishes
obs = env.reset()
p.send([obs, rew, done, info])
elif cmd is 'reset':
elif cmd == 'reset':
p.send(env.reset())
elif cmd is 'close':
elif cmd == 'close':
p.close()
break
elif cmd is 'render':
elif cmd == 'render':
p.send(env.render())
elif cmd is 'seed':
elif cmd == 'seed':
p.send(env.seed(data))
else:
raise NotImplementedError
@ -163,7 +170,6 @@ class SubprocVectorEnv(object):
p.join()
class RayVectorEnv(object):
"""docstring for RayVectorEnv"""
def __init__(self, env_fns, **kwargs):

View File

@ -4,7 +4,9 @@ import cloudpickle
class CloudpickleWrapper(object):
def __init__(self, data):
self.data = data
def __getstate__(self):
return cloudpickle.dumps(self.data)
def __setstate__(self, data):
self.data = cloudpickle.loads(data)