518 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			518 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import warnings
 | 
						|
from collections.abc import Sequence
 | 
						|
from typing import Any
 | 
						|
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from torch import nn
 | 
						|
 | 
						|
from tianshou.utils.net.common import (
 | 
						|
    MLP,
 | 
						|
    BaseActor,
 | 
						|
    TActionShape,
 | 
						|
    TLinearLayer,
 | 
						|
    get_output_dim,
 | 
						|
)
 | 
						|
 | 
						|
# Lower/upper bounds used with torch.clamp on the output of a state-conditioned
# sigma head; the clamped value is then exponentiated, i.e. it is a log-std.
SIGMA_MIN, SIGMA_MAX = -20, 2
 | 
						|
 | 
						|
 | 
						|
class Actor(BaseActor):
    """Deterministic actor for continuous action spaces.

    Layout: ``preprocess_net -> MLP -> tanh -> * max_action``.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: hidden layer sizes of the MLP placed after
        preprocess_net. Defaults to an empty sequence, in which case the MLP
        is a single linear layer.
    :param max_action: scale applied to the tanh-squashed output. Defaults
        to 1.
    :param device: device passed to the MLP head and stored on the module.
    :param preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: TActionShape,
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        # Total number of action components (product of the action shape).
        self.output_dim = int(np.prod(action_shape))
        self.last = MLP(
            get_output_dim(preprocess_net, preprocess_net_output_dim),
            self.output_dim,
            hidden_sizes,
            device=self.device,
        )
        self.max_action = max_action

    def get_preprocess_net(self) -> nn.Module:
        """Return the wrapped preprocess network."""
        return self.preprocess

    def get_output_dim(self) -> int:
        """Return the flattened action dimension."""
        return self.output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        """Mapping: obs -> logits -> action.

        :return: ``(action, hidden)``, where ``action`` is
            ``max_action * tanh(MLP(features))`` and ``hidden`` is the second
            value returned by ``preprocess_net``.
        """
        features, hidden = self.preprocess(obs, state)
        squashed = torch.tanh(self.last(features))
        return self.max_action * squashed, hidden
 | 
						|
 | 
						|
 | 
						|
class Critic(nn.Module):
    """Simple critic network.

    It will create a critic operated in continuous action space with structure
    of preprocess_net ---> 1 (Q value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state. When an action is passed to ``forward`` it is
        concatenated to the observation, so preprocess_net must accept
        ``obs_dim + act_dim`` inputs in that case.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param device: device used for input conversion and passed to the MLP
        head. Default to "cpu".
    :param preprocess_net_output_dim: the output dimension of
        preprocess_net.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param flatten_input: whether to flatten input data for the last layer.
        Default to True.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
        linear_layer: TLinearLayer = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        # A critic emits a single scalar Q value per sample.
        self.output_dim = 1
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,
            self.output_dim,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """Mapping: (s, a) -> logits -> Q(s, a).

        ``obs`` (and ``act`` when given) are converted to float32 tensors on
        ``self.device`` and flattened from dimension 1 onward. The action is
        concatenated to the observation along dim 1 before preprocess_net.

        :param obs: batch of observations.
        :param act: optional batch of actions.
        :param info: unused; kept for interface compatibility.
        :return: the output of the final MLP head.
        """
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32).flatten(1)
        if act is not None:
            act = torch.as_tensor(act, device=self.device, dtype=torch.float32).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        # preprocess_net returns (features, state); the state is not needed here.
        logits, _ = self.preprocess(obs)
        return self.last(logits)
 | 
						|
 | 
						|
 | 
						|
