Support Actor preprocessing network reuse for continuous case, fixes in DQN network (#1128)
This PR fixes a bug in DQN and lifts a limitation in reusing the actor's preprocessing network for continuous environments.

* `atari_network.DQN`:
  * Fix input validation
  * Fix `output_dim` not being set if `features_only=True` and `output_dim_added_layer` is not None
* `continuous.Critic`:
  * Add flag `apply_preprocess_net_to_obs_only` to allow the preprocessing network to be applied to the observations only (without the actions concatenated), which is essential for the case where we want to reuse the actor's preprocessing network
* `CriticFactoryReuseActor`: Use the flag, fixing the case where we want to reuse an actor's preprocessing network for the critic (it must be applied before the actions are concatenated)
* Minor improvements in docs/docstrings
This commit is contained in: commit a65920fc68
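To illustrate what the new flag enables, here is a minimal sketch (not part of the commit; names and sizes are illustrative) of a critic that reuses an actor's preprocessing network. The critic is called without an action input, matching the V-critic case targeted by `CriticFactoryReuseActor`:

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

obs_dim, act_dim = 8, 2  # illustrative dimensions
shared = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
actor = ActorProb(shared, action_shape=act_dim)
# The critic reuses the actor's preprocessing network; the new flag ensures it
# is applied to the observations alone, before anything else happens to them.
critic = Critic(shared, apply_preprocess_net_to_obs_only=True)
values = critic(torch.randn(4, obs_dim))  # state values of shape (4, 1)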
.github/PULL_REQUEST_TEMPLATE.md (vendored, 2 changes)

@@ -1,6 +1,6 @@
 - [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
 - [ ] I have provided a description of the changes in this Pull Request
-- [ ] I have added documentation for my changes
+- [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md
 - [ ] If applicable, I have added tests to cover my changes.
 - [ ] I have reformatted the code using `poe format`
 - [ ] I have checked style and types with `poe lint` and `poe type-check`
CHANGELOG.md (13 changes)

@@ -19,6 +19,19 @@
 - New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!).
   Launchers for parallelization currently in alpha state. #1074
 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
+- `continuous.Critic`:
+  - Add flag `apply_preprocess_net_to_obs_only` to allow the
+    preprocessing network to be applied to the observations only (without
+    the actions concatenated), which is essential for the case where we want
+    to reuse the actor's preprocessing network #1128
+
+### Fixes
+- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
+  fixing the case where we want to reuse an actor's preprocessing network for the critic (affects usages
+  of the experiment builder method `with_critic_factory_use_actor` with continuous environments) #1128
+- `atari_network.DQN`:
+  - Fix constructor input validation #1128
+  - Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128
+
 ### Internal Improvements
 - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
@@ -152,7 +152,7 @@
    "id": "Lh2-hwE5Dn9I"
   },
   "source": [
-   "Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
+   "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
   ]
  },
 {
@@ -66,7 +66,7 @@ class DQN(NetBase[Any]):
         layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
     ) -> None:
         # TODO: Add docstring
-        if features_only and output_dim_added_layer is not None:
+        if not features_only and output_dim_added_layer is not None:
             raise ValueError(
                 "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.",
             )
@@ -98,6 +98,7 @@ class DQN(NetBase[Any]):
                 layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)),
                 nn.ReLU(inplace=True),
             )
+            self.output_dim = output_dim_added_layer
         else:
             self.output_dim = base_cnn_output_dim
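A quick sketch of the behavior the second hunk fixes (hypothetical usage; the import path assumes the repository's Atari examples layout, and the input shape is illustrative):

from examples.atari.atari_network import DQN  # assumed module location

# hypothetical Atari-shaped input: 4 stacked 84x84 frames, 6 actions
net = DQN(c=4, h=84, w=84, action_shape=6, features_only=True, output_dim_added_layer=512)
assert net.output_dim == 512  # previously the attribute was not set in this branch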
@@ -172,18 +172,19 @@ class ReplayBuffer:
         return np.array([last] if not self.done[last] and self._size else [], int)

     def prev(self, index: int | np.ndarray) -> np.ndarray:
-        """Return the index of previous transition.
-
-        The index won't be modified if it is the beginning of an episode.
+        """Return the index of preceding step within the same episode if it exists.
+        If it does not exist (because it is the first index within the episode),
+        the index remains unmodified.
         """
-        index = (index - 1) % self._size
+        index = (index - 1) % self._size  # compute preceding index with wrap-around
+        # end_flag will be 1 if the previous index is the last step of an episode or
+        # if it is the very last index of the buffer (wrap-around case), and 0 otherwise
         end_flag = self.done[index] | (index == self.last_index[0])
         return (index + end_flag) % self._size

     def next(self, index: int | np.ndarray) -> np.ndarray:
-        """Return the index of next transition.
-
-        The index won't be modified if it is the end of an episode.
+        """Return the index of next step if there is a next step within the episode.
+        If there isn't a next step, the index remains unmodified.
         """
         end_flag = self.done[index] | (index == self.last_index[0])
         return (index + (1 - end_flag)) % self._size
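The wrap-around logic that the new comments describe can be checked in isolation. A small numpy sketch (illustrative, not from the commit), using a full 6-slot buffer that holds two three-step episodes:

import numpy as np

# done flags for two episodes occupying slots [0,1,2] and [3,4,5]
done = np.array([False, False, True, False, False, True])
size, last_index = len(done), 5

def prev(index: np.ndarray) -> np.ndarray:
    index = (index - 1) % size  # preceding index with wrap-around
    # stay put if the preceding slot ends an episode or is the buffer's last slot
    end_flag = done[index] | (index == last_index)
    return (index + end_flag) % size

print(prev(np.arange(6)))  # [0 0 1 3 3 4]: episode starts map to themselves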
@@ -118,9 +118,12 @@ class SamplingConfig(ToStringMixin):
     replay_buffer_ignore_obs_next: bool = False

     replay_buffer_save_only_last_obs: bool = False
-    """if True, only the most recent frame is saved when appending to experiences rather than the
-    full stacked frames. This avoids duplicating observations in buffer memory. Set to False to
-    save stacked frames in full.
+    """if True, for the case where the environment outputs stacked frames (e.g. because it
+    is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate
+    observations in buffer memory. Specifically, if the environment outputs observations `obs` with
+    shape (N, ...), only obs[-1] of shape (...) will be stored.
+    Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting
+    :attr:`replay_buffer_stack_num`.
     """

     replay_buffer_stack_num: int = 1
@@ -128,6 +131,9 @@ class SamplingConfig(ToStringMixin):
     the number of consecutive environment observations to stack and use as the observation input
     to the agent for each time step. Setting this to a value greater than 1 can help agents learn
     temporal aspects (e.g. velocities of moving objects for which only positions are observed).
+
+    If the environment already stacks frames (e.g. using a `FrameStack` wrapper), this should either not
+    be used or should be used in conjunction with :attr:`replay_buffer_save_only_last_obs`.
     """

     @property
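Putting the two options together as the amended docstrings suggest, here is a hedged configuration sketch for an environment wrapped in a 4-frame `FrameStack` (the field values are illustrative, not from the commit):

from tianshou.highlevel.config import SamplingConfig

sampling_config = SamplingConfig(
    num_epochs=10,  # illustrative training-loop settings
    step_per_epoch=100_000,
    # the env already stacks 4 frames, so store only the newest frame ...
    replay_buffer_save_only_last_obs=True,
    # ... and let the buffer re-assemble the 4-frame stack when sampling
    replay_buffer_stack_num=4,
)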
@@ -197,7 +197,11 @@ class CriticFactoryReuseActor(CriticFactory):
                 last_size=last_size,
             ).to(device)
         elif envs.get_type().is_continuous():
-            return continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
+            return continuous.Critic(
+                actor.get_preprocess_net(),
+                device=device,
+                apply_preprocess_net_to_obs_only=True,
+            ).to(device)
         else:
             raise ValueError
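For context, this factory is what the high-level API selects via `with_critic_factory_use_actor`. A hypothetical sketch of that path (the env factory is assumed to exist and the builder arguments are illustrative):

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, PPOExperimentBuilder

# env_factory: an EnvFactory for a continuous-control task (assumed given)
experiment = (
    PPOExperimentBuilder(env_factory, ExperimentConfig(), SamplingConfig())
    .with_critic_factory_use_actor()  # critic now shares the actor's preprocessing net
    .build()
)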
@@ -15,6 +15,7 @@ from tianshou.utils.net.common import (
     TLinearLayer,
     get_output_dim,
 )
+from tianshou.utils.pickle import setstate

 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -109,6 +110,9 @@ class Critic(CriticBase):
         `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
     :param linear_layer: use this module as linear layer.
     :param flatten_input: whether to flatten input data for the last layer.
+    :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before
+        concatenating with the action) - and without the observations being modified in any way beforehand.
+        This allows the actor's preprocessing network to be reused for the critic.

     For advanced usage (how to customize the network), please refer to
     :ref:`build_the_network`.
@@ -122,11 +126,13 @@ class Critic(CriticBase):
         preprocess_net_output_dim: int | None = None,
         linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
+        apply_preprocess_net_to_obs_only: bool = False,
     ) -> None:
         super().__init__()
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
+        self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only
         input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
@@ -137,6 +143,14 @@ class Critic(CriticBase):
             flatten_input=flatten_input,
         )

+    def __setstate__(self, state: dict) -> None:
+        setstate(
+            Critic,
+            self,
+            state,
+            new_default_properties={"apply_preprocess_net_to_obs_only": False},
+        )
+
     def forward(
         self,
         obs: np.ndarray | torch.Tensor,
@@ -148,7 +162,10 @@ class Critic(CriticBase):
             obs,
             device=self.device,
             dtype=torch.float32,
-        ).flatten(1)
+        )
+        if self.apply_preprocess_net_to_obs_only:
+            obs, _ = self.preprocess(obs)
+        obs = obs.flatten(1)
         if act is not None:
             act = torch.as_tensor(
                 act,
@@ -156,8 +173,9 @@ class Critic(CriticBase):
                 dtype=torch.float32,
             ).flatten(1)
             obs = torch.cat([obs, act], dim=1)
-        values_B, hidden_BH = self.preprocess(obs)
-        return self.last(values_B)
+        if not self.apply_preprocess_net_to_obs_only:
+            obs, _ = self.preprocess(obs)
+        return self.last(obs)


 class ActorProb(BaseActor):
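The `__setstate__` hook exists because critics pickled before this change lack the new attribute; on load, the hook backfills it with its default. A hedged round-trip sketch (network sizes illustrative):

import pickle

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Critic

critic = pickle.loads(pickle.dumps(Critic(Net(state_shape=4, hidden_sizes=[16]))))
# for states written by older versions (without the attribute),
# __setstate__ fills in the flag's default value on unpickling
assert critic.apply_preprocess_net_to_obs_only is False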
tianshou/utils/pickle.py (new file, 97 lines)

@@ -0,0 +1,97 @@
+"""Helper functions for persistence/pickling, which have been copied from sensAI (specifically `sensai.util.pickle`)."""
+
+from collections.abc import Iterable
+from copy import copy
+from typing import Any
+
+
+def setstate(
+    cls: type,
+    obj: Any,
+    state: dict[str, Any],
+    renamed_properties: dict[str, str] | None = None,
+    new_optional_properties: list[str] | None = None,
+    new_default_properties: dict[str, Any] | None = None,
+    removed_properties: list[str] | None = None,
+) -> None:
+    """Helper function for safe implementations of `__setstate__` in classes, which appropriately handles the cases where
+    a parent class already implements `__setstate__` and where it does not. Call this function whenever you would actually
+    like to call the super-class' implementation.
+    Unfortunately, `__setstate__` is not implemented in `object`, rendering `super().__setstate__(state)` invalid in the general case.
+
+    :param cls: the class in which you are implementing `__setstate__`
+    :param obj: the instance of `cls`
+    :param state: the state dictionary
+    :param renamed_properties: a mapping from old property names to new property names
+    :param new_optional_properties: a list of names of new property names, which, if not present, shall be initialized with None
+    :param new_default_properties: a dictionary mapping property names to their default values, which shall be added if they are not present
+    :param removed_properties: a list of names of properties that are no longer being used
+    """
+    # handle new/changed properties
+    if renamed_properties is not None:
+        for mOld, mNew in renamed_properties.items():
+            if mOld in state:
+                state[mNew] = state[mOld]
+                del state[mOld]
+    if new_optional_properties is not None:
+        for mNew in new_optional_properties:
+            if mNew not in state:
+                state[mNew] = None
+    if new_default_properties is not None:
+        for mNew, mValue in new_default_properties.items():
+            if mNew not in state:
+                state[mNew] = mValue
+    if removed_properties is not None:
+        for p in removed_properties:
+            if p in state:
+                del state[p]
+    # call super implementation, if any
+    s = super(cls, obj)
+    if hasattr(s, "__setstate__"):
+        s.__setstate__(state)
+    else:
+        obj.__dict__ = state
+
+
+def getstate(
+    cls: type,
+    obj: Any,
+    transient_properties: Iterable[str] | None = None,
+    excluded_properties: Iterable[str] | None = None,
+    override_properties: dict[str, Any] | None = None,
+    excluded_default_properties: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Helper function for safe implementations of `__getstate__` in classes, which appropriately handles the cases where
+    a parent class already implements `__getstate__` and where it does not. Call this function whenever you would actually
+    like to call the super-class' implementation.
+    Unfortunately, `__getstate__` is not implemented in `object`, rendering `super().__getstate__()` invalid in the general case.
+
+    :param cls: the class in which you are implementing `__getstate__`
+    :param obj: the instance of `cls`
+    :param transient_properties: transient properties which shall be set to None in serializations
+    :param excluded_properties: properties which shall be completely removed from serializations
+    :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
+        use this to set a fixed value for an existing property or to add a completely new property
+    :param excluded_default_properties: properties which shall be completely removed from serializations, if they are set
+        to the given default value
+    :return: the state dictionary, which may be modified by the receiver
+    """
+    s = super(cls, obj)
+    d = s.__getstate__() if hasattr(s, "__getstate__") else obj.__dict__
+    d = copy(d)
+    if transient_properties is not None:
+        for p in transient_properties:
+            if p in d:
+                d[p] = None
+    if excluded_properties is not None:
+        for p in excluded_properties:
+            if p in d:
+                del d[p]
+    if override_properties is not None:
+        for k, v in override_properties.items():
+            d[k] = v
+    if excluded_default_properties is not None:
+        for p, v in excluded_default_properties.items():
+            if p in d and d[p] == v:
+                del d[p]
+    return d
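A hedged usage sketch for these helpers (the class and attribute names are hypothetical), mirroring how `Critic.__setstate__` uses `setstate` above:

from tianshou.utils.pickle import getstate, setstate

class Example:
    def __init__(self) -> None:
        self.weight = 1.0
        self.new_flag = False  # attribute introduced in a newer version

    def __getstate__(self) -> dict:
        # omit the flag from serializations while it still has its default value
        return getstate(Example, self, excluded_default_properties={"new_flag": False})

    def __setstate__(self, state: dict) -> None:
        # backfill the flag when loading pickles that predate it
        setstate(Example, self, state, new_default_properties={"new_flag": False})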