diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 213b49d..592f091 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -226,6 +226,7 @@ class Collector(object):
                     ]
                 except TypeError:  # envpool's action space is not for per-env
                     act_sample = [self._action_space.sample() for _ in ready_env_ids]
+                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
                 self.data.update(act=act_sample)
             else:
                 if no_grad:
@@ -451,6 +452,7 @@ class AsyncCollector(Collector):
                     ]
                 except TypeError:  # envpool's action space is not for per-env
                     act_sample = [self._action_space.sample() for _ in ready_env_ids]
+                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
                 self.data.update(act=act_sample)
             else:
                 if no_grad:
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 3572e83..fd055ab 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import gym
 import numpy as np
@@ -178,6 +178,33 @@ class BasePolicy(ABC, nn.Module):
             act = low + (high - low) * (act + 1.0) / 2.0  # type: ignore
         return act
 
+    def map_action_inverse(
+        self, act: Union[Batch, List, np.ndarray]
+    ) -> Union[Batch, List, np.ndarray]:
+        """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.
+
+        This function is called in :meth:`~tianshou.data.Collector.collect` for
+        random initial steps. It scales [action_space.low, action_space.high] to
+        the value ranges of policy.forward.
+
+        :param act: a data batch, list or numpy.ndarray which is the action taken
+            by gym.spaces.Box.sample().
+
+        :return: action remapped.
+        """
+        if isinstance(self.action_space, gym.spaces.Box):
+            act = to_numpy(act)
+            if isinstance(act, np.ndarray):
+                if self.action_scaling:
+                    low, high = self.action_space.low, self.action_space.high
+                    scale = high - low
+                    eps = np.finfo(np.float32).eps.item()
+                    scale[scale < eps] += eps
+                    act = (act - low) * 2.0 / scale - 1.0
+                if self.action_bound_method == "tanh":
+                    act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0  # type: ignore
+        return act
+
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
     ) -> Batch:
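
Note for reviewers (not part of the patch): the collector draws random warm-up actions in env coordinates via `action_space.sample()`, and the new `map_action_inverse` call converts them back to the raw value range that `policy.forward` produces before they are stored, so the buffer holds a single consistent action encoding. Below is a minimal standalone sketch of the round-trip this diff establishes, assuming `action_scaling=True` and `action_bound_method="tanh"`; the free functions, bounds, and sample values here are hypothetical stand-ins for the `BasePolicy` methods, not the library API.

```python
import numpy as np

# Hypothetical stand-ins for BasePolicy.map_action / map_action_inverse,
# specialized to action_scaling=True and action_bound_method="tanh".

def map_action(raw, low, high):
    bounded = np.tanh(raw)                              # squash raw output into (-1, 1)
    return low + (high - low) * (bounded + 1.0) / 2.0   # rescale to [low, high]

def map_action_inverse(act, low, high):
    eps = np.finfo(np.float32).eps.item()
    scale = high - low
    scale[scale < eps] += eps                     # guard zero-width dims, as in the patch
    bounded = (act - low) * 2.0 / scale - 1.0     # rescale back to [-1, 1]
    # atanh; an action exactly at low/high would give +/-inf, mirroring the patched code
    return (np.log(1.0 + bounded) - np.log(1.0 - bounded)) / 2.0

low, high = np.array([-2.0, 0.0]), np.array([2.0, 1.0])
raw = np.array([0.3, -1.2])
assert np.allclose(map_action_inverse(map_action(raw, low, high), low, high), raw)
```

The apparent motivation: without the inverse mapping, random initial steps would be recorded in `[low, high]` while policy-generated steps are recorded in the network's raw range, giving the learner two inconsistent action representations.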