class ActorProb(BaseActor):
    """Simple actor network (output with a Gauss distribution).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param max_action: scale applied to ``tanh(mu)`` when ``unbounded`` is
        False. Default to 1.
    :param device: device passed to the MLP heads and stored on the module.
        Default to "cpu".
    :param unbounded: if True, ``mu`` is returned as produced by the MLP,
        with no tanh squashing and no ``max_action`` scaling. In that case
        ``max_action`` is reset to 1, with a warning if a different value was
        given. If False (default), ``mu = max_action * tanh(mu)``.
    :param conditioned_sigma: True when sigma is calculated from the
        input, False when sigma is an independent parameter. Default to False.
    :param preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    # TODO: force kwargs, adjust downstream code
    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: TActionShape,
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn("Note that max_action input will be discarded when unbounded is True.")
            max_action = 1.0
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = MLP(
                input_dim,
                self.output_dim,
                hidden_sizes,
                device=self.device,
            )
        else:
            # State-independent log-sigma, one entry per action component.
            self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def get_preprocess_net(self) -> nn.Module:
        """Return the wrapped preprocess network."""
        return self.preprocess

    def get_output_dim(self) -> int:
        """Return the flattened action dimension."""
        return self.output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
        """Mapping: obs -> logits -> (mu, sigma).

        :param obs: batch of observations, forwarded to preprocess_net.
        :param state: state forwarded to preprocess_net.
        :param info: unused; kept for interface compatibility.
        :return: ``((mu, sigma), hidden)``, where ``hidden`` is the state
            returned by preprocess_net.
        """
        logits, hidden = self.preprocess(obs, state)
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            # The conditioned head predicts log-sigma; clamp before exp for stability.
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            # sigma_param is (output_dim, 1); view it so the action axis lines up
            # with dim 1 of mu, then broadcast to mu's full shape.
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # Return the preprocess net's *new* hidden state (as Actor does) rather
        # than the incoming one, so a recurrent preprocess_net can advance.
        return (mu, sigma), hidden
 | 
						|
 | 
						|
 | 
						|
class RecurrentActorProb(nn.Module):
 | 
						|
    """Recurrent version of ActorProb.
 | 
						|
 | 
						|
    For advanced usage (how to customize the network), please refer to
 | 
						|
    :ref:`build_the_network`.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        layer_num: int,
 | 
						|
        state_shape: Sequence[int],
 | 
						|
        action_shape: Sequence[int],
 | 
						|
        hidden_layer_size: int = 128,
 | 
						|
        max_action: float = 1.0,
 | 
						|
        device: str | int | torch.device = "cpu",
 | 
						|
        unbounded: bool = False,
 | 
						|
        conditioned_sigma: bool = False,
 | 
						|
    ) -> None:
 | 
						|
        super().__init__()
 | 
						|
        if unbounded and not np.isclose(max_action, 1.0):
 | 
						|
            warnings.warn("Note that max_action input will be discarded when unbounded is True.")
 | 
						|
            max_action = 1.0
 | 
						|
        self.device = device
 | 
						|
        self.nn = nn.LSTM(
 | 
						|
            input_size=int(np.prod(state_shape)),
 | 
						|
            hidden_size=hidden_layer_size,
 | 
						|
            num_layers=layer_num,
 | 
						|
            batch_first=True,
 | 
						|
        )
 | 
						|
        output_dim = int(np.prod(action_shape))
 | 
						|
        self.mu = nn.Linear(hidden_layer_size, output_dim)
 | 
						|
        self._c_sigma = conditioned_sigma
 | 
						|
        if conditioned_sigma:
 | 
						|
            self.sigma = nn.Linear(hidden_layer_size, output_dim)
 | 
						|
        else:
 | 
						|
            self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
 | 
						|
        self.max_action = max_action
 | 
						|
        self._unbounded = unbounded
 | 
						|
 | 
						|
    def forward(
 | 
						|
        self,
 | 
						|
        obs: np.ndarray | torch.Tensor,
 | 
						|
        state: dict[str, torch.Tensor] | None = None,
 | 
						|
        info: dict[str, Any] | None = None,
 | 
						|
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]:
 | 
						|
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
 | 
						|
        if info is None:
 | 
						|
            info = {}
 | 
						|
        obs = torch.as_tensor(
 | 
						|
            obs,
 | 
						|
            device=self.device,
 | 
						|
            dtype=torch.float32,
 | 
						|
        )
 | 
						|
        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
 | 
						|
        # In short, the tensor's shape in training phase is longer than which
 | 
						|
        # in evaluation phase.
 | 
						|
        if len(obs.shape) == 2:
 | 
						|
            obs = obs.unsqueeze(-2)
 | 
						|
        self.nn.flatten_parameters()
 | 
						|
        if state is None:
 | 
						|
            obs, (hidden, cell) = self.nn(obs)
 | 
						|
        else:
 | 
						|
            # we store the stack data in [bsz, len, ...] format
 | 
						|
            # but pytorch rnn needs [len, bsz, ...]
 | 
						|
            obs, (hidden, cell) = self.nn(
 | 
						|
                obs,
 | 
						|
                (
 | 
						|
                    state["hidden"].transpose(0, 1).contiguous(),
 | 
						|
                    state["cell"].transpose(0, 1).contiguous(),
 | 
						|
                ),
 | 
						|
            )
 | 
						|
        logits = obs[:, -1]
 | 
						|
        mu = self.mu(logits)
 | 
						|
        if not self._unbounded:
 | 
						|
            mu = self.max_action * torch.tanh(mu)
 | 
						|
        if self._c_sigma:
 | 
						|
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
 | 
						|
        else:
 | 
						|
            shape = [1] * len(mu.shape)
 | 
						|
            shape[1] = -1
 | 
						|
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
 | 
						|
        # please ensure the first dim is batch size: [bsz, len, ...]
 | 
						|
        return (mu, sigma), {
 | 
						|
            "hidden": hidden.transpose(0, 1).detach(),
 | 
						|
            "cell": cell.transpose(0, 1).detach(),
 | 
						|
        }
 | 
						|
 | 
						|
 | 
						|
