Asynchronous sampling vector environment (#134)
Fix #103 Co-authored-by: youkaichao <youkaichao@126.com> Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
		
							parent
							
								
									30368c29a6
								
							
						
					
					
						commit
						e024afab8c
					
				@ -1,5 +1,6 @@
 | 
			
		||||
import gym
 | 
			
		||||
import time
 | 
			
		||||
import random
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gym.spaces import Discrete, MultiDiscrete, Box
 | 
			
		||||
 | 
			
		||||
@ -9,9 +10,10 @@ class MyTestEnv(gym.Env):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
 | 
			
		||||
                 multidiscrete_action=False):
 | 
			
		||||
                 multidiscrete_action=False, random_sleep=False):
 | 
			
		||||
        self.size = size
 | 
			
		||||
        self.sleep = sleep
 | 
			
		||||
        self.random_sleep = random_sleep
 | 
			
		||||
        self.dict_state = dict_state
 | 
			
		||||
        self.ma_rew = ma_rew
 | 
			
		||||
        self._md_action = multidiscrete_action
 | 
			
		||||
@ -48,7 +50,9 @@ class MyTestEnv(gym.Env):
 | 
			
		||||
        if self.done:
 | 
			
		||||
            raise ValueError('step after done !!!')
 | 
			
		||||
        if self.sleep > 0:
 | 
			
		||||
            time.sleep(self.sleep)
 | 
			
		||||
            sleep_time = random.random() if self.random_sleep else 1
 | 
			
		||||
            sleep_time *= self.sleep
 | 
			
		||||
            time.sleep(sleep_time)
 | 
			
		||||
        if self.index == self.size:
 | 
			
		||||
            self.done = True
 | 
			
		||||
            return self._get_dict_state(), self._get_reward(), self.done, {}
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@ import numpy as np
 | 
			
		||||
from torch.utils.tensorboard import SummaryWriter
 | 
			
		||||
 | 
			
		||||
from tianshou.policy import BasePolicy
 | 
			
		||||
from tianshou.env import VectorEnv, SubprocVectorEnv
 | 
			
		||||
from tianshou.env import VectorEnv, SubprocVectorEnv, AsyncVectorEnv
 | 
			
		||||
from tianshou.data import Collector, Batch, ReplayBuffer
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
@ -103,6 +103,51 @@ def test_collector():
 | 
			
		||||
    c2.collect(n_episode=[1, 1, 1, 1], random=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_collector_with_async():
    """Check that a Collector fed by an AsyncVectorEnv stores whole episodes.

    Ten episodes are collected from four ``MyTestEnv`` instances whose sizes
    are 2, 3, 4 and 5.  The buffer content is then scanned episode by
    episode and the test asserts that

    * every transition of an episode carries the same ``env_id``,
    * the episode length equals the size of the environment it came from,
    * the observations of each episode are ``0, 1, ..., length - 1``.
    """
    env_lens = [2, 3, 4, 5]
    writer = SummaryWriter('log/async_collector')
    logger = Logger(writer)
    # bind the loop variable as a default argument so that every factory
    # keeps its own size (avoids the late-binding closure pitfall)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
               for i in env_lens]

    venv = AsyncVectorEnv(env_fns)
    policy = MyPolicy()
    c1 = Collector(policy, venv,
                   ReplayBuffer(size=1000, ignore_obs_next=False),
                   logger.preprocess_fn)
    c1.collect(n_episode=10)
    # check if the data in the buffer is chronological
    # i.e. data in the buffer are full episodes, and each episode is
    # returned by the same environment
    env_id = c1.buffer.info['env_id']
    size = len(c1.buffer)
    obs = c1.buffer.obs[:size]
    done = c1.buffer.done[:size]
    print(env_id[:size])
    print(obs)
    obs_ground_truth = []
    start = 0
    while start < size:
        # ``start`` is the first transition of an episode, ``last`` walks
        # forward to the transition flagged as done
        owner = env_id[start]
        last = start
        while not done[last]:
            last += 1
            assert last < size, 'the buffer ends in the middle of an episode'
            # in one episode, the environment id is the same
            assert env_id[last] == owner
        episode_len = last - start + 1
        assert episode_len == env_lens[owner]
        # one-transition episodes contribute to the ground truth as well,
        # otherwise the final comparison would have mismatching shapes
        obs_ground_truth += list(range(episode_len))
        start = last + 1  # start of the next episode
    assert np.allclose(obs, obs_ground_truth)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_collector_with_dict_state():
 | 
			
		||||
    env = MyTestEnv(size=5, sleep=0, dict_state=True)
 | 
			
		||||
    policy = MyPolicy(dict_state=True)
 | 
			
		||||
@ -181,3 +226,4 @@ if __name__ == '__main__':
 | 
			
		||||
    test_collector()
 | 
			
		||||
    test_collector_with_dict_state()
 | 
			
		||||
    test_collector_with_ma()
 | 
			
		||||
    test_collector_with_async()
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,9 @@
 | 
			
		||||
import time
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gym.spaces.discrete import Discrete
 | 
			
		||||
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
 | 
			
		||||
from tianshou.data import Batch
 | 
			
		||||
from tianshou.env import VectorEnv, SubprocVectorEnv, \
 | 
			
		||||
    RayVectorEnv, AsyncVectorEnv
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    from env import MyTestEnv
 | 
			
		||||
