import time
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast

import gymnasium as gym
import numpy as np
import torch

from tianshou.data import (
    Batch,
    CachedReplayBuffer,
    PrioritizedReplayBuffer,
    ReplayBuffer,
    ReplayBufferManager,
    SequenceSummaryStats,
    VectorReplayBuffer,
    to_numpy,
)
from tianshou.data.batch import alloc_by_keys_diff
from tianshou.data.types import RolloutBatchProtocol
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy


@dataclass(kw_only=True)
class CollectStatsBase:
    """The most basic stats, often used for offline learning."""

    n_collected_episodes: int = 0
    """The number of collected episodes."""
    n_collected_steps: int = 0
    """The number of collected steps."""


@dataclass(kw_only=True)
class CollectStats(CollectStatsBase):
    """A data structure for storing the statistics of rollouts."""

    collect_time: float = 0.0
    """The time for collecting transitions."""
    collect_speed: float = 0.0
    """The speed of collecting (env_step per second)."""
    returns: np.ndarray
    """The collected episode returns."""
    returns_stat: SequenceSummaryStats | None  # can be None if no episode ends during collect step
    """Stats of the collected returns."""
    lens: np.ndarray
    """The collected episode lengths."""
    lens_stat: SequenceSummaryStats | None  # can be None if no episode ends during collect step
    """Stats of the collected episode lengths."""


class Collector:
    """Collector enables the policy to interact with different types of envs for an exact number of steps or episodes.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        If set to None, a :class:`~tianshou.data.VectorReplayBuffer` with one
        slot per environment is created and used. Default to None.
    :param preprocess_fn: a function called before the data is added to the
        buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
    :param exploration_noise: determine whether the action needs to be modified
        with the corresponding policy's exploration noise. If so,
        ``policy.exploration_noise(act, batch)`` will be called automatically to
        add the exploration noise to the action. Default to False.

    The "preprocess_fn" is a function called before the data is added to the
    buffer in batch format. It will receive only "obs" and "env_id" when the
    collector resets the environment, and will receive the keys "obs_next",
    "rew", "terminated", "truncated", "info", "policy" and "env_id" in a normal
    env step. Alternatively, it may also accept the keys "obs_next", "rew",
    "done", "info", "policy" and "env_id".
    It returns either a dict or a :class:`~tianshou.data.Batch` with the
    modified keys and values. Examples are in "test/base/test_collector.py".
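
    A minimal sketch of such a function (illustrative only; here it merely
    scales rewards and leaves reset calls untouched)::

        def preprocess_fn(**kwargs):
            # a normal env step carries "rew"; an env reset does not
            if "rew" in kwargs:
                return Batch(rew=kwargs["rew"] / 100.0)
            return Batch()
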
    .. note::

        Please make sure the given environment has a time limitation when using
        the n_episode collect option.

    .. note::

        In past versions of Tianshou, the replay buffer that was passed to `__init__`
        was automatically reset. This is not done in the current implementation.
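
    A minimal usage sketch (environment choice and sizes are illustrative, and
    ``policy`` is assumed to be an already constructed policy instance)::

        envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
        buffer = VectorReplayBuffer(total_size=20000, buffer_num=4)
        collector = Collector(policy, envs, buffer, exploration_noise=True)
        stats = collector.collect(n_step=2000)  # or collector.collect(n_episode=10)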
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        policy: BasePolicy,
 | 
						|
        env: gym.Env | BaseVectorEnv,
 | 
						|
        buffer: ReplayBuffer | None = None,
 | 
						|
        preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
 | 
						|
        exploration_noise: bool = False,
 | 
						|
    ) -> None:
 | 
						|
        super().__init__()
 | 
						|
        if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
 | 
						|
            self.env = DummyVectorEnv([lambda: env])
 | 
						|
        else:
 | 
						|
            self.env = env  # type: ignore
 | 
						|
        self.env_num = len(self.env)
 | 
						|
        self.exploration_noise = exploration_noise
 | 
						|
        self.buffer: ReplayBuffer
 | 
						|
        self._assign_buffer(buffer)
 | 
						|
        self.policy = policy
 | 
						|
        self.preprocess_fn = preprocess_fn
 | 
						|
        self._action_space = self.env.action_space
 | 
						|
        self.data: RolloutBatchProtocol
 | 
						|
        # avoid creating attribute outside __init__
 | 
						|
        self.reset(False)
 | 
						|
 | 
						|
    def _assign_buffer(self, buffer: ReplayBuffer | None) -> None:
        """Check if the buffer matches the constraint."""
        if buffer is None:
            buffer = VectorReplayBuffer(self.env_num, self.env_num)
        elif isinstance(buffer, ReplayBufferManager):
            assert buffer.buffer_num >= self.env_num
            if isinstance(buffer, CachedReplayBuffer):
                assert buffer.cached_buffer_num >= self.env_num
        else:  # ReplayBuffer or PrioritizedReplayBuffer
            assert buffer.maxsize > 0
            if self.env_num > 1:
                # PrioritizedReplayBuffer subclasses ReplayBuffer, so the more
                # specific check must come second to overwrite the names
                if isinstance(buffer, ReplayBuffer):
                    buffer_type = "ReplayBuffer"
                    vector_type = "VectorReplayBuffer"
                if isinstance(buffer, PrioritizedReplayBuffer):
                    buffer_type = "PrioritizedReplayBuffer"
                    vector_type = "PrioritizedVectorReplayBuffer"
                raise TypeError(
                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.",
                )
        self.buffer = buffer

    def reset(
        self,
        reset_buffer: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the environment, statistics, current data and possibly replay memory.

        :param reset_buffer: if true, reset the replay buffer that is attached
            to the collector.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments are passed).
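
        For example (assuming the underlying environments accept a ``seed``
        keyword in their ``reset``)::

            collector.reset(gym_reset_kwargs={"seed": 0})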
        """
 | 
						|
        # use empty Batch for "state" so that self.data supports slicing
 | 
						|
        # convert empty Batch to None when passing data to policy
 | 
						|
        data = Batch(
 | 
						|
            obs={},
 | 
						|
            act={},
 | 
						|
            rew={},
 | 
						|
            terminated={},
 | 
						|
            truncated={},
 | 
						|
            done={},
 | 
						|
            obs_next={},
 | 
						|
            info={},
 | 
						|
            policy={},
 | 
						|
        )
 | 
						|
        self.data = cast(RolloutBatchProtocol, data)
 | 
						|
        self.reset_env(gym_reset_kwargs)
 | 
						|
        if reset_buffer:
 | 
						|
            self.reset_buffer()
 | 
						|
        self.reset_stat()
 | 
						|
 | 
						|
    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer."""
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None:
        """Reset all of the environments."""
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        obs, info = self.env.reset(**gym_reset_kwargs)
        if self.preprocess_fn:
            processed_data = self.preprocess_fn(obs=obs, info=info, env_id=np.arange(self.env_num))
            obs = processed_data.get("obs", obs)
            info = processed_data.get("info", info)
        self.data.info = info  # type: ignore
        self.data.obs = obs

    def _reset_state(self, id: int | list[int]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        if hasattr(self.data.policy, "hidden_state"):
            state = self.data.policy.hidden_state  # it is a reference
            if isinstance(state, torch.Tensor):
                state[id].zero_()
            elif isinstance(state, np.ndarray):
                state[id] = None if state.dtype == object else 0
            elif isinstance(state, Batch):
                state.empty_(id)

    def _reset_env_with_ids(
        self,
        local_ids: list[int] | np.ndarray,
        global_ids: list[int] | np.ndarray,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the envs specified by ``global_ids`` and write the result into ``self.data`` at ``local_ids``."""
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs)
        if self.preprocess_fn:
            processed_data = self.preprocess_fn(obs=obs_reset, info=info, env_id=global_ids)
            obs_reset = processed_data.get("obs", obs_reset)
            info = processed_data.get("info", info)
        self.data.info[local_ids] = info  # type: ignore
        self.data.obs_next[local_ids] = obs_reset  # type: ignore

    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        """Collect a specified number of steps or episodes.

        To ensure an unbiased sampling result with the n_episode option, this
        function will first collect ``n_episode - env_num`` episodes, and then
        collect the last ``env_num`` episodes evenly from each env.

        :param n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param random: whether to use random policy for collecting data. Default
            to False.
        :param render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param no_grad: whether to compute actions within ``torch.no_grad()``.
            Default to True (no gradient retained).
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments are passed).

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: A dataclass object containing the statistics of this collect
            call (see :class:`CollectStats`).
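
        A hedged sketch of both modes (numbers are illustrative)::

            stats = collector.collect(n_step=1200)   # ideally a multiple of env_num
            stats = collector.collect(n_episode=16)  # episodes are spread over envs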
        """
 | 
						|
        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
 | 
						|
        if n_step is not None:
 | 
						|
            assert n_episode is None, (
 | 
						|
                f"Only one of n_step or n_episode is allowed in Collector."
 | 
						|
                f"collect, got n_step={n_step}, n_episode={n_episode}."
 | 
						|
            )
 | 
						|
            assert n_step > 0
 | 
						|
            if n_step % self.env_num != 0:
 | 
						|
                warnings.warn(
 | 
						|
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
 | 
						|
                    "which may cause extra transitions collected into the buffer.",
 | 
						|
                )
 | 
						|
            ready_env_ids = np.arange(self.env_num)
 | 
						|
        elif n_episode is not None:
 | 
						|
            assert n_episode > 0
 | 
						|
            ready_env_ids = np.arange(min(self.env_num, n_episode))
 | 
						|
            self.data = self.data[: min(self.env_num, n_episode)]
 | 
						|
        else:
 | 
						|
            raise TypeError(
 | 
						|
                "Please specify at least one (either n_step or n_episode) "
 | 
						|
                "in AsyncCollector.collect().",
 | 
						|
            )
 | 
						|
 | 
						|
        start_time = time.time()
 | 
						|
 | 
						|
        step_count = 0
 | 
						|
        episode_count = 0
 | 
						|
        episode_returns: list[float] = []
 | 
						|
        episode_lens: list[int] = []
 | 
						|
        episode_start_indices: list[int] = []
 | 
						|
 | 
						|
        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # get the next action
            if random:
                try:
                    act_sample = [self._action_space[i].sample() for i in ready_env_ids]
                except TypeError:  # envpool's action space is shared, not per-env
                    act_sample = [self._action_space.sample() for _ in ready_env_ids]
                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
                self.data.update(act=act_sample)
            else:
                if no_grad:
                    with torch.no_grad():  # faster than the retain_grad version
                        # self.data.obs will be used by agent to get result
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(policy=policy, act=act)

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            obs_next, rew, terminated, truncated, info = self.env.step(
                action_remap,
                ready_env_ids,
            )
            done = np.logical_or(terminated, truncated)

            self.data.update(
                obs_next=obs_next,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                done=done,
                info=info,
            )
            if self.preprocess_fn:
                self.data.update(
                    self.preprocess_fn(
                        obs_next=self.data.obs_next,
                        rew=self.data.rew,
                        done=self.data.done,
                        info=self.data.info,
                        policy=self.data.policy,
                        env_id=ready_env_ids,
                        act=self.data.act,
                    ),
                )

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids)

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.extend(ep_len[env_ind_local])
                episode_returns.extend(ep_rew[env_ind_local])
                episode_start_indices.extend(ep_idx[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids
                # to avoid bias in selecting environments
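                # e.g. with n_episode=10 and 4 envs: once 8 episodes have
                # finished, only 2 more are needed, so 2 of the 4 ready envs
                # are dropped to avoid biasing the result toward short episodes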
                if n_episode:
                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            self.data.obs = self.data.obs_next

            if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        collect_time = max(time.time() - start_time, 1e-9)
        self.collect_time += collect_time

        if n_episode:
            data = Batch(
                obs={},
                act={},
                rew={},
                terminated={},
                truncated={},
                done={},
                obs_next={},
                info={},
                policy={},
            )
            self.data = cast(RolloutBatchProtocol, data)
            self.reset_env()

        return CollectStats(
            n_collected_episodes=episode_count,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
            returns=np.array(episode_returns),
            returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
            if len(episode_returns) > 0
            else None,
            lens=np.array(episode_lens, int),
            lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
            if len(episode_lens) > 0
            else None,
        )


class AsyncCollector(Collector):
    """Async Collector handles async vector environment.

    The arguments are exactly the same as :class:`~tianshou.data.Collector`, please
    refer to :class:`~tianshou.data.Collector` for a more detailed explanation.
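
    A minimal usage sketch (``env_fns`` and ``policy`` are assumed to exist; a
    :class:`~tianshou.env.SubprocVectorEnv` with a non-zero ``wait_num`` steps
    asynchronously)::

        envs = SubprocVectorEnv(env_fns, wait_num=2)
        buffer = VectorReplayBuffer(total_size=20000, buffer_num=len(env_fns))
        collector = AsyncCollector(policy, envs, buffer)
        stats = collector.collect(n_step=2000)  # may collect a few extra steps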
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        policy: BasePolicy,
 | 
						|
        env: BaseVectorEnv,
 | 
						|
        buffer: ReplayBuffer | None = None,
 | 
						|
        preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
 | 
						|
        exploration_noise: bool = False,
 | 
						|
    ) -> None:
 | 
						|
        # assert env.is_async
 | 
						|
        warnings.warn("Using async setting may collect extra transitions into buffer.")
 | 
						|
        super().__init__(
 | 
						|
            policy,
 | 
						|
            env,
 | 
						|
            buffer,
 | 
						|
            preprocess_fn,
 | 
						|
            exploration_noise,
 | 
						|
        )
 | 
						|
 | 
						|
    def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None:
 | 
						|
        super().reset_env(gym_reset_kwargs)
 | 
						|
        self._ready_env_ids = np.arange(self.env_num)
 | 
						|
 | 
						|
    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        """Collect a specified number of steps or episodes with the async env setting.

        This function doesn't collect an exact number of n_step or n_episode
        transitions. Instead, to support the async setting, it may collect more
        than the requested n_step or n_episode transitions and save them into
        the buffer.

        :param n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param random: whether to use random policy for collecting data. Default
            to False.
        :param render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param no_grad: whether to compute actions within ``torch.no_grad()``.
            Default to True (no gradient retained).
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments are passed).

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: A dataclass object containing the statistics of this collect
            call (see :class:`CollectStats`).
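
        A hedged example of the overshoot behavior (numbers are illustrative)::

            stats = collector.collect(n_step=100)
            assert stats.n_collected_steps >= 100  # may exceed the request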
        """
 | 
						|
        # collect at least n_step or n_episode
 | 
						|
        if n_step is not None:
 | 
						|
            assert n_episode is None, (
 | 
						|
                "Only one of n_step or n_episode is allowed in Collector."
 | 
						|
                f"collect, got n_step={n_step}, n_episode={n_episode}."
 | 
						|
            )
 | 
						|
            assert n_step > 0
 | 
						|
        elif n_episode is not None:
 | 
						|
            assert n_episode > 0
 | 
						|
        else:
 | 
						|
            raise TypeError(
 | 
						|
                "Please specify at least one (either n_step or n_episode) "
 | 
						|
                "in AsyncCollector.collect().",
 | 
						|
            )
 | 
						|
 | 
						|
        ready_env_ids = self._ready_env_ids
 | 
						|
 | 
						|
        start_time = time.time()
 | 
						|
 | 
						|
        step_count = 0
 | 
						|
        episode_count = 0
 | 
						|
        episode_returns: list[float] = []
 | 
						|
        episode_lens: list[int] = []
 | 
						|
        episode_start_indices: list[int] = []
 | 
						|
 | 
						|
        while True:
 | 
						|
            whole_data = self.data
 | 
						|
            self.data = self.data[ready_env_ids]
 | 
						|
            assert len(whole_data) == self.env_num  # major difference
 | 
						|
            # restore the state: if the last state is None, it won't store
 | 
						|
            last_state = self.data.policy.pop("hidden_state", None)
 | 
						|
 | 
						|
            # get the next action
 | 
						|
            if random:
 | 
						|
                try:
 | 
						|
                    act_sample = [self._action_space[i].sample() for i in ready_env_ids]
 | 
						|
                except TypeError:  # envpool's action space is not for per-env
 | 
						|
                    act_sample = [self._action_space.sample() for _ in ready_env_ids]
 | 
						|
                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
 | 
						|
                self.data.update(act=act_sample)
 | 
						|
            else:
 | 
						|
                if no_grad:
 | 
						|
                    with torch.no_grad():  # faster than retain_grad version
 | 
						|
                        # self.data.obs will be used by agent to get result
 | 
						|
                        result = self.policy(self.data, last_state)
 | 
						|
                else:
 | 
						|
                    result = self.policy(self.data, last_state)
 | 
						|
                # update state / act / policy into self.data
 | 
						|
                policy = result.get("policy", Batch())
 | 
						|
                assert isinstance(policy, Batch)
 | 
						|
                state = result.get("state", None)
 | 
						|
                if state is not None:
 | 
						|
                    policy.hidden_state = state  # save state into buffer
 | 
						|
                act = to_numpy(result.act)
 | 
						|
                if self.exploration_noise:
 | 
						|
                    act = self.policy.exploration_noise(act, self.data)
 | 
						|
                self.data.update(policy=policy, act=act)
 | 
						|
 | 
						|
            # save act/policy before env.step
 | 
						|
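            # a ValueError below means whole_data lacks keys or shapes that
            # self.data has just gained (e.g. a newly added "policy" entry);
            # alloc_by_keys_diff grows whole_data so the assignment can succeed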
            try:
                whole_data.act[ready_env_ids] = self.data.act  # type: ignore
                whole_data.policy[ready_env_ids] = self.data.policy
            except ValueError:
                alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                whole_data[ready_env_ids] = self.data  # lots of overhead

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            obs_next, rew, terminated, truncated, info = self.env.step(
                action_remap,
                ready_env_ids,
            )
            done = np.logical_or(terminated, truncated)

            # change self.data here because ready_env_ids has changed
            try:
                ready_env_ids = info["env_id"]
            except Exception:
                ready_env_ids = np.array([i["env_id"] for i in info])
            self.data = whole_data[ready_env_ids]

            self.data.update(
                obs_next=obs_next,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                info=info,
            )
            if self.preprocess_fn:
                try:
                    self.data.update(
                        self.preprocess_fn(
                            obs_next=self.data.obs_next,
                            rew=self.data.rew,
                            terminated=self.data.terminated,
                            truncated=self.data.truncated,
                            info=self.data.info,
                            env_id=ready_env_ids,
                            act=self.data.act,
                        ),
                    )
                except TypeError:  # fall back to the older "done"-style signature
                    self.data.update(
                        self.preprocess_fn(
                            obs_next=self.data.obs_next,
                            rew=self.data.rew,
                            done=self.data.done,
                            info=self.data.info,
                            env_id=ready_env_ids,
                            act=self.data.act,
                        ),
                    )

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids)

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.extend(ep_len[env_ind_local])
                episode_returns.extend(ep_rew[env_ind_local])
                episode_start_indices.extend(ep_idx[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
                for i in env_ind_local:
                    self._reset_state(i)

            try:
                # Need to ignore types b/c according to mypy Tensors cannot be indexed
                # by arrays (which they can...)
                whole_data.obs[ready_env_ids] = self.data.obs_next  # type: ignore
                whole_data.rew[ready_env_ids] = self.data.rew
                whole_data.done[ready_env_ids] = self.data.done
                whole_data.info[ready_env_ids] = self.data.info  # type: ignore
            except ValueError:
                alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                self.data.obs = self.data.obs_next
                # lots of overhead
                whole_data[ready_env_ids] = self.data
            self.data = whole_data

            if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
                break

        self._ready_env_ids = ready_env_ids

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        collect_time = max(time.time() - start_time, 1e-9)
        self.collect_time += collect_time

        return CollectStats(
            n_collected_episodes=episode_count,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
            returns=np.array(episode_returns),
            returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
            if len(episode_returns) > 0
            else None,
            lens=np.array(episode_lens, int),
            lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
            if len(episode_lens) > 0
            else None,
        )