add PSRL policy (#202)

Add PSRL policy in tianshou/policy/modelbase/psrl.py.

Co-authored-by: n+e <trinkle23897@cmu.edu>
Yao Feng 2020-09-23 20:57:33 +08:00 committed by GitHub
parent bf39b9ef7d
commit dcfcbb37f4
9 changed files with 329 additions and 0 deletions

README.md

@@ -31,6 +31,7 @@
- Vanilla Imitation Learning
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
Here are Tianshou's other features:

docs/index.rst

@@ -19,6 +19,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

examples/modelbase/README.md

@@ -0,0 +1,7 @@
# PSRL
`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1`
`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20`
`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20`
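
The priors these flags configure can also be built directly in Python. A minimal sketch, assuming a discrete-state, discrete-action task; the flat priors below mirror the script's defaults (`--rew-mean-prior 0 --rew-std-prior 1`):

```python
import gym
import numpy as np

from tianshou.policy import PSRLPolicy

env = gym.make("NChain-v0")
n_state = env.observation_space.n
n_action = env.action_space.n

# flat Dirichlet prior over transitions, N(0, 1) priors over rewards
trans_count_prior = np.ones((n_state, n_action, n_state))
rew_mean_prior = np.zeros((n_state, n_action))
rew_std_prior = np.ones((n_state, n_action))

policy = PSRLPolicy(trans_count_prior, rew_mean_prior, rew_std_prior,
                    discount_factor=0.99, epsilon=0.01)
```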

examples/modelbase/psrl.py Symbolic link

@@ -0,0 +1 @@
../../test/modelbase/test_psrl.py

test/modelbase/__init__.py

test/modelbase/test_psrl.py

@@ -0,0 +1,97 @@
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='NChain-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=50000)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=5)
    parser.add_argument('--collect-per-step', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=1)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.0)
    parser.add_argument('--rew-mean-prior', type=float, default=0.0)
    parser.add_argument('--rew-std-prior', type=float, default=1.0)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--eps', type=float, default=0.01)
    parser.add_argument('--add-done-loop', action='store_true')
    return parser.parse_known_args()[0]


def test_psrl(args=get_args()):
    env = gym.make(args.task)
    if args.task == "NChain-v0":
        env.spec.reward_threshold = 3647  # described in PSRL paper
    print("reward threshold:", env.spec.reward_threshold)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    n_action = args.action_shape
    n_state = args.state_shape
    trans_count_prior = np.ones((n_state, n_action, n_state))
    rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior)
    rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
    policy = PSRLPolicy(
        trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps,
        args.add_done_loop)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
    writer = SummaryWriter(args.logdir + '/' + args.task)

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        else:
            return False

    train_collector.collect(n_step=args.buffer_size, random=True)
    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, 1,
        args.test_num, 0, stop_fn=stop_fn, writer=writer,
        test_in_train=False)

    if __name__ == '__main__':
        # only print and render when run as a script (not under pytest)
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
    elif env.spec.reward_threshold:
        assert result["best_reward"] >= env.spec.reward_threshold


if __name__ == '__main__':
    test_psrl()

tianshou/policy/__init__.py

@@ -9,6 +9,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.modelbase.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
@@ -24,5 +25,6 @@ __all__ = [
"TD3Policy",
"SACPolicy",
"DiscreteSACPolicy",
"PSRLPolicy",
"MultiAgentPolicyManager",
]

tianshou/policy/modelbase/__init__.py

tianshou/policy/modelbase/psrl.py

@@ -0,0 +1,220 @@
import torch
import numpy as np
from typing import Any, Dict, Tuple, Union, Optional

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PSRLModel(object):
    """Implementation of Posterior Sampling Reinforcement Learning Model.

    :param np.ndarray trans_count_prior: Dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param np.ndarray rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param float discount_factor: in [0, 1].
    :param float epsilon: for precision control in value iteration.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float,
        epsilon: float,
    ) -> None:
        self.trans_count = trans_count_prior
        self.n_state, self.n_action = rew_mean_prior.shape
        self.rew_mean = rew_mean_prior
        self.rew_std = rew_std_prior
        self.rew_square_sum = np.zeros_like(rew_mean_prior)
        self.rew_std_prior = rew_std_prior
        self.discount_factor = discount_factor
        # start the counts at epsilon so the prior mean carries ~no weight
        self.rew_count = np.full(rew_mean_prior.shape, epsilon)
        self.eps = epsilon
        self.policy: np.ndarray
        self.value = np.zeros(self.n_state)
        self.updated = False
        self.__eps = np.finfo(np.float32).eps.item()

    def observe(
        self,
        trans_count: np.ndarray,
        rew_sum: np.ndarray,
        rew_square_sum: np.ndarray,
        rew_count: np.ndarray,
    ) -> None:
        """Add data into the memory pool.

        For rewards, we start from a normal prior. Once rewards have been
        observed for a given state-action pair, the empirical mean of the
        observations, rather than the prior mean, serves as the posterior
        mean. The standard deviations shrink in inverse proportion to the
        number of corresponding observations.

        :param np.ndarray trans_count: the number of observations, with shape
            (n_state, n_action, n_state).
        :param np.ndarray rew_sum: total rewards, with shape
            (n_state, n_action).
        :param np.ndarray rew_square_sum: total rewards' squares, with shape
            (n_state, n_action).
        :param np.ndarray rew_count: the number of rewards, with shape
            (n_state, n_action).
        """
        self.updated = False
        self.trans_count += trans_count
        sum_count = self.rew_count + rew_count
        self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
        self.rew_square_sum += rew_square_sum
        raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2
        self.rew_std = np.sqrt(1 / (
            sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2))
        self.rew_count = sum_count

    def sample_trans_prob(self) -> np.ndarray:
        # draw one transition matrix from the Dirichlet posterior
        sample_prob = torch.distributions.Dirichlet(
            torch.from_numpy(self.trans_count)).sample().numpy()
        return sample_prob

    def sample_reward(self) -> np.ndarray:
        # draw one reward table from the normal posterior
        return np.random.normal(self.rew_mean, self.rew_std)

    def solve_policy(self) -> None:
        self.updated = True
        self.policy, self.value = self.value_iteration(
            self.sample_trans_prob(),
            self.sample_reward(),
            self.discount_factor,
            self.eps,
            self.value,
        )

    @staticmethod
    def value_iteration(
        trans_prob: np.ndarray,
        rew: np.ndarray,
        discount_factor: float,
        eps: float,
        value: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Value iteration solver for MDPs.

        :param np.ndarray trans_prob: transition probabilities, with shape
            (n_state, n_action, n_state).
        :param np.ndarray rew: rewards, with shape (n_state, n_action).
        :param float discount_factor: in [0, 1].
        :param float eps: for precision control.
        :param np.ndarray value: the initial value of the value array, with
            shape (n_state, ).

        :return: the optimal policy with shape (n_state, ) and the converged
            value array with shape (n_state, ).
        """
        Q = rew + discount_factor * trans_prob.dot(value)
        new_value = Q.max(axis=1)
        while not np.allclose(new_value, value, eps):
            value = new_value
            Q = rew + discount_factor * trans_prob.dot(value)
            new_value = Q.max(axis=1)
        # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly
        Q += eps * np.random.randn(*Q.shape)
        return Q.argmax(axis=1), new_value

    def __call__(
        self,
        obs: np.ndarray,
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> np.ndarray:
        if not self.updated:
            self.solve_policy()
        return self.policy[obs]


class PSRLPolicy(BasePolicy):
    """Implementation of Posterior Sampling Reinforcement Learning.

    Reference: Strens M. A Bayesian framework for reinforcement learning.
    In ICML, 2000: 943-950.

    :param np.ndarray trans_count_prior: Dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param np.ndarray rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param float discount_factor: in [0, 1].
    :param float epsilon: for precision control in value iteration.
    :param bool add_done_loop: whether to add an extra self-loop for the
        terminal state in the MDP, defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more
        detailed explanation.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float = 0.99,
        epsilon: float = 0.01,
        add_done_loop: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        assert (
            0.0 <= discount_factor <= 1.0
        ), "discount factor should be in [0, 1]"
        self.model = PSRLModel(
            trans_count_prior, rew_mean_prior, rew_std_prior,
            discount_factor, epsilon)
        self._add_done_loop = add_done_loop

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data with PSRL model.

        :return: A :class:`~tianshou.data.Batch` with "act" key containing
            the action.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        act = self.model(batch.obs, state=state, info=batch.info)
        return Batch(act=act)

    def learn(
        self, batch: Batch, *args: Any, **kwargs: Any
    ) -> Dict[str, float]:
        n_s, n_a = self.model.n_state, self.model.n_action
        trans_count = np.zeros((n_s, n_a, n_s))
        rew_sum = np.zeros((n_s, n_a))
        rew_square_sum = np.zeros((n_s, n_a))
        rew_count = np.zeros((n_s, n_a))
        # accumulate sufficient statistics, one transition at a time
        for b in batch.split(size=1):
            obs, act, obs_next = b.obs, b.act, b.obs_next
            trans_count[obs, act, obs_next] += 1
            rew_sum[obs, act] += b.rew
            rew_square_sum[obs, act] += b.rew ** 2
            rew_count[obs, act] += 1
            if self._add_done_loop and b.done:
                # special operation for terminal states: add a self-loop
                trans_count[obs_next, :, obs_next] += 1
                rew_count[obs_next, :] += 1
        self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
        return {
            "psrl/rew_mean": self.model.rew_mean.mean(),
            "psrl/rew_std": self.model.rew_std.mean(),
        }
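
A quick sanity check of the value iteration solver above (a sketch; the toy MDP is invented for illustration, not part of the commit): in a two-state, two-action MDP where both actions transition uniformly and action 1 always pays reward 1, the solver should pick action 1 in every state, with value approaching 1 / (1 - gamma) = 10 for gamma = 0.9.

```python
import numpy as np
from tianshou.policy.modelbase.psrl import PSRLModel

trans_prob = np.full((2, 2, 2), 0.5)      # uniform transitions
rew = np.array([[0.0, 1.0], [0.0, 1.0]])  # action 1 always pays 1
policy, value = PSRLModel.value_iteration(
    trans_prob, rew, discount_factor=0.9, eps=1e-5, value=np.zeros(2))
print(policy)  # [1 1]
print(value)   # ~[10. 10.]
```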