Michael Panchenko 600f4bbd55
Python 3.9, black + ruff formatting (#921)
Preparation for #914 and #920

Changes formatting to ruff and black. Removes Python 3.8 support.

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. Now CI and Make use pre-commit, reducing the
duplication of linting calls
- Removed check-docstyle option (ruff is doing that)
- Merged format and lint. In CI, the format-lint step fails if any
changes are made, so it also fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
2023-08-25 14:40:56 -07:00

47 lines
1.6 KiB
Python

from typing import Any, Optional, Union, cast
import numpy as np
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
class RandomPolicy(BasePolicy):
    """A random agent used in multi-agent learning.

    It randomly chooses an action from the legal action set.
    """

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: Optional[Union[dict, BatchProtocol, np.ndarray]] = None,
        **kwargs: Any,
    ) -> ActBatchProtocol:
        """Compute a random legal action for each sample in the given batch.

        The input must carry a boolean mask in ``batch.obs.mask``, where
        ``True`` marks an available action and ``False`` an unavailable one.
        For example, ``batch.obs.mask == np.array([[False, True, False]])``
        means that with batch size 1, action "1" is available while actions
        "0" and "2" are not.

        :return: A :class:`~tianshou.data.Batch` with an "act" key holding
            the randomly chosen action indices.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            a more detailed explanation.
        """
        action_mask = batch.obs.mask  # type: ignore
        # Draw a random score per action, then disqualify illegal actions by
        # forcing their scores to -inf so argmax can never pick them.
        scores = np.where(action_mask, np.random.rand(*action_mask.shape), -np.inf)
        return cast(ActBatchProtocol, Batch(act=np.argmax(scores, axis=-1)))

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        """Since a random agent learns nothing, it returns an empty dict."""
        return {}