Tianshou/tianshou/data/collector.py
Carlo Cagnetta 5fc314bd4b
Docs/use nbqa on notebooks (#1041)
- Added nbqa to pyproject.toml
- Resolved mypy issues on notebooks and related files
- Conducting ruff checks on notebooks
- Add DataclassPPrintMixin for better stats representation
- Improved Notebooks wording and explanations

Resolve: #1004
Related to #974
2024-02-07 17:28:16 +01:00

659 lines
26 KiB
Python

import time
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast
import gymnasium as gym
import numpy as np
import torch
from tianshou.data import (
Batch,
CachedReplayBuffer,
PrioritizedReplayBuffer,
ReplayBuffer,
ReplayBufferManager,
SequenceSummaryStats,
VectorReplayBuffer,
to_numpy,
)
from tianshou.data.batch import alloc_by_keys_diff
from tianshou.data.types import RolloutBatchProtocol
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.utils.print import DataclassPPrintMixin
@dataclass(kw_only=True)
class CollectStatsBase(DataclassPPrintMixin):
    """Bare-minimum collection statistics: just episode and step counts.

    These two counters are frequently all that is needed, e.g. for offline learning.
    """

    n_collected_episodes: int = 0
    """How many episodes were collected."""
    n_collected_steps: int = 0
    """How many environment steps were collected."""
@dataclass(kw_only=True)
class CollectStats(CollectStatsBase):
    """Statistics gathered while collecting rollouts.

    Extends the basic episode/step counters with timing information and the
    returns and lengths of the episodes that finished during collection.
    """

    collect_time: float = 0.0
    """Time spent collecting transitions."""
    collect_speed: float = 0.0
    """Collection throughput, measured in env steps per second."""
    returns: np.ndarray
    """Returns of the collected episodes."""
    returns_stat: SequenceSummaryStats | None  # None when no episode finished during the collect call
    """Summary statistics over ``returns``."""
    lens: np.ndarray
    """Lengths of the collected episodes."""
    lens_stat: SequenceSummaryStats | None  # None when no episode finished during the collect call
    """Summary statistics over ``lens``."""
class Collector:
    """Collector enables the policy to interact with different types of envs with exact number of steps or episodes.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        If set to None, a minimal :class:`~tianshou.data.VectorReplayBuffer` with
        total size ``env_num`` split into ``env_num`` sub-buffers is created, so
        only the most recent transition of each env is kept. Default to None.
    :param preprocess_fn: a function called before the data has been added to
        the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
    :param exploration_noise: determine whether the action needs to be modified
        with corresponding policy's exploration noise. If so,
        ``policy.exploration_noise(act, batch)`` will be called automatically to
        add the exploration noise into action. Default to False.

    The "preprocess_fn" is a function called before the data has been added to the
    buffer with batch format. When the collector resets environments it receives
    the keys "obs", "info" and "env_id". During a normal env step this class calls
    it with the keys "obs_next", "rew", "done", "info", "policy", "env_id" and
    "act"; :class:`AsyncCollector` instead first passes "obs_next", "rew",
    "terminated", "truncated", "info", "env_id" and "act", and falls back to
    passing "done" (instead of "terminated"/"truncated") if that call raises a
    ``TypeError``.
    It returns either a dict or a :class:`~tianshou.data.Batch` with the modified
    keys and values. Examples are in "test/base/test_collector.py".

    .. note::

        Please make sure the given environment has a time limitation if using n_episode
        collect option.

    .. note::

        In past versions of Tianshou, the replay buffer that was passed to `__init__`
        was automatically reset. This is not done in the current implementation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: gym.Env | BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
        exploration_noise: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
            warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
            # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
            self.env = DummyVectorEnv([lambda: env])
        else:
            self.env = env  # type: ignore
        self.env_num = len(self.env)
        self.exploration_noise = exploration_noise
        self.buffer: ReplayBuffer
        self._assign_buffer(buffer)
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self._action_space = self.env.action_space
        self.data: RolloutBatchProtocol
        # avoid creating attribute outside __init__; preprocess_fn must already be
        # set because reset() -> reset_env() may call it
        self.reset(False)

    def _assign_buffer(self, buffer: ReplayBuffer | None) -> None:
        """Validate ``buffer`` against the number of envs and store it as ``self.buffer``.

        :param buffer: the buffer to use. If None, a ``VectorReplayBuffer`` with total
            size ``env_num`` and ``env_num`` sub-buffers is created.
        :raises TypeError: if a non-vectorized buffer is given while collecting from
            more than one env.
        """
        if buffer is None:
            buffer = VectorReplayBuffer(self.env_num, self.env_num)
        elif isinstance(buffer, ReplayBufferManager):
            # a vectorized buffer needs at least one sub-buffer per env
            assert buffer.buffer_num >= self.env_num
            if isinstance(buffer, CachedReplayBuffer):
                assert buffer.cached_buffer_num >= self.env_num
        else:  # ReplayBuffer or PrioritizedReplayBuffer
            assert buffer.maxsize > 0
            if self.env_num > 1:
                if isinstance(buffer, PrioritizedReplayBuffer):
                    buffer_type = "PrioritizedReplayBuffer"
                    vector_type = "PrioritizedVectorReplayBuffer"
                else:
                    buffer_type = "ReplayBuffer"
                    vector_type = "VectorReplayBuffer"
                raise TypeError(
                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.",
                )
        self.buffer = buffer

    @staticmethod
    def _make_empty_data() -> RolloutBatchProtocol:
        """Return a fresh rollout batch whose reserved keys all hold empty ``Batch`` objects."""
        # use empty Batch for "state" so that self.data supports slicing
        # convert empty Batch to None when passing data to policy
        data = Batch(
            obs={},
            act={},
            rew={},
            terminated={},
            truncated={},
            done={},
            obs_next={},
            info={},
            policy={},
        )
        return cast(RolloutBatchProtocol, data)

    def reset(
        self,
        reset_buffer: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the environment, statistics, current data and possibly replay memory.

        :param reset_buffer: if true, reset the replay buffer that is attached
            to the collector.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).
        """
        self.data = self._make_empty_data()
        self.reset_env(gym_reset_kwargs)
        if reset_buffer:
            self.reset_buffer()
        self.reset_stat()

    def reset_stat(self) -> None:
        """Reset the cumulative step, episode and collect-time counters."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer.

        :param keep_statistics: forwarded unchanged to ``self.buffer.reset``.
        """
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None:
        """Reset all of the environments and store the initial obs/info in ``self.data``.

        If ``preprocess_fn`` is set, it is called with ``obs``, ``info`` and ``env_id``
        and may replace the observations and infos.

        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).
        """
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        obs, info = self.env.reset(**gym_reset_kwargs)
        if self.preprocess_fn:
            processed_data = self.preprocess_fn(obs=obs, info=info, env_id=np.arange(self.env_num))
            obs = processed_data.get("obs", obs)
            info = processed_data.get("info", info)
        self.data.info = info  # type: ignore
        self.data.obs = obs

    def _reset_state(self, id: int | list[int]) -> None:
        """Reset the hidden state ``self.data.policy.hidden_state[id]``, if one exists.

        Tensors and numeric arrays are zeroed, object arrays are set to None and
        Batches are emptied at ``id``. ``self.data`` is modified in place.
        """
        if hasattr(self.data.policy, "hidden_state"):
            state = self.data.policy.hidden_state  # it is a reference
            if isinstance(state, torch.Tensor):
                state[id].zero_()
            elif isinstance(state, np.ndarray):
                state[id] = None if state.dtype == object else 0
            elif isinstance(state, Batch):
                state.empty_(id)

    def _reset_env_with_ids(
        self,
        local_ids: list[int] | np.ndarray,
        global_ids: list[int] | np.ndarray,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the envs ``global_ids`` and write their new obs/info into ``self.data``.

        :param local_ids: positions of the reset envs within ``self.data``.
        :param global_ids: ids of the envs (within ``self.env``) to reset.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).
        """
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs)
        if self.preprocess_fn:
            processed_data = self.preprocess_fn(obs=obs_reset, info=info, env_id=global_ids)
            obs_reset = processed_data.get("obs", obs_reset)
            info = processed_data.get("info", info)
        self.data.info[local_ids] = info  # type: ignore
        # the reset observation goes into obs_next; collect() then copies obs_next to obs
        self.data.obs_next[local_ids] = obs_reset  # type: ignore

    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        """Collect a specified number of steps or episodes.

        To ensure unbiased sampling result with n_episode option, this function will
        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
        episodes, they will be collected evenly from each env.

        :param n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param random: whether to use random policy for collecting data. Default
            to False.
        :param render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param no_grad: whether to retain gradient in policy.forward(). Default to
            True (no gradient retaining).
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function, both for envs whose episode ended and for the final reset
            after an ``n_episode`` collection. Defaults to None (no extra keyword
            arguments).

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: a :class:`CollectStats` with the step/episode counts, timing, and
            the returns and lengths of the episodes that finished during this call.
        :raises TypeError: if neither ``n_step`` nor ``n_episode`` is given.
        """
        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
        if n_step is not None:
            assert n_episode is None, (
                "Only one of n_step or n_episode is allowed in Collector."
                f"collect, got n_step={n_step}, n_episode={n_episode}."
            )
            assert n_step > 0
            if n_step % self.env_num != 0:
                warnings.warn(
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
                    "which may cause extra transitions collected into the buffer.",
                )
            ready_env_ids = np.arange(self.env_num)
        elif n_episode is not None:
            assert n_episode > 0
            ready_env_ids = np.arange(min(self.env_num, n_episode))
            self.data = self.data[: min(self.env_num, n_episode)]
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                "in Collector.collect().",
            )

        start_time = time.time()

        step_count = 0
        episode_count = 0
        episode_returns: list[float] = []
        episode_lens: list[int] = []

        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # get the next action
            if random:
                try:
                    act_sample = [self._action_space[i].sample() for i in ready_env_ids]
                except TypeError:  # envpool's action space is not for per-env
                    act_sample = [self._action_space.sample() for _ in ready_env_ids]
                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
                self.data.update(act=act_sample)
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        # self.data.obs will be used by agent to get result
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(policy=policy, act=act)

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            obs_next, rew, terminated, truncated, info = self.env.step(
                action_remap,
                ready_env_ids,
            )
            done = np.logical_or(terminated, truncated)

            self.data.update(
                obs_next=obs_next,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                done=done,
                info=info,
            )
            if self.preprocess_fn:
                self.data.update(
                    self.preprocess_fn(
                        obs_next=self.data.obs_next,
                        rew=self.data.rew,
                        done=self.data.done,
                        info=self.data.info,
                        policy=self.data.policy,
                        env_id=ready_env_ids,
                        act=self.data.act,
                    ),
                )

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer (insertion pointers and episode start
            # indices are not needed here)
            _, ep_rew, ep_len, _ = self.buffer.add(self.data, buffer_ids=ready_env_ids)

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.extend(ep_len[env_ind_local])
                episode_returns.extend(ep_rew[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids
                # to avoid bias in selecting environments
                if n_episode:
                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            self.data.obs = self.data.obs_next

            if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        collect_time = max(time.time() - start_time, 1e-9)
        self.collect_time += collect_time

        if n_episode:
            # self.data may have been shrunk above; restore a full-size batch and
            # restart every env, honouring the caller's reset kwargs
            self.data = self._make_empty_data()
            self.reset_env(gym_reset_kwargs)

        return CollectStats(
            n_collected_episodes=episode_count,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
            returns=np.array(episode_returns),
            returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
            if len(episode_returns) > 0
            else None,
            lens=np.array(episode_lens, int),
            lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
            if len(episode_lens) > 0
            else None,
        )
class AsyncCollector(Collector):
    """Async Collector handles async vector environment.

    The arguments are exactly the same as :class:`~tianshou.data.Collector`, please
    refer to :class:`~tianshou.data.Collector` for more detailed explanation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
        exploration_noise: bool = False,
    ) -> None:
        # assert env.is_async
        warnings.warn("Using async setting may collect extra transitions into buffer.")
        super().__init__(
            policy,
            env,
            buffer,
            preprocess_fn,
            exploration_noise,
        )

    def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None:
        """Reset all environments and mark every env as ready for the next action.

        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).
        """
        super().reset_env(gym_reset_kwargs)
        self._ready_env_ids = np.arange(self.env_num)

    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        """Collect a specified number of step or episode with async env setting.

        This function doesn't collect exactly n_step or n_episode number of
        transitions. Instead, in order to support async setting, it may collect more
        than given n_step or n_episode transitions and save into buffer.

        :param n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param random: whether to use random policy for collecting data. Default
            to False.
        :param render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param no_grad: whether to retain gradient in policy.forward(). Default to
            True (no gradient retaining).
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: a :class:`CollectStats` with the step/episode counts, timing, and
            the returns and lengths of the episodes that finished during this call.
        :raises TypeError: if neither ``n_step`` nor ``n_episode`` is given.
        """
        # collect at least n_step or n_episode
        if n_step is not None:
            assert n_episode is None, (
                "Only one of n_step or n_episode is allowed in AsyncCollector."
                f"collect, got n_step={n_step}, n_episode={n_episode}."
            )
            assert n_step > 0
        elif n_episode is not None:
            assert n_episode > 0
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                "in AsyncCollector.collect().",
            )

        ready_env_ids = self._ready_env_ids

        start_time = time.time()

        step_count = 0
        episode_count = 0
        episode_returns: list[float] = []
        episode_lens: list[int] = []

        while True:
            # whole_data keeps one entry per env; self.data is the sub-batch of the
            # envs that are currently ready for a new action
            whole_data = self.data
            self.data = self.data[ready_env_ids]
            assert len(whole_data) == self.env_num  # major difference
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # get the next action
            if random:
                try:
                    act_sample = [self._action_space[i].sample() for i in ready_env_ids]
                except TypeError:  # envpool's action space is not for per-env
                    act_sample = [self._action_space.sample() for _ in ready_env_ids]
                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
                self.data.update(act=act_sample)
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        # self.data.obs will be used by agent to get result
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(policy=policy, act=act)

            # save act/policy before env.step
            try:
                whole_data.act[ready_env_ids] = self.data.act  # type: ignore
                whole_data.policy[ready_env_ids] = self.data.policy
            except ValueError:
                alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                whole_data[ready_env_ids] = self.data  # lots of overhead

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            obs_next, rew, terminated, truncated, info = self.env.step(
                action_remap,
                ready_env_ids,
            )
            done = np.logical_or(terminated, truncated)
            any_done = np.any(done)

            # change self.data here because ready_env_ids has changed;
            # info is either a mapping with an "env_id" entry or a sequence of
            # per-env dicts that each carry "env_id"
            try:
                ready_env_ids = info["env_id"]
            except Exception:
                ready_env_ids = np.array([i["env_id"] for i in info])
            self.data = whole_data[ready_env_ids]

            # done must be stored too: the "done" preprocess_fn fallback and the
            # whole_data write-back below both read self.data.done
            self.data.update(
                obs_next=obs_next,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                done=done,
                info=info,
            )
            if self.preprocess_fn:
                try:
                    self.data.update(
                        self.preprocess_fn(
                            obs_next=self.data.obs_next,
                            rew=self.data.rew,
                            terminated=self.data.terminated,
                            truncated=self.data.truncated,
                            info=self.data.info,
                            env_id=ready_env_ids,
                            act=self.data.act,
                        ),
                    )
                except TypeError:
                    self.data.update(
                        self.preprocess_fn(
                            obs_next=self.data.obs_next,
                            rew=self.data.rew,
                            done=self.data.done,
                            info=self.data.info,
                            env_id=ready_env_ids,
                            act=self.data.act,
                        ),
                    )

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer (insertion pointers and episode start
            # indices are not needed here)
            _, ep_rew, ep_len, _ = self.buffer.add(self.data, buffer_ids=ready_env_ids)

            # collect statistics
            step_count += len(ready_env_ids)

            if any_done:
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.extend(ep_len[env_ind_local])
                episode_returns.extend(ep_rew[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
                for i in env_ind_local:
                    self._reset_state(i)

            try:
                # Need to ignore types b/c according to mypy Tensors cannot be indexed
                # by arrays (which they can...)
                whole_data.obs[ready_env_ids] = self.data.obs_next  # type: ignore
                whole_data.rew[ready_env_ids] = self.data.rew
                whole_data.done[ready_env_ids] = self.data.done
                whole_data.info[ready_env_ids] = self.data.info  # type: ignore
                if any_done:
                    # self.data came from indexing whole_data with an id array, which
                    # yields a copy, so hidden states zeroed by _reset_state above
                    # must be written back or they would leak into the next episode
                    whole_data.policy[ready_env_ids] = self.data.policy
            except ValueError:
                alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                self.data.obs = self.data.obs_next
                # lots of overhead
                whole_data[ready_env_ids] = self.data
            self.data = whole_data

            if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
                break

        self._ready_env_ids = ready_env_ids

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        collect_time = max(time.time() - start_time, 1e-9)
        self.collect_time += collect_time

        return CollectStats(
            n_collected_episodes=episode_count,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
            returns=np.array(episode_returns),
            returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
            if len(episode_returns) > 0
            else None,
            lens=np.array(episode_lens, int),
            lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
            if len(episode_lens) > 0
            else None,
        )