Closes #947

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim` (see the sketch below). In fact, the resulting default behavior should cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively non-optional, since `BasePolicy.__init__` raised a `ValueError` without it. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing it as `True` there was never allowed; an error would have been raised in `compute_n_step_reward`. It has now been removed from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that `critic` in DDPG was optional for the sole reason that child classes used `critic1`. This optionality was removed (DDPG can't do anything with `critic=None`).
4. Several renamings of fields (mostly private to public, so backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication is really not necessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
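For illustration, here is a rough sketch of how a constructor call is intended to look after this change, using SAC as an example. `build_actor_critic` is a hypothetical helper standing in for the usual network/optimizer construction, and the parameter names follow the description above (keyword-only arguments, `critic` instead of `critic1`, optional `critic2`/`critic2_optim`); they may not match the final signatures exactly.

```python
import gymnasium as gym

from tianshou.policy import SACPolicy

env = gym.make("Pendulum-v1")
# hypothetical helper that builds the actor/critic networks and their optimizers
actor, actor_optim, critic, critic_optim = build_actor_critic(env)

policy = SACPolicy(
    actor=actor,
    actor_optim=actor_optim,
    critic=critic,  # renamed from critic1
    critic_optim=critic_optim,
    # critic2/critic2_optim are omitted (None); the default behavior covers most cases
    action_space=env.action_space,  # no longer optional
)
```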
243 lines · 8.9 KiB · Python
from typing import Any, cast

import gymnasium as gym
import numpy as np
import torch

from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler

class PSRLModel:
    """Implementation of Posterior Sampling Reinforcement Learning Model.

    :param trans_count_prior: Dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param discount_factor: in [0, 1].
    :param epsilon: for precision control in value iteration.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float,
        epsilon: float,
    ) -> None:
        self.trans_count = trans_count_prior
        self.n_state, self.n_action = rew_mean_prior.shape
        self.rew_mean = rew_mean_prior
        self.rew_std = rew_std_prior
        self.rew_square_sum = np.zeros_like(rew_mean_prior)
        self.rew_std_prior = rew_std_prior
        self.discount_factor = discount_factor
        self.rew_count = np.full(rew_mean_prior.shape, epsilon)  # no weight
        self.eps = epsilon
        self.policy: np.ndarray
        self.value = np.zeros(self.n_state)
        self.updated = False
        self.__eps = np.finfo(np.float32).eps.item()

    def observe(
        self,
        trans_count: np.ndarray,
        rew_sum: np.ndarray,
        rew_square_sum: np.ndarray,
        rew_count: np.ndarray,
    ) -> None:
        """Add data into the memory pool.

        For rewards, we start from a normal prior. Once a reward has been
        observed for a given state-action pair, we use the mean value of the
        observations instead of the prior mean as the posterior mean. The
        standard deviations are in inverse proportion to the number of
        corresponding observations.

        :param trans_count: the number of observations, with shape
            (n_state, n_action, n_state).
        :param rew_sum: total rewards, with shape (n_state, n_action).
        :param rew_square_sum: total rewards' squares, with shape
            (n_state, n_action).
        :param rew_count: the number of rewards, with shape (n_state, n_action).
        """
        self.updated = False
        self.trans_count += trans_count
        sum_count = self.rew_count + rew_count
        self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
        self.rew_square_sum += rew_square_sum
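        # The next two lines form a precision-weighted posterior estimate:
        # raw_std2 is the empirical reward variance per (state, action) pair
        # (E[r^2] - E[r]^2), and the posterior std combines the data precision
        # (count / empirical variance) with the prior precision (1 / prior_std^2).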
        raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2
        self.rew_std = np.sqrt(
            1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2),
        )
        self.rew_count = sum_count

    def sample_trans_prob(self) -> np.ndarray:
        return torch.distributions.Dirichlet(torch.from_numpy(self.trans_count)).sample().numpy()

    def sample_reward(self) -> np.ndarray:
        return np.random.normal(self.rew_mean, self.rew_std)

    def solve_policy(self) -> None:
        self.updated = True
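        # Thompson-sampling step: draw one concrete MDP from the current
        # posterior (a transition matrix from the Dirichlet counts and a reward
        # table from the normal posteriors) and solve it with value iteration.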
        self.policy, self.value = self.value_iteration(
            self.sample_trans_prob(),
            self.sample_reward(),
            self.discount_factor,
            self.eps,
            self.value,
        )

    @staticmethod
    def value_iteration(
        trans_prob: np.ndarray,
        rew: np.ndarray,
        discount_factor: float,
        eps: float,
        value: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Value iteration solver for MDPs.

        :param trans_prob: transition probabilities, with shape
            (n_state, n_action, n_state).
        :param rew: rewards, with shape (n_state, n_action).
        :param eps: for precision control.
        :param discount_factor: in [0, 1].
        :param value: the initial value of the value array, with shape (n_state, ).

        :return: the optimal policy, with shape (n_state, ), and the
            corresponding value function, with shape (n_state, ).
        """
        Q = rew + discount_factor * trans_prob.dot(value)
        new_value = Q.max(axis=1)
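        # Repeat Bellman optimality backups until the value function stops
        # changing, with eps used as the tolerance handed to np.allclose.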
        while not np.allclose(new_value, value, eps):
            value = new_value
            Q = rew + discount_factor * trans_prob.dot(value)
            new_value = Q.max(axis=1)
        # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly
        Q += eps * np.random.randn(*Q.shape)
        return Q.argmax(axis=1), new_value

    def __call__(
        self,
        obs: np.ndarray,
        state: Any = None,
        info: Any = None,
    ) -> np.ndarray:
        if not self.updated:
            self.solve_policy()
        return self.policy[obs]

class PSRLPolicy(BasePolicy):
    """Implementation of Posterior Sampling Reinforcement Learning.

    Reference: Strens M., "A Bayesian Framework for Reinforcement Learning",
    ICML 2000: 943-950.

    :param trans_count_prior: Dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param action_space: Env's action_space.
    :param discount_factor: in [0, 1].
    :param epsilon: for precision control in value iteration.
    :param add_done_loop: whether to add an extra self-loop for the
        terminal state in the MDP.
    :param observation_space: Env's observation space.
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for a more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        action_space: gym.spaces.Discrete,
        discount_factor: float = 0.99,
        epsilon: float = 0.01,
        add_done_loop: bool = False,
        observation_space: gym.Space | None = None,
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=False,
            action_bound_method=None,
            lr_scheduler=lr_scheduler,
        )
        assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
        self.model = PSRLModel(
            trans_count_prior,
            rew_mean_prior,
            rew_std_prior,
            discount_factor,
            epsilon,
        )
        self._add_done_loop = add_done_loop

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        **kwargs: Any,
    ) -> ActBatchProtocol:
        """Compute action over the given batch data with PSRL model.

        :return: A :class:`~tianshou.data.Batch` with "act" key containing
            the action.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation"
        act = self.model(batch.obs, state=state, info=batch.info)
        result = Batch(act=act)
        return cast(ActBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        n_s, n_a = self.model.n_state, self.model.n_action
        trans_count = np.zeros((n_s, n_a, n_s))
        rew_sum = np.zeros((n_s, n_a))
        rew_square_sum = np.zeros((n_s, n_a))
        rew_count = np.zeros((n_s, n_a))
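        # Accumulate sufficient statistics (transition counts plus reward sums,
        # squared sums, and counts) from the collected transitions, then hand
        # them to the Bayesian model as a single batched observation.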
        for minibatch in batch.split(size=1):
            obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
            obs_next = cast(np.ndarray, obs_next)
            assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here"
            trans_count[obs, act, obs_next] += 1
            rew_sum[obs, act] += minibatch.rew
            rew_square_sum[obs, act] += minibatch.rew**2
            rew_count[obs, act] += 1
            if self._add_done_loop and minibatch.done:
                # special operation for terminal states: add a self-loop
                trans_count[obs_next, :, obs_next] += 1
                rew_count[obs_next, :] += 1
        self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
        return {
            "psrl/rew_mean": float(self.model.rew_mean.mean()),
            "psrl/rew_std": float(self.model.rew_std.mean()),
        }
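

# Illustrative usage sketch (not part of the original module): construct flat
# priors for a small discrete environment and query an action. "FrozenLake-v1"
# and the all-ones/all-zeros prior shapes are assumptions chosen for the
# example; any finite, discrete gymnasium environment works the same way.
if __name__ == "__main__":
    env = gym.make("FrozenLake-v1")
    # both spaces are gym.spaces.Discrete, so they expose .n
    n_state = int(env.observation_space.n)
    n_action = int(env.action_space.n)
    psrl_policy = PSRLPolicy(
        trans_count_prior=np.ones((n_state, n_action, n_state)),
        rew_mean_prior=np.zeros((n_state, n_action)),
        rew_std_prior=np.ones((n_state, n_action)),
        action_space=env.action_space,
        discount_factor=0.99,
        epsilon=0.01,
    )
    # Sample an MDP from the prior, solve it, and act greedily for observation 0.
    demo_batch = Batch(obs=np.array([0]), info=Batch())
    print(psrl_policy.forward(demo_batch).act)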