class RecurrentCritic(nn.Module):
 | 
						|
    """Recurrent version of Critic.
 | 
						|
 | 
						|
    For advanced usage (how to customize the network), please refer to
 | 
						|
    :ref:`build_the_network`.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        layer_num: int,
 | 
						|
        state_shape: Sequence[int],
 | 
						|
        action_shape: Sequence[int] = [0],
 | 
						|
        device: str | int | torch.device = "cpu",
 | 
						|
        hidden_layer_size: int = 128,
 | 
						|
    ) -> None:
 | 
						|
        super().__init__()
 | 
						|
        self.state_shape = state_shape
 | 
						|
        self.action_shape = action_shape
 | 
						|
        self.device = device
 | 
						|
        self.nn = nn.LSTM(
 | 
						|
            input_size=int(np.prod(state_shape)),
 | 
						|
            hidden_size=hidden_layer_size,
 | 
						|
            num_layers=layer_num,
 | 
						|
            batch_first=True,
 | 
						|
        )
 | 
						|
        self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
 | 
						|
 | 
						|
    def forward(
 | 
						|
        self,
 | 
						|
        obs: np.ndarray | torch.Tensor,
 | 
						|
        act: np.ndarray | torch.Tensor | None = None,
 | 
						|
        info: dict[str, Any] | None = None,
 | 
						|
    ) -> torch.Tensor:
 | 
						|
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
 | 
						|
        if info is None:
 | 
						|
            info = {}
 | 
						|
        obs = torch.as_tensor(
 | 
						|
            obs,
 | 
						|
            device=self.device,
 | 
						|
            dtype=torch.float32,
 | 
						|
        )
 | 
						|
        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
 | 
						|
        # In short, the tensor's shape in training phase is longer than which
 | 
						|
        # in evaluation phase.
 | 
						|
        assert len(obs.shape) == 3
 | 
						|
        self.nn.flatten_parameters()
 | 
						|
        obs, (hidden, cell) = self.nn(obs)
 | 
						|
        obs = obs[:, -1]
 | 
						|
        if act is not None:
 | 
						|
            act = torch.as_tensor(
 | 
						|
                act,
 | 
						|
                device=self.device,
 | 
						|
                dtype=torch.float32,
 | 
						|
            )
 | 
						|
            obs = torch.cat([obs, act], dim=1)
 | 
						|
        return self.fc2(obs)
 | 
						|
 | 
						|
 | 
						|
class Perturbation(nn.Module):
 | 
						|
    """Implementation of perturbation network in BCQ algorithm.
 | 
						|
 | 
						|
    Given a state and action, it can generate perturbed action.
 | 
						|
 | 
						|
    :param preprocess_net: a self-defined preprocess_net which output a
 | 
						|
        flattened hidden state.
 | 
						|
    :param max_action: the maximum value of each dimension of action.
 | 
						|
    :param device: which device to create this model on.
 | 
						|
        Default to cpu.
 | 
						|
    :param phi: max perturbation parameter for BCQ. Default to 0.05.
 | 
						|
 | 
						|
    For advanced usage (how to customize the network), please refer to
 | 
						|
    :ref:`build_the_network`.
 | 
						|
 | 
						|
    .. seealso::
 | 
						|
 | 
						|
        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        preprocess_net: nn.Module,
 | 
						|
        max_action: float,
 | 
						|
        device: str | int | torch.device = "cpu",
 | 
						|
        phi: float = 0.05,
 | 
						|
    ):
 | 
						|
        # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
 | 
						|
        super().__init__()
 | 
						|
        self.preprocess_net = preprocess_net
 | 
						|
        self.device = device
 | 
						|
        self.max_action = max_action
 | 
						|
        self.phi = phi
 | 
						|
 | 
						|
    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
 | 
						|
        # preprocess_net
 | 
						|
        logits = self.preprocess_net(torch.cat([state, action], -1))[0]
 | 
						|
        noise = self.phi * self.max_action * torch.tanh(logits)
 | 
						|
        # clip to [-max_action, max_action]
 | 
						|
        return (noise + action).clamp(-self.max_action, self.max_action)
 | 
						|
 | 
						|
 | 
						|
