maybe finished collector?
This commit is contained in:
parent
f58c1397c6
commit
f16e05c0e7
2
.github/workflows/pytest.yml
vendored
2
.github/workflows/pytest.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
|||||||
# stop the build if there are Python syntax errors or undefined names
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
pip install pytest pytest-cov
|
pip install pytest pytest-cov
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
|
||||||
|
2
setup.py
2
setup.py
@ -42,7 +42,7 @@ setup(
|
|||||||
'tqdm',
|
'tqdm',
|
||||||
'numpy',
|
'numpy',
|
||||||
'torch',
|
'torch',
|
||||||
'cloudpickle'
|
'cloudpickle',
|
||||||
'tensorboard',
|
'tensorboard',
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
16
test/test_batch.py
Normal file
16
test/test_batch.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tianshou.data import Batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch():
|
||||||
|
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||||
|
batch.update(obs=[1])
|
||||||
|
assert batch.obs == [1]
|
||||||
|
batch.append(batch)
|
||||||
|
assert batch.obs == [1, 1]
|
||||||
|
assert batch.np.shape == (6, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_batch()
|
@ -14,8 +14,7 @@ def test_replaybuffer(size=10, bufsize=20):
|
|||||||
obs_next, rew, done, info = env.step(a)
|
obs_next, rew, done, info = env.step(a)
|
||||||
buf.add(obs, a, rew, done, obs_next, info)
|
buf.add(obs, a, rew, done, obs_next, info)
|
||||||
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
||||||
indice = buf.sample_indice(4)
|
data, indice = buf.sample(4)
|
||||||
data = buf.sample(4)
|
|
||||||
assert (indice < len(buf)).all()
|
assert (indice < len(buf)).all()
|
||||||
assert (data.obs < size).all()
|
assert (data.obs < size).all()
|
||||||
assert (0 <= data.done).all() and (data.done <= 1).all()
|
assert (0 <= data.done).all() and (data.done <= 1).all()
|
||||||
|
@ -79,7 +79,9 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
|
|||||||
for a in action_list:
|
for a in action_list:
|
||||||
e.step([a] * num)
|
e.step([a] * num)
|
||||||
t[i] = time.time() - t[i]
|
t[i] = time.time() - t[i]
|
||||||
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
|
print(f'VectorEnv: {t[0]:.6f}s')
|
||||||
|
print(f'SubprocVectorEnv: {t[1]:.6f}s')
|
||||||
|
print(f'RayVectorEnv: {t[2]:.6f}s')
|
||||||
for v in venv:
|
for v in venv:
|
||||||
v.close()
|
v.close()
|
||||||
|
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
from tianshou import data, env, utils
|
from tianshou import data, env, utils, policy
|
||||||
|
|
||||||
__version__ = '0.2.0'
|
__version__ = '0.2.0'
|
||||||
__all__ = ['data', 'env', 'utils']
|
__all__ = [
|
||||||
|
'data',
|
||||||
|
'env',
|
||||||
|
'utils',
|
||||||
|
'policy'
|
||||||
|
]
|
||||||
|
@ -2,4 +2,9 @@ from tianshou.data.batch import Batch
|
|||||||
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
|
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||||
from tianshou.data.collector import Collector
|
from tianshou.data.collector import Collector
|
||||||
|
|
||||||
__all__ = ['Batch', 'ReplayBuffer', 'PrioritizedReplayBuffer', 'Collector']
|
__all__ = [
|
||||||
|
'Batch',
|
||||||
|
'ReplayBuffer',
|
||||||
|
'PrioritizedReplayBuffer',
|
||||||
|
'Collector'
|
||||||
|
]
|
||||||
|
@ -1,9 +1,29 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Batch(object):
|
class Batch(object):
|
||||||
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
|
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.obs_next = None
|
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def update(self, **kwargs):
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
|
|
||||||
|
def append(self, batch):
|
||||||
|
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||||
|
for k in batch.__dict__.keys():
|
||||||
|
if batch.__dict__[k] is None:
|
||||||
|
continue
|
||||||
|
if not hasattr(self, k) or self.__dict__[k] is None:
|
||||||
|
self.__dict__[k] = batch.__dict__[k]
|
||||||
|
elif isinstance(batch.__dict__[k], np.ndarray):
|
||||||
|
self.__dict__[k] = np.concatenate([
|
||||||
|
self.__dict__[k], batch.__dict__[k]])
|
||||||
|
elif isinstance(batch.__dict__[k], list):
|
||||||
|
self.__dict__[k] += batch.__dict__[k]
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'No support append method with {} in class Batch.'
|
||||||
|
.format(type(batch.__dict__[k])))
|
||||||
|
@ -4,6 +4,7 @@ from tianshou.data.batch import Batch
|
|||||||
|
|
||||||
class ReplayBuffer(object):
|
class ReplayBuffer(object):
|
||||||
"""docstring for ReplayBuffer"""
|
"""docstring for ReplayBuffer"""
|
||||||
|
|
||||||
def __init__(self, size):
|
def __init__(self, size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._maxsize = size
|
self._maxsize = size
|
||||||
@ -19,7 +20,8 @@ class ReplayBuffer(object):
|
|||||||
if isinstance(inst, np.ndarray):
|
if isinstance(inst, np.ndarray):
|
||||||
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
||||||
elif isinstance(inst, dict):
|
elif isinstance(inst, dict):
|
||||||
self.__dict__[name] = np.array([{} for _ in range(self._maxsize)])
|
self.__dict__[name] = np.array(
|
||||||
|
[{} for _ in range(self._maxsize)])
|
||||||
else: # assume `inst` is a number
|
else: # assume `inst` is a number
|
||||||
self.__dict__[name] = np.zeros([self._maxsize])
|
self.__dict__[name] = np.zeros([self._maxsize])
|
||||||
self.__dict__[name][self._index] = inst
|
self.__dict__[name][self._index] = inst
|
||||||
@ -28,7 +30,8 @@ class ReplayBuffer(object):
|
|||||||
'''
|
'''
|
||||||
weight: importance weights, disabled here
|
weight: importance weights, disabled here
|
||||||
'''
|
'''
|
||||||
assert isinstance(info, dict), 'You should return a dict in the last argument of env.step function.'
|
assert isinstance(info, dict),\
|
||||||
|
'You should return a dict in the last argument of env.step().'
|
||||||
self._add_to_buffer('obs', obs)
|
self._add_to_buffer('obs', obs)
|
||||||
self._add_to_buffer('act', act)
|
self._add_to_buffer('act', act)
|
||||||
self._add_to_buffer('rew', rew)
|
self._add_to_buffer('rew', rew)
|
||||||
@ -42,18 +45,11 @@ class ReplayBuffer(object):
|
|||||||
self._index = self._size = 0
|
self._index = self._size = 0
|
||||||
self.indice = []
|
self.indice = []
|
||||||
|
|
||||||
def sample_indice(self, batch_size):
|
def sample(self, batch_size):
|
||||||
if batch_size > 0:
|
if batch_size > 0:
|
||||||
self.indice = np.random.choice(self._size, batch_size)
|
indice = np.random.choice(self._size, batch_size)
|
||||||
else:
|
else:
|
||||||
self.indice = np.arange(self._size)
|
indice = np.arange(self._size)
|
||||||
return self.indice
|
|
||||||
|
|
||||||
def sample(self, batch_size, indice=None):
|
|
||||||
if indice is None:
|
|
||||||
indice = self.sample_indice(batch_size)
|
|
||||||
else:
|
|
||||||
self.indice = indice
|
|
||||||
return Batch(
|
return Batch(
|
||||||
obs=self.obs[indice],
|
obs=self.obs[indice],
|
||||||
act=self.act[indice],
|
act=self.act[indice],
|
||||||
@ -61,11 +57,12 @@ class ReplayBuffer(object):
|
|||||||
done=self.done[indice],
|
done=self.done[indice],
|
||||||
obs_next=self.obs_next[indice],
|
obs_next=self.obs_next[indice],
|
||||||
info=self.info[indice]
|
info=self.info[indice]
|
||||||
)
|
), indice
|
||||||
|
|
||||||
|
|
||||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||||
"""docstring for PrioritizedReplayBuffer"""
|
"""docstring for PrioritizedReplayBuffer"""
|
||||||
|
|
||||||
def __init__(self, size):
|
def __init__(self, size):
|
||||||
super().__init__(size)
|
super().__init__(size)
|
||||||
|
|
||||||
|
@ -5,8 +5,10 @@ from tianshou.env import BaseVectorEnv
|
|||||||
from tianshou.data import Batch, ReplayBuffer
|
from tianshou.data import Batch, ReplayBuffer
|
||||||
from tianshou.utils import MovAvg
|
from tianshou.utils import MovAvg
|
||||||
|
|
||||||
|
|
||||||
class Collector(object):
|
class Collector(object):
|
||||||
"""docstring for Collector"""
|
"""docstring for Collector"""
|
||||||
|
|
||||||
def __init__(self, policy, env, buffer):
|
def __init__(self, policy, env, buffer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.env = env
|
self.env = env
|
||||||
@ -18,15 +20,18 @@ class Collector(object):
|
|||||||
if self.multi_env:
|
if self.multi_env:
|
||||||
self.env_num = len(env)
|
self.env_num = len(env)
|
||||||
if isinstance(self.buffer, list):
|
if isinstance(self.buffer, list):
|
||||||
assert len(self.buffer) == self.env_num, 'The data buffer number does not match the input env number.'
|
assert len(self.buffer) == self.env_num,\
|
||||||
|
'Data buffer number does not match the input env number.'
|
||||||
elif isinstance(self.buffer, ReplayBuffer):
|
elif isinstance(self.buffer, ReplayBuffer):
|
||||||
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
|
self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
|
||||||
else:
|
else:
|
||||||
raise TypeError('The buffer in data collector is invalid!')
|
raise TypeError('The buffer in data collector is invalid!')
|
||||||
self.reset_env()
|
self.reset_env()
|
||||||
self.clear_buffer()
|
self.clear_buffer()
|
||||||
# state over batch is either a list, an np.ndarray, or torch.Tensor (hasattr 'shape')
|
# state over batch is either a list, an np.ndarray, or torch.Tensor
|
||||||
self.state = None
|
self.state = None
|
||||||
|
self.stat_reward = MovAvg()
|
||||||
|
self.stat_length = MovAvg()
|
||||||
|
|
||||||
def clear_buffer(self):
|
def clear_buffer(self):
|
||||||
if self.multi_env:
|
if self.multi_env:
|
||||||
@ -38,39 +43,64 @@ class Collector(object):
|
|||||||
def reset_env(self):
|
def reset_env(self):
|
||||||
self._obs = self.env.reset()
|
self._obs = self.env.reset()
|
||||||
self._act = self._rew = self._done = self._info = None
|
self._act = self._rew = self._done = self._info = None
|
||||||
|
if self.multi_env:
|
||||||
|
self.reward = np.zeros(self.env_num)
|
||||||
|
self.length = np.zeros(self.env_num)
|
||||||
|
else:
|
||||||
|
self.reward, self.length = 0, 0
|
||||||
|
|
||||||
def collect(self, n_step=0, n_episode=0, tqdm_hook=None):
|
def collect(self, n_step=0, n_episode=0):
|
||||||
assert sum([(n_step > 0), (n_episode > 0)]) == 1, "One and only one collection number specification permitted!"
|
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
|
||||||
|
"One and only one collection number specification permitted!"
|
||||||
cur_step = 0
|
cur_step = 0
|
||||||
cur_episode = np.zeros(self.env_num) if self.multi_env else 0
|
cur_episode = np.zeros(self.env_num) if self.multi_env else 0
|
||||||
while True:
|
while True:
|
||||||
if self.multi_env:
|
if self.multi_env:
|
||||||
batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, info=self._info)
|
batch_data = Batch(
|
||||||
|
obs=self._obs, act=self._act, rew=self._rew,
|
||||||
|
done=self._done, obs_next=None, info=self._info)
|
||||||
else:
|
else:
|
||||||
batch_data = Batch(obs=[self._obs], act=[self._act], rew=[self._rew], done=[self._done], info=[self_info])
|
batch_data = Batch(
|
||||||
|
obs=[self._obs], act=[self._act], rew=[self._rew],
|
||||||
|
done=[self._done], obs_next=None, info=[self._info])
|
||||||
result = self.policy.act(batch_data, self.state)
|
result = self.policy.act(batch_data, self.state)
|
||||||
self.state = result.state
|
self.state = result.state
|
||||||
self._act = result.act
|
self._act = result.act
|
||||||
obs_next, self._rew, self._done, self._info = self.env.step(self._act)
|
obs_next, self._rew, self._done, self._info = self.env.step(
|
||||||
|
self._act)
|
||||||
cur_step += 1
|
cur_step += 1
|
||||||
|
self.length += 1
|
||||||
|
self.reward += self._rew
|
||||||
if self.multi_env:
|
if self.multi_env:
|
||||||
for i in range(self.env_num):
|
for i in range(self.env_num):
|
||||||
if n_episode > 0 and cur_episode[i] < n_episode or n_episode == 0:
|
if n_episode > 0 and \
|
||||||
self.buffer[i].add(self._obs[i], self._act[i], self._rew[i], self._done[i], obs_next[i], self._info[i])
|
cur_episode[i] < n_episode or n_episode == 0:
|
||||||
|
self.buffer[i].add(
|
||||||
|
self._obs[i], self._act[i], self._rew[i],
|
||||||
|
self._done[i], obs_next[i], self._info[i])
|
||||||
if self._done[i]:
|
if self._done[i]:
|
||||||
cur_episode[i] += 1
|
cur_episode[i] += 1
|
||||||
|
self.stat_reward.add(self.reward[i])
|
||||||
|
self.stat_length.add(self.length[i])
|
||||||
|
self.reward[i], self.length[i] = 0, 0
|
||||||
if isinstance(self.state, list):
|
if isinstance(self.state, list):
|
||||||
self.state[i] = None
|
self.state[i] = None
|
||||||
else:
|
else:
|
||||||
self.state[i] = self.state[i] * 0
|
self.state[i] = self.state[i] * 0
|
||||||
if hasattr(self.state, 'detach'): # remove count in torch
|
if hasattr(self.state, 'detach'):
|
||||||
|
# remove ref in torch
|
||||||
self.state = self.state.detach()
|
self.state = self.state.detach()
|
||||||
if n_episode > 0 and (cur_episode >= n_episode).all():
|
if n_episode > 0 and (cur_episode >= n_episode).all():
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info)
|
self.buffer.add(
|
||||||
|
self._obs, self._act[0], self._rew,
|
||||||
|
self._done, obs_next, self._info)
|
||||||
if self._done:
|
if self._done:
|
||||||
cur_episode += 1
|
cur_episode += 1
|
||||||
|
self.stat_reward.add(self.reward)
|
||||||
|
self.stat_length.add(self.length)
|
||||||
|
self.reward, self.length = 0, 0
|
||||||
self.state = None
|
self.state = None
|
||||||
if n_episode > 0 and cur_episode >= n_episode:
|
if n_episode > 0 and cur_episode >= n_episode:
|
||||||
break
|
break
|
||||||
@ -79,8 +109,29 @@ class Collector(object):
|
|||||||
self._obs = obs_next
|
self._obs = obs_next
|
||||||
self._obs = obs_next
|
self._obs = obs_next
|
||||||
|
|
||||||
def sample(self):
|
def sample(self, batch_size):
|
||||||
pass
|
if self.multi_env:
|
||||||
|
if batch_size > 0:
|
||||||
|
lens = [len(b) for b in self.buffer]
|
||||||
|
total = sum(lens)
|
||||||
|
ib = np.random.choice(
|
||||||
|
total, batch_size, p=np.array(lens) / total)
|
||||||
|
else:
|
||||||
|
ib = np.array([])
|
||||||
|
batch_data = Batch()
|
||||||
|
for i, b in enumerate(self.buffer):
|
||||||
|
cur_batch = (ib == i).sum()
|
||||||
|
if batch_size and cur_batch or batch_size <= 0:
|
||||||
|
batch, indice = b.sample(cur_batch)
|
||||||
|
batch = self.process_fn(batch, b, indice)
|
||||||
|
batch_data.append(batch)
|
||||||
|
else:
|
||||||
|
batch_data, indice = self.buffer.sample(batch_size)
|
||||||
|
batch_data = self.process_fn(batch_data, self.buffer, indice)
|
||||||
|
return batch_data
|
||||||
|
|
||||||
def stat(self):
|
def stat(self):
|
||||||
pass
|
return {
|
||||||
|
'reward': self.stat_reward.get(),
|
||||||
|
'length': self.stat_length.get(),
|
||||||
|
}
|
||||||
|
12
tianshou/env/__init__.py
vendored
12
tianshou/env/__init__.py
vendored
@ -1,3 +1,11 @@
|
|||||||
from tianshou.env.wrapper import FrameStack, BaseVectorEnv, VectorEnv, SubprocVectorEnv, RayVectorEnv
|
from tianshou.env.wrapper import FrameStack,\
|
||||||
|
BaseVectorEnv, VectorEnv, SubprocVectorEnv,\
|
||||||
|
RayVectorEnv
|
||||||
|
|
||||||
__all__ = ['FrameStack', 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv']
|
__all__ = [
|
||||||
|
'FrameStack',
|
||||||
|
'BaseVectorEnv',
|
||||||
|
'VectorEnv',
|
||||||
|
'SubprocVectorEnv',
|
||||||
|
'RayVectorEnv'
|
||||||
|
]
|
||||||
|
22
tianshou/env/wrapper.py
vendored
22
tianshou/env/wrapper.py
vendored
@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
from abc import ABC
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
@ -64,6 +64,7 @@ class BaseVectorEnv(ABC):
|
|||||||
|
|
||||||
class VectorEnv(BaseVectorEnv):
|
class VectorEnv(BaseVectorEnv):
|
||||||
"""docstring for VectorEnv"""
|
"""docstring for VectorEnv"""
|
||||||
|
|
||||||
def __init__(self, env_fns, reset_after_done=False):
|
def __init__(self, env_fns, reset_after_done=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.envs = [_() for _ in env_fns]
|
self.envs = [_() for _ in env_fns]
|
||||||
@ -129,14 +130,19 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
|
|||||||
|
|
||||||
class SubprocVectorEnv(BaseVectorEnv):
|
class SubprocVectorEnv(BaseVectorEnv):
|
||||||
"""docstring for SubProcVectorEnv"""
|
"""docstring for SubProcVectorEnv"""
|
||||||
|
|
||||||
def __init__(self, env_fns, reset_after_done=False):
|
def __init__(self, env_fns, reset_after_done=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.parent_remote, self.child_remote = zip(*[Pipe() for _ in range(self.env_num)])
|
self.parent_remote, self.child_remote = \
|
||||||
|
zip(*[Pipe() for _ in range(self.env_num)])
|
||||||
self.processes = [
|
self.processes = [
|
||||||
Process(target=worker, args=(parent, child, CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
|
Process(target=worker, args=(
|
||||||
for (parent, child, env_fn) in zip(self.parent_remote, self.child_remote, env_fns)
|
parent, child,
|
||||||
|
CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
|
||||||
|
for (parent, child, env_fn) in zip(
|
||||||
|
self.parent_remote, self.child_remote, env_fns)
|
||||||
]
|
]
|
||||||
for p in self.processes:
|
for p in self.processes:
|
||||||
p.start()
|
p.start()
|
||||||
@ -185,6 +191,7 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
|
|
||||||
class RayVectorEnv(BaseVectorEnv):
|
class RayVectorEnv(BaseVectorEnv):
|
||||||
"""docstring for RayVectorEnv"""
|
"""docstring for RayVectorEnv"""
|
||||||
|
|
||||||
def __init__(self, env_fns, reset_after_done=False):
|
def __init__(self, env_fns, reset_after_done=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
@ -193,8 +200,11 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
if not ray.is_initialized():
|
if not ray.is_initialized():
|
||||||
ray.init()
|
ray.init()
|
||||||
except NameError:
|
except NameError:
|
||||||
raise ImportError('Please install ray to support VectorEnv: pip3 install ray -U')
|
raise ImportError(
|
||||||
self.envs = [ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) for e in env_fns]
|
'Please install ray to support VectorEnv: pip3 install ray -U')
|
||||||
|
self.envs = [
|
||||||
|
ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
|
||||||
|
for e in env_fns]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.env_num
|
return self.env_num
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy.base import BasePolicy
|
||||||
|
|
||||||
__all__ = ['BasePolicy']
|
__all__ = [
|
||||||
|
'BasePolicy'
|
||||||
|
]
|
||||||
|
@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
class BasePolicy(ABC):
|
class BasePolicy(ABC):
|
||||||
"""docstring for BasePolicy"""
|
"""docstring for BasePolicy"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -21,8 +22,8 @@ class BasePolicy(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_fn(batch, buffer, index):
|
def process_fn(batch, buffer, indice):
|
||||||
pass
|
return batch
|
||||||
|
|
||||||
def exploration(self):
|
def exploration(self):
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user