flake8 fix

Trinkle23897 2020-03-11 09:38:14 +08:00
parent 776acd9f13
commit 5550aed0a1
6 changed files with 42 additions and 28 deletions

@@ -30,10 +30,7 @@ jobs:
     - name: Lint with flake8
      run: |
        pip install flake8
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        ./flake_check.sh
     - name: Test with pytest
      run: |
        pip install pytest

flake_check.sh (new executable file)
@@ -0,0 +1,3 @@
+#!/bin/sh
+flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
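The script keeps the split described by the removed workflow comments: the first flake8 pass selects only E9/F63/F7/F82 (syntax errors, literal-comparison bugs, and undefined names) and fails the build when any are found, while the second pass runs with --exit-zero, so complexity and line-length findings are reported as warnings without breaking CI.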

@@ -46,8 +46,14 @@ class ReplayBuffer(object):
     def sample(self, batch_size):
         indice = self.sample_index(batch_size)
-        return Batch(obs=self.obs[indice], act=self.act[indice], rew=self.rew[indice],
-                     done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice])
+        return Batch(
+            obs=self.obs[indice],
+            act=self.act[indice],
+            rew=self.rew[indice],
+            done=self.done[indice],
+            obs_next=self.obs_next[indice],
+            info=self.info[indice]
+        )


 class PrioritizedReplayBuffer(ReplayBuffer):

@@ -1,6 +1,10 @@
 import numpy as np
 from collections import deque
 from multiprocessing import Process, Pipe
+try:
+    import ray
+except ImportError:
+    pass
 from tianshou.utils import CloudpickleWrapper
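The try/except around `import ray` makes ray an optional dependency: the module still imports on machines without ray, and only the ray-backed vectorized env needs it at runtime. A minimal sketch of the pattern (illustrative only; the `ray = None` guard and the error message are assumptions, the commit itself just uses `pass`):

try:
    import ray
except ImportError:
    ray = None  # illustrative variant; this commit simply passes

class RayVectorEnv(object):
    """Sketch only -- not the implementation in this repository."""
    def __init__(self, env_fns, **kwargs):
        if ray is None:  # fail only when the optional backend is actually requested
            raise ImportError('Please install ray to use RayVectorEnv.')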
@@ -89,12 +93,15 @@ class VectorEnv(object):
 class SubprocVectorEnv(object):
     """docstring for SubProcVectorEnv"""
     def __init__(self, env_fns, **kwargs):
         super().__init__()
         self.env_num = len(env_fns)
         self.closed = False
         self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
-        self.processes = [Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
-                          for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)]
+        self.processes = [
+            Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
+            for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
+        ]
         for p in self.processes:
             p.start()
         for c in self.child_remote:
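For reference, a self-contained sketch of the Pipe/Process layout used here: one connection end stays in the main process, the other goes to the worker, and each side closes the end it does not use. Names and the echo payload are illustrative, not taken from the repository:

from multiprocessing import Process, Pipe

def _echo_worker(remote, parent_remote):
    parent_remote.close()          # this end belongs to the main process
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            remote.send(data * 2)  # stand-in for env.step(data)
        elif cmd == 'close':
            remote.close()
            break

if __name__ == '__main__':
    parent_remote, child_remote = Pipe()
    p = Process(target=_echo_worker, args=(child_remote, parent_remote), daemon=True)
    p.start()
    child_remote.close()           # the worker owns this end now
    parent_remote.send(('step', 21))
    print(parent_remote.recv())    # 42
    parent_remote.send(('close', None))
    p.join()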
@@ -103,26 +110,26 @@ class SubprocVectorEnv(object):
     def __len__(self):
         return self.env_num

-    def worker(parent, p, env_fn_wrapper, **kwargs):
+    def worker(self, parent, p, env_fn_wrapper, **kwargs):
         reset_after_done = kwargs.get('reset_after_done', True)
         parent.close()
         env = env_fn_wrapper.data()
         while True:
             cmd, data = p.recv()
-            if cmd is 'step':
+            if cmd == 'step':
                 obs, rew, done, info = env.step(data)
                 if reset_after_done and done:
                     # s_ is useless when episode finishes
                     obs = env.reset()
                 p.send([obs, rew, done, info])
-            elif cmd is 'reset':
+            elif cmd == 'reset':
                 p.send(env.reset())
-            elif cmd is 'close':
+            elif cmd == 'close':
                 p.close()
                 break
-            elif cmd is 'render':
+            elif cmd == 'render':
                 p.send(env.render())
-            elif cmd is 'seed':
+            elif cmd == 'seed':
                 p.send(env.seed(data))
             else:
                 raise NotImplementedError
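The `is 'step'` to `== 'step'` changes fix pyflakes F632 (use `==`/`!=` to compare string literals), which falls under the F63 family that flake_check.sh treats as build-breaking. `is` tests object identity rather than value equality and only appears to work for short literals because of CPython string interning. A small illustration (not project code):

a = 'step'
b = ''.join(['st', 'ep'])  # same value, but built at runtime
print(a == b)              # True: value equality, what the worker loop needs
print(a is b)              # typically False in CPython: two distinct objects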
@@ -163,7 +170,6 @@ class SubprocVectorEnv(object):
             p.join()

-

 class RayVectorEnv(object):
     """docstring for RayVectorEnv"""
     def __init__(self, env_fns, **kwargs):

@@ -4,7 +4,9 @@ import cloudpickle
 class CloudpickleWrapper(object):
     def __init__(self, data):
         self.data = data
+
     def __getstate__(self):
         return cloudpickle.dumps(self.data)
+
     def __setstate__(self, data):
         self.data = cloudpickle.loads(data)
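CloudpickleWrapper exists because the env factories handed to SubprocVectorEnv are often lambdas or closures, which the standard-library pickle used by multiprocessing cannot serialize, while cloudpickle can. A minimal round trip (the `make_env` lambda is a placeholder, not code from this repository):

import pickle
import cloudpickle

make_env = lambda: 'CartPole-v0'   # stand-in for something like lambda: gym.make(...)

try:
    pickle.dumps(make_env)
except Exception as err:
    print('stdlib pickle failed:', type(err).__name__)  # lambdas are not picklable

blob = cloudpickle.dumps(make_env)  # cloudpickle serializes the lambda by value
print(cloudpickle.loads(blob)())    # -> 'CartPole-v0'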