41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
|
import numpy as np
|
||
|
from typing import Union, Optional, Dict, List
|
||
|
|
||
|
from tianshou.data import Batch
|
||
|
from tianshou.policy import BasePolicy
|
||
|
|
||
|
|
||
|
class RandomPolicy(BasePolicy):
|
||
|
"""A random agent used in multi-agent learning. It randomly chooses an
|
||
|
action from the legal action.
|
||
|
"""
|
||
|
|
||
|
def forward(self, batch: Batch,
|
||
|
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||
|
**kwargs) -> Batch:
|
||
|
"""Compute the random action over the given batch data. The input
|
||
|
should contain a mask in batch.obs, with "True" to be available and
|
||
|
"False" to be unavailable.
|
||
|
For example, ``batch.obs.mask == np.array([[False, True, False]])``
|
||
|
means with batch size 1, action "1" is available but action "0" and
|
||
|
"2" are unavailable.
|
||
|
|
||
|
:return: A :class:`~tianshou.data.Batch` with "act" key, containing
|
||
|
the random action.
|
||
|
|
||
|
.. seealso::
|
||
|
|
||
|
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
|
||
|
more detailed explanation.
|
||
|
"""
|
||
|
mask = batch.obs.mask
|
||
|
logits = np.random.rand(*mask.shape)
|
||
|
logits[~mask] = -np.inf
|
||
|
return Batch(act=logits.argmax(axis=-1))
|
||
|
|
||
|
def learn(self, batch: Batch, **kwargs
|
||
|
) -> Dict[str, Union[float, List[float]]]:
|
||
|
"""No need of a learn function for a random agent, so it returns an
|
||
|
empty dict."""
|
||
|
return {}
|