2020-04-28 20:56:02 +08:00
import pprint
2020-06-20 22:23:12 +08:00
import warnings
2023-08-25 23:40:56 +02:00
from collections.abc import Collection, Iterable, Iterator, Sequence
2020-07-06 20:30:15 +08:00
from copy import deepcopy
2020-06-24 15:43:48 +02:00
from numbers import Number

Improved typing and reduced duplication (#912)
# Goals of the PR
The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but still the overall changes are "small" in a certain
sense.
## Major Change 1 - BatchProtocol
**TL;DR:** One can now annotate which fields the batch is expected to
have on input params and which fields a returned batch has. This should be
useful for reading the code, getting meaningful IDE support, and
catching bugs with mypy. This annotation strategy will continue to work
if Batch is replaced by TensorDict or by something else.
**In more detail:** Batch itself has no fields and using it for
annotations is of limited informational power. Batches with fields are
not separate classes but instead instances of Batch directly, so there
is no type that could be used for annotation. Fortunately, Python's
`Protocol` comes to the rescue. With these changes we can now do
things like:
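```python
from typing import Any, Dict, Optional, Protocol, Sequence, Union

import numpy as np
import torch

from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.policy import BasePolicy


# Protocol must be listed explicitly for a subclass to remain a protocol (PEP 544).
class ActionBatchProtocol(BatchProtocol, Protocol):
    logits: Sequence[Union[tuple, torch.Tensor]]
    dist: torch.distributions.Distribution
    act: torch.Tensor
    state: Optional[torch.Tensor]


class RolloutBatchProtocol(BatchProtocol, Protocol):
    obs: torch.Tensor
    obs_next: torch.Tensor
    info: Dict[str, Any]
    rew: torch.Tensor
    terminated: torch.Tensor
    truncated: torch.Tensor


class PGPolicy(BasePolicy):
    ...

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> ActionBatchProtocol:
        ...
```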
The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.
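As a minimal, self-contained sketch of the idea (this is not Tianshou's actual `BatchProtocol`), the snippet below defines a hypothetical `RolloutLike` protocol and a plain class that satisfies it purely structurally, without inheriting from it. The hypothetical `TBatchLike` TypeVar, bound to the protocol, mirrors the `TBatch` pattern used in this module: the function returns the same concrete type it receives. Note that `runtime_checkable` only checks at runtime that the attributes exist, not their types; the detailed checking is done statically by mypy or the IDE.

```python
from typing import Protocol, TypeVar, runtime_checkable

import numpy as np


@runtime_checkable
class RolloutLike(Protocol):
    """Hypothetical protocol: any object with these attributes matches it."""

    obs: np.ndarray
    rew: np.ndarray


# Generic over any type that satisfies the protocol (cf. TBatch in batch.py).
TBatchLike = TypeVar("TBatchLike", bound=RolloutLike)


class SimpleRollout:
    """Does not inherit from RolloutLike; it matches structurally."""

    def __init__(self, obs: np.ndarray, rew: np.ndarray) -> None:
        self.obs = obs
        self.rew = rew


def scale_rewards(batch: TBatchLike, factor: float) -> TBatchLike:
    # The return type is the concrete type that was passed in,
    # so callers keep full IDE completion on the result.
    batch.rew = batch.rew * factor
    return batch


rollout = scale_rewards(SimpleRollout(np.zeros((2, 3)), np.array([1.0, 2.0])), 0.5)
print(isinstance(rollout, RolloutLike))  # True: both attributes are present
print(rollout.rew)
```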
## Major Change 2 - remove duplication in trainer package
**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.
**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.
1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type` which was passed as kwarg to the base class and was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called the
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstring and interfaces. They themselves
had no docstring and the interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and which things can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support with such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. Since my
application needed a new trainer, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
essentially no changes to the superclass were introduced. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion).
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.
I have some more reasons, but maybe the above ones are convincing
enough.
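Point 2 can be illustrated with a small sketch using hypothetical stand-ins (`ToyTrainer`, `toy_trainer`) rather than Tianshou's real classes: a thin `*args, **kwargs` wrapper hides the constructor's parameters from `inspect.signature`, which roughly reflects what IDEs and mypy can see, while instantiating the class and calling `.run()` keeps the full signature visible.

```python
import inspect
from typing import Any


class ToyTrainer:
    """Hypothetical stand-in for a trainer class with an explicit signature."""

    def __init__(self, max_epoch: int, step_per_epoch: int, batch_size: int = 64) -> None:
        self.max_epoch = max_epoch
        self.step_per_epoch = step_per_epoch
        self.batch_size = batch_size

    def run(self) -> dict[str, int]:
        # A real trainer would loop over epochs here; we only report the config.
        return {"epochs": self.max_epoch, "steps": self.max_epoch * self.step_per_epoch}


def toy_trainer(*args: Any, **kwargs: Any) -> dict[str, int]:
    """Function-style wrapper in the style of the removed `onpolicy_trainer`."""
    return ToyTrainer(*args, **kwargs).run()


# The wrapper exposes only *args/**kwargs; the class exposes every parameter.
print(list(inspect.signature(toy_trainer).parameters))
print(list(inspect.signature(ToyTrainer).parameters))
print(ToyTrainer(max_epoch=2, step_per_epoch=10).run())
```

Both calls do the same work; the difference is purely in how much of the interface tooling can discover.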
## Minor changes: improved input validation and types
I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed this validation. In those tests I added
`action_scaling=isinstance(env.action_space, Box)`, after which tests
were green. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess some aspect was
not tested and didn't crash.
I also added `Literal` types in some places, in particular for
`action_bound_method`. It is no longer allowed to pass an empty
string; instead, one should pass `None`. Here, too, there is input
validation with clear error messages.
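A minimal sketch of the `Literal`-plus-validation pattern described above. The function name and the literal values (`"clip"`, `"tanh"`) are illustrative assumptions, not a statement of Tianshou's exact API. The annotation lets mypy reject unknown strings statically, and `typing.get_args` reuses the same literal values to produce a clear runtime error, with `None` (not an empty string) meaning "no bounding".

```python
from typing import Literal, Optional, get_args

# Hypothetical alias; the allowed values here are for illustration only.
BoundMethod = Literal["clip", "tanh"]


def validate_bound_method(action_bound_method: Optional[BoundMethod]) -> Optional[BoundMethod]:
    """Return the value unchanged if it is allowed, otherwise raise a ValueError."""
    allowed = get_args(BoundMethod)  # ("clip", "tanh"), kept in sync with the annotation
    if action_bound_method is not None and action_bound_method not in allowed:
        raise ValueError(
            f"action_bound_method must be one of {allowed} or None, "
            f"got {action_bound_method!r}"
        )
    return action_bound_method
```

Calling `validate_bound_method("")` raises a `ValueError` at runtime, and mypy would already flag the empty string as an invalid argument.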
@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during docs build
instead of committed?). I also haven't adjusted the documentation yet,
which atm still mentions the trainers of the type
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`.
## Breaking Changes
The adjustments to the trainer package introduce breaking changes as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them.
---------
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 18:54:46 +02:00
from typing import (
    Any,
    Optional,
    Protocol,
    TypeVar,
    Union,
    cast,
    overload,
    runtime_checkable,
)
2021-09-03 05:05:04 +08:00
import numpy as np
import torch
2020-03-13 17:49:22 +08:00


2023-08-25 23:40:56 +02:00
IndexType = Union[slice, int, np.ndarray, list[int]]
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
TBatch = TypeVar("TBatch", bound="BatchProtocol")
arr_type = Union[torch.Tensor, np.ndarray]
2021-03-30 16:06:03 +08:00

2020-03-13 17:49:22 +08:00

2022-01-30 00:53:56 +08:00
def _is_batch_set(obj: Any) -> bool:
2020-07-16 19:36:32 +08:00
    # Batch set is a list/tuple of dict/Batch objects,
2021-03-30 16:06:03 +08:00
    # or 1-D np.ndarray with object type,
2020-07-16 19:36:32 +08:00
    # where each element is a dict/Batch object
2022-01-30 00:53:56 +08:00
    if isinstance(obj, np.ndarray):  # most often case
        # "for element in obj" will just unpack the first dimension,
        # but obj.tolist() will flatten ndarray of objects
        # so do not use obj.tolist()
fix a bug in batch._is_batch_set (#825)
- [ ] I have marked all applicable categories:
    + [x] exception-raising fix
    + [ ] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every item in this Pull Request below
I'm developing a new PettingZoo environment. It is a two-player,
turn-based board game.
```python
obs_space = dict(
    board=gym.spaces.MultiBinary([8, 8]),
    player=gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2),
    other_player=gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2),
)
self._observation_space = gym.spaces.Dict(spaces=obs_space)
self._action_space = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2)
...

# this cache ensures that same space object is returned for the same agent
# allows action space seeding to work as expected
@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
    # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/
    return self._observation_space

@functools.lru_cache(maxsize=None)
def action_space(self, agent):
    return self._action_space
```
My test is:
```python
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv, PettingZooEnv
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy

# CollectCoinsEnv is the author's own environment (not shown here).


def test_with_tianshou():
    action = None
    # env = gym.make('qwertyenv/CollectCoins-v0', pieces=['rock', 'rock'])
    env = CollectCoinsEnv(pieces=['rock', 'rock'], with_mask=True)

    def another_action_taken(action_taken):
        nonlocal action
        action = action_taken

    # Wrap the original environment so that a valid action is always taken.
    env = EnsureValidAction(
        env,
        env.check_action_valid,
        env.provide_alternative_valid_action,
        another_action_taken,
    )
    pz_env = PettingZooEnv(env)
    policies = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()], pz_env)
    venv = DummyVectorEnv([lambda: pz_env])
    collector = Collector(policies, venv)
    result = collector.collect(n_step=200, render=0.1)
```
I also have a wrapper that may be redundant given Tianshou's support for action masks, but it is still part of the code:
```python
from typing import Callable, TypeVar

import gymnasium as gym
from pettingzoo.utils.wrappers import BaseWrapper

Action = TypeVar("Action")


class ActionWrapper(BaseWrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)

    def step(self, action):
        action = self.action(action)
        self.env.step(action)

    def action(self, action):
        pass

    def render(self, *args, **kwargs):
        self.env.render(*args, **kwargs)


class EnsureValidAction(ActionWrapper):
    """
    A gym environment wrapper to help with the case that the agent wants to take invalid actions.

    For example, consider a chess game where you let the action_space be any piece moving to any
    square on the board. When a wrong move is taken, instead of returning a big negative reward,
    you just take another action, this time a valid one. To make sure the learning algorithm is
    aware of the action taken, a callback should be provided.
    """

    def __init__(
        self,
        env: gym.Env,
        check_action_valid: Callable[[Action], bool],
        provide_alternative_valid_action: Callable[[Action], Action],
        alternative_action_cb: Callable[[Action], None],
    ):
        super().__init__(env)
        self.check_action_valid = check_action_valid
        self.provide_alternative_valid_action = provide_alternative_valid_action
        self.alternative_action_cb = alternative_action_cb

    def action(self, action: Action) -> Action:
        if self.check_action_valid(action):
            return action
        alternative_action = self.provide_alternative_valid_action(action)
        self.alternative_action_cb(alternative_action)
        return alternative_action
```
To make the above work, I had to patch PettingZoo slightly (I opened a pull request there) and make a small patch here (this PR).
Maybe I'm doing something wrong, yet I fail to see it.
With both fixes, in PettingZoo and in Tianshou, I have two tests: one for the environment by itself, and the other as shown above.
2023-03-13 01:58:09 +01:00
        if obj.shape == ():
            return False
2023-08-25 23:40:56 +02:00
        return obj.dtype == object and all(isinstance(element, (dict, Batch)) for element in obj)
    if (
        isinstance(obj, (list, tuple))
        and len(obj) > 0
        and all(isinstance(element, (dict, Batch)) for element in obj)
    ):
        return True
2020-06-26 12:37:50 +02:00
    return False

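To illustrate why the `obj.shape == ()` guard from #825 is needed, the sketch below uses NumPy only, with plain `dict` standing in for `Batch`; the helper names are hypothetical. Iterating over a zero-dimensional object array raises `TypeError`, so a check without the guard crashes on such input, while the guarded version simply returns `False`.

```python
import numpy as np


def is_dict_set_unguarded(obj: np.ndarray) -> bool:
    # Same structure as the ndarray branch of _is_batch_set, but without the 0-d guard.
    return obj.dtype == object and all(isinstance(e, dict) for e in obj)


def is_dict_set_guarded(obj: np.ndarray) -> bool:
    if obj.shape == ():  # a 0-d array cannot be iterated
        return False
    return obj.dtype == object and all(isinstance(e, dict) for e in obj)


zero_d = np.array({"a": 1}, dtype=object)  # 0-d object array wrapping a single dict
one_d = np.array([{"a": 1}, {"b": 2}], dtype=object)  # 1-D object array of dicts

print(zero_d.shape)  # ()
print(is_dict_set_guarded(zero_d))  # False
print(is_dict_set_guarded(one_d))  # True
try:
    is_dict_set_unguarded(zero_d)
except TypeError as err:
    print("unguarded check failed:", err)
```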
2020-07-19 15:20:35 +08:00
def _is_scalar(value: Any) -> bool:
    # check if the value is a scalar
    # 1. python bool object, number object: isinstance(value, Number)
    # 2. numpy scalar: isinstance(value, np.generic)
    # 3. python object rather than dict / Batch / tensor
    # the check of dict / Batch is omitted because this only checks a value.
    # a dict / Batch will eventually check their values
2020-07-21 10:47:56 +02:00
    if isinstance(value, torch.Tensor):
        return value.numel() == 1 and not value.shape
2023-08-25 23:40:56 +02:00
    # np.asanyarray will cause an infinite loop in some cases
    return np.isscalar(value)
2020-07-19 15:20:35 +08:00


def _is_number(value: Any) -> bool:
    # isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc.
    # isinstance(value, np.number) checks np.int32(1), np.float64(1.0), etc.
    # isinstance(value, np.bool_) checks np.bool_(True), etc.
2020-08-27 12:15:18 +08:00
    # similar to np.isscalar but np.isscalar('st') returns True
    return isinstance(value, (Number, np.number, np.bool_))
2020-07-19 15:20:35 +08:00


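The comments above distinguish `np.isscalar` (used by `_is_scalar`) from the stricter `_is_number`. This NumPy-only sketch re-implements `_is_number`'s check under a hypothetical name so it runs outside Tianshou. It shows that strings count as scalars but not as numbers, and that 0-d arrays are neither. The `torch.Tensor` branch of `_is_scalar` is not covered here.

```python
from numbers import Number

import numpy as np


def is_number(value: object) -> bool:
    # Mirrors the check in _is_number: Python numbers, NumPy numbers, NumPy bools.
    return isinstance(value, (Number, np.number, np.bool_))


# Compare both predicates on a few representative values.
for value in [1, 2.5, True, np.float32(1.0), np.bool_(True), "st", np.array(1.0)]:
    print(repr(value), np.isscalar(value), is_number(value))
```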
2022-01-30 00:53:56 +08:00
def _to_array_with_correct_type(obj: Any) -> np.ndarray:
2023-08-25 23:40:56 +02:00
    if isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, (np.bool_, np.number)):
2022-01-30 00:53:56 +08:00
        return obj  # most often case
2020-07-19 15:20:35 +08:00
    # convert the value to np.ndarray
2022-01-30 00:53:56 +08:00
    # convert to object dtype if neither bool nor number
2021-09-03 05:05:04 +08:00
    # raises an exception if array's elements are tensors themselves
2023-08-09 19:27:18 +02:00
    try:
        obj_array = np.asanyarray(obj)
    except ValueError:
        obj_array = np.asanyarray(obj, dtype=object)
2022-01-30 00:53:56 +08:00
    if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
        obj_array = obj_array.astype(object)
    if obj_array.dtype == object:
        # scalar ndarray with object dtype is very annoying
2020-07-19 15:20:35 +08:00
        # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
        # a is not array([{}, {}], dtype=object), and a[0]={} results in
        # something very strange:
        # array([{}, array({}, dtype=object)], dtype=object)
2022-01-30 00:53:56 +08:00
        if not obj_array.shape:
            obj_array = obj_array.item(0)
        elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
            return obj_array  # various length, np.array([[1], [2, 3], [4, 5, 6]])
        elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
2020-07-21 10:47:56 +02:00
            raise ValueError("Numpy arrays of tensors are not supported yet.")
2022-01-30 00:53:56 +08:00
    return obj_array
2020-07-19 15:20:35 +08:00


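The `try`/`except ValueError` fallback in `_to_array_with_correct_type` matters because NumPy 1.24 and later raise `ValueError` when building an array from ragged nested sequences without `dtype=object`. The following is a simplified sketch under a hypothetical name that omits the 0-d and tensor special cases. It shows the two conversions: numeric and bool data keep their dtype, and everything else (strings, ragged rows) ends up as `object`.

```python
from typing import Any

import numpy as np


def to_array_sketch(obj: Any) -> np.ndarray:
    # Simplified version of _to_array_with_correct_type without the
    # 0-d and tensor special cases.
    try:
        arr = np.asanyarray(obj)
    except ValueError:
        # NumPy >= 1.24 raises for ragged nested sequences unless dtype=object.
        arr = np.asanyarray(obj, dtype=object)
    if not issubclass(arr.dtype.type, (np.bool_, np.number)):
        # strings, lists, dicts, ... are stored as Python objects
        arr = arr.astype(object)
    return arr


print(to_array_sketch([1, 2, 3]).dtype)      # an integer dtype, kept as-is
print(to_array_sketch(["a", "b"]).dtype)     # object (not a fixed-width string dtype)
print(to_array_sketch([[1], [2, 3]]).shape)  # (2,): one object slot per ragged row
```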
2023-08-25 23:40:56 +02:00
def create_value(
    inst: Any,
    size: int,
    stack: bool = True,
) -> Union["Batch", np.ndarray, torch.Tensor]:
2020-09-11 07:55:37 +08:00
    """Create empty place-holders according to inst's shape.

2021-02-19 10:33:49 +08:00
    :param bool stack: whether to stack or to concatenate. E.g. if inst has shape of
        (3, 5), size = 10, stack=True returns an np.ndarray with shape of (10, 3, 5),
        otherwise (10, 5)
2020-07-12 23:45:42 +08:00
    """
    has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
2020-07-19 15:20:35 +08:00
    is_scalar = _is_scalar(inst)
2020-07-12 23:45:42 +08:00
    if not stack and is_scalar:
2021-02-19 10:33:49 +08:00
        # should never be hit since this is already checked in Batch.cat_; here we do not
        # consider scalar types, following the behavior of numpy which does not support
        # concatenation of zero-dimensional arrays (scalars)
2020-07-16 19:36:32 +08:00
        raise TypeError(f"cannot concatenate with {inst} which is scalar")
2020-07-12 23:45:42 +08:00
    if has_shape:
        shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
2020-06-27 03:06:40 +02:00
    if isinstance(inst, np.ndarray):
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
        target_type = (
2023-08-25 23:40:56 +02:00
            inst.dtype.type if issubclass(inst.dtype.type, (np.bool_, np.number)) else object
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
|
|
|
        )
        return np.full(shape, fill_value=None if target_type == object else 0, dtype=target_type)
    if isinstance(inst, torch.Tensor):
        return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype)
    if isinstance(inst, (dict, Batch)):
        zero_batch = Batch()
        for key, val in inst.items():
            zero_batch.__dict__[key] = create_value(val, size, stack=stack)
        return zero_batch
    if is_scalar:
        return create_value(np.asarray(inst), size, stack=stack)
    # fall back to object
    return np.array([None for _ in range(size)], object)
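The `np.ndarray` branch of `create_value` allocates its placeholder with `np.full`. It zero-fills numeric and boolean dtypes and uses `None` for object dtype. The following is a minimal stdlib-only sketch of the shape and fill-value arithmetic. It assumes, as in tianshou's source just before this excerpt, that the placeholder shape is `(size, *inst.shape)` when stacking and `(size, *inst.shape[1:])` when concatenating. `placeholder_spec` is a hypothetical helper, not a tianshou API. The scalar guard assumes NumPy's rule that zero-dimensional arrays cannot be concatenated.

```python
from typing import Optional


def placeholder_spec(
    shape: tuple[int, ...],
    size: int,
    stack: bool = True,
    numeric: bool = True,
) -> tuple[tuple[int, ...], Optional[int]]:
    """Return (placeholder_shape, fill_value) for an entry of the given shape.

    Hypothetical helper sketching the ndarray branch of create_value.
    """
    if not stack and len(shape) == 0:
        # Assumption: a scalar has no leading axis to concatenate along.
        raise TypeError("cannot concatenate a scalar entry")
    # Stacking adds a new leading axis of length `size`;
    # concatenating replaces the entry's own leading axis with `size`.
    new_shape = (size, *shape) if stack else (size, *shape[1:])
    # Numeric/bool dtypes are zero-filled, object dtype is filled with None.
    fill_value = 0 if numeric else None
    return new_shape, fill_value
```

For example, stacking a `(3, 5)` entry into a buffer of size 10 gives `(10, 3, 5)`. Concatenating gives `(10, 5)`, because the entry's leading axis is replaced rather than kept.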


def _assert_type_keys(keys: Iterable[str]) -> None:
    assert all(isinstance(key, str) for key in keys), f"keys should all be string, but got {keys}"


def _parse_value(obj: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
    if isinstance(obj, Batch):  # most common case
        return obj
    if (
        (isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, (np.bool_, np.number)))
        or isinstance(obj, torch.Tensor)
        or obj is None
    ):  # third most common case
        return obj
    if _is_number(obj):  # second most common case, but checked later because the test is slower
        return np.asanyarray(obj)
    if isinstance(obj, dict):
        return Batch(obj)
    if (
        not isinstance(obj, np.ndarray)
        and isinstance(obj, Collection)
        and len(obj) > 0
        and all(isinstance(element, torch.Tensor) for element in obj)
    ):
        try:
            obj = cast(list[torch.Tensor], obj)
            return torch.stack(obj)
        except RuntimeError as exception:
            raise TypeError(
                "Batch does not support non-stackable iterable"
                " of torch.Tensor as unique value yet.",
            ) from exception
    if _is_batch_set(obj):
        obj = Batch(obj)  # list of dict / Batch
    else:
        # None, scalar, normal obj list (main case)
        # or an actual list of objects
        try:
            obj = _to_array_with_correct_type(obj)
        except ValueError as exception:
            raise TypeError(
                "Batch does not support heterogeneous list/tuple of tensors as unique value yet.",
            ) from exception
    return obj
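The tensor-list branch of `_parse_value` only calls `torch.stack` when the value is a non-empty collection whose elements are all tensors. The `len(obj) > 0` test matters because `all()` over an empty iterable is `True`. Without it, an empty list would be routed to `torch.stack`, which raises on empty input. Below is a stdlib-only sketch of that guard. `int` stands in for `torch.Tensor`, the `np.ndarray` exclusion is omitted, and `is_nonempty_collection_of` is a hypothetical name, not part of tianshou.

```python
from collections.abc import Collection
from typing import Any


def is_nonempty_collection_of(obj: Any, element_type: type) -> bool:
    """Hypothetical helper sketching the tensor-list guard in _parse_value."""
    return (
        isinstance(obj, Collection)
        # all() over an empty collection is vacuously True, so check size first
        and len(obj) > 0
        and all(isinstance(element, element_type) for element in obj)
    )
```

The checks are ordered so that cheap type and length tests short-circuit before the per-element scan.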


def alloc_by_keys_diff(
    meta: "BatchProtocol",
    batch: "BatchProtocol",
    size: int,
    stack: bool = True,
) -> None:
Improved typing and reduced duplication (#912)
# Goals of the PR
The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but still the overall changes are "small" in a certain
sense.
## Major Change 1 - BatchProtocol
**TL;DR:** One can now annotate which fields the batch is expected to
have on input params and which fields a returned batch has. Should be
useful for reading the code. getting meaningful IDE support, and
catching bugs with mypy. This annotation strategy will continue to work
if Batch is replaced by TensorDict or by something else.
**In more detail:** Batch itself has no fields and using it for
annotations is of limited informational power. Batches with fields are
not separate classes but instead instances of Batch directly, so there
is no type that could be used for annotation. Fortunately, python
`Protocol` is here for the rescue. With these changes we can now do
things like
```python
class ActionBatchProtocol(BatchProtocol):
logits: Sequence[Union[tuple, torch.Tensor]]
dist: torch.distributions.Distribution
act: torch.Tensor
state: Optional[torch.Tensor]
class RolloutBatchProtocol(BatchProtocol):
obs: torch.Tensor
obs_next: torch.Tensor
info: Dict[str, Any]
rew: torch.Tensor
terminated: torch.Tensor
truncated: torch.Tensor
class PGPolicy(BasePolicy):
...
def forward(
self,
batch: RolloutBatchProtocol,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> ActionBatchProtocol:
```
The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.
## Major Change 2 - remove duplication in trainer package
**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.
**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.
1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type` which was passed as kwarg to the base class and was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called the
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstring and interfaces. They themselves
had no dosctring and the interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and which things can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support with such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. As for my
application I needed a new trainer, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
essentially no changes to the superclass were introduced. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion)
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.
I have some more reasons, but maybe the above ones are convincing
enough.
## Minor changes: improved input validation and types
I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed to pass this validation. There I added
`action_scaling=isinstance(env.action_space, Box)`, after which tests
were green. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess some aspect was
not tested and didn't crash.
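A minimal sketch of this kind of check, assuming hypothetical stand-ins `ToyBox` and `ToyDiscrete` in place of real gymnasium spaces and a hypothetical `check_action_scaling` helper (not Tianshou's actual validation code). The last line mirrors the fix applied to the tests: the flag is derived from the type of the action space.

```python
from dataclasses import dataclass


@dataclass
class ToyBox:
    """Hypothetical stand-in for a continuous action space such as gymnasium's Box."""
    low: float
    high: float


@dataclass
class ToyDiscrete:
    """Hypothetical stand-in for a discrete action space."""
    n: int


def check_action_scaling(action_space: object, action_scaling: bool) -> None:
    # Rescaling actions into [low, high] only makes sense for continuous spaces.
    if action_scaling and not isinstance(action_space, ToyBox):
        raise ValueError(
            "action_scaling=True requires a continuous action space, "
            f"got {type(action_space).__name__}"
        )


space = ToyDiscrete(n=4)
# Deriving the flag from the space type keeps discrete envs valid.
check_action_scaling(space, action_scaling=isinstance(space, ToyBox))
```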
I also added Literal in some places, in particular for
`action_bound_method`. Now it is no longer allowed to pass an empty
string, instead one should pass `None`. Also here there is input
validation with clear error messages.
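The `Literal`-plus-validation pattern can be sketched as follows. The accepted names `"clip"` and `"tanh"` are an assumption for illustration, and `validate_action_bound_method` is a hypothetical helper. `typing.get_args` recovers the allowed values from the `Literal` type, so the runtime check and the static annotation cannot drift apart.

```python
from typing import Literal, Optional, get_args

# Assumption: "clip" and "tanh" are the accepted method names; None disables bounding.
ActionBoundMethod = Literal["clip", "tanh"]


def validate_action_bound_method(
    method: Optional[ActionBoundMethod],
) -> Optional[ActionBoundMethod]:
    allowed = get_args(ActionBoundMethod)  # ('clip', 'tanh')
    if method is not None and method not in allowed:
        raise ValueError(
            f"action_bound_method must be one of {allowed} or None, got {method!r}"
        )
    return method


validate_action_bound_method("clip")  # accepted
validate_action_bound_method(None)  # accepted: no bounding
# validate_action_bound_method("") raises ValueError; mypy also flags it statically.
```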
@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during docs build
instead of committed?). I also haven't adjusted the documentation yet,
which atm still mentions the trainers of the type
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`
## Breaking Changes
The adjustments to the trainer package introduce breaking changes as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them.
---------
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 18:54:46 +02:00
"""Creates placeholders in meta for keys that are in batch but not in meta.
This is mainly an internal method; use it only if you know what you are doing.
"""
2021-03-02 12:28:28 +08:00
for key in batch.keys():
if key in meta.keys():
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
alloc_by_keys_diff(meta[key], batch[key], size, stack)
2021-03-02 12:28:28 +08:00
elif isinstance(meta[key], Batch) and meta[key].is_empty():
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
meta[key] = create_value(batch[key], size, stack)
2021-03-02 12:28:28 +08:00
else:
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
meta[key] = create_value(batch[key], size, stack)
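The branching above can be mirrored with plain dicts to see what it does. This is a simplified analogue, not Tianshou's implementation: an empty dict plays the role of an empty `Batch`, and `create_placeholder` is a hypothetical stand-in for `create_value` that only handles NumPy-convertible values and nested dicts. The shape rule (prepend `size` when stacking, replace the leading axis when concatenating) is an assumption about the intended semantics.

```python
from typing import Any

import numpy as np


def create_placeholder(value: Any, size: int, stack: bool) -> Any:
    """Hypothetical stand-in for create_value: zero-filled storage for `size` items."""
    if isinstance(value, dict):
        return {k: create_placeholder(v, size, stack) for k, v in value.items()}
    arr = np.asarray(value)
    # Stacking adds a new leading axis; concatenating replaces the existing one.
    shape = (size, *arr.shape) if stack else (size, *arr.shape[1:])
    return np.zeros(shape, dtype=arr.dtype)


def alloc_missing_keys(meta: dict, batch: dict, size: int, stack: bool = True) -> None:
    """Dict-based analogue of alloc_by_keys_diff: fill in keys that meta lacks."""
    for key, value in batch.items():
        if key in meta:
            if isinstance(meta[key], dict) and isinstance(value, dict):
                alloc_missing_keys(meta[key], value, size, stack)  # recurse
            elif isinstance(meta[key], dict) and not meta[key]:
                meta[key] = create_placeholder(value, size, stack)  # empty slot
            # Otherwise meta already holds storage for this key: leave it alone.
        else:
            meta[key] = create_placeholder(value, size, stack)


meta = {"obs": np.zeros((4, 3))}
batch = {"obs": np.ones(3), "rew": 1.0, "info": {"env_id": np.array([7])}}
alloc_missing_keys(meta, batch, size=4, stack=True)
```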
2021-03-02 12:28:28 +08:00
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
# Note: This is implemented as a protocol because the interface
# of Batch is always extended by adding new fields. Having a hierarchy of
# protocols building off this one allows for type safety and IDE support despite
# the dynamic nature of Batch
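The comment above explains why `BatchProtocol` is a `Protocol`. The sketch below shows the same pattern in isolation, using hypothetical names (`DynamicBatch`, `BatchLikeProtocol`, `RolloutLikeProtocol`): a container whose fields are added at runtime, plus a small protocol hierarchy that tells a type checker which fields a function expects. `@runtime_checkable` additionally enables `isinstance` checks, which only test that the attributes are present, not their types.

```python
from typing import Any, Protocol, runtime_checkable


class DynamicBatch:
    """Hypothetical minimal container whose fields are set at runtime."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)

    def keys(self) -> Any:
        return self.__dict__.keys()


@runtime_checkable
class BatchLikeProtocol(Protocol):
    """Base protocol: anything that exposes its field names."""

    def keys(self) -> Any: ...


@runtime_checkable
class RolloutLikeProtocol(BatchLikeProtocol, Protocol):
    """Extends the base protocol with the fields a rollout is expected to carry."""

    obs: Any
    rew: Any


def total_reward(batch: RolloutLikeProtocol) -> float:
    # mypy can verify that `.rew` exists on anything annotated this way.
    return float(sum(batch.rew))


rollout = DynamicBatch(obs=[0, 1], rew=[1.0, 2.0])
print(isinstance(rollout, RolloutLikeProtocol))  # True
print(isinstance(DynamicBatch(obs=[0]), RolloutLikeProtocol))  # False: no `rew`
```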
@runtime_checkable
class BatchProtocol(Protocol):
2020-09-11 07:55:37 +08:00
"""The internal data structure in Tianshou.
2021-09-03 05:05:04 +08:00
Batch is a kind of supercharged array (of temporal data) stored individually in a
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
(recursive) dictionary of objects that can be either numpy arrays, torch tensors, or
batches themselves. It is designed to make it extremely easy to access, manipulate
2021-09-03 05:05:04 +08:00
and set partial views of the heterogeneous data.
2020-03-13 17:49:22 +08:00
2020-07-19 15:20:35 +08:00
For a detailed description, please refer to :ref:`batch_concept`.
"""
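The docstring describes Batch as a recursive dictionary of arrays that supports partial views. Tianshou's `Batch` provides this through ordinary indexing such as `batch[0:2]`; the sketch below only imitates the idea with plain dicts of NumPy arrays and hypothetical helpers (`index_nested`, `assign_nested`). Note that basic slices are NumPy views, so writes through them reach the original data, whereas fancy indexing would return copies.

```python
from typing import Any

import numpy as np


def index_nested(data: dict, index: Any) -> dict:
    """Apply one index to every array in a (possibly nested) dict of arrays."""
    return {
        key: index_nested(value, index) if isinstance(value, dict) else value[index]
        for key, value in data.items()
    }


def assign_nested(data: dict, index: Any, update: dict) -> None:
    """Write `update` into `data` at `index`, recursing into nested dicts."""
    for key, value in update.items():
        if isinstance(value, dict):
            assign_nested(data[key], index, value)
        else:
            data[key][index] = value


data = {"obs": np.arange(6).reshape(3, 2), "info": {"env_id": np.array([10, 11, 12])}}
first_two = index_nested(data, slice(0, 2))
# Basic slicing returns views, so writing through `first_two` changes `data`.
first_two["obs"][0, 0] = 99
assign_nested(data, 2, {"obs": np.array([-1, -1]), "info": {"env_id": 0}})
```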
2020-08-27 12:15:18 +08:00
Improved typing and reduced duplication (#912)
# Goals of the PR
The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but still the overall changes are "small" in a certain
sense.
## Major Change 1 - BatchProtocol
**TL;DR:** One can now annotate which fields the batch is expected to
have on input params and which fields a returned batch has. Should be
useful for reading the code. getting meaningful IDE support, and
catching bugs with mypy. This annotation strategy will continue to work
if Batch is replaced by TensorDict or by something else.
**In more detail:** Batch itself has no fields and using it for
annotations is of limited informational power. Batches with fields are
not separate classes but instead instances of Batch directly, so there
is no type that could be used for annotation. Fortunately, python
`Protocol` is here for the rescue. With these changes we can now do
things like
```python
class ActionBatchProtocol(BatchProtocol):
logits: Sequence[Union[tuple, torch.Tensor]]
dist: torch.distributions.Distribution
act: torch.Tensor
state: Optional[torch.Tensor]
class RolloutBatchProtocol(BatchProtocol):
obs: torch.Tensor
obs_next: torch.Tensor
info: Dict[str, Any]
rew: torch.Tensor
terminated: torch.Tensor
truncated: torch.Tensor
class PGPolicy(BasePolicy):
...
def forward(
self,
batch: RolloutBatchProtocol,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> ActionBatchProtocol:
```
The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.
## Major Change 2 - remove duplication in trainer package
**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.
**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.
1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type` which was passed as kwarg to the base class and was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called the
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstring and interfaces. They themselves
had no dosctring and the interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and which things can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support with such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. As for my
application I needed a new trainer, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
essentially no changes to the superclass were introduced. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion)
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.
I have some more reasons, but maybe the above ones are convincing
enough.
## Minor changes: improved input validation and types
I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed to pass this validation. There I added
`action_scaling=isinstance(env.action_space, Box)`, after which tests
were green. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess some aspect was
not tested and didn't crash.
I also added Literal in some places, in particular for
`action_bound_method`. Now it is no longer allowed to pass an empty
string, instead one should pass `None`. Also here there is input
validation with clear error messages.
@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during docs build
instead of committed?). I also haven't adjusted the documentation yet,
which atm still mentions the trainers of the type
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`
## Breaking Changes
The adjustments to the trainer package introduce breaking changes as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them
---------
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 18:54:46 +02:00
|
|
|
    @property
    def shape(self) -> list[int]:
        ...

    def __setattr__(self, key: str, value: Any) -> None:
        ...

    def __getattr__(self, key: str) -> Any:
        ...

    def __contains__(self, key: str) -> bool:
        ...

    def __getstate__(self) -> dict:
        ...

    def __setstate__(self, state: dict) -> None:
        ...

    @overload
    def __getitem__(self, index: str) -> Any:
        ...

    @overload
    def __getitem__(self: TBatch, index: IndexType) -> TBatch:
        ...

    def __getitem__(self, index: Union[str, IndexType]) -> Any:
        ...

    def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
        ...

    def __iadd__(self: TBatch, other: Union[TBatch, Number, np.number]) -> TBatch:
        ...

    def __add__(self: TBatch, other: Union[TBatch, Number, np.number]) -> TBatch:
        ...

    def __imul__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
        ...

    def __mul__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
        ...

    def __itruediv__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
        ...

    def __truediv__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
        ...

    def __repr__(self) -> str:
        ...

    def to_numpy(self) -> None:
        """Change all torch.Tensor to numpy.ndarray in-place."""
        ...

    def to_torch(
        self,
        dtype: Optional[torch.dtype] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        """Change all numpy.ndarray to torch.Tensor in-place."""
        ...

    def cat_(self, batches: Union[TBatch, Sequence[Union[dict, TBatch]]]) -> None:
        """Concatenate a list of (or one) Batch objects into the current batch."""
        ...

    @staticmethod
    def cat(batches: Sequence[Union[dict, TBatch]]) -> TBatch:
        """Concatenate a list of Batch objects into a single new batch.

        For keys that are not shared across all batches, the batches lacking
        those keys are padded with zeros of the appropriate shape. For example::

            >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
            >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
            >>> c = Batch.cat([a, b])
            >>> c.a.shape
            (7, 4)
            >>> c.b.shape
            (7, 3)
            >>> c.common.c.shape
            (7, 5)
        """
        ...

    def stack_(self, batches: Sequence[Union[dict, TBatch]], axis: int = 0) -> None:
        """Stack a list of Batch objects into the current batch."""
        ...

    @staticmethod
    def stack(batches: Sequence[Union[dict, TBatch]], axis: int = 0) -> TBatch:
        """Stack a list of Batch objects into a single new batch.

        For keys that are not shared across all batches, the batches lacking
        those keys are padded with zeros. For example::

            >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
            >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
            >>> c = Batch.stack([a, b])
            >>> c.a.shape
            (2, 4, 4)
            >>> c.b.shape
            (2, 4, 6)
            >>> c.common.c.shape
            (2, 4, 5)

        .. note::

            If there are keys that are not shared across all batches, ``stack``
            with ``axis != 0`` is undefined and will raise an exception.
        """
        ...

    def empty_(self: TBatch, index: Optional[Union[slice, IndexType]] = None) -> TBatch:
        """Return an empty Batch object filled with 0 or None.

        If ``index`` is specified, only the data at that index is reset. For example::

            >>> data.empty_()
            >>> print(data)
            Batch(
                a: array([[0., 0.],
                          [0., 0.]]),
                b: array([None, None], dtype=object),
            )
            >>> b = {'c': [2., 'st'], 'd': [1., 0.]}
            >>> data = Batch(a=[False, True], b=b)
            >>> data[0] = Batch.empty(data[1])
            >>> data
            Batch(
                a: array([False, True]),
                b: Batch(
                       c: array([None, 'st']),
                       d: array([0., 0.]),
                   ),
            )
        """
        ...

    @staticmethod
    def empty(batch: TBatch, index: Optional[IndexType] = None) -> TBatch:
        """Return an empty Batch object filled with 0 or None.

        The shape is the same as that of the given Batch.
        """
        ...

    def update(self, batch: Optional[Union[dict, TBatch]] = None, **kwargs: Any) -> None:
        """Update this batch from another dict/Batch."""
        ...

    def __len__(self) -> int:
        ...

    def is_empty(self, recurse: bool = False) -> bool:
        ...

    def split(
        self: TBatch,
        size: int,
        shuffle: bool = True,
        merge_last: bool = False,
    ) -> Iterator[TBatch]:
        """Split the whole data into multiple small batches.

        :param int size: the size of each small batch; if the batch is shorter
            than ``size``, a single batch is returned. A size of -1 means the
            whole batch.
        :param bool shuffle: if True, randomly shuffle the entire data batch
            before splitting; otherwise keep the original order. Defaults to True.
        :param bool merge_last: merge the last batch into the previous one.
            Defaults to False.
        """
        ...


class Batch(BatchProtocol):
    """See :class:`~tianshou.data.batch.BatchProtocol`."""

    __doc__ = BatchProtocol.__doc__

    def __init__(
        self,
        batch_dict: Optional[
            Union[dict, BatchProtocol, Sequence[Union[dict, BatchProtocol]], np.ndarray]
        ] = None,
        copy: bool = False,
        **kwargs: Any,
    ) -> None:
        if copy:
            batch_dict = deepcopy(batch_dict)
        if batch_dict is not None:
Improved typing and reduced duplication (#912)
# Goals of the PR
The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but still the overall changes are "small" in a certain
sense.
## Major Change 1 - BatchProtocol
**TL;DR:** One can now annotate which fields the batch is expected to
have on input params and which fields a returned batch has. Should be
useful for reading the code. getting meaningful IDE support, and
catching bugs with mypy. This annotation strategy will continue to work
if Batch is replaced by TensorDict or by something else.
**In more detail:** Batch itself has no fields and using it for
annotations is of limited informational power. Batches with fields are
not separate classes but instead instances of Batch directly, so there
is no type that could be used for annotation. Fortunately, python
`Protocol` is here for the rescue. With these changes we can now do
things like
```python
class ActionBatchProtocol(BatchProtocol):
logits: Sequence[Union[tuple, torch.Tensor]]
dist: torch.distributions.Distribution
act: torch.Tensor
state: Optional[torch.Tensor]
class RolloutBatchProtocol(BatchProtocol):
obs: torch.Tensor
obs_next: torch.Tensor
info: Dict[str, Any]
rew: torch.Tensor
terminated: torch.Tensor
truncated: torch.Tensor
class PGPolicy(BasePolicy):
...
def forward(
self,
batch: RolloutBatchProtocol,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> ActionBatchProtocol:
```
The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.
## Major Change 2 - remove duplication in trainer package
**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.
**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.
1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type`, which was passed as a kwarg to the base class but was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called the
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstring and interfaces. They themselves
had no dosctring and the interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and which things can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support with such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. Since my
application needed a new trainer, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
they introduced essentially no changes to the superclass. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion).
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.
I have some more reasons, but maybe the above ones are convincing
enough.
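To make point 2 concrete, here is a small sketch with hypothetical names (`OnpolicyLikeTrainer`, `onpolicy_like_trainer`). It is not Tianshou's actual trainer API. An explicit constructor exposes its parameters to `inspect`, IDEs and mypy, while a `*args, **kwargs` wrapper exposes none of them.

```python
import inspect
from typing import Any


class OnpolicyLikeTrainer:
    """Hypothetical trainer with an explicit, documented constructor."""

    def __init__(self, max_epoch: int, step_per_epoch: int, batch_size: int = 64) -> None:
        self.max_epoch = max_epoch
        self.step_per_epoch = step_per_epoch
        self.batch_size = batch_size

    def run(self) -> dict[str, int]:
        # Placeholder for the training loop: just reports what would be run.
        return {"epochs": self.max_epoch, "steps": self.max_epoch * self.step_per_epoch}


def onpolicy_like_trainer(*args: Any, **kwargs: Any) -> dict[str, int]:
    """Function-style wrapper of the kind the PR removes."""
    return OnpolicyLikeTrainer(*args, **kwargs).run()


# The wrapper hides the real parameters; the class exposes them.
print(list(inspect.signature(onpolicy_like_trainer).parameters))  # ['args', 'kwargs']
print(list(inspect.signature(OnpolicyLikeTrainer).parameters))  # ['max_epoch', 'step_per_epoch', 'batch_size']
```

Both calls do the same work. The only cost of dropping the wrapper is writing `.run()` explicitly.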
## Minor changes: improved input validation and types
I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed this validation. There, I added
`action_scaling=isinstance(env.action_space, Box)`, after which the tests
were green. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess that code path
simply wasn't exercised, so nothing crashed.
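Here is a sketch of this kind of check. It assumes hypothetical `Box`/`Discrete` stand-ins rather than the real Gym(nasium) space classes, and a hypothetical `validate_action_scaling` helper. In Tianshou, action scaling maps actions from [-1, 1] to the `[low, high]` range of a continuous action space. That is why the flag is rejected for anything that is not a `Box`.

```python
from dataclasses import dataclass


@dataclass
class Box:
    """Hypothetical stand-in for a continuous action space."""

    low: float
    high: float


@dataclass
class Discrete:
    """Hypothetical stand-in for a discrete action space."""

    n: int


def validate_action_scaling(action_space: object, action_scaling: bool) -> None:
    # Rescaling from [-1, 1] to [low, high] only makes sense for continuous (Box) spaces.
    if action_scaling and not isinstance(action_space, Box):
        raise ValueError(
            "action_scaling=True is only supported for Box action spaces, "
            f"got {type(action_space).__name__}.",
        )


# The pattern used to fix the tests: derive the flag from the space.
for space in (Box(low=-2.0, high=2.0), Discrete(n=4)):
    validate_action_scaling(space, action_scaling=isinstance(space, Box))
```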
I also added `Literal` types in some places, in particular for
`action_bound_method`. Passing an empty string is no longer allowed;
pass `None` instead. Here, too, invalid values are rejected with
clear error messages.
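Here is a sketch of `Literal` combined with runtime validation. It assumes the allowed values are `"clip"` and `"tanh"`, the bounding methods Tianshou's policies accept, and it uses a hypothetical `validate_action_bound_method` helper. `typing.get_args` reads the allowed values back from the `Literal`, so the annotation and the runtime check cannot drift apart.

```python
from typing import Literal, Optional, get_args

ActionBoundMethod = Literal["clip", "tanh"]


def validate_action_bound_method(
    method: Optional[ActionBoundMethod],
) -> Optional[ActionBoundMethod]:
    """Return method unchanged if valid; None means "do not bound actions"."""
    allowed = get_args(ActionBoundMethod)  # ('clip', 'tanh')
    if method is not None and method not in allowed:
        # An empty string used to mean "off"; it is now rejected explicitly.
        raise ValueError(
            f"action_bound_method must be one of {allowed} or None, got {method!r}.",
        )
    return method
```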
@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during docs build
instead of committed?). I also haven't adjusted the documentation yet,
which at the moment still mentions trainers of the form
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`.
## Breaking Changes
The adjustments to the trainer package introduce breaking changes, as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them.
---------
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 18:54:46 +02:00
if isinstance(batch_dict, (dict, BatchProtocol)):
2020-07-11 09:44:47 +08:00
_assert_type_keys(batch_dict.keys())
2022-01-30 00:53:56 +08:00
for batch_key, obj in batch_dict.items():
self.__dict__[batch_key] = _parse_value(obj)
2020-07-08 21:00:00 +08:00
elif _is_batch_set(batch_dict):
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
batch_dict = cast(Sequence[Union[dict, BatchProtocol]], batch_dict)
self.stack_(batch_dict)
2020-06-23 16:50:59 +02:00
if len(kwargs) > 0:
2020-09-13 19:31:50 +08:00
self.__init__(kwargs, copy=copy) # type: ignore
2020-03-12 22:20:33 +08:00

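The constructor shown above dispatches on its first argument. A dict (or `BatchProtocol`) is copied key by key through `_parse_value`. A sequence of batches is stacked with `stack_`. Keyword arguments are folded in by calling `__init__` again. The sketch below is a heavily simplified version of that dispatch and assumes a hypothetical `MiniBatch` class. Here "stacking" just collects each key's values into a list, which is far cruder than the real `stack_`.

```python
from collections.abc import Sequence
from typing import Any, Optional, Union


class MiniBatch:
    """Hypothetical, heavily simplified stand-in for Batch.__init__'s dispatch."""

    def __init__(
        self,
        batch_dict: Optional[Union[dict, Sequence[dict]]] = None,
        **kwargs: Any,
    ) -> None:
        if isinstance(batch_dict, dict):
            for key, obj in batch_dict.items():
                if not isinstance(key, str):  # rough analogue of _assert_type_keys
                    raise TypeError(f"keys must be str, got {key!r}")
                # nested dicts become nested MiniBatch objects (one of _parse_value's conversions)
                self.__dict__[key] = MiniBatch(obj) if isinstance(obj, dict) else obj
        elif (
            isinstance(batch_dict, Sequence)
            and len(batch_dict) > 0
            and all(isinstance(b, dict) for b in batch_dict)
        ):
            # crude analogue of stack_: gather each key's values across the sequence
            for key in batch_dict[0]:
                self.__dict__[key] = [b[key] for b in batch_dict]
        if len(kwargs) > 0:
            # keyword arguments are handled by re-entering the dict branch
            self.__init__(kwargs)
```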
2020-08-27 12:15:18 +08:00
def __setattr__(self, key: str, value: Any) -> None:
2020-09-11 07:55:37 +08:00
"""Set self.key = value."""
2020-07-19 15:20:35 +08:00
self.__dict__[key] = _parse_value(value)
2020-06-30 18:02:44 +08:00

2020-09-13 19:31:50 +08:00
def __getattr__(self, key: str) -> Any:
"""Return self.key. The "Any" return type is needed for mypy."""
return getattr(self.__dict__, key)

def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__

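Note that `__getattr__` forwards to `getattr(self.__dict__, key)`, not to `self.__dict__[key]`. Python only calls `__getattr__` after normal lookup fails. Stored fields are therefore found directly in the instance `__dict__`, while names such as `keys`, `items` and `values` fall through to the dict's own methods. That would explain how `self.items()` works in `__getstate__` and `__getitem__` without an explicit definition in this excerpt. Here is a standalone sketch of the pattern with a hypothetical `AttrBatch` class (no `_parse_value` conversion):

```python
from typing import Any


class AttrBatch:
    """Hypothetical dict-backed container mirroring Batch's attribute hooks."""

    def __setattr__(self, key: str, value: Any) -> None:
        # store every attribute directly in the instance dict
        self.__dict__[key] = value

    def __getattr__(self, key: str) -> Any:
        # only reached when normal lookup fails: delegate to the dict object itself
        return getattr(self.__dict__, key)

    def __contains__(self, key: str) -> bool:
        return key in self.__dict__


b = AttrBatch()
b.obs = [1, 2, 3]
print(b.obs)  # [1, 2, 3]: stored field, found without calling __getattr__
print(list(b.keys()))  # ['obs']: 'keys' is not a field, so it resolves to dict.keys
print("obs" in b)  # True
```

One consequence of this design, at least in this sketch: a field named, say, `items` would shadow the dict method of the same name on that instance.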
2023-08-25 23:40:56 +02:00
def __getstate__(self) -> dict[str, Any]:
2020-09-11 07:55:37 +08:00
"""Pickling interface.

Only the actual data are serialized for both efficiency and simplicity.
2020-05-30 15:29:33 +02:00
"""
state = {}
2022-01-30 00:53:56 +08:00
for batch_key, obj in self.items():
if isinstance(obj, Batch):
2023-08-25 23:40:56 +02:00
state[batch_key] = obj.__getstate__()
else:
state[batch_key] = obj
2020-05-30 15:29:33 +02:00
return state

2023-08-25 23:40:56 +02:00
def __setstate__(self, state: dict[str, Any]) -> None:
2020-09-11 07:55:37 +08:00
"""Unpickling interface.

At this point, self is an empty Batch instance that has not been
initialized, so it can safely be initialized by the pickle state.
2020-05-30 15:29:33 +02:00
"""
2020-09-13 19:31:50 +08:00
self.__init__(**state) # type: ignore
2020-05-30 15:29:33 +02:00

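`__getstate__` flattens nested batches into plain dicts. `__setstate__` rebuilds the object by running `__init__` on an instance that pickle created without calling it. Here is a standard-library sketch of that round trip, with a hypothetical `Node` class in place of `Batch`:

```python
import pickle
from typing import Any


class Node:
    """Hypothetical stand-in for Batch: nested dicts become nested Nodes."""

    def __init__(self, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            self.__dict__[key] = Node(**value) if isinstance(value, dict) else value

    def __getstate__(self) -> dict[str, Any]:
        # Serialize only the data; nested Nodes are flattened to plain dicts.
        return {
            key: value.__getstate__() if isinstance(value, Node) else value
            for key, value in self.__dict__.items()
        }

    def __setstate__(self, state: dict[str, Any]) -> None:
        # pickle created self via __new__ without calling __init__; build it now.
        self.__init__(**state)


original = Node(obs=[1, 2], info={"env_id": 3})
restored = pickle.loads(pickle.dumps(original))
print(type(restored.info).__name__, restored.info.env_id)  # Node 3
```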
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
@overload
def __getitem__(self, index: str) -> Any:
...

@overload
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
...

2021-03-30 16:06:03 +08:00
def __getitem__(self, index: Union[str, IndexType]) -> Any:
2020-04-04 21:02:06 +08:00
"""Return self[index]."""
2020-04-28 20:56:02 +08:00
if isinstance(index, str):
2020-06-26 12:37:50 +02:00
return self.__dict__[index]
2020-07-08 16:29:37 +02:00
batch_items = self.items()
if len(batch_items) > 0:
2022-01-30 00:53:56 +08:00
new_batch = Batch()
for batch_key, obj in batch_items:
if isinstance(obj, Batch) and obj.is_empty():
new_batch.__dict__[batch_key] = Batch()
2020-07-08 16:29:37 +02:00
else:
2022-01-30 00:53:56 +08:00
new_batch.__dict__[batch_key] = obj[index]
return new_batch
2023-08-25 23:40:56 +02:00
raise IndexError("Cannot access item from empty Batch object.")
2020-06-24 15:43:48 +02:00

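`__getitem__` treats a string as a key lookup. Any other index (int, slice, array, ...) is a request to index every field, and the result is a new batch. Empty sub-batches are carried over as-is, and indexing an empty batch raises `IndexError`. The sketch below reproduces that logic with a hypothetical `MiniBatch` class, using plain Python lists in place of arrays. It includes two `@overload` signatures so type checkers know that a string returns a field and any other index returns a batch.

```python
from typing import Any, TypeVar, Union, overload

TMiniBatch = TypeVar("TMiniBatch", bound="MiniBatch")


class MiniBatch:
    """Hypothetical stand-in for Batch, with lists instead of arrays."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)

    @overload
    def __getitem__(self, index: str) -> Any: ...

    @overload
    def __getitem__(self: TMiniBatch, index: Union[int, slice]) -> TMiniBatch: ...

    def __getitem__(self, index: Union[str, int, slice]) -> Any:
        if isinstance(index, str):
            return self.__dict__[index]  # field access: b["obs"]
        if not self.__dict__:
            raise IndexError("Cannot access item from empty Batch object.")
        new = type(self)()
        for key, obj in self.__dict__.items():
            if isinstance(obj, MiniBatch) and not obj.__dict__:
                new.__dict__[key] = MiniBatch()  # empty sub-batch: nothing to index
            else:
                new.__dict__[key] = obj[index]  # recurses into nested MiniBatch objects
        return new


b = MiniBatch(obs=[10, 20, 30], nested=MiniBatch(rew=[1, 2, 3]), info=MiniBatch())
print(b["obs"])  # [10, 20, 30]
print(b[1].obs, b[1].nested.rew)  # 20 2
print(b[0:2].obs)  # [10, 20]
```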
2021-03-30 16:06:03 +08:00
def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
2020-06-27 03:06:40 +02:00
"""Assign value to self[index]."""
2020-08-27 12:15:18 +08:00
value = _parse_value(value)
2020-06-25 14:39:30 +02:00
if isinstance(index, str):
2020-08-27 12:15:18 +08:00
self.__dict__[index] = value
2020-06-26 12:37:50 +02:00
return
2020-09-13 19:31:50 +08:00
if not isinstance(value, Batch):
fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`
before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
a: Batch(),
b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```
2. Creating new keys inside a sub-batch through item assignment, e.g.
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
info: Batch(
key1: array([0, 1]),
key2: array([0, 3]),
),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```
3. Small optimization for `Batch.stack_` and `Batch.cat_`; raise ValueError when receiving an invalid data format.
2021-02-02 19:28:05 +08:00
raise ValueError(
2021-09-03 05:05:04 +08:00
"Batch does not support tensor assignment. "
2023-08-25 23:40:56 +02:00
"Use a compatible Batch or dict instead.",
2021-09-03 05:05:04 +08:00
)
if not set(value.keys()).issubset(self.__dict__.keys()):
raise ValueError("Creating keys is not supported by item assignment.")
2020-06-26 12:37:50 +02:00
for key, val in self.items():
try:
self.__dict__[key][index] = value[key]
except KeyError:
if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
elif isinstance(val, torch.Tensor) or (
isinstance(val, np.ndarray)
and issubclass(val.dtype.type, (np.bool_, np.number))
):
2020-06-26 12:37:50 +02:00
self.__dict__[key][index] = 0
else:
self.__dict__[key][index] = None
2020-06-25 14:39:30 +02:00

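Item assignment mirrors this. A string index sets a field. Any other index requires a batch-like value whose keys are a subset of the existing ones; creating keys this way raises `ValueError`, as described in #284. Fields missing from the value are reset at that index: to an empty sub-batch, to `0` for numeric data, or to `None` otherwise. The sketch below uses lists and a hypothetical `MiniBatch` class. Its numeric check is an assumption standing in for the tensor/ndarray dtype test above.

```python
from typing import Any, Union


class MiniBatch:
    """Hypothetical stand-in for Batch, with lists instead of arrays."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)

    def __setitem__(self, index: Union[str, int], value: Any) -> None:
        if isinstance(index, str):
            self.__dict__[index] = value  # plain field assignment
            return
        if not isinstance(value, MiniBatch):
            raise ValueError("Only a compatible MiniBatch can be assigned to an index.")
        if not set(value.__dict__).issubset(self.__dict__):
            raise ValueError("Creating keys is not supported by item assignment.")
        for key, val in self.__dict__.items():
            if key in value.__dict__:
                val[index] = value.__dict__[key]
            elif isinstance(val, MiniBatch):
                val[index] = MiniBatch()  # recursively resets the sub-batch's fields
            elif all(isinstance(x, (int, float)) for x in val):
                val[index] = 0  # numeric data: stand-in for the tensor/ndarray branch
            else:
                val[index] = None


a = MiniBatch(obs=[0, 1, 2], rew=[0.5, 1.5, 2.5], info=["a", "b", "c"])
a[1] = MiniBatch(obs=9)
print(a.obs, a.rew, a.info)  # [0, 9, 2] [0.5, 0, 2.5] ['a', None, 'c']
```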
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
    def __iadd__(self: TBatch, other: Union[TBatch, Number, np.number]) -> TBatch:
2020-09-11 07:55:37 +08:00
        """Algebraic addition with another Batch instance in-place."""
2020-06-27 03:06:40 +02:00
        if isinstance(other, Batch):
2022-01-30 00:53:56 +08:00
            for (batch_key, obj), value in zip(
2023-08-25 23:40:56 +02:00
                self.__dict__.items(),
                other.__dict__.values(),
2020-09-12 15:39:01 +08:00
            ):  # TODO are keys consistent?
2022-01-30 00:53:56 +08:00
                if isinstance(obj, Batch) and obj.is_empty():
2020-06-26 12:37:50 +02:00
                    continue
2023-08-25 23:40:56 +02:00
                self.__dict__[batch_key] += value
2020-06-24 15:43:48 +02:00
            return self
2023-08-25 23:40:56 +02:00
        if _is_number(other):
2022-01-30 00:53:56 +08:00
            for batch_key, obj in self.items():
                if isinstance(obj, Batch) and obj.is_empty():
2020-06-26 12:37:50 +02:00
                    continue
2023-08-25 23:40:56 +02:00
                self.__dict__[batch_key] += other
2020-06-24 15:43:48 +02:00
            return self
2023-08-25 23:40:56 +02:00
        raise TypeError("Only addition of Batch or number is supported.")
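The `# TODO are keys consistent?` comment points at a real subtlety. `zip(self.__dict__.items(), other.__dict__.values())` pairs fields by insertion order, not by key. The plain-dict sketch below does not use `Batch`, and its helper names are hypothetical. It shows what positional pairing does when two containers hold the same keys in a different order, next to a key-aligned alternative. Whether two `Batch` objects can actually end up with differently ordered keys depends on how they were built.

```python
from typing import Dict


def add_by_position(a: Dict[str, float], b: Dict[str, float]) -> Dict[str, float]:
    # Mirrors the zip(...items(), ...values()) pairing in __iadd__: order decides.
    return {key: value + other for (key, value), other in zip(a.items(), b.values())}


def add_by_key(a: Dict[str, float], b: Dict[str, float]) -> Dict[str, float]:
    # Key-aligned alternative: refuse mismatched keys, then pair values by name.
    if a.keys() != b.keys():
        raise KeyError(f"key mismatch: {sorted(a)} vs {sorted(b)}")
    return {key: value + b[key] for key, value in a.items()}


a = {"obs": 1.0, "rew": 10.0}
b = {"rew": 20.0, "obs": 2.0}  # same keys, different insertion order

print(add_by_position(a, b))  # {'obs': 21.0, 'rew': 12.0}: values are crossed
print(add_by_key(a, b))       # {'obs': 3.0, 'rew': 30.0}
```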
2020-06-24 15:43:48 +02:00

Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
    def __add__(self: TBatch, other: Union[TBatch, Number, np.number]) -> TBatch:
2020-09-11 07:55:37 +08:00
        """Algebraic addition with another Batch instance out-of-place."""
2020-07-06 20:30:15 +08:00
        return deepcopy(self).__iadd__(other)
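`__add__` builds on `__iadd__` by deep-copying first. The sketch below shows why the copy has to be deep. It uses a hypothetical `MiniBatch` class (not tianshou's `Batch`) whose nested sub-batches are updated in place, as `self.__dict__[batch_key] += value` does for a nested `Batch`. A shallow copy would still share the nested object, so the original would change too.

```python
from copy import copy, deepcopy
from typing import Any


class MiniBatch:
    """Hypothetical stand-in: fields live in __dict__ and may be nested MiniBatches."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)

    def __iadd__(self, other: float) -> "MiniBatch":
        for key, value in self.__dict__.items():
            # For a nested MiniBatch this calls __iadd__, mutating it in place.
            value += other
            self.__dict__[key] = value
        return self

    def __add__(self, other: float) -> "MiniBatch":
        # Same pattern as Batch.__add__: copy everything, then add in place.
        return deepcopy(self).__iadd__(other)


original = MiniBatch(rew=1.0, info=MiniBatch(cost=5.0))
result = original + 1.0
print(original.info.cost, result.info.cost)  # 5.0 6.0

shallow = copy(original)
shallow += 1.0
print(original.rew, original.info.cost)  # 1.0 6.0: the nested object was shared
```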
2020-06-24 15:43:48 +02:00

Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
    def __imul__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
2020-06-27 03:06:40 +02:00
        """Algebraic multiplication with a scalar value in-place."""
2022-01-30 00:53:56 +08:00
        assert _is_number(value), "Only multiplication by a number is supported."
        for batch_key, obj in self.__dict__.items():
            if isinstance(obj, Batch) and obj.is_empty():
2020-07-16 19:36:32 +08:00
                continue
2022-01-30 00:53:56 +08:00
            self.__dict__[batch_key] *= value
2020-06-26 12:37:50 +02:00
        return self
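`__imul__` (like `__itruediv__` further down) validates the scalar with `assert`. Python drops `assert` statements when run with `-O`, so in optimized mode the check disappears. The sketch below runs the same loop over a plain dict but raises an explicit `TypeError` instead. The function name is hypothetical. `numbers.Number` stands in for the module's `_is_number` helper, whose exact definition is not shown here. Empty nested dicts play the role of empty sub-batches.

```python
from numbers import Number
from typing import Any, Dict


def scale_in_place(fields: Dict[str, Any], value: Any) -> Dict[str, Any]:
    """Multiply every leaf by `value` in place, skipping empty nested containers."""
    # Unlike `assert`, this check survives `python -O`.
    if not isinstance(value, Number):
        raise TypeError("Only multiplication by a number is supported.")
    for key, obj in fields.items():
        if isinstance(obj, dict):
            if obj:  # empty nested dicts are skipped, like empty sub-batches
                scale_in_place(obj, value)
            continue
        fields[key] = obj * value
    return fields


print(scale_in_place({"rew": 2.0, "info": {"cost": 3}, "extra": {}}, 10))
# {'rew': 20.0, 'info': {'cost': 30}, 'extra': {}}
```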
2020-06-24 15:43:48 +02:00

Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
    def __mul__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
2020-06-27 03:06:40 +02:00
        """Algebraic multiplication with a scalar value out-of-place."""
2022-01-30 00:53:56 +08:00
        return deepcopy(self).__imul__(value)
2020-06-26 12:37:50 +02:00

Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
    def __itruediv__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
2020-07-08 13:45:29 +08:00
        """Algebraic division with a scalar value in-place."""
2022-01-30 00:53:56 +08:00
        assert _is_number(value), "Only division by a number is supported."
        for batch_key, obj in self.__dict__.items():
            if isinstance(obj, Batch) and obj.is_empty():
2020-07-16 19:36:32 +08:00
                continue
2022-01-30 00:53:56 +08:00
            self.__dict__[batch_key] /= value
2020-06-26 12:37:50 +02:00
        return self

Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
    def __truediv__(self: TBatch, value: Union[Number, np.number]) -> TBatch:
2020-07-08 13:45:29 +08:00
        """Algebraic division with a scalar value out-of-place."""
2022-01-30 00:53:56 +08:00
        return deepcopy(self).__itruediv__(value)
2020-04-28 20:56:02 +08:00

2020-05-12 11:31:47 +08:00
    def __repr__(self) -> str:
2020-04-09 19:53:45 +08:00
        """Return str(self)."""
2022-01-30 00:53:56 +08:00
        self_str = self.__class__.__name__ + "(\n"
2020-04-09 19:53:45 +08:00
        flag = False
2022-01-30 00:53:56 +08:00
        for batch_key, obj in self.__dict__.items():
            rpl = "\n" + " " * (6 + len(batch_key))
            obj_name = pprint.pformat(obj).replace("\n", rpl)
            self_str += f"    {batch_key}: {obj_name},\n"
2020-06-24 15:43:48 +02:00
            flag = True
2020-04-09 19:53:45 +08:00
        if flag:
2022-01-30 00:53:56 +08:00
            self_str += ")"
2020-04-09 19:53:45 +08:00
        else:
2022-01-30 00:53:56 +08:00
            self_str = self.__class__.__name__ + "()"
        return self_str
|
2020-04-09 19:53:45 +08:00
|
|
|
|
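In `__repr__`, each field is printed as four spaces, the key, and `": "`. That prefix is `6 + len(batch_key)` characters long. Replacing every newline in the `pprint.pformat` output with a newline plus that many spaces therefore keeps multi-line values aligned under their first character.

A minimal stand-alone sketch of that layout follows. `format_fields` is a hypothetical helper that takes a plain dict instead of a Batch. It replaces the `flag` bookkeeping with an emptiness check, which should be equivalent.

```python
import pprint
from typing import Any


def format_fields(class_name: str, fields: dict[str, Any]) -> str:
    """Hypothetical mirror of the Batch.__repr__ layout for a plain dict."""
    if not fields:
        return class_name + "()"
    out = class_name + "(\n"
    for key, obj in fields.items():
        # Each entry starts with 4 spaces + key + ": " == 6 + len(key) chars,
        # so continuation lines are indented by that much to stay aligned.
        rpl = "\n" + " " * (6 + len(key))
        obj_name = pprint.pformat(obj).replace("\n", rpl)
        out += f"    {key}: {obj_name},\n"
    return out + ")"


print(format_fields("Batch", {"obs": list(range(30))}))
```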
2020-05-29 14:45:21 +02:00
    def to_numpy(self) -> None:
2022-01-30 00:53:56 +08:00
        for batch_key, obj in self.items():
            if isinstance(obj, torch.Tensor):
                self.__dict__[batch_key] = obj.detach().cpu().numpy()
            elif isinstance(obj, Batch):
                obj.to_numpy()
2020-05-29 14:45:21 +02:00

2020-09-12 15:39:01 +08:00
    def to_torch(
        self,
        dtype: Optional[torch.dtype] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
2020-05-30 15:40:31 +02:00
        if not isinstance(device, torch.device):
            device = torch.device(device)

2022-01-30 00:53:56 +08:00
        for batch_key, obj in self.items():
            if isinstance(obj, torch.Tensor):
2023-08-22 18:54:46 +02:00
                if (
2023-08-25 23:40:56 +02:00
                    dtype is not None
                    and obj.dtype != dtype
2023-08-22 18:54:46 +02:00
                    or obj.device.type != device.type
                    or device.index != obj.device.index
                ):
2020-05-30 15:40:31 +02:00
                    if dtype is not None:
2023-08-25 23:40:56 +02:00
                        self.__dict__[batch_key] = obj.type(dtype).to(device)
                    else:
                        self.__dict__[batch_key] = obj.to(device)
2022-01-30 00:53:56 +08:00
            elif isinstance(obj, Batch):
                obj.to_torch(dtype, device)
2020-07-19 15:20:35 +08:00
            else:
                # ndarray or scalar
2022-01-30 00:53:56 +08:00
                if not isinstance(obj, np.ndarray):
2023-08-25 23:40:56 +02:00
                    obj = np.asanyarray(obj)  # noqa: PLW2901
                obj = torch.from_numpy(obj).to(device)  # noqa: PLW2901
2020-07-19 15:20:35 +08:00
                if dtype is not None:
2023-08-25 23:40:56 +02:00
                    obj = obj.type(dtype)  # noqa: PLW2901
2022-01-30 00:53:56 +08:00
                self.__dict__[batch_key] = obj
2020-04-29 17:48:48 +08:00

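The condition guarding the tensor branch of `to_torch` mixes `and` and `or` without extra parentheses. Because `and` binds tighter than `or`, the dtype comparison only counts when a `dtype` was requested. A mismatch in device type or device index always triggers a conversion.

The sketch below reproduces that expression with hypothetical `FakeDevice`/`FakeTensor` dataclasses so the truth table can be checked without PyTorch. They are stand-ins, not the real `torch.device`/`torch.Tensor`. Note that in this sketch a device given without an index (`index=None`) compares unequal to one with index `0`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for torch.device: a type plus an optional index."""

    type: str
    index: Optional[int] = None


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in exposing only the .dtype and .device attributes."""

    dtype: str
    device: FakeDevice


def needs_conversion(obj: FakeTensor, dtype: Optional[str], device: FakeDevice) -> bool:
    # Same expression as in to_torch; `and` binds tighter than `or`, so it reads
    # (dtype requested and differs) or device type differs or device index differs.
    return (
        dtype is not None
        and obj.dtype != dtype
        or obj.device.type != device.type
        or device.index != obj.device.index
    )


cpu = FakeDevice("cpu")
t = FakeTensor("float32", cpu)
print(needs_conversion(t, None, cpu))  # no dtype requested, same device
print(needs_conversion(t, "float64", cpu))  # dtype differs
```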
2023-08-25 23:40:56 +02:00
    def __cat(self: TBatch, batches: Sequence[Union[dict, TBatch]], lens: list[int]) -> None:
2020-09-11 07:55:37 +08:00
        """Private method for Batch.cat_.

        ::
2020-07-16 19:36:32 +08:00

            >>> a = Batch(a=np.random.randn(3, 4))
            >>> x = Batch(a=a, b=np.random.randn(4, 4))
            >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))

        If we want to concatenate x and y, we want to pad y.a.a with zeros.
        Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
        not be able to know how to pad y.a. So ``Batch.cat_`` should compute
        the ``lens`` to give ``Batch.__cat`` a hint.
        ::

            >>> ans = Batch.cat([x, y])
            >>> # this is equivalent to the following line
            >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
            >>> # this lens is equal to [len(x), len(y)]
2020-06-20 22:23:12 +08:00
        """
2020-07-11 21:46:01 +08:00
        # partial keys will be padded by zeros
        # with the shape of [len, rest_shape]
2020-07-12 23:45:42 +08:00
        sum_lens = [0]
2022-01-30 00:53:56 +08:00
        for len_ in lens:
            sum_lens.append(sum_lens[-1] + len_)
2020-07-16 19:36:32 +08:00
        # collect non-empty keys
        keys_map = [
2023-08-25 23:40:56 +02:00
            {
                batch_key
                for batch_key, obj in batch.items()
2022-01-30 00:53:56 +08:00
                if not (isinstance(obj, Batch) and obj.is_empty())
2023-08-25 23:40:56 +02:00
            }
            for batch in batches
2021-09-03 05:05:04 +08:00
        ]
2020-07-11 21:46:01 +08:00
        keys_shared = set.intersection(*keys_map)
2022-01-30 00:53:56 +08:00
        values_shared = [[batch[key] for batch in batches] for key in keys_shared]
        for key, shared_value in zip(keys_shared, values_shared):
            if all(isinstance(element, (dict, Batch)) for element in shared_value):
2020-07-16 19:36:32 +08:00
                batch_holder = Batch()
2022-01-30 00:53:56 +08:00
                batch_holder.__cat(shared_value, lens=lens)
                self.__dict__[key] = batch_holder
            elif all(isinstance(element, torch.Tensor) for element in shared_value):
                self.__dict__[key] = torch.cat(shared_value)
2020-07-11 21:46:01 +08:00
            else:
2020-07-16 19:36:32 +08:00
                # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
                # will fail here
2023-08-25 23:40:56 +02:00
                self.__dict__[key] = _to_array_with_correct_type(np.concatenate(shared_value))
2022-01-30 00:53:56 +08:00
        keys_total = set.union(*[set(batch.keys()) for batch in batches])
2020-07-16 19:36:32 +08:00
        keys_reserve_or_partial = set.difference(keys_total, keys_shared)
        # keys that are reserved (hold an empty Batch) in every batch that has them
        keys_reserve = set.difference(keys_total, set.union(*keys_map))
        # keys that occur only in some batches, but not all
        keys_partial = keys_reserve_or_partial.difference(keys_reserve)
2022-01-30 00:53:56 +08:00
        for key in keys_reserve:
2020-07-16 19:36:32 +08:00
            # reserved keys
2022-01-30 00:53:56 +08:00
            self.__dict__[key] = Batch()
        for key in keys_partial:
            for i, batch in enumerate(batches):
                if key not in batch.__dict__:
2020-07-16 19:36:32 +08:00
                    continue
2022-01-30 00:53:56 +08:00
                value = batch.get(key)
                if isinstance(value, Batch) and value.is_empty():
2020-07-16 19:36:32 +08:00
                    continue
                try:
2023-08-25 23:40:56 +02:00
                    self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
2020-07-16 19:36:32 +08:00
                except KeyError:
2023-08-22 18:54:46 +02:00
                    self.__dict__[key] = create_value(value, sum_lens[-1], stack=False)
2023-08-25 23:40:56 +02:00
                    self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
2020-07-16 19:36:32 +08:00

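The zero-padding described in the `__cat` docstring rests on `sum_lens`, the running sums of `lens`. Batch `i` owns the output rows `sum_lens[i]:sum_lens[i + 1]`, and a key missing from that batch simply leaves those rows at zero.

Below is a hedged, list-based sketch of that idea. `cat_with_padding` is a hypothetical helper, not part of Batch. Unlike `__cat`, it pads every key instead of concatenating shared keys directly, and it does not recurse into nested batches.

```python
from itertools import accumulate


def cat_with_padding(
    batches: list[dict[str, list[float]]], lens: list[int]
) -> dict[str, list[float]]:
    """Concatenate dict "batches", zero-filling rows for missing keys.

    Hypothetical helper: batches[i] holds lens[i] rows for every key it has.
    """
    # Offsets: lens [3, 4] -> sum_lens [0, 3, 7]; batch i owns rows
    # sum_lens[i] : sum_lens[i + 1] of every output column.
    sum_lens = [0, *accumulate(lens)]
    keys_total = set().union(*(batch.keys() for batch in batches))
    out: dict[str, list[float]] = {}
    for key in sorted(keys_total):
        # Start from a zero column covering all rows ...
        column = [0.0] * sum_lens[-1]
        for i, batch in enumerate(batches):
            if key not in batch:
                continue  # ... so rows of batches lacking `key` stay zero.
            if len(batch[key]) != lens[i]:
                raise ValueError(f"batch {i} key {key!r} has wrong length")
            column[sum_lens[i] : sum_lens[i + 1]] = batch[key]
        out[key] = column
    return out


x = {"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}
y = {"b": [7.0, 8.0, 9.0, 10.0]}
print(cat_with_padding([x, y], lens=[3, 4]))
```

The length check matters because Python slice assignment would silently resize the column if a batch supplied the wrong number of rows.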
2023-08-25 23:40:56 +02:00
    def cat_(self, batches: Union[BatchProtocol, Sequence[Union[dict, BatchProtocol]]]) -> None:
2023-08-22 18:54:46 +02:00
        if isinstance(batches, (BatchProtocol, dict)):
2020-07-16 19:36:32 +08:00
            batches = [batches]
fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`
before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
a: Batch(),
b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```
2. creating keys in a batch's subkey, e.g.
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
info: Batch(
key1: array([0, 1]),
key2: array([0, 3]),
),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```
3. small optimization for `Batch.stack_` and `Batch.cat_`, raise ValueError when receiving invalid data format.
2021-02-02 19:28:05 +08:00
        # check input format
        batch_list = []
2022-01-30 00:53:56 +08:00
        for batch in batches:
            if isinstance(batch, dict):
                if len(batch) > 0:
                    batch_list.append(Batch(batch))
            elif isinstance(batch, Batch):
2021-02-02 19:28:05 +08:00
                # x.is_empty() means that x is Batch() and should be ignored
2022-01-30 00:53:56 +08:00
                if not batch.is_empty():
                    batch_list.append(batch)
2021-02-02 19:28:05 +08:00
            else:
2022-01-30 00:53:56 +08:00
                raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
2021-02-02 19:28:05 +08:00
        if len(batch_list) == 0:
2020-07-16 19:36:32 +08:00
            return
2021-02-02 19:28:05 +08:00
        batches = batch_list
2020-07-16 19:36:32 +08:00
        try:
            # x.is_empty(recurse=True) here means x is a nested empty batch
            # like Batch(a=Batch()), and we have to treat it as length zero and
            # keep it.
2023-08-25 23:40:56 +02:00
            lens = [0 if batch.is_empty(recurse=True) else len(batch) for batch in batches]
2022-01-30 00:53:56 +08:00
        except TypeError as exception:
2020-08-27 12:15:18 +08:00
            raise ValueError(
2020-09-12 15:39:01 +08:00
                "Batch.cat_ meets an exception. Maybe because there is any "
                f"scalar in {batches} but Batch.cat_ does not support the "
2023-08-25 23:40:56 +02:00
                "concatenation of scalar.",
2022-01-30 00:53:56 +08:00
            ) from exception
2020-07-16 19:36:32 +08:00
        if not self.is_empty():
2023-08-25 23:40:56 +02:00
            batches = [self, *list(batches)]
            lens = [0 if self.is_empty(recurse=True) else len(self), *lens]
2020-09-12 15:39:01 +08:00
        self.__cat(batches, lens)
2020-03-17 11:37:31 +08:00

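The input handling in `cat_` can be illustrated without tianshou. It drops empty inputs and keeps nested-empty batches with length zero. It also converts a `TypeError` from `len()` into a `ValueError` and prepends `self` when `self` is non-empty. This is a simplified sketch, not the library's implementation. Plain dicts stand in for `Batch`, and an empty dict plays the role of `Batch()`. The helpers `is_empty`, `batch_len` and `prepare_cat` are hypothetical. `batch_len` assumes that a batch's length is the length of its shortest non-empty leaf.

```python
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical stand-in for Batch: a plain dict whose values are sized
# data (lists), nested dicts (sub-batches) or scalars.
MiniBatch = Dict[str, Any]


def is_empty(batch: MiniBatch, recurse: bool = False) -> bool:
    """{} is empty; with recurse=True, {"a": {}} also counts as empty."""
    if not recurse:
        return len(batch) == 0
    return all(isinstance(v, dict) and is_empty(v, recurse=True) for v in batch.values())


def batch_len(batch: MiniBatch) -> int:
    """Assumed length rule: shortest non-empty leaf; len() of a scalar raises TypeError."""
    lens = []
    for value in batch.values():
        if isinstance(value, dict):
            if not is_empty(value, recurse=True):
                lens.append(batch_len(value))
        else:
            lens.append(len(value))  # TypeError for ints, floats, ...
    return min(lens) if lens else 0


def prepare_cat(current: MiniBatch, batches: Sequence[Any]) -> Tuple[List[MiniBatch], List[int]]:
    """Return the batches and lengths that cat_ would pass on to __cat."""
    batch_list = []
    for batch in batches:
        if not isinstance(batch, dict):
            raise ValueError(f"Cannot concatenate {type(batch)} in prepare_cat")
        if not is_empty(batch):  # a plain {} (i.e. Batch()) is dropped
            batch_list.append(batch)
    if not batch_list:
        return [], []  # cat_ returns early in this case
    try:
        # a nested empty batch such as {"b": {}} is kept with length 0
        lens = [0 if is_empty(b, recurse=True) else batch_len(b) for b in batch_list]
    except TypeError as exc:
        raise ValueError("scalars cannot be concatenated") from exc
    if not is_empty(current):
        batch_list = [current, *batch_list]
        lens = [0 if is_empty(current, recurse=True) else batch_len(current), *lens]
    return batch_list, lens


batches, lens = prepare_cat({}, [{"a": [1, 2]}, {}, {"a": [3]}, {"b": {}}])
```

The empty `{}` input is discarded. `{"b": {}}` is a nested empty batch, so it is kept with length 0, matching the comment in `cat_`.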
2020-06-30 18:02:44 +08:00
    @staticmethod
2023-08-22 18:54:46 +02:00
    def cat(batches: Sequence[Union[dict, TBatch]]) -> TBatch:
2020-06-30 18:02:44 +08:00
        batch = Batch()
2020-07-11 21:46:01 +08:00
        batch.cat_(batches)
2023-08-22 18:54:46 +02:00
        return batch  # type: ignore
2020-06-23 16:50:59 +02:00

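The key bookkeeping in the `stack_` method that follows is plain set algebra: it sorts keys into shared, reserved and partial keys. Below is a minimal sketch of that classification. It assumes plain dicts in place of `Batch` and an empty nested dict in place of an empty `Batch()`. `classify_keys` is a hypothetical helper, not a tianshou function.

```python
from typing import Any, Dict, List, Set, Tuple


def classify_keys(batches: List[Dict[str, Any]]) -> Tuple[Set[str], Set[str], Set[str]]:
    """Split keys into (shared, reserved, partial), mirroring the set logic of stack_."""
    # per batch: keys whose value is not an empty placeholder ({} ~ Batch())
    keys_map = [
        {key for key, obj in batch.items() if not (isinstance(obj, dict) and not obj)}
        for batch in batches
    ]
    # keys holding data in every batch
    keys_shared = set.intersection(*keys_map)
    # every key that appears anywhere
    keys_total = set.union(*[set(batch.keys()) for batch in batches])
    # keys that never hold data in any batch
    keys_reserve = keys_total - set.union(*keys_map)
    # keys that hold data in some, but not all, batches
    keys_partial = (keys_total - keys_shared) - keys_reserve
    return keys_shared, keys_reserve, keys_partial


shared, reserve, partial = classify_keys(
    [
        {"obs": [1, 2], "info": {}, "rew": [0.5]},
        {"obs": [3, 4], "info": {}},
    ],
)
```

Here `info` only ever holds an empty placeholder, so `stack_` would set it to `Batch()`. `rew` appears in only one batch, which makes it a partial key, so `stack_` raises `ValueError` for any `axis` other than 0.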
2023-08-25 23:40:56 +02:00
    def stack_(self, batches: Sequence[Union[dict, BatchProtocol]], axis: int = 0) -> None:
2021-02-02 19:28:05 +08:00
        # check input format
        batch_list = []
2022-01-30 00:53:56 +08:00
        for batch in batches:
            if isinstance(batch, dict):
                if len(batch) > 0:
                    batch_list.append(Batch(batch))
            elif isinstance(batch, Batch):
2021-02-02 19:28:05 +08:00
                # x.is_empty() means that x is Batch() and should be ignored
2022-01-30 00:53:56 +08:00
                if not batch.is_empty():
                    batch_list.append(batch)
2021-02-02 19:28:05 +08:00
            else:
2022-01-30 00:53:56 +08:00
                raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
2021-02-02 19:28:05 +08:00
        if len(batch_list) == 0:
2020-07-12 23:45:42 +08:00
            return
2021-02-02 19:28:05 +08:00
        batches = batch_list
2020-07-16 19:36:32 +08:00
        if not self.is_empty():
2023-08-25 23:40:56 +02:00
            batches = [self, *batches]
2020-07-16 19:36:32 +08:00
        # collect non-empty keys
        keys_map = [
2023-08-25 23:40:56 +02:00
            {
                batch_key
                for batch_key, obj in batch.items()
2023-08-22 18:54:46 +02:00
                if not (isinstance(obj, BatchProtocol) and obj.is_empty())
2023-08-25 23:40:56 +02:00
            }
            for batch in batches
2021-09-03 05:05:04 +08:00
        ]
2020-06-27 03:06:40 +02:00
        keys_shared = set.intersection(*keys_map)
2022-01-30 00:53:56 +08:00
        values_shared = [[batch[key] for batch in batches] for key in keys_shared]
        for shared_key, value in zip(keys_shared, values_shared):
            # second most common case: every value is a torch.Tensor
            if all(isinstance(element, torch.Tensor) for element in value):
                self.__dict__[shared_key] = torch.stack(value, axis)
            # third most common case: every value is a Batch or a dict
2023-08-22 18:54:46 +02:00
            elif all(isinstance(element, (BatchProtocol, dict)) for element in value):
2022-01-30 00:53:56 +08:00
                self.__dict__[shared_key] = Batch.stack(value, axis)
2020-08-27 12:15:18 +08:00
            else:  # most common case: np.ndarray
2021-06-26 18:08:41 +08:00
                try:
2023-08-25 23:40:56 +02:00
                    self.__dict__[shared_key] = _to_array_with_correct_type(np.stack(value, axis))
2021-06-26 18:08:41 +08:00
                except ValueError:
2021-09-03 05:05:04 +08:00
                    warnings.warn(
                        "You are using tensors with different shape,"
2023-08-25 23:40:56 +02:00
                        " fallback to dtype=object by default.",
2021-09-03 05:05:04 +08:00
                    )
2022-01-30 00:53:56 +08:00
                    self.__dict__[shared_key] = np.array(value, dtype=object)
2020-07-16 19:36:32 +08:00
|
|
|
# all the keys
|
2022-01-30 00:53:56 +08:00
|
|
|
keys_total = set.union(*[set(batch.keys()) for batch in batches])
|
2020-07-16 19:36:32 +08:00
|
|
|
# keys that are reserved in all batches
|
|
|
|
keys_reserve = set.difference(keys_total, set.union(*keys_map))
|
|
|
|
# keys that are either partial or reserved
|
|
|
|
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
|
|
|
|
# keys that occur only in some batches, but not all
|
|
|
|
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
|
2020-07-12 23:45:42 +08:00
|
|
|
if keys_partial and axis != 0:
|
|
|
|
raise ValueError(
|
2020-09-12 15:39:01 +08:00
|
|
|
f"Stack of Batch with non-shared keys {keys_partial} is only "
|
2023-08-25 23:40:56 +02:00
|
|
|
f"supported with axis=0, but got axis={axis}!",
|
2021-09-03 05:05:04 +08:00
|
|
|
)
|
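The key bookkeeping above reduces to set arithmetic. The sketch below assumes, as a simplification, that plain dicts stand in for batches and that an empty dict plays the role of an empty `Batch`; it also assumes `keys_map` holds, per batch, the keys whose values are non-empty (consistent with how `keys_reserve` is derived, though its construction is not shown in this excerpt). `classify_keys` is a hypothetical name.

```python
from __future__ import annotations

from typing import Any


def classify_keys(
    batches: list[dict[str, Any]],
    axis: int = 0,
) -> tuple[set[str], set[str], set[str]]:
    """Return (shared, reserved, partial) key sets for dict 'batches'."""
    # Per batch: keys whose value is not an empty dict (empty-Batch stand-in).
    keys_map = [
        {k for k, v in b.items() if not (isinstance(v, dict) and not v)}
        for b in batches
    ]
    keys_total = set.union(*[set(b.keys()) for b in batches])
    # Non-empty in every batch.
    keys_shared = set.intersection(*keys_map)
    # Present in at least one batch, but empty wherever it appears.
    keys_reserve = keys_total - set.union(*keys_map)
    # Non-empty in some batches, but not all.
    keys_partial = (keys_total - keys_shared) - keys_reserve
    if keys_partial and axis != 0:
        raise ValueError(
            f"Stack of Batch with non-shared keys {keys_partial} is only "
            f"supported with axis=0, but got axis={axis}!",
        )
    return keys_shared, keys_reserve, keys_partial
```

Only partial keys need per-batch padding, which is why they are restricted to `axis=0`: padding along another axis would have no obvious meaning.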
2022-01-30 00:53:56 +08:00
|
|
|
for key in keys_reserve:
|
2020-07-16 19:36:32 +08:00
|
|
|
# reserved keys
|
2022-01-30 00:53:56 +08:00
|
|
|
self.__dict__[key] = Batch()
|
|
|
|
for key in keys_partial:
|
|
|
|
for i, batch in enumerate(batches):
|
|
|
|
if key not in batch.__dict__:
|
2020-07-16 19:36:32 +08:00
|
|
|
continue
|
2022-01-30 00:53:56 +08:00
|
|
|
value = batch.get(key)
|
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
|
|
|
# TODO: fix code/annotations s.t. the ignores can be removed
|
|
|
|
if (
|
|
|
|
isinstance(value, BatchProtocol) # type: ignore
|
|
|
|
and value.is_empty() # type: ignore
|
|
|
|
):
|
2022-01-30 00:53:56 +08:00
|
|
|
continue # type: ignore
|
2020-07-16 19:36:32 +08:00
|
|
|
try:
|
2022-01-30 00:53:56 +08:00
|
|
|
self.__dict__[key][i] = value
|
2020-07-16 19:36:32 +08:00
|
|
|
except KeyError:
|
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
|
|
|
self.__dict__[key] = create_value(value, len(batches))
|
2022-01-30 00:53:56 +08:00
|
|
|
self.__dict__[key][i] = value
|
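For a partial key, the loop above allocates a container sized for all batches the first time the key is seen (via `create_value(value, len(batches))`), then writes each available value into its slot; batches lacking the key (or holding an empty `Batch` there) keep the default. A NumPy-only approximation, assuming numeric values and zero as the default fill; the hypothetical `fill_partial_key` is not the library's implementation, and `create_value` may pick other defaults, such as `None` for object data:

```python
from __future__ import annotations

from typing import Any

import numpy as np


def fill_partial_key(batches: list[dict[str, Any]], key: str) -> np.ndarray | None:
    out: np.ndarray | None = None
    for i, batch in enumerate(batches):
        if key not in batch:
            continue  # this slot keeps its default (zero) value
        value = np.asarray(batch[key])
        if out is None:
            # First occurrence: allocate a zero-filled buffer for all batches,
            # shaped like the value with a leading batch dimension.
            out = np.zeros((len(batches), *value.shape), dtype=value.dtype)
        out[i] = value
    return out
```

For example, `[{"rew": 1.0}, {}, {"rew": 3.0}]` yields `array([1., 0., 3.])`.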
2020-06-27 03:06:40 +02:00
|
|
|
|
|
|
|
@staticmethod
|
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
|
|
|
def stack(batches: Sequence[Union[dict, TBatch]], axis: int = 0) -> TBatch:
|
2020-06-27 03:06:40 +02:00
|
|
|
batch = Batch()
|
|
|
|
batch.stack_(batches, axis)
|
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
|
|
|
# can't cast to a generic type, so we have to ignore the type here
|
|
|
|
return batch # type: ignore
|
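`stack` is an out-of-place wrapper: it creates an empty `Batch` and delegates to the in-place `stack_`. The same pattern on a hypothetical, minimal `MiniBatch` class (it only stacks NumPy arrays under keys shared by all inputs and simply overwrites its own fields) might look like this:

```python
from __future__ import annotations

from typing import Any

import numpy as np


class MiniBatch:
    """Hypothetical stand-in for Batch that only handles shared array keys."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)

    def stack_(self, batches: list[MiniBatch | dict], axis: int = 0) -> None:
        # In place: write the stacked values into this object's fields.
        dicts = [b if isinstance(b, dict) else b.__dict__ for b in batches]
        shared = set.intersection(*[set(d) for d in dicts])
        for key in shared:
            self.__dict__[key] = np.stack([d[key] for d in dicts], axis)

    @staticmethod
    def stack(batches: list[MiniBatch | dict], axis: int = 0) -> MiniBatch:
        # Out of place: start from an empty object and reuse stack_.
        batch = MiniBatch()
        batch.stack_(batches, axis)
        return batch
```

Keeping the logic in `stack_` means the static method adds no behaviour of its own; both entry points accept a mix of dicts and batch objects.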
2020-07-19 15:20:35 +08:00
|
|
|
|
Improved typing and reduced duplication (#912)
2023-08-22 18:54:46 +02:00
|
|
|
def empty_(self: TBatch, index: Optional[Union[slice, IndexType]] = None) -> TBatch:
|
2022-01-30 00:53:56 +08:00
|
|
|
for batch_key, obj in self.items():
|
|
|
|
if isinstance(obj, torch.Tensor):  # most common case
|
|
|
|
self.__dict__[batch_key][index] = 0
|
|
|
|
elif obj is None:
|
2020-08-27 12:15:18 +08:00
|
|
|
continue
|
2022-01-30 00:53:56 +08:00
|
|
|
elif isinstance(obj, np.ndarray):
|
|
|
|
if obj.dtype == object:
|
|
|
|
self.__dict__[batch_key][index] = None
|
2020-07-06 20:30:15 +08:00
|
|
|
else:
|
2022-01-30 00:53:56 +08:00
|
|
|
self.__dict__[batch_key][index] = 0
|
|
|
|
elif isinstance(obj, Batch):
|
|
|
|
self.__dict__[batch_key].empty_(index=index)
|
2020-07-06 20:30:15 +08:00
|
|
|
else: # scalar value
|
2021-09-03 05:05:04 +08:00
|
|
|
warnings.warn(
|
|
|
|
"You are calling Batch.empty on a NumPy scalar, "
|
2023-08-25 23:40:56 +02:00
|
|
|
"which may cause undefined behavior.",
|
2021-09-03 05:05:04 +08:00
|
|
|
)
|
2022-01-30 00:53:56 +08:00
|
|
|
if _is_number(obj):
|
|
|
|
self.__dict__[batch_key] = obj.__class__(0)
|
2020-07-06 20:30:15 +08:00
|
|
|
else:
|
2022-01-30 00:53:56 +08:00
|
|
|
self.__dict__[batch_key] = None
|
2020-06-30 18:02:44 +08:00
|
|
|
return self

    @staticmethod
2023-08-22 18:54:46 +02:00
    def empty(batch: TBatch, index: Optional[IndexType] = None) -> TBatch:
2020-07-06 20:30:15 +08:00
        return deepcopy(batch).empty_(index)
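To make the `empty_`/`empty` semantics concrete, here is a standalone sketch that applies the same rules to a nested dict of NumPy arrays instead of a `Batch`, so it runs without tianshou or torch. The rules are:

- numeric arrays are zeroed at `index`;
- object arrays are set to `None`;
- nested containers are handled recursively;
- scalars are replaced wholesale.

Some parts of the sketch are assumptions:

- `empty_dict_` and `empty_dict` are hypothetical helpers.
- `_is_number` is approximated with an `int`/`float` check.
- Mapping `index=None` to `slice(None)` is an assumption about intent. The original indexes with `None` directly, which for NumPy arrays also assigns to every element.

```python
from copy import deepcopy
from typing import Any, Optional, Union

import numpy as np

# Hypothetical stand-in: a (possibly nested) dict plays the role of a Batch.
Nested = dict[str, Any]


def empty_dict_(data: Nested, index: Optional[Union[int, slice]] = None) -> Nested:
    """Reset values in place following the Batch.empty_ rules (simplified sketch)."""
    sel = slice(None) if index is None else index  # None means "every element"
    for key, obj in data.items():
        if obj is None:
            continue
        if isinstance(obj, np.ndarray):
            # Object arrays are reset to None, numeric arrays to zero.
            obj[sel] = None if obj.dtype == object else 0
        elif isinstance(obj, dict):
            empty_dict_(obj, index)  # recurse into nested containers
        else:
            # Scalars are replaced wholesale (the real method also warns here);
            # the int/float check is an approximation of ``_is_number``.
            data[key] = type(obj)(0) if isinstance(obj, (int, float)) else None
    return data


def empty_dict(data: Nested, index: Optional[Union[int, slice]] = None) -> Nested:
    """Non-in-place variant: copy first, then empty the copy."""
    return empty_dict_(deepcopy(data), index)
```

The non-in-place variant mirrors `Batch.empty`. It deep-copies first, so the input is left untouched.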
2020-06-30 18:02:44 +08:00

2023-08-25 23:40:56 +02:00
    def update(self, batch: Optional[Union[dict, TBatch]] = None, **kwargs: Any) -> None:
2020-07-11 21:46:01 +08:00
        if batch is None:
            self.update(kwargs)
            return
2022-01-30 00:53:56 +08:00
        for batch_key, obj in batch.items():
            self.__dict__[batch_key] = _parse_value(obj)
2020-07-11 21:46:01 +08:00
        if kwargs:
            self.update(kwargs)
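The control flow of `update` is easy to misread because the method calls itself. The following sketch shows the resulting precedence: the positional mapping is applied first and the keyword arguments last, so keywords win on conflicting keys. It uses a hypothetical `Record` class and skips the `_parse_value` normalization.

```python
from typing import Any, Optional


class Record:
    """Hypothetical minimal container illustrating the ``update`` control flow."""

    def update(self, batch: Optional[dict] = None, **kwargs: Any) -> None:
        if batch is None:
            # Only keyword arguments were given: treat them as the batch.
            self.update(kwargs)
            return
        for key, obj in batch.items():
            # The real method normalizes values via ``_parse_value``; stored as-is here.
            self.__dict__[key] = obj
        if kwargs:
            # Keyword arguments are applied last, so they win on key conflicts.
            self.update(kwargs)
```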

2020-05-12 11:31:47 +08:00
    def __len__(self) -> int:
2020-04-04 21:02:06 +08:00
        """Return len(self)."""
2022-01-30 00:53:56 +08:00
        lens = []
        for obj in self.__dict__.values():
            if isinstance(obj, Batch) and obj.is_empty(recurse=True):
2020-06-24 15:43:48 +02:00
                continue
2023-08-25 23:40:56 +02:00
            if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
2022-01-30 00:53:56 +08:00
                lens.append(len(obj))
2020-06-24 15:43:48 +02:00
            else:
2022-01-30 00:53:56 +08:00
                raise TypeError(f"Object {obj} in {self} has no len()")
        if len(lens) == 0:
2020-07-19 15:20:35 +08:00
            # An empty batch has an unspecified shape, like TensorFlow's '?' shape,
            # so it has no length.
2020-07-16 19:36:32 +08:00
            raise TypeError(f"Object {self} has no len()")
2022-01-30 00:53:56 +08:00
        return min(lens)
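Here is a standalone sketch of the length rule in `__len__`, over a flat dict instead of a `Batch`. `min_len` is a hypothetical helper, and it omits the skip for recursively empty sub-batches. Fields with a length contribute it, 0-d arrays and scalars raise `TypeError`, and the shortest field determines the result. Objects without `ndim`, such as plain lists, are treated as 1-D here; that is an assumption, since in a real `Batch` values are already converted by `_parse_value`.

```python
from typing import Any


def min_len(values: dict[str, Any]) -> int:
    """Sketch of the ``Batch.__len__`` rule over a flat dict (hypothetical helper)."""
    lens = []
    for obj in values.values():
        # 0-d arrays define __len__ but have ndim == 0, so they are rejected too.
        if hasattr(obj, "__len__") and getattr(obj, "ndim", 1) > 0:
            lens.append(len(obj))
        else:
            raise TypeError(f"Object {obj!r} has no len()")
    if not lens:
        raise TypeError("An empty container has no len()")
    # Fields of different lengths: the shortest one determines the length.
    return min(lens)
```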
2020-06-24 15:43:48 +02:00

2020-09-12 15:39:01 +08:00
    def is_empty(self, recurse: bool = False) -> bool:
2020-09-11 07:55:37 +08:00
        """Test if a Batch is empty.

        If ``recurse=True``, it further tests the values of the object; else
        it only tests the existence of any key.
2020-07-16 19:36:32 +08:00

        ``b.is_empty(recurse=True)`` is mainly used to distinguish
        ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
        exceptions when applied to ``len()``, but the former can be used in
        ``cat``, while the latter is a scalar and cannot be used in ``cat``.

        Another usage is in ``__len__``, where we have to skip checking the
2020-07-19 15:20:35 +08:00
        length of recursively empty Batches.
2020-07-16 19:36:32 +08:00
        ::

            >>> Batch().is_empty()
            True
            >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
            False
            >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
            True
            >>> Batch(d=1).is_empty()
            False
            >>> Batch(a=np.float64(1.0)).is_empty()
            False
        """
        if len(self.__dict__) == 0:
            return True
        if not recurse:
            return False
2020-09-12 15:39:01 +08:00
        return all(
2022-01-30 00:53:56 +08:00
            False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
            for obj in self.values()
2021-09-03 05:05:04 +08:00
        )
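The `recurse` distinction from the docstring can be reproduced with plain nested dicts standing in for `Batch` objects. This is a hypothetical helper, not tianshou code, and its test cases mirror the doctest examples.

```python
from typing import Any


def is_empty(data: dict[str, Any], recurse: bool = False) -> bool:
    """Nested-dict analogue of ``Batch.is_empty`` (hypothetical helper)."""
    if len(data) == 0:
        return True
    if not recurse:
        # Any key at all counts as "not empty" in the shallow check.
        return False
    # Recursively empty only if every value is itself a recursively empty container.
    return all(isinstance(v, dict) and is_empty(v, recurse=True) for v in data.values())
```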
2020-07-11 09:44:47 +08:00

2020-06-24 15:43:48 +02:00
    @property
2023-08-25 23:40:56 +02:00
    def shape(self) -> list[int]:
2020-06-30 18:02:44 +08:00
        """Return self.shape."""
2020-07-16 19:36:32 +08:00
        if self.is_empty():
2020-06-30 18:02:44 +08:00
            return []
2023-08-25 23:40:56 +02:00
        data_shape = []
        for obj in self.__dict__.values():
            try:
                data_shape.append(list(obj.shape))
            except AttributeError:
                data_shape.append([])
        return list(map(min, zip(*data_shape))) if len(data_shape) > 1 else data_shape[0]
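`shape` takes the element-wise minimum over all field shapes. Because `zip` truncates to the shortest shape, only the leading dimensions shared by every field survive, and any shapeless field (treated as `[]`) collapses the result to `[]`. Here is a hypothetical helper over a list of values, for illustration only:

```python
from typing import Any


def common_shape(values: list[Any]) -> list[int]:
    """Element-wise minimum over field shapes, as in ``Batch.shape`` (hypothetical helper)."""
    if not values:
        return []
    # Objects without a ``shape`` attribute count as shape [] (like the AttributeError branch).
    data_shape = [list(getattr(v, "shape", [])) for v in values]
    if len(data_shape) == 1:
        return data_shape[0]
    # zip stops at the shortest shape, so only the shared leading dims survive.
    return list(map(min, zip(*data_shape)))
```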

    def split(
        self: TBatch,
        size: int,
        shuffle: bool = True,
        merge_last: bool = False,
    ) -> Iterator[TBatch]:
2020-04-03 21:28:12 +08:00
        length = len(self)
2023-08-22 18:54:46 +02:00
        if size == -1:
            size = length
2023-08-25 23:40:56 +02:00
        assert size >= 1  # size may exceed length; the whole batch is then yielded as one chunk
        indices = np.random.permutation(length) if shuffle else np.arange(length)
2020-08-27 12:15:18 +08:00
        merge_last = merge_last and length % size > 0
        for idx in range(0, length, size):
            if merge_last and idx + size + size >= length:
                yield self[indices[idx:]]
                break
2023-08-25 23:40:56 +02:00
            yield self[indices[idx : idx + size]]
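The chunking in `split` depends only on `length`, `size`, `shuffle`, and `merge_last`, so it can be sketched on indices alone. `split_indices` is a hypothetical helper that yields index arrays instead of sub-batches. With `merge_last=True`, a short final remainder is folded into the previous chunk.

```python
from collections.abc import Iterator

import numpy as np


def split_indices(
    length: int, size: int, shuffle: bool = True, merge_last: bool = False
) -> Iterator[np.ndarray]:
    """Yield index chunks following the ``Batch.split`` logic (hypothetical helper)."""
    if size == -1:
        size = length
    assert size >= 1
    indices = np.random.permutation(length) if shuffle else np.arange(length)
    # Merging only matters if the last chunk would be a short remainder.
    merge_last = merge_last and length % size > 0
    for idx in range(0, length, size):
        if merge_last and idx + size + size >= length:
            # Fewer than ``size`` items would remain after this chunk: emit them together.
            yield indices[idx:]
            break
        yield indices[idx : idx + size]
```

For example, splitting 10 items into chunks of 4 yields sizes 4, 4, 2 by default and 4, 6 with `merge_last=True`. A `size` larger than the length yields everything in one chunk.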