import logging
import time
import warnings
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from typing import Any, Self, TypeVar, cast

import gymnasium as gym
import numpy as np
import torch
from overrides import override

from tianshou.data import (
    Batch,
    CachedReplayBuffer,
    ReplayBuffer,
    ReplayBufferManager,
    SequenceSummaryStats,
    VectorReplayBuffer,
    to_numpy,
)
from tianshou.data.types import (
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode

log = logging.getLogger(__name__)


@dataclass(kw_only=True)
class CollectStatsBase(DataclassPPrintMixin):
    """The most basic stats, often used for offline learning."""

    n_collected_episodes: int = 0
    """The number of collected episodes."""
    n_collected_steps: int = 0
    """The number of collected steps."""


@dataclass(kw_only=True)
class CollectStats(CollectStatsBase):
    """A data structure for storing the statistics of rollouts."""

    collect_time: float = 0.0
    """The time for collecting transitions."""
    collect_speed: float = 0.0
    """The speed of collecting (env_step per second)."""
    returns: np.ndarray
    """The collected episode returns."""
    returns_stat: SequenceSummaryStats | None  # can be None if no episode ends during the collect step
    """Stats of the collected returns."""
    lens: np.ndarray
    """The collected episode lengths."""
    lens_stat: SequenceSummaryStats | None  # can be None if no episode ends during the collect step
    """Stats of the collected episode lengths."""

    @classmethod
    def with_autogenerated_stats(
        cls,
        returns: np.ndarray,
        lens: np.ndarray,
        n_collected_episodes: int = 0,
        n_collected_steps: int = 0,
        collect_time: float = 0.0,
        collect_speed: float = 0.0,
    ) -> Self:
        """Return a new instance with the stats autogenerated from the given arrays."""
        returns_stat = SequenceSummaryStats.from_sequence(returns) if returns.size > 0 else None
        lens_stat = SequenceSummaryStats.from_sequence(lens) if lens.size > 0 else None
        return cls(
            n_collected_episodes=n_collected_episodes,
            n_collected_steps=n_collected_steps,
            collect_time=collect_time,
            collect_speed=collect_speed,
            returns=returns,
            returns_stat=returns_stat,
            lens=np.array(lens, int),
            lens_stat=lens_stat,
        )
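
    # Illustrative usage sketch (hypothetical values, not part of the class definition):
    # given two finished episodes with the returns/lengths below, the summary stats are
    # derived automatically, and empty arrays simply yield `None` stats.
    #
    #     stats = CollectStats.with_autogenerated_stats(
    #         returns=np.array([1.0, 3.0]),
    #         lens=np.array([10, 20]),
    #         n_collected_episodes=2,
    #         n_collected_steps=30,
    #         collect_time=0.5,
    #         collect_speed=60.0,
    #     )
    #     assert stats.returns_stat is not None  # populated because returns.size > 0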


_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None")


def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike:
    """Return None, or the values at the given indices if the object is not None."""
    if obj is not None:
        return obj[indices]  # type: ignore[index, return-value]
    return None  # type: ignore[unreachable]
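
# Behavior sketch (hypothetical arrays, shown for illustration only):
#
#     arr = np.array([10, 20, 30])
#     _nullable_slice(arr, np.array([0, 2]))   # -> array([10, 30])
#     _nullable_slice(None, np.array([0, 2]))  # -> None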


def _dict_of_arr_to_arr_of_dicts(dict_of_arr: dict[str, np.ndarray | dict]) -> np.ndarray:
    return np.array(Batch(dict_of_arr).to_list_of_dicts())
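
# Conversion sketch (illustrative input, assuming Batch.to_list_of_dicts behaves as in tianshou):
#
#     _dict_of_arr_to_arr_of_dicts({"a": np.array([1, 2]), "b": np.array([3, 4])})
#     # -> roughly array([{'a': 1, 'b': 3}, {'a': 2, 'b': 4}], dtype=object)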


def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch:
    """TODO: this exists because of multiple bugs in Batch and to restore backwards compatibility.

    Batch should be fixed and this function should be removed ASAP!
    """
    if info_array.dtype != np.dtype("O"):
        raise ValueError(
            f"Expected info_array to have dtype=object, but got {info_array.dtype}.",
        )

    truthy_info_indices = info_array.nonzero()[0]
    falsy_info_indices = set(range(len(info_array))) - set(truthy_info_indices)
    falsy_info_indices = np.array(list(falsy_info_indices), dtype=int)

    if len(falsy_info_indices) == len(info_array):
        return Batch()

    some_nonempty_info = None
    for info in info_array:
        if info:
            some_nonempty_info = info
            break

    info_array = copy(info_array)
    info_array[falsy_info_indices] = some_nonempty_info
    result_batch_parent = Batch(info=info_array)
    result_batch_parent.info[falsy_info_indices] = {}
    return result_batch_parent.info
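
# Sketch of the intended behavior on a mixed info array (hypothetical values, hedged):
#
#     infos = np.array([{}, {"x": 1}], dtype=object)
#     batch = _HACKY_create_info_batch(infos)
#     # batch behaves like an aligned Batch of length 2 where index 1 carries {"x": 1}
#     # and index 0 is emptied again after being temporarily filled for alignment.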


class BaseCollector(ABC):
    """Used to collect data from a vector environment into a buffer using a given policy.

    .. note::

        Please make sure the given environment has a time limit if using the `n_episode`
        collect option.

    .. note::

        In past versions of Tianshou, the replay buffer passed to `__init__`
        was automatically reset. This is not done in the current implementation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: BaseVectorEnv | gym.Env,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
    ) -> None:
        if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
            warnings.warn("Single environment detected, wrapping it in a DummyVectorEnv.")
            # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
            env = DummyVectorEnv([lambda: env])  # type: ignore

        if buffer is None:
            buffer = VectorReplayBuffer(len(env), len(env))

        self.buffer: ReplayBuffer = buffer
        self.policy = policy
        self.env = cast(BaseVectorEnv, env)
        self.exploration_noise = exploration_noise
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

        self._action_space = self.env.action_space
        self._is_closed = False

        self._validate_buffer()

    def _validate_buffer(self) -> None:
        buf = self.buffer
        # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager.
        # We should probably rename the manager
        if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num:
            raise ValueError(
                f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.",
            )
        if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num:
            raise ValueError(
                f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.",
            )
        # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance
        if not isinstance(buf, ReplayBufferManager):
            if buf.maxsize == 0:
                raise ValueError("Buffer maxsize should be greater than 0.")
            if self.env_num > 1:
                raise ValueError(
                    f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). "
                    f"Please use the corresponding VectorReplayBuffer instead.",
                )
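
    # Sizing sketch (illustrative numbers): for 4 parallel envs, a compatible buffer
    # needs at least 4 sub-buffers, e.g.
    #
    #     buffer = VectorReplayBuffer(total_size=20_000, buffer_num=4)
    #
    # while a plain ReplayBuffer(20_000) is only accepted when env_num == 1.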

    @property
    def env_num(self) -> int:
        return len(self.env)

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    def close(self) -> None:
        """Close the collector and the environment."""
        self.env.close()
        self._is_closed = True

    def reset(
        self,
        reset_buffer: bool = True,
        reset_stats: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Reset the environment, statistics, and data needed to start the collection.

        :param reset_buffer: if true, reset the replay buffer attached
            to the collector.
        :param reset_stats: if true, reset the statistics attached to the collector.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).
        :return: The initial observation and info from the environment.
        """
        obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
        if reset_buffer:
            self.reset_buffer()
        if reset_stats:
            self.reset_stat()
        self._is_closed = False
        return obs_NO, info_N
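
    # Usage sketch (assuming `collector` is an already constructed Collector subclass):
    #
    #     obs, info = collector.reset()                    # fresh envs, empty buffer, zeroed stats
    #     obs, info = collector.reset(reset_buffer=False)  # keep previously collected data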

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer."""
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Reset the environments and the initial obs, info, and hidden state of the collector."""
        gym_reset_kwargs = gym_reset_kwargs or {}
        obs_NO, info_N = self.env.reset(**gym_reset_kwargs)
        # TODO: hack, wrap envpool envs such that they don't return a dict
        if isinstance(info_N, dict):  # type: ignore[unreachable]
            # this can happen if the env is an envpool env. Then the thing returned by reset is a dict
            # with array entries instead of an array of dicts
            # We use Batch to turn it into an array of dicts
            info_N = _dict_of_arr_to_arr_of_dicts(info_N)  # type: ignore[unreachable]
        return obs_NO, info_N

    @abstractmethod
    def _collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        pass

    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        reset_before_collect: bool = False,
        gym_reset_kwargs: dict[str, Any] | None = None,
        eval_mode: bool = False,
    ) -> CollectStats:
        """Collect a specified number of steps or episodes.

        To ensure an unbiased sampling result with the n_episode option, this function will
        first collect ``n_episode - env_num`` episodes and then collect the last ``env_num``
        episodes evenly from each env.

        :param n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param random: whether to sample actions randomly instead of using the policy.
        :param render: the sleep time between rendering consecutive frames.
        :param no_grad: whether to disable gradient tracking in policy.forward(). If True
            (the default), no gradients are retained.
        :param reset_before_collect: whether to reset the environment before collecting data.
            (The collector needs the initial obs and info to function properly.)
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Only used if reset_before_collect is True.
        :param eval_mode: whether to collect data in evaluation mode. Will
            set the policy to training mode otherwise.

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: The collected stats.
        """
        # check that exactly one of n_step or n_episode is set and that the other is larger than 0
        self._validate_n_step_n_episode(n_episode, n_step)

        if reset_before_collect:
            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
        with policy_mode_context(self.policy):
            return self._collect(
                n_step=n_step,
                n_episode=n_episode,
                random=random,
                render=render,
                no_grad=no_grad,
                gym_reset_kwargs=gym_reset_kwargs,
            )
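
    # End-to-end usage sketch (assumes an already initialized `policy` and gymnasium's
    # CartPole-v1; names other than the tianshou classes imported above are placeholders):
    #
    #     env = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    #     buffer = VectorReplayBuffer(total_size=20_000, buffer_num=4)
    #     collector = Collector(policy, env, buffer, exploration_noise=True)
    #     stats = collector.collect(n_step=256, reset_before_collect=True)
    #     print(stats.n_collected_steps, stats.collect_speed)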

    def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None:
        if not n_step and not n_episode:
            raise ValueError(
                f"Only one of n_step and n_episode should be set to a value larger than zero "
                f"but got {n_step=}, {n_episode=}.",
            )
        if n_step is None and n_episode is None:
            raise ValueError(
                "Exactly one of n_step and n_episode should be set but got None for both.",
            )
        if n_step and n_step % self.env_num != 0:
            warnings.warn(
                f"{n_step=} is not a multiple of ({self.env_num=}), "
                "which may cause extra transitions being collected into the buffer.",
            )
        if n_episode and self.env_num > n_episode:
            warnings.warn(
                f"{n_episode=} should be larger than {self.env_num=} to "
                f"collect at least one trajectory in each environment.",
            )
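
    # Validation sketch (illustrative calls, assuming env_num == 4):
    #
    #     self._validate_n_step_n_episode(n_episode=None, n_step=256)   # ok, 256 % 4 == 0
    #     self._validate_n_step_n_episode(n_episode=None, n_step=10)    # warns: not a multiple of 4
    #     self._validate_n_step_n_episode(n_episode=2, n_step=None)     # warns: fewer episodes than envs
    #     self._validate_n_step_n_episode(n_episode=None, n_step=None)  # raises ValueError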


class Collector(BaseCollector):
    # NAMING CONVENTION (mostly suffixes):
    # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
    #   the corresponding env is either reset or removed from the ready envs.
    # N - number of envs, always fixed and >= R.
    # R - number of ready env ids. Note that this might change when envs get idle.
    #   This can only happen in the n_episode case, see the explanation in the corresponding block.
    #   For n_step, we always use all envs to collect the data, while for n_episode,
    #   R will be at most n_episode at the beginning, but can decrease during the collection.
    # O - dimension(s) of observations
    # A - dimension(s) of actions
    # H - dimension(s) of hidden state
    # D - number of envs that reached done in the current collect iteration. Only relevant in the n_episode case.
    # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
    #   Only used in the n_episode case. Then, R becomes R-S.

    # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
    # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
    # policy.deterministic_eval)

    def __init__(
        self,
        policy: BasePolicy,
        env: gym.Env | BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
    ) -> None:
        """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
        :param env: a ``gym.Env`` environment or an instance of the
            :class:`~tianshou.env.BaseVectorEnv` class.
        :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
            If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
            as the default buffer.
        :param exploration_noise: determine whether the action needs to be modified
            with the corresponding policy's exploration noise. If so,
            "policy.exploration_noise(act, batch)" will be called automatically to add the
            exploration noise into the action. Defaults to False.
        """
        super().__init__(policy, env, buffer, exploration_noise=exploration_noise)
        self._pre_collect_obs_RO: np.ndarray | None = None
        self._pre_collect_info_R: np.ndarray | None = None
        self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None

        self._is_closed = False
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def close(self) -> None:
        super().close()
        self._pre_collect_obs_RO = None
        self._pre_collect_info_R = None

    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Reset the environments and the initial obs, info, and hidden state of the collector."""
        obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
        # We assume that R = N when reset is called.
        # TODO: there is currently no mechanism that ensures this and it's a public method!
        self._pre_collect_obs_RO = obs_NO
        self._pre_collect_info_R = info_N
        self._pre_collect_hidden_state_RH = None
        return obs_NO, info_N

    def _compute_action_policy_hidden(
        self,
        random: bool,
        ready_env_ids_R: np.ndarray,
        use_grad: bool,
        last_obs_RO: np.ndarray,
        last_info_R: np.ndarray,
        last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
    ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]:
        """Returns the action, the normalized action, a "policy" entry, and the hidden state."""
        if random:
            try:
                act_normalized_RA = np.array(
                    [self._action_space[i].sample() for i in ready_env_ids_R],
                )
            # TODO: test whether envpool env explicitly
            except TypeError:  # envpool's action space is not for per-env
                act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R])
            act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA))
            policy_R = Batch()
            hidden_state_RH = None

        else:
            info_batch = _HACKY_create_info_batch(last_info_R)
            obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch))

            with torch.set_grad_enabled(use_grad):
                act_batch_RA = self.policy(
                    obs_batch_R,
                    last_hidden_state_RH,
                )

            act_RA = to_numpy(act_batch_RA.act)
            if self.exploration_noise:
                act_RA = self.policy.exploration_noise(act_RA, obs_batch_R)
            act_normalized_RA = self.policy.map_action(act_RA)

            # TODO: cleanup the whole policy in batch thing
            # todo policy_R can also be none, check
            policy_R = act_batch_RA.get("policy", Batch())
            if not isinstance(policy_R, Batch):
                raise RuntimeError(
                    f"The policy result should be a {Batch}, but got {type(policy_R)}",
                )

            hidden_state_RH = act_batch_RA.get("state", None)
            # TODO: do we need the conditional? Would be better to just add hidden_state which could be None
            if hidden_state_RH is not None:
                policy_R.hidden_state = (
                    hidden_state_RH  # save state into buffer through policy attr
                )
        return act_RA, act_normalized_RA, policy_R, hidden_state_RH
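
    # Output convention sketch (hedged; the exact Batch fields depend on the policy implementation):
    # for R ready envs, `act_RA` holds the raw policy actions, `act_normalized_RA` the actions
    # mapped into the env's action space, `policy_R` a Batch of extra per-step policy outputs
    # (possibly empty), and `hidden_state_RH` the recurrent state or None. A random rollout
    # therefore degenerates to:
    #
    #     act, act_norm, policy, hidden = self._compute_action_policy_hidden(
    #         random=True,
    #         ready_env_ids_R=np.arange(self.env_num),
    #         use_grad=False,
    #         last_obs_RO=obs,
    #         last_info_R=info,
    #     )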

    # TODO: reduce complexity, remove the noqa
    def _collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        # TODO: can't do it in init since AsyncCollector is currently a subclass of Collector
        if self.env.is_async:
            raise ValueError(
                f"Please use {AsyncCollector.__name__} for asynchronous environments. "
                f"Env class: {self.env.__class__.__name__}.",
            )

        if n_step is not None:
            ready_env_ids_R = np.arange(self.env_num)
        elif n_episode is not None:
            ready_env_ids_R = np.arange(min(self.env_num, n_episode))

        use_grad = not no_grad

        start_time = time.time()
        if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
            raise ValueError(
                "Initial obs and info should not be None. "
                "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.",
            )

        # take the previously stored obs as the current obs, since a new call to collect
        # does not restart trajectories (which we also really don't want)
        step_count = 0
        num_collected_episodes = 0
        episode_returns: list[float] = []
        episode_lens: list[int] = []
        episode_start_indices: list[int] = []

        # in case we select fewer episodes than envs, we run only some of them
        last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R)
        last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R)
        last_hidden_state_RH = _nullable_slice(
            self._pre_collect_hidden_state_RH,
            ready_env_ids_R,
        )

        while True:
            # todo check if we need this when using cur_rollout_batch
            # if len(cur_rollout_batch) != len(ready_env_ids):
            #     raise RuntimeError(
            #         f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids"
            #         f"{len(ready_env_ids)}. This should not happen and could be a bug!",
            #     )
            # restore the state: if the last state is None, it won't store

            # get the next action
            (
                act_RA,
                act_normalized_RA,
                policy_R,
                hidden_state_RH,
            ) = self._compute_action_policy_hidden(
                random=random,
                ready_env_ids_R=ready_env_ids_R,
                use_grad=use_grad,
                last_obs_RO=last_obs_RO,
                last_info_R=last_info_R,
                last_hidden_state_RH=last_hidden_state_RH,
            )

            obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step(
                act_normalized_RA,
                ready_env_ids_R,
            )
            if isinstance(info_R, dict):  # type: ignore[unreachable]
                # This can happen if the env is an envpool env. Then the info returned by step is a dict
                info_R = _dict_of_arr_to_arr_of_dicts(info_R)  # type: ignore[unreachable]
            done_R = np.logical_or(terminated_R, truncated_R)

            current_iteration_batch = cast(
                RolloutBatchProtocol,
                Batch(
                    obs=last_obs_RO,
                    act=act_RA,
                    policy=policy_R,
                    obs_next=obs_next_RO,
                    rew=rew_R,
                    terminated=terminated_R,
                    truncated=truncated_R,
                    done=done_R,
                    info=info_R,
                ),
            )

            # TODO: only makes sense if render_mode is human.
            # Also, doubtful whether it makes sense at all for true vectorized envs
            if render:
                self.env.render()
                if not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add(
                current_iteration_batch,
                buffer_ids=ready_env_ids_R,
            )

            # collect statistics
            num_episodes_done_this_iter = np.sum(done_R)
            num_collected_episodes += num_episodes_done_this_iter
            step_count += len(ready_env_ids_R)

            # preparing for the next iteration
            # obs_next, info and hidden_state will be modified inplace in the code below,
            # so we copy to not affect the data in the buffer
            last_obs_RO = copy(obs_next_RO)
            last_info_R = copy(info_R)
            last_hidden_state_RH = copy(hidden_state_RH)

            # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration
            # Resetting envs that reached done, or removing some of them from the collection if needed (see below)
            if num_episodes_done_this_iter > 0:
                # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays
                # D - number of envs that reached done in the rollout above
                env_ind_local_D = np.where(done_R)[0]
                env_ind_global_D = ready_env_ids_R[env_ind_local_D]
                episode_lens.extend(ep_len_R[env_ind_local_D])
                episode_returns.extend(ep_rew_R[env_ind_local_D])
                episode_start_indices.extend(ep_idx_R[env_ind_local_D])

                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                gym_reset_kwargs = gym_reset_kwargs or {}
                obs_reset_DO, info_reset_D = self.env.reset(
                    env_id=env_ind_global_D,
                    **gym_reset_kwargs,
                )

                # Set the hidden state to zero or None for the envs that reached done
                # TODO: does it have to be so complicated? We should have a single clear type for hidden_state
                # instead of this complex logic
                self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH)

                # preparing for the next iteration
                last_obs_RO[env_ind_local_D] = obs_reset_DO
                last_info_R[env_ind_local_D] = info_reset_D

            # Handling the case when we have more ready envs than desired and are not done yet
            #
            # This can only happen if we are collecting a fixed number of episodes
            # If we have more ready envs than there are remaining episodes to collect,
            # we will remove some of them for the next rollout
            # One effect of this is the following: only envs that have completed an episode
            # in the last step can ever be removed from the ready envs.
            # Thus, this guarantees that each env will contribute at least one episode to the
            # collected data (the buffer). This effect was previously called "avoiding bias in selecting environments".
            # However, it is not at all clear whether this is actually useful or necessary.
            # Additional naming convention:
            # S - number of surplus envs
            # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones.
            # Changing R to R-S highly increases the complexity of the code.
            if n_episode:
                remaining_episodes_to_collect = n_episode - num_collected_episodes
                surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect
                if surplus_env_num > 0:
                    # R becomes R-S here, preparing for the next iteration in the while loop
                    # Everything that was of length R needs to be filtered and become of length R-S.
                    # Note that this won't be the last iteration, as one iteration equals one
                    # step and we still need to collect the remaining episodes to reach the breaking condition.

                    # creating the mask
                    env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num]
                    env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool)
                    env_should_remain_R[env_to_be_ignored_ind_local_S] = False
                    # stripping the "idle" indices, shortening the relevant quantities from R to R-S
                    ready_env_ids_R = ready_env_ids_R[env_should_remain_R]
                    last_obs_RO = last_obs_RO[env_should_remain_R]
                    last_info_R = last_info_R[env_should_remain_R]
                    if hidden_state_RH is not None:
                        last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R]  # type: ignore[index]

            if (n_step and step_count >= n_step) or (
                n_episode and num_collected_episodes >= n_episode
            ):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += num_collected_episodes
        collect_time = max(time.time() - start_time, 1e-9)
        self.collect_time += collect_time

        if n_step:
            # persist for future collect iterations
            self._pre_collect_obs_RO = last_obs_RO
            self._pre_collect_info_R = last_info_R
            self._pre_collect_hidden_state_RH = last_hidden_state_RH
        elif n_episode:
            # reset envs and the _pre_collect fields
            self.reset_env(gym_reset_kwargs)  # todo still necessary?

        return CollectStats.with_autogenerated_stats(
            returns=np.array(episode_returns),
            lens=np.array(episode_lens),
            n_collected_episodes=num_collected_episodes,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
        )

    def _reset_hidden_state_based_on_type(
        self,
        env_ind_local_D: np.ndarray,
        last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None,
    ) -> None:
        if isinstance(last_hidden_state_RH, torch.Tensor):
            last_hidden_state_RH[env_ind_local_D].zero_()  # type: ignore[index]
        elif isinstance(last_hidden_state_RH, np.ndarray):
            last_hidden_state_RH[env_ind_local_D] = (
                None if last_hidden_state_RH.dtype == object else 0
            )
        elif isinstance(last_hidden_state_RH, Batch):
            last_hidden_state_RH.empty_(env_ind_local_D)
        # todo is this inplace magic and just working?
|
|
|
|
|
2020-07-26 12:01:21 +02:00
|
|
|
|
2021-02-19 10:33:49 +08:00
|
|
|
class AsyncCollector(Collector):
|
|
|
|
"""Async Collector handles async vector environment.
|
|
|
|
|
2024-04-26 16:46:03 +02:00
|
|
|
Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation.

    """

    def __init__(
        self,
        policy: BasePolicy,
        env: BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
    ) -> None:
        if not env.is_async:
            # TODO: raise an exception?
            log.error(
                f"Please use {Collector.__name__} if not using async venv. "
                f"Env class: {env.__class__.__name__}",
            )
        # assert env.is_async
        warnings.warn("Using async setting may collect extra transitions into buffer.")
        super().__init__(
            policy,
            env,
            buffer,
            exploration_noise,
        )
        # E denotes the number of parallel environments: self.env_num
        # At init, E=R but during collection R <= E
        # Keep in sync with reset!
        self._ready_env_ids_R: np.ndarray = np.arange(self.env_num)
        self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO)
        self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R)
        self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy(
            self._pre_collect_hidden_state_RH,
        )
        self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num)
        self._current_policy_in_all_envs_E: Batch | None = None
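        # Shape/suffix convention sketch (assumed values, for illustration only):
        #   E = self.env_num (all envs); R = len(self._ready_env_ids_R) (ready subset), R <= E
        #   self._current_obs_in_all_envs_EO -> one entry per env, length E
        #   self._ready_env_ids_R            -> e.g. np.array([0, 2]) when only envs 0 and 2 are ready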

    def reset(
        self,
        reset_buffer: bool = True,
        reset_stats: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Reset the environment, statistics, and data needed to start the collection.

        :param reset_buffer: if true, reset the replay buffer attached
            to the collector.
        :param reset_stats: if true, reset the statistics attached to the collector.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None.
        :return: The initial observation and info from the environment.
        """
        # This sets the _pre_collect attrs
        result = super().reset(
            reset_buffer=reset_buffer,
            reset_stats=reset_stats,
            gym_reset_kwargs=gym_reset_kwargs,
        )
        # Keep in sync with init!
        self._ready_env_ids_R = np.arange(self.env_num)
        # E denotes the number of parallel environments: self.env_num
        self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO)
        self._current_info_in_all_envs_E = copy(self._pre_collect_info_R)
        self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH)
        self._current_action_in_all_envs_EA = np.empty(self.env_num)
        self._current_policy_in_all_envs_E = None
        return result
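
    # Typical call order (sketch): construct the collector, call reset() once, then call
    # collect(...) repeatedly; reset(reset_buffer=False) keeps already collected transitions
    # while still re-initializing the envs and the per-env bookkeeping above.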

    @override
    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        # Some envs may still be busy with pending asynchronous steps; collect those results
        # first so that every env is ready to be interacted with before the actual reset.
        if self.env.waiting_id:
            self.env.step(None, id=self.env.waiting_id)
        return super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
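
    # Sketch of the pending-step handling above (assumes an async venv, i.e. wait_num < env_num
    # or a timeout is set): self.env.waiting_id lists envs whose dispatched step() results have
    # not been fetched yet, and stepping with action=None only collects those pending results
    # without sending new actions.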

    @override
    def _collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        no_grad: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        use_grad = not no_grad
        start_time = time.time()

        step_count = 0
        num_collected_episodes = 0
        episode_returns: list[float] = []
        episode_lens: list[int] = []
        episode_start_indices: list[int] = []

        ready_env_ids_R = self._ready_env_ids_R
        # last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R]  # type: ignore[index]
        # last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R]  # type: ignore[index]
        # last_hidden_state_RH = self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]  # type: ignore[index]
        # last_obs_RO = self._pre_collect_obs_RO
        # last_info_R = self._pre_collect_info_R
        # last_hidden_state_RH = self._pre_collect_hidden_state_RH
        if self._current_obs_in_all_envs_EO is None or self._current_info_in_all_envs_E is None:
            raise RuntimeError(
                "Current obs or info array is None, did you call reset or pass reset_at_collect=True?",
            )

        last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R]
        last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R]
        last_hidden_state_RH = _nullable_slice(
            self._current_hidden_state_in_all_envs_EH,
            ready_env_ids_R,
        )
        # Each iteration of the AsyncCollector only steps a subset of the envs. The last
        # observation / hidden state of the envs not included in the current iteration has
        # to be retained.
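        # Retention sketch (assumed example: 4 envs, only envs 1 and 2 currently ready):
        #
        #     ready_env_ids_R = np.array([1, 2])
        #     last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R]  # 2 entries
        #     # envs 0 and 3 keep their previously stored obs / info / hidden state untouched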

        while True:
            # todo: do we need this?
            # todo: extend to all current attributes, but some could be None at init
            if self._current_obs_in_all_envs_EO is None:
                raise RuntimeError(
                    "Current obs is None, did you call reset or pass reset_at_collect=True?",
                )
            if (
                not len(self._current_obs_in_all_envs_EO)
                == len(self._current_action_in_all_envs_EA)
                == self.env_num
            ):  # major difference
                raise RuntimeError(
                    f"{len(self._current_obs_in_all_envs_EO)=} and "
                    f"{len(self._current_action_in_all_envs_EA)=} have to equal "
                    f"{self.env_num=} as they track the current transition "
                    f"in all envs",
                )

            # get the next action
            (
                act_RA,
                act_normalized_RA,
                policy_R,
                hidden_state_RH,
            ) = self._compute_action_policy_hidden(
                random=random,
                ready_env_ids_R=ready_env_ids_R,
                use_grad=use_grad,
                last_obs_RO=last_obs_RO,
                last_info_R=last_info_R,
                last_hidden_state_RH=last_hidden_state_RH,
            )

            # save act_RA / policy_R / hidden_state_RH before env.step
            self._current_action_in_all_envs_EA[ready_env_ids_R] = act_RA
            if self._current_policy_in_all_envs_E:
                self._current_policy_in_all_envs_E[ready_env_ids_R] = policy_R
            else:
                self._current_policy_in_all_envs_E = policy_R  # first iteration
            if hidden_state_RH is not None:
                if self._current_hidden_state_in_all_envs_EH is not None:
                    # Need to cast since if it's a Tensor, the assignment might in fact fail if
                    # hidden_state_RH is not a tensor as well. This is hard to express with
                    # proper typing, even using @overload, so we cheat and hope that if one of
                    # the two is a tensor, the other one is as well.
                    self._current_hidden_state_in_all_envs_EH = cast(
                        np.ndarray | Batch,
                        self._current_hidden_state_in_all_envs_EH,
                    )
                    self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = hidden_state_RH
                else:
                    self._current_hidden_state_in_all_envs_EH = hidden_state_RH

            # step in env
            obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step(
                act_normalized_RA,
                ready_env_ids_R,
            )
            done_R = np.logical_or(terminated_R, truncated_R)
            # Not all environments of the AsyncCollector might have performed a step in this
            # iteration, so update ready_env_ids_R to the envs that actually did. In particular,
            # R can change from one iteration to the next.
            try:
                ready_env_ids_R = cast(np.ndarray, info_R["env_id"])
            # TODO: don't use bare Exception!
            except Exception:
                ready_env_ids_R = np.array([i["env_id"] for i in info_R])
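            # Illustrative shapes of info_R for the two branches above (assumed examples):
            #   dict of arrays: info_R["env_id"] == np.array([1, 2])
            #   list of dicts:  info_R == [{"env_id": 1, ...}, {"env_id": 2, ...}]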

            current_iteration_batch = cast(
                RolloutBatchProtocol,
                Batch(
                    obs=self._current_obs_in_all_envs_EO[ready_env_ids_R],
                    act=self._current_action_in_all_envs_EA[ready_env_ids_R],
                    policy=self._current_policy_in_all_envs_E[ready_env_ids_R],
                    obs_next=obs_next_RO,
                    rew=rew_R,
                    terminated=terminated_R,
                    truncated=truncated_R,
                    done=done_R,
                    info=info_R,
                ),
            )

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add(
                current_iteration_batch,
                buffer_ids=ready_env_ids_R,
            )
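            # Roughly: buffer.add returns, per ready env, the insertion index (ptr_R) plus the
            # episode return / length / start index for envs whose episode just finished
            # (zeros for envs that are still mid-episode).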

            # collect statistics
            num_episodes_done_this_iter = np.sum(done_R)
            step_count += len(ready_env_ids_R)
            num_collected_episodes += num_episodes_done_this_iter

            # preparing for the next iteration
            # todo: it seems we can get rid of this last_* bookkeeping altogether
            last_obs_RO = copy(obs_next_RO)
            last_info_R = copy(info_R)
            last_hidden_state_RH = copy(
                self._current_hidden_state_in_all_envs_EH[ready_env_ids_R],  # type: ignore[index]
            )
            if num_episodes_done_this_iter:
                env_ind_local_D = np.where(done_R)[0]
                env_ind_global_D = ready_env_ids_R[env_ind_local_D]
                episode_lens.extend(ep_len_R[env_ind_local_D])
                episode_returns.extend(ep_rew_R[env_ind_local_D])
                episode_start_indices.extend(ep_idx_R[env_ind_local_D])

                # now we copy obs_next_RO to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                gym_reset_kwargs = gym_reset_kwargs or {}
                obs_reset_DO, info_reset_D = self.env.reset(
                    env_id=env_ind_global_D,
                    **gym_reset_kwargs,
                )
                last_obs_RO[env_ind_local_D] = obs_reset_DO
                last_info_R[env_ind_local_D] = info_reset_D

                self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH)
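                # Index bookkeeping sketch (assumed example with two finished episodes):
                #   done_R           == np.array([True, False, True])
                #   env_ind_local_D  == np.array([0, 2])             # positions within ready_env_ids_R
                #   env_ind_global_D == ready_env_ids_R[[0, 2]]      # actual env ids passed to env.reset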

            # update based on the current transition in all envs
            self._current_obs_in_all_envs_EO[ready_env_ids_R] = last_obs_RO
            # this is a list, so loop over
            for idx, ready_env_id in enumerate(ready_env_ids_R):
                self._current_info_in_all_envs_E[ready_env_id] = last_info_R[idx]
            if self._current_hidden_state_in_all_envs_EH is not None:
                # Need to cast since if it's a Tensor, the assignment might in fact fail if
                # hidden_state_RH is not a tensor as well. This is hard to express with proper
                # typing, even using @overload, so we cheat and hope that if one of the two is
                # a tensor, the other one is as well.
                self._current_hidden_state_in_all_envs_EH = cast(
                    np.ndarray | Batch,
                    self._current_hidden_state_in_all_envs_EH,
                )
                self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = last_hidden_state_RH
            else:
                self._current_hidden_state_in_all_envs_EH = last_hidden_state_RH

            if (n_step and step_count >= n_step) or (
                n_episode and num_collected_episodes >= n_episode
            ):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += num_collected_episodes
        collect_time = max(time.time() - start_time, 1e-9)
        self.collect_time += collect_time

        # persist for future collect iterations
        self._ready_env_ids_R = ready_env_ids_R

        return CollectStats.with_autogenerated_stats(
            returns=np.array(episode_returns),
            lens=np.array(episode_lens),
            n_collected_episodes=num_collected_episodes,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
        )