Add map_action_inverse for fixing error of storing random action (#568)

(Issue #512) Random start in Collector sample actions from the action space, while policies output action in a range (typically [-1, 1]) and map action to the action space. The buffer only stores unmapped actions, so the actions randomly initialized are not correct when the action range is not [-1, 1]. This may influence policy learning and particularly model learning in model-based methods.

This PR fixes it by adding an inverse operation before adding random initial actions to the buffer.
This commit is contained in:
Minhui Li 2022-03-12 22:26:00 +08:00 committed by GitHub
parent 9cb74e60c9
commit 39f8391cfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 1 deletions

View File

@ -226,6 +226,7 @@ class Collector(object):
]
except TypeError: # envpool's action space is not for per-env
act_sample = [self._action_space.sample() for _ in ready_env_ids]
act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
self.data.update(act=act_sample)
else:
if no_grad:
@ -451,6 +452,7 @@ class AsyncCollector(Collector):
]
except TypeError: # envpool's action space is not for per-env
act_sample = [self._action_space.sample() for _ in ready_env_ids]
act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
self.data.update(act=act_sample)
else:
if no_grad:

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
@ -178,6 +178,33 @@ class BasePolicy(ABC, nn.Module):
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
return act
def map_action_inverse(
self, act: Union[Batch, List, np.ndarray]
) -> Union[Batch, List, np.ndarray]:
"""Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.
This function is called in :meth:`~tianshou.data.Collector.collect` for
random initial steps. It scales [action_space.low, action_space.high] to
the value ranges of policy.forward.
:param act: a data batch, list or numpy.ndarray which is the action taken
by gym.spaces.Box.sample().
:return: action remapped.
"""
if isinstance(self.action_space, gym.spaces.Box):
act = to_numpy(act)
if isinstance(act, np.ndarray):
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
return act
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch: