Add map_action_inverse for fixing error of storing random action (#568)
(Issue #512) Random start in Collector sample actions from the action space, while policies output action in a range (typically [-1, 1]) and map action to the action space. The buffer only stores unmapped actions, so the actions randomly initialized are not correct when the action range is not [-1, 1]. This may influence policy learning and particularly model learning in model-based methods. This PR fixes it by adding an inverse operation before adding random initial actions to the buffer.
This commit is contained in:
parent
9cb74e60c9
commit
39f8391cfb
@ -226,6 +226,7 @@ class Collector(object):
|
||||
]
|
||||
except TypeError: # envpool's action space is not for per-env
|
||||
act_sample = [self._action_space.sample() for _ in ready_env_ids]
|
||||
act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
|
||||
self.data.update(act=act_sample)
|
||||
else:
|
||||
if no_grad:
|
||||
@ -451,6 +452,7 @@ class AsyncCollector(Collector):
|
||||
]
|
||||
except TypeError: # envpool's action space is not for per-env
|
||||
act_sample = [self._action_space.sample() for _ in ready_env_ids]
|
||||
act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
|
||||
self.data.update(act=act_sample)
|
||||
else:
|
||||
if no_grad:
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -178,6 +178,33 @@ class BasePolicy(ABC, nn.Module):
|
||||
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
|
||||
return act
|
||||
|
||||
def map_action_inverse(
|
||||
self, act: Union[Batch, List, np.ndarray]
|
||||
) -> Union[Batch, List, np.ndarray]:
|
||||
"""Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.
|
||||
|
||||
This function is called in :meth:`~tianshou.data.Collector.collect` for
|
||||
random initial steps. It scales [action_space.low, action_space.high] to
|
||||
the value ranges of policy.forward.
|
||||
|
||||
:param act: a data batch, list or numpy.ndarray which is the action taken
|
||||
by gym.spaces.Box.sample().
|
||||
|
||||
:return: action remapped.
|
||||
"""
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
act = to_numpy(act)
|
||||
if isinstance(act, np.ndarray):
|
||||
if self.action_scaling:
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
scale = high - low
|
||||
eps = np.finfo(np.float32).eps.item()
|
||||
scale[scale < eps] += eps
|
||||
act = (act - low) * 2.0 / scale - 1.0
|
||||
if self.action_bound_method == "tanh":
|
||||
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
|
||||
return act
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
|
||||
) -> Batch:
|
||||
|
Loading…
x
Reference in New Issue
Block a user