@ -9,6 +11,43 @@ else:  # pytest
 | 
			
		||||
    from test.base.env import MyTestEnv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_async_env(num=8, sleep=0.1):
    """Step an AsyncVectorEnv through a fixed action list and time it.

    :param num: number of environments to create.
    :param sleep: value passed as ``sleep`` to every ``MyTestEnv``
        (created with ``random_sleep=True``).

    :return: a tuple ``(spent_time, data)`` with the wall-clock time of the
        stepping loop and the concatenation of all returned batches.
    """
    # simplify the test case, just keep stepping
    size = 10000
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    v = AsyncVectorEnv(env_fns, wait_num=num // 2)
    v.seed()
    v.reset()
    # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
    # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
    # expectation of v is n / (n + 1)
    # for a synchronous environment, the following actions should take
    # about 7 * sleep * num / (num + 1) seconds
    # for AsyncVectorEnv, the analysis is complicated, but the time cost
    # should be smaller
    action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
    current_index_start = 0
    action = action_list[:num]
    env_ids = list(range(num))
    o = []
    spent_time = time.time()
    while current_index_start < len(action_list):
        A, B, C, D = v.step(action=action, id=env_ids)
        b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
        env_ids = b.info.env_id
        o.append(b)
        current_index_start += len(action)
        # near the end of action_list fewer than len(A) actions may remain
        action = action_list[current_index_start: current_index_start + len(A)]
        # keep ids and actions the same length: only the first
        # len(action) returned environments receive a new action
        env_ids = env_ids[:len(action)]
    spent_time = time.time() - spent_time
    data = Batch.cat(o)
    # the synchronous estimate is 7 * sleep * num / (num + 1); require the
    # asynchronous run to stay below 6/7 of it
    assert spent_time < 6.0 * sleep * num / (num + 1)
    return spent_time, data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_vecenv(size=10, num=8, sleep=0.001):
 | 
			
		||||
    verbose = __name__ == '__main__'
 | 
			
		||||
    env_fns = [
 | 
			
		||||
@ -60,3 +99,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    test_vecenv()
 | 
			
		||||
    test_async_env()
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,11 @@ import warnings
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Any, Dict, List, Union, Optional, Callable
 | 
			
		||||
 | 
			
		||||
from tianshou.env import BaseVectorEnv, VectorEnv
 | 
			
		||||
from tianshou.env import BaseVectorEnv, VectorEnv, AsyncVectorEnv
 | 
			
		||||
from tianshou.policy import BasePolicy
 | 
			
		||||
from tianshou.exploration import BaseNoise
 | 
			
		||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 | 
			
		||||
from tianshou.data.batch import _create_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Collector(object):
 | 
			
		||||
@ -96,6 +97,13 @@ class Collector(object):
 | 
			
		||||
            env = VectorEnv([lambda: env])
 | 
			
		||||
        self.env = env
 | 
			
		||||
        self.env_num = len(env)
 | 
			
		||||
        # environments that are available in step()
 | 
			
		||||
        # this means all environments in synchronous simulation
 | 
			
		||||
        # but only a subset of environments in asynchronous simulation
 | 
			
		||||
        self._ready_env_ids = np.arange(self.env_num)
 | 
			
		||||
        # self.is_async is a flag to indicate whether this collector works
 | 
			
		||||
        # with asynchronous simulation
 | 
			
		||||
        self.is_async = isinstance(env, AsyncVectorEnv)
 | 
			
		||||
        # need cache buffers before storing in the main buffer
 | 
			
		||||
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
 | 
			
		||||
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
 | 
			
		||||
@ -105,6 +113,9 @@ class Collector(object):
 | 
			
		||||
        self.process_fn = policy.process_fn
 | 
			
		||||
        self._action_noise = action_noise
 | 
			
		||||
        self._rew_metric = reward_metric or Collector._default_rew_metric
 | 
			
		||||
        # avoid creating attribute outside __init__
 | 
			
		||||
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
 | 
			
		||||
                          obs_next={}, policy={})
 | 
			
		||||
        self.reset()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -139,6 +150,7 @@ class Collector(object):
 | 
			
		||||
        """Reset all of the environment(s)' states and reset all of the cache
 | 
			
		||||
        buffers (if need).
 | 
			
		||||
        """
 | 
			
		||||
        self._ready_env_ids = np.arange(self.env_num)
 | 
			
		||||
        obs = self.env.reset()
 | 
			
		||||
        if self.preprocess_fn:
 | 
			
		||||
            obs = self.preprocess_fn(obs=obs).get('obs', obs)
 | 
			
		||||
@ -159,7 +171,7 @@ class Collector(object):
 | 
			
		||||
        self.env.close()
 | 
			
		||||
 | 
			
		||||
    def _reset_state(self, id: Union[int, List[int]]) -> None:
 | 
			
		||||
        """Reset self.data.state[id]."""
 | 
			
		||||
        """Reset the hidden state: self.data.state[id]."""
 | 
			
		||||
        state = self.data.state  # it is a reference
 | 
			
		||||
        if isinstance(state, torch.Tensor):
 | 
			
		||||
            state[id].zero_()
 | 
			
		||||
@ -207,6 +219,7 @@ class Collector(object):
 | 
			
		||||
        # episode of each environment
 | 
			
		||||
        episode_count = np.zeros(self.env_num)
 | 
			
		||||
        reward_total = 0.0
 | 
			
		||||
        whole_data = Batch()
 | 
			
		||||
        while True:
 | 
			
		||||
            if step_count >= 100000 and episode_count.sum() == 0:
 | 
			
		||||
                warnings.warn(
 | 
			
		||||
@ -214,6 +227,15 @@ class Collector(object):
 | 
			
		||||
                    'You should add a time limitation to your environment!',
 | 
			
		||||
                    Warning)
 | 
			
		||||
 | 
			
		||||
            if self.is_async:
 | 
			
		||||
                # self.data are the data for all environments
 | 
			
		||||
                # in async simulation, only a subset of data are disposed
 | 
			
		||||
                # so we store the whole data in ``whole_data``, let self.data
 | 
			
		||||
                # to be all the data available in ready environments, and
 | 
			
		||||
                # finally set these back into all the data
 | 
			
		||||
                whole_data = self.data
 | 
			
		||||
                self.data = self.data[self._ready_env_ids]
 | 
			
		||||
 | 
			
		||||
            # restore the state and the input data
 | 
			
		||||
            last_state = self.data.state
 | 
			
		||||
            if last_state.is_empty():
 | 
			
		||||
@ -222,8 +244,16 @@ class Collector(object):
 | 
			
		||||
 | 
			
		||||
            # calculate the next action
 | 
			
		||||
            if random:
 | 
			
		||||
                if self.is_async:
 | 
			
		||||
                    # TODO self.env.action_space will invoke remote call for
 | 
			
		||||
                    #  all environments, which may hang in async simulation.
 | 
			
		||||
                    #  This can be avoided by using a random policy, but not
 | 
			
		||||
                    #  in the collector level. Leave it as a future work.
 | 
			
		||||
                    raise RuntimeError("cannot use random "
 | 
			
		||||
                                       "sampling in async simulation!")
 | 
			
		||||
                spaces = self.env.action_space
 | 
			
		||||
                result = Batch(
 | 
			
		||||
                    act=[a.sample() for a in self.env.action_space])
 | 
			
		||||
                    act=[spaces[i].sample() for i in self._ready_env_ids])
 | 
			
		||||
            else:
 | 
			
		||||
                with torch.no_grad():
 | 
			
		||||
                    result = self.policy(self.data, last_state)
 | 
			
		||||