class VAE(nn.Module):
 | 
						|
    """Implementation of VAE.
 | 
						|
 | 
						|
    It models the distribution of action. Given a state, it can generate actions similar to those in batch.
 | 
						|
    It is used in BCQ algorithm.
 | 
						|
 | 
						|
    :param encoder: the encoder in VAE. Its input_dim must be
 | 
						|
        state_dim + action_dim, and output_dim must be hidden_dim.
 | 
						|
    :param decoder: the decoder in VAE. Its input_dim must be
 | 
						|
        state_dim + latent_dim, and output_dim must be action_dim.
 | 
						|
    :param hidden_dim: the size of the last linear-layer in encoder.
 | 
						|
    :param latent_dim: the size of latent layer.
 | 
						|
    :param max_action: the maximum value of each dimension of action.
 | 
						|
    :param device: which device to create this model on.
 | 
						|
        Default to "cpu".
 | 
						|
 | 
						|
    For advanced usage (how to customize the network), please refer to
 | 
						|
    :ref:`build_the_network`.
 | 
						|
 | 
						|
    .. seealso::
 | 
						|
 | 
						|
        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        encoder: nn.Module,
 | 
						|
        decoder: nn.Module,
 | 
						|
        hidden_dim: int,
 | 
						|
        latent_dim: int,
 | 
						|
        max_action: float,
 | 
						|
        device: str | torch.device = "cpu",
 | 
						|
    ):
 | 
						|
        super().__init__()
 | 
						|
        self.encoder = encoder
 | 
						|
 | 
						|
        self.mean = nn.Linear(hidden_dim, latent_dim)
 | 
						|
        self.log_std = nn.Linear(hidden_dim, latent_dim)
 | 
						|
 | 
						|
        self.decoder = decoder
 | 
						|
 | 
						|
        self.max_action = max_action
 | 
						|
        self.latent_dim = latent_dim
 | 
						|
        self.device = device
 | 
						|
 | 
						|
    def forward(
 | 
						|
        self,
 | 
						|
        state: torch.Tensor,
 | 
						|
        action: torch.Tensor,
 | 
						|
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 | 
						|
        # [state, action] -> z , [state, z] -> action
 | 
						|
        latent_z = self.encoder(torch.cat([state, action], -1))
 | 
						|
        # shape of z: (state.shape[:-1], hidden_dim)
 | 
						|
 | 
						|
        mean = self.mean(latent_z)
 | 
						|
        # Clamped for numerical stability
 | 
						|
        log_std = self.log_std(latent_z).clamp(-4, 15)
 | 
						|
        std = torch.exp(log_std)
 | 
						|
        # shape of mean, std: (state.shape[:-1], latent_dim)
 | 
						|
 | 
						|
        latent_z = mean + std * torch.randn_like(std)  # (state.shape[:-1], latent_dim)
 | 
						|
 | 
						|
        reconstruction = self.decode(state, latent_z)  # (state.shape[:-1], action_dim)
 | 
						|
        return reconstruction, mean, std
 | 
						|
 | 
						|
    def decode(
 | 
						|
        self,
 | 
						|
        state: torch.Tensor,
 | 
						|
        latent_z: torch.Tensor | None = None,
 | 
						|
    ) -> torch.Tensor:
 | 
						|
        # decode(state) -> action
 | 
						|
        if latent_z is None:
 | 
						|
            # state.shape[0] may be batch_size
 | 
						|
            # latent vector clipped to [-0.5, 0.5]
 | 
						|
            latent_z = (
 | 
						|
                torch.randn(state.shape[:-1] + (self.latent_dim,)).to(self.device).clamp(-0.5, 0.5)
 | 
						|
            )
 | 
						|
 | 
						|
        # decode z with state!
 | 
						|
        return self.max_action * torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
 |