flake8 fix
This commit is contained in:
parent
776acd9f13
commit
5550aed0a1
5
.github/workflows/pytest.yml
vendored
5
.github/workflows/pytest.yml
vendored
@ -30,10 +30,7 @@ jobs:
|
|||||||
- name: Lint with flake8
|
- name: Lint with flake8
|
||||||
run: |
|
run: |
|
||||||
pip install flake8
|
pip install flake8
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
./flake_check.sh
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
|
||||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
|
||||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
pip install pytest
|
pip install pytest
|
||||||
|
3
flake_check.sh
Executable file
3
flake_check.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||||
|
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
@ -20,7 +20,7 @@ class ReplayBuffer(object):
|
|||||||
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
||||||
elif isinstance(inst, dict):
|
elif isinstance(inst, dict):
|
||||||
self.__dict__[name] = np.array([{} for _ in range(self._maxsize)])
|
self.__dict__[name] = np.array([{} for _ in range(self._maxsize)])
|
||||||
else: # assume `inst` is a number
|
else: # assume `inst` is a number
|
||||||
self.__dict__[name] = np.zeros([self._maxsize])
|
self.__dict__[name] = np.zeros([self._maxsize])
|
||||||
self.__dict__[name][self._index] = inst
|
self.__dict__[name][self._index] = inst
|
||||||
|
|
||||||
@ -46,8 +46,14 @@ class ReplayBuffer(object):
|
|||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size):
|
||||||
indice = self.sample_index(batch_size)
|
indice = self.sample_index(batch_size)
|
||||||
return Batch(obs=self.obs[indice], act=self.act[indice], rew=self.rew[indice],
|
return Batch(
|
||||||
done=self.done[indice], obs_next=self.obs_next[indice], info=self.info[indice])
|
obs=self.obs[indice],
|
||||||
|
act=self.act[indice],
|
||||||
|
rew=self.rew[indice],
|
||||||
|
done=self.done[indice],
|
||||||
|
obs_next=self.obs_next[indice],
|
||||||
|
info=self.info[indice]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||||
|
24
tianshou/env/wrapper.py
vendored
24
tianshou/env/wrapper.py
vendored
@ -1,6 +1,10 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
|
try:
|
||||||
|
import ray
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
from tianshou.utils import CloudpickleWrapper
|
from tianshou.utils import CloudpickleWrapper
|
||||||
|
|
||||||
@ -89,12 +93,15 @@ class VectorEnv(object):
|
|||||||
class SubprocVectorEnv(object):
|
class SubprocVectorEnv(object):
|
||||||
"""docstring for SubProcVectorEnv"""
|
"""docstring for SubProcVectorEnv"""
|
||||||
def __init__(self, env_fns, **kwargs):
|
def __init__(self, env_fns, **kwargs):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
|
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
|
||||||
self.processes = [Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
|
self.processes = [
|
||||||
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)]
|
Process(target=self.worker, args=(parent, child, CloudpickleWrapper(env_fn), kwargs), daemon=True)
|
||||||
|
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
|
||||||
|
]
|
||||||
for p in self.processes:
|
for p in self.processes:
|
||||||
p.start()
|
p.start()
|
||||||
for c in self.child_remote:
|
for c in self.child_remote:
|
||||||
@ -103,26 +110,26 @@ class SubprocVectorEnv(object):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.env_num
|
return self.env_num
|
||||||
|
|
||||||
def worker(parent, p, env_fn_wrapper, **kwargs):
|
def worker(self, parent, p, env_fn_wrapper, **kwargs):
|
||||||
reset_after_done = kwargs.get('reset_after_done', True)
|
reset_after_done = kwargs.get('reset_after_done', True)
|
||||||
parent.close()
|
parent.close()
|
||||||
env = env_fn_wrapper.data()
|
env = env_fn_wrapper.data()
|
||||||
while True:
|
while True:
|
||||||
cmd, data = p.recv()
|
cmd, data = p.recv()
|
||||||
if cmd is 'step':
|
if cmd == 'step':
|
||||||
obs, rew, done, info = env.step(data)
|
obs, rew, done, info = env.step(data)
|
||||||
if reset_after_done and done:
|
if reset_after_done and done:
|
||||||
# s_ is useless when episode finishes
|
# s_ is useless when episode finishes
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
p.send([obs, rew, done, info])
|
p.send([obs, rew, done, info])
|
||||||
elif cmd is 'reset':
|
elif cmd == 'reset':
|
||||||
p.send(env.reset())
|
p.send(env.reset())
|
||||||
elif cmd is 'close':
|
elif cmd == 'close':
|
||||||
p.close()
|
p.close()
|
||||||
break
|
break
|
||||||
elif cmd is 'render':
|
elif cmd == 'render':
|
||||||
p.send(env.render())
|
p.send(env.render())
|
||||||
elif cmd is 'seed':
|
elif cmd == 'seed':
|
||||||
p.send(env.seed(data))
|
p.send(env.seed(data))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -163,7 +170,6 @@ class SubprocVectorEnv(object):
|
|||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RayVectorEnv(object):
|
class RayVectorEnv(object):
|
||||||
"""docstring for RayVectorEnv"""
|
"""docstring for RayVectorEnv"""
|
||||||
def __init__(self, env_fns, **kwargs):
|
def __init__(self, env_fns, **kwargs):
|
||||||
|
@ -4,7 +4,9 @@ import cloudpickle
|
|||||||
class CloudpickleWrapper(object):
|
class CloudpickleWrapper(object):
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return cloudpickle.dumps(self.data)
|
return cloudpickle.dumps(self.data)
|
||||||
|
|
||||||
def __setstate__(self, data):
|
def __setstate__(self, data):
|
||||||
self.data = cloudpickle.loads(data)
|
self.data = cloudpickle.loads(data)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user