@ -243,8 +273,18 @@ class Collector(object):
 | 
			
		||||
                self.data.act += self._action_noise(self.data.act.shape)
 | 
			
		||||
 | 
			
		||||
            # step in env
 | 
			
		||||
            obs_next, rew, done, info = self.env.step(self.data.act)
 | 
			
		||||
 | 
			
		||||
            if not self.is_async:
 | 
			
		||||
                obs_next, rew, done, info = self.env.step(self.data.act)
 | 
			
		||||
            else:
 | 
			
		||||
                # store computed actions, states, etc
 | 
			
		||||
                _batch_set_item(whole_data, self._ready_env_ids,
 | 
			
		||||
                                self.data, self.env_num)
 | 
			
		||||
                # fetch finished data
 | 
			
		||||
                obs_next, rew, done, info = self.env.step(
 | 
			
		||||
                    action=self.data.act, id=self._ready_env_ids)
 | 
			
		||||
                self._ready_env_ids = np.array([i['env_id'] for i in info])
 | 
			
		||||
                # get the stepped data
 | 
			
		||||
                self.data = whole_data[self._ready_env_ids]
 | 
			
		||||
            # move data to self.data
 | 
			
		||||
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
 | 
			
		||||
 | 
			
		||||
@ -256,9 +296,11 @@ class Collector(object):
 | 
			
		||||
            if self.preprocess_fn:
 | 
			
		||||
                result = self.preprocess_fn(**self.data)
 | 
			
		||||
                self.data.update(result)
 | 
			
		||||
            for i in range(self.env_num):
 | 
			
		||||
                self._cached_buf[i].add(**self.data[i])
 | 
			
		||||
                if self.data.done[i]:
 | 
			
		||||
            for j, i in enumerate(self._ready_env_ids):
 | 
			
		||||
                # j is the index in current ready_env_ids
 | 
			
		||||
                # i is the index in all environments
 | 
			
		||||
                self._cached_buf[i].add(**self.data[j])
 | 
			
		||||
                if self.data.done[j]:
 | 
			
		||||
                    if n_step or np.isscalar(n_episode) or \
 | 
			
		||||
                            episode_count[i] < n_episode[i]:
 | 
			
		||||
                        episode_count[i] += 1
 | 
			
		||||
@ -267,17 +309,24 @@ class Collector(object):
 | 
			
		||||
                        if self.buffer is not None:
 | 
			
		||||
                            self.buffer.update(self._cached_buf[i])
 | 
			
		||||
                    self._cached_buf[i].reset()
 | 
			
		||||
                    self._reset_state(i)
 | 
			
		||||
                    self._reset_state(j)
 | 
			
		||||
            obs_next = self.data.obs_next
 | 
			
		||||
            if sum(self.data.done):
 | 
			
		||||
                env_ind = np.where(self.data.done)[0]
 | 
			
		||||
                obs_reset = self.env.reset(env_ind)
 | 
			
		||||
                env_ind_local = np.where(self.data.done)[0]
 | 
			
		||||
                env_ind_global = self._ready_env_ids[env_ind_local]
 | 
			
		||||
                obs_reset = self.env.reset(env_ind_global)
 | 
			
		||||
                if self.preprocess_fn:
 | 
			
		||||
                    obs_next[env_ind] = self.preprocess_fn(
 | 
			
		||||
                    obs_next[env_ind_local] = self.preprocess_fn(
 | 
			
		||||
                        obs=obs_reset).get('obs', obs_reset)
 | 
			
		||||
                else:
 | 
			
		||||
                    obs_next[env_ind] = obs_reset
 | 
			
		||||
                    obs_next[env_ind_local] = obs_reset
 | 
			
		||||
            self.data.obs = obs_next
 | 
			
		||||
            if self.is_async:
 | 
			
		||||
                # set data back
 | 
			
		||||
                _batch_set_item(whole_data, self._ready_env_ids,
 | 
			
		||||
                                self.data, self.env_num)
 | 
			
		||||
                # let self.data be the data in all environments again
 | 
			
		||||
                self.data = whole_data
 | 
			
		||||
            if n_step:
 | 
			
		||||
                if step_count >= n_step:
 | 
			
		||||
                    break
 | 
			
		||||
@ -320,3 +369,24 @@ class Collector(object):
 | 
			
		||||
        batch_data, indice = self.buffer.sample(batch_size)
 | 
			
		||||
        batch_data = self.process_fn(batch_data, self.buffer, indice)
 | 
			
		||||
        return batch_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _batch_set_item(source: Batch, indices: np.ndarray,
                    target: Batch, size: int):
    """Write the entries of ``target`` into ``source`` at ``indices``.

    Note the direction: ``source`` is the batch that gets modified and
    ``target`` is the batch that is read.

    For every key of ``target`` one of three situations applies:

    1. ``target[k]`` is reserved (an empty ``Batch``): nothing is written.
    2. ``target[k]`` holds data but ``source[k]`` is missing or reserved:
       ``source[k]`` is first created via ``_create_value(target[k][0],
       size)`` and then filled at ``indices``.
    3. both hold data: ``source[k][indices]`` is overwritten.

    :param source: batch that is updated in place.
    :param indices: positions inside ``source`` that receive the data.
    :param target: batch providing the values.
    :param size: passed to ``_create_value`` when a key has to be created
        in ``source``.
    """
    for key, value in target.items():
        if isinstance(value, Batch) and value.is_empty():
            # case 1: target[key] is reserved, there is nothing to copy
            continue
        current = source.get(key, Batch())
        if isinstance(current, Batch) and current.is_empty():
            # case 2: allocate storage in source before writing
            # use __dict__ to avoid many type checks
            source.__dict__[key] = _create_value(value[0], size)
        # case 2 and case 3: copy the values into the selected slots
        source.__dict__[key][indices] = value
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								tianshou/env/__init__.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								tianshou/env/__init__.py
									
									
									
									
										vendored
									
									
								
							@ -1,11 +1,15 @@
 | 
			
		||||
from tianshou.env.basevecenv import BaseVectorEnv
 | 
			
		||||
from tianshou.env.vecenv import VectorEnv, SubprocVectorEnv, RayVectorEnv
 | 
			
		||||
from tianshou.env.vecenv.base import BaseVectorEnv
 | 
			
		||||
from tianshou.env.vecenv.dummy import VectorEnv
 | 
			
		||||
from tianshou.env.vecenv.subproc import SubprocVectorEnv
 | 
			
		||||
from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
 | 
			
		||||
from tianshou.env.vecenv.rayenv import RayVectorEnv
 | 
			
		||||
from tianshou.env.maenv import MultiAgentEnv
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'BaseVectorEnv',
 | 
			
		||||
    'VectorEnv',
 | 
			
		||||
    'SubprocVectorEnv',
 | 
			
		||||
    'AsyncVectorEnv',
 | 
			
		||||
    'RayVectorEnv',
 | 
			
		||||
    'MultiAgentEnv',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										247
									
								
								tianshou/env/vecenv.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										247
									
								
								tianshou/env/vecenv.py
									
									
									
									
										vendored
									
									
								
							@ -1,247 +0,0 @@
 | 
			
		||||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from multiprocessing import Process, Pipe
 | 
			
		||||
from typing import List, Tuple, Union, Optional, Callable, Any
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    import ray
 | 
			
		||||
except ImportError:
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
from tianshou.env import BaseVectorEnv
 | 
			
		||||
from tianshou.env.utils import CloudpickleWrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    All environments live in the current process and are stepped one after
    another.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Create one environment per factory in ``env_fns``."""
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        """Return attribute ``key`` of every environment as a list.

        Environments lacking the attribute contribute ``None``.
        """
        return [getattr(env, key) if hasattr(env, key) else None
                for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``).

        :return: the stacked initial observations, in the order of ``id``.
        """
        # self.env_num is not assigned in this class; presumably it is set
        # by BaseVectorEnv.__init__ -- confirm against the base class
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        :param action: one action per selected environment; ``action[i]`` is
            applied to environment ``id[i]``.
        :param id: a single environment index, a list of indices or ``None``.

        :return: ``(obs, rew, done, info)``, each stacked over the selected
            environments in the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # ``action`` is aligned with ``id`` by position, so pair them up
        # instead of indexing ``action`` with the environment index
        result = [self.envs[j].step(action[i]) for i, j in enumerate(id)]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every environment that has a ``seed`` method.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        passed to every environment unchanged.

        :return: the values returned by each ``env.seed`` call.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        """Call ``render(**kwargs)`` on every environment that supports it."""
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        """Close every environment and return the list of results."""
        return [e.close() for e in self.envs]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def worker(parent, p, env_fn_wrapper):
    """Serve one environment over a pipe until a ``'close'`` command arrives.

    :param parent: the parent's end of the pipe; it is closed immediately
        because this function only talks through ``p``.
    :param p: the connection used to receive ``(cmd, data)`` messages and to
        send back the result of each command.
    :param env_fn_wrapper: object whose ``data`` attribute is a callable that
        builds the environment.

    Supported commands are ``'step'``, ``'reset'``, ``'close'``,
    ``'render'``, ``'seed'`` and ``'getattr'``.  ``'render'``, ``'seed'`` and
    ``'getattr'`` answer ``None`` when the environment lacks the requested
    attribute.  Any other command closes ``p`` and raises
    ``NotImplementedError``.  A ``KeyboardInterrupt`` closes ``p`` silently.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            command, payload = p.recv()
            if command == 'close':
                # answer first, then shut the connection and stop serving
                p.send(env.close())
                p.close()
                return
            if command == 'step':
                reply = env.step(payload)
            elif command == 'reset':
                reply = env.reset()
            elif command == 'render':
                reply = None
                if hasattr(env, 'render'):
                    reply = env.render(**payload)
            elif command == 'seed':
                reply = None
                if hasattr(env, 'seed'):
                    reply = env.seed(payload)
            elif command == 'getattr':
                reply = None
                if hasattr(env, payload):
                    reply = getattr(env, payload)
            else:
                p.close()
                raise NotImplementedError
            p.send(reply)
    except KeyboardInterrupt:
        p.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    Every environment runs in its own daemon process executing ``worker``
    and is driven through a pipe.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Spawn one worker process per factory in ``env_fns``."""
        super().__init__(env_fns)
        self.closed = False
        # self.env_num is not assigned in this class; presumably it is set
        # by BaseVectorEnv.__init__ -- confirm against the base class
        pipes = [Pipe() for _ in range(self.env_num)]
        self.parent_remote, self.child_remote = zip(*pipes)
        spawned = [
            Process(
                target=worker,
                args=(parent_end, child_end, CloudpickleWrapper(env_fn)),
                daemon=True)
            for parent_end, child_end, env_fn in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]
        self.processes = spawned
        for proc in spawned:
            proc.start()
        # the child ends are only needed inside the workers
        for child_end in self.child_remote:
            child_end.close()

    def __getattr__(self, key):
        """Ask every worker for attribute ``key`` and return the answers."""
        for conn in self.parent_remote:
            conn.send(['getattr', key])
        return [conn.recv() for conn in self.parent_remote]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        ``action[i]`` is sent to environment ``id[i]``; all requests are sent
        before any answer is read.

        :return: ``(obs, rew, done, info)``, each stacked over the selected
            environments in the order of ``id``.
        """
        if id is None:
            env_ids = range(self.env_num)
        else:
            env_ids = [id] if np.isscalar(id) else id
        assert len(action) == len(env_ids)
        for pos, env_idx in enumerate(env_ids):
            self.parent_remote[env_idx].send(['step', action[pos]])
        replies = [self.parent_remote[env_idx].recv() for env_idx in env_ids]
        obs, rew, done, info = (np.stack(column) for column in zip(*replies))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``).

        :return: the stacked initial observations, in the order of ``id``.
        """
        if id is None:
            env_ids = range(self.env_num)
        else:
            env_ids = [id] if np.isscalar(id) else id
        for env_idx in env_ids:
            self.parent_remote[env_idx].send(['reset', None])
        first_obs = [self.parent_remote[env_idx].recv() for env_idx in env_ids]
        return np.stack(first_obs)

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Send a seed to every worker and return their answers.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        forwarded to every worker unchanged.
        """
        if seed is None:
            seeds = [None] * self.env_num
        elif np.isscalar(seed):
            seeds = [seed + offset for offset in range(self.env_num)]
        else:
            seeds = seed
        for conn, value in zip(self.parent_remote, seeds):
            conn.send(['seed', value])
        return [conn.recv() for conn in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Forward ``render(**kwargs)`` to every worker."""
        for conn in self.parent_remote:
            conn.send(['render', kwargs])
        return [conn.recv() for conn in self.parent_remote]

    def close(self) -> List[Any]:
        """Close all environments and join the worker processes.

        :return: the answers to the ``'close'`` command, or ``[]`` when the
            environments were already closed.
        """
        if self.closed:
            return []
        for conn in self.parent_remote:
            conn.send(['close', None])
        answers = [conn.recv() for conn in self.parent_remote]
        self.closed = True
        for proc in self.processes:
            proc.join()
        return answers
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. However, according to our
    test, it is about two times slower than
    :class:`~tianshou.env.SubprocVectorEnv`.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        """Initialise ray if necessary and wrap each environment remotely.

        :raises ImportError: if the ``ray`` module could not be imported.
        """
        super().__init__(env_fns)
        try:
            # ``ray`` is undefined when the optional import failed, which
            # surfaces here as a NameError
            if not ray.is_initialized():
                ray.init()
        except NameError:
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip3 install ray')
        remote_envs = []
        for make_env in env_fns:
            remote_envs.append(
                ray.remote(gym.Wrapper).options(num_cpus=0).remote(make_env()))
        self.envs = remote_envs

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every remote environment."""
        pending = [e.__getattr__.remote(key) for e in self.envs]
        return ray.get(pending)

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the environments selected by ``id`` (all if ``None``).

        ``action[i]`` is applied to environment ``id[i]``.

        :return: ``(obs, rew, done, info)``, each stacked over the selected
            environments in the order of ``id``.
        """
        # self.env_num is not assigned in this class; presumably it is set
        # by BaseVectorEnv.__init__ -- confirm against the base class
        if id is None:
            env_ids = range(self.env_num)
        else:
            env_ids = [id] if np.isscalar(id) else id
        assert len(action) == len(env_ids)
        pending = [self.envs[env_idx].step.remote(action[pos])
                   for pos, env_idx in enumerate(env_ids)]
        replies = ray.get(pending)
        obs, rew, done, info = (np.stack(column) for column in zip(*replies))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the environments selected by ``id`` (all if ``None``).

        :return: the stacked initial observations, in the order of ``id``.
        """
        if id is None:
            env_ids = range(self.env_num)
        else:
            env_ids = [id] if np.isscalar(id) else id
        pending = [self.envs[env_idx].reset.remote() for env_idx in env_ids]
        return np.stack(ray.get(pending))

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every remote environment.

        Returns ``[]`` when the first environment handle has no ``seed``
        attribute.  A scalar ``seed`` is expanded to ``seed, seed + 1, ...``;
        ``None`` is forwarded unchanged.
        """
        if not hasattr(self.envs[0], 'seed'):
            return []
        if seed is None:
            seeds = [None] * self.env_num
        elif np.isscalar(seed):
            seeds = [seed + offset for offset in range(self.env_num)]
        else:
            seeds = seed
        pending = [e.seed.remote(s) for e, s in zip(self.envs, seeds)]
        return ray.get(pending)

    def render(self, **kwargs) -> List[Any]:
        """Render every remote environment.

        Returns a list of ``None`` when the first environment handle has no
        ``render`` attribute.
        """
        if not hasattr(self.envs[0], 'render'):
            return [None] * len(self.envs)
        pending = [e.render.remote(**kwargs) for e in self.envs]
        return ray.get(pending)

    def close(self) -> List[Any]:
        """Close every remote environment and return the results."""
        pending = [e.close.remote() for e in self.envs]
        return ray.get(pending)
 | 
			
		||||
							
								
								
									
										0
									
								
								tianshou/env/vecenv/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tianshou/env/vecenv/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
								
								
									
										104
									
								
								tianshou/env/vecenv/asyncenv.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								tianshou/env/vecenv/asyncenv.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,104 @@
 | 
			
		||||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from multiprocessing import connection
 | 
			
		||||
from typing import List, Tuple, Union, Optional, Callable, Any
 | 
			
		||||
 | 
			
		||||
from tianshou.env import SubprocVectorEnv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AsyncVectorEnv(SubprocVectorEnv):
    """Vectorized asynchronous environment wrapper based on subprocess.

    :param env_fns: a list of callables, each returning one environment;
        forwarded unchanged to :class:`SubprocVectorEnv`.
    :param wait_num: used in asynchronous simulation if the time cost of
        ``env.step`` varies with time and synchronously waiting for all
        environments to finish a step is time-wasting. In that case, we can
        return when ``wait_num`` environments finish a step and keep on
        simulation in these environments. If ``None``, asynchronous simulation
        is disabled; else, ``1 <= wait_num <= env_num``.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None) -> None:
        super().__init__(env_fns)
        # a falsy wait_num (None or 0) means "wait for every environment"
        self.wait_num = wait_num or len(env_fns)
        assert 1 <= self.wait_num <= len(env_fns), \
            f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
        # pipe ends whose step() reply has not been received yet; kept
        # index-aligned with self.waiting_id
        self.waiting_conn = []
        # environments in self.ready_id is actually ready
        # but environments in self.waiting_id are just waiting when checked,
        # and they may be ready now, but this is not known until we check it
        # in the step() function
        self.waiting_id = []
        # all environments are ready in the beginning
        self.ready_id = list(range(self.env_num))

    def _assert_and_transform_id(self,
                                 id: Optional[Union[int, List[int]]] = None
                                 ) -> List[int]:
        """Normalize ``id`` to a list and check that every entry is usable.

        ``None`` selects all environments and a scalar becomes a one-element
        list. Every selected id must be in ``self.ready_id`` and must not be
        in ``self.waiting_id``; otherwise an ``AssertionError`` is raised.
        """
        if id is None:
            id = list(range(self.env_num))
        elif np.isscalar(id):
            id = [id]
        for i in id:
            assert i not in self.waiting_id, \
                f'Cannot reset environment {i} which is stepping now!'
            assert i in self.ready_id, \
                f'Can only reset ready environments {self.ready_id}.'
        return id

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments; all of them must be ready."""
        id = self._assert_and_transform_id(id)
        return super().reset(id)

    def render(self, **kwargs) -> List[Any]:
        """Render all environments.

        :raises RuntimeError: if any environment is still stepping.
        """
        if len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return super().render(**kwargs)

    def close(self) -> List[Any]:
        """Fetch every outstanding step reply, then close all workers.

        Returns an empty list if the environments were already closed.
        """
        if self.closed:
            return []
        # finish remaining steps, and close. One step(None) call may stop
        # as soon as wait_num replies arrived, so keep fetching until no
        # reply is outstanding; otherwise a leftover step reply would be
        # read as the answer to the 'close' command.
        while len(self.waiting_conn) > 0:
            self.step(None)
        return super().close()

    def step(self,
             action: Optional[np.ndarray],
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Provide the given action to the environments. The action sequence
        should correspond to the ``id`` argument, and the ``id`` argument
        should be a subset of the ``env_id`` in the last returned ``info``
        (initially they are env_ids of all the environments). If action is
        ``None``, fetch unfinished step() calls instead.

        :return: stacked ``(obs, rew, done, info)`` of the environments
            whose reply was received in this call; each ``info`` gets an
            ``"env_id"`` entry identifying its environment. If no step is
            outstanding, four empty arrays are returned.
        """
        if action is not None:
            id = self._assert_and_transform_id(id)
            assert len(action) == len(id)
            for act, env_id in zip(action, id):
                self.parent_remote[env_id].send(['step', act])
                self.waiting_conn.append(self.parent_remote[env_id])
                self.waiting_id.append(env_id)
            self.ready_id = [x for x in self.ready_id if x not in id]
        result = []
        while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
            # blocks until at least one pending connection is readable
            ready_conns = connection.wait(self.waiting_conn)
            for conn in ready_conns:
                waiting_index = self.waiting_conn.index(conn)
                self.waiting_conn.pop(waiting_index)
                env_id = self.waiting_id.pop(waiting_index)
                obs, rew, done, info = conn.recv()
                info["env_id"] = env_id
                result.append((obs, rew, done, info))
                self.ready_id.append(env_id)
        if not result:
            # nothing was outstanding (e.g. step(None) on idle environments);
            # the 4-way unpacking below cannot handle an empty result list
            return np.array([]), np.array([]), np.array([]), np.array([])
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info
 | 
			
		||||
							
								
								
									
										65
									
								
								tianshou/env/vecenv/dummy.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								tianshou/env/vecenv/dummy.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,65 @@
 | 
			
		||||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import List, Tuple, Union, Optional, Callable, Any
 | 
			
		||||
 | 
			
		||||
from tianshou.env import BaseVectorEnv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    :param env_fns: a list of callables, each returning one environment;
        every callable is invoked once in the constructor.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def __getattr__(self, key):
        """Return ``key`` looked up on every environment (``None`` if absent).

        Only called when normal attribute lookup on the wrapper fails.
        """
        if key == 'envs':
            # self.envs is not set yet (e.g. while the object is being
            # copied or unpickled); without this guard the lookup of
            # self.envs below would re-enter __getattr__ forever
            raise AttributeError(key)
        return [getattr(env, key) if hasattr(env, key) else None
                for env in self.envs]

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and stack their observations.

        :param id: ``None`` for all environments, a scalar for a single one,
            or a sequence of environment indices.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        obs = np.stack([self.envs[i].reset() for i in id])
        return obs

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments, one after another.

        :param action: one action per selected environment; ``action[k]``
            is applied to the environment ``id[k]``.
        :param id: ``None`` for all environments, a scalar for a single one,
            or a sequence of environment indices.
        :return: stacked ``(obs, rew, done, info)`` in the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        # index the action by its position in ``id``, not by the environment
        # index, so that stepping a subset of environments works
        result = [self.envs[j].step(action[i]) for i, j in enumerate(id)]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments and return the collected results.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        passed to every environment as-is. Environments without a ``seed``
        attribute are skipped, so the result may be shorter than the number
        of environments.
        """
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                result.append(e.seed(s))
        return result

    def render(self, **kwargs) -> List[Any]:
        """Render every environment that has a ``render`` attribute."""
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                result.append(e.render(**kwargs))
        return result

    def close(self) -> List[Any]:
        """Close every environment and return the collected results."""
        return [e.close() for e in self.envs]
 | 
			
		||||
							
								
								
									
										76
									
								
								tianshou/env/vecenv/rayenv.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								tianshou/env/vecenv/rayenv.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,76 @@
 | 
			
		||||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import List, Tuple, Union, Optional, Callable, Any
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    import ray
 | 
			
		||||
except ImportError:
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
from tianshou.env import BaseVectorEnv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. This is a choice to run
    distributed environments in a cluster.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        try:
            # ``ray`` is unbound when its import at module level failed
            if not ray.is_initialized():
                ray.init()
        except NameError:
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip install ray')
        # one remote gym.Wrapper actor per environment
        actors = []
        for make_env in env_fns:
            actor_cls = ray.remote(gym.Wrapper).options(num_cpus=0)
            actors.append(actor_cls.remote(make_env()))
        self.envs = actors

    def __getattr__(self, key):
        """Fetch attribute ``key`` from every remote environment."""
        futures = [env.__getattr__.remote(key) for env in self.envs]
        return ray.get(futures)

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments; ``action[k]`` goes to ``id[k]``.

        :return: stacked ``(obs, rew, done, info)`` in the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        futures = []
        for position, env_index in enumerate(id):
            remote_step = self.envs[env_index].step
            futures.append(remote_step.remote(action[position]))
        result = ray.get(futures)
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and stack their observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        futures = [self.envs[env_index].reset.remote() for env_index in id]
        return np.stack(ray.get(futures))

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed every remote environment and gather what each one returns.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        passed to every environment as-is. Returns an empty list when the
        first environment handle has no ``seed`` attribute.
        """
        if not hasattr(self.envs[0], 'seed'):
            return []
        if seed is None:
            seed = [None] * self.env_num
        elif np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        futures = [env.seed.remote(value)
                   for env, value in zip(self.envs, seed)]
        return ray.get(futures)

    def render(self, **kwargs) -> List[Any]:
        """Render every remote environment.

        Returns a list of ``None`` when the first environment handle has no
        ``render`` attribute.
        """
        if hasattr(self.envs[0], 'render'):
            futures = [env.render.remote(**kwargs) for env in self.envs]
            return ray.get(futures)
        return [None] * len(self.envs)

    def close(self) -> List[Any]:
        """Close every remote environment and return the collected results."""
        futures = [env.close.remote() for env in self.envs]
        return ray.get(futures)
 | 
			
		||||
							
								
								
									
										115
									
								
								tianshou/env/vecenv/subproc.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								tianshou/env/vecenv/subproc.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,115 @@
 | 
			
		||||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from multiprocessing import Process, Pipe
 | 
			
		||||
from typing import List, Tuple, Union, Optional, Callable, Any
 | 
			
		||||
 | 
			
		||||
from tianshou.env import BaseVectorEnv
 | 
			
		||||
from tianshou.env.utils import CloudpickleWrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def worker(parent, p, env_fn_wrapper):
    """Serve environment commands received over a pipe until told to close.

    :param parent: the other end of the pipe; closed right away because
        this function only talks through ``p``.
    :param p: connection used to receive ``(cmd, data)`` pairs and to send
        back one reply per command.
    :param env_fn_wrapper: object whose ``data`` attribute is a callable
        that builds the environment.

    Supported commands are ``'step'``, ``'reset'``, ``'close'``,
    ``'render'``, ``'seed'`` and ``'getattr'``. Any other command closes
    ``p`` and raises ``NotImplementedError``. A ``KeyboardInterrupt`` closes
    ``p`` and ends the loop silently.
    """
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'close':
                # reply first, then shut the connection down and stop
                p.send(env.close())
                p.close()
                return
            if cmd == 'step':
                reply = env.step(data)
            elif cmd == 'reset':
                reply = env.reset()
            elif cmd == 'render':
                # ``data`` carries the keyword arguments for render()
                reply = env.render(**data) if hasattr(env, 'render') else None
            elif cmd == 'seed':
                reply = env.seed(data) if hasattr(env, 'seed') else None
            elif cmd == 'getattr':
                # missing attributes are reported as None
                reply = getattr(env, data, None)
            else:
                p.close()
                raise NotImplementedError
            p.send(reply)
    except KeyboardInterrupt:
        p.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    Each environment lives in its own daemon process running ``worker`` and
    is driven through a dedicated pipe.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
        super().__init__(env_fns)
        self.closed = False
        # one pipe per environment: (parent end, child end)
        pipes = [Pipe() for _ in range(self.env_num)]
        self.parent_remote, self.child_remote = zip(*pipes)
        processes = []
        for parent_end, child_end, make_env in zip(
                self.parent_remote, self.child_remote, env_fns):
            processes.append(Process(
                target=worker,
                args=(parent_end, child_end, CloudpickleWrapper(make_env)),
                daemon=True))
        self.processes = processes
        for proc in self.processes:
            proc.start()
        # this process only uses the parent ends from here on
        for child_end in self.child_remote:
            child_end.close()

    def __getattr__(self, key):
        """Ask every worker for attribute ``key`` and return the replies."""
        for conn in self.parent_remote:
            conn.send(['getattr', key])
        return [conn.recv() for conn in self.parent_remote]

    def step(self,
             action: np.ndarray,
             id: Optional[Union[int, List[int]]] = None
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Step the selected environments; ``action[k]`` goes to ``id[k]``.

        All commands are sent before any reply is read.

        :return: stacked ``(obs, rew, done, info)`` in the order of ``id``.
        """
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        assert len(action) == len(id)
        for position, env_index in enumerate(id):
            self.parent_remote[env_index].send(['step', action[position]])
        result = [self.parent_remote[env_index].recv() for env_index in id]
        obs, rew, done, info = map(np.stack, zip(*result))
        return obs, rew, done, info

    def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
        """Reset the selected environments and stack their observations."""
        if id is None:
            id = range(self.env_num)
        elif np.isscalar(id):
            id = [id]
        for env_index in id:
            self.parent_remote[env_index].send(['reset', None])
        replies = [self.parent_remote[env_index].recv() for env_index in id]
        return np.stack(replies)

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
        """Seed the environments and return one reply per worker.

        A scalar ``seed`` is expanded to ``seed, seed + 1, ...``; ``None`` is
        passed to every environment as-is.
        """
        if seed is None:
            seed = [None] * self.env_num
        elif np.isscalar(seed):
            seed = [seed + offset for offset in range(self.env_num)]
        for conn, value in zip(self.parent_remote, seed):
            conn.send(['seed', value])
        return [conn.recv() for conn in self.parent_remote]

    def render(self, **kwargs) -> List[Any]:
        """Render every environment; ``kwargs`` are sent to each worker."""
        for conn in self.parent_remote:
            conn.send(['render', kwargs])
        return [conn.recv() for conn in self.parent_remote]

    def close(self) -> List[Any]:
        """Close all environments and wait for the worker processes to exit.

        Returns an empty list if the environments were already closed.
        """
        if self.closed:
            return []
        for conn in self.parent_remote:
            conn.send(['close', None])
        result = [conn.recv() for conn in self.parent_remote]
        self.closed = True
        for proc in self.processes:
            proc.join()
        return result
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user