diff --git a/README.md b/README.md index c68fbf4..636b487 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ - Vanilla Imitation Learning - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) +- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) Here is Tianshou's other features: diff --git a/docs/index.rst b/docs/index.rst index 454997e..7d96f3f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -19,6 +19,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ +* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md new file mode 100644 index 0000000..c3563f6 --- /dev/null +++ b/examples/modelbase/README.md @@ -0,0 +1,7 @@ +# PSRL + +`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1` + +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` + +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py new file mode 120000 index 0000000..228d259 --- /dev/null +++ b/examples/modelbase/psrl.py @@ -0,0 +1 @@ +../../test/modelbase/test_psrl.py \ No newline at end of file diff --git a/test/modelbase/__init__.py b/test/modelbase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py new file mode 100644 index 0000000..6fb0e16 --- /dev/null +++ b/test/modelbase/test_psrl.py @@ -0,0 +1,97 @@ +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import PSRLPolicy +from tianshou.trainer import onpolicy_trainer +from tianshou.data import Collector, ReplayBuffer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='NChain-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--buffer-size', type=int, default=50000) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=5) + parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--training-num', type=int, default=1) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.0) + parser.add_argument('--rew-mean-prior', type=float, default=0.0) + parser.add_argument('--rew-std-prior', type=float, default=1.0) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--eps', type=float, default=0.01) + parser.add_argument('--add-done-loop', action='store_true') + return parser.parse_known_args()[0] + + +def test_psrl(args=get_args()): + env = gym.make(args.task) + if args.task == "NChain-v0": + env.spec.reward_threshold = 3647 # described in PSRL paper + print("reward threshold:", env.spec.reward_threshold) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + n_action = args.action_shape + n_state = args.state_shape + trans_count_prior = np.ones((n_state, n_action, n_state)) + rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) + rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) + policy = PSRLPolicy( + trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps, + args.add_done_loop) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # log + writer = SummaryWriter(args.logdir + '/' + args.task) + + def stop_fn(x): + if env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: + return False + + train_collector.collect(n_step=args.buffer_size, random=True) + # trainer + result = onpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, 1, + args.test_num, 0, stop_fn=stop_fn, writer=writer, + test_in_train=False) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + elif env.spec.reward_threshold: + assert result["best_reward"] >= env.spec.reward_threshold + + +if __name__ == '__main__': + test_psrl() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index cf1f4de..456993d 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -9,6 +9,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -24,5 +25,6 @@ __all__ = [ "TD3Policy", "SACPolicy", "DiscreteSACPolicy", + "PSRLPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/modelbase/__init__.py b/tianshou/policy/modelbase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py new file mode 100644 index 0000000..dcf6a5d --- /dev/null +++ b/tianshou/policy/modelbase/psrl.py @@ -0,0 +1,220 @@ +import torch +import numpy as np +from typing import Any, Dict, Union, Optional + +from tianshou.data import Batch +from tianshou.policy import BasePolicy + + +class PSRLModel(object): + """Implementation of Posterior Sampling Reinforcement Learning Model. + + :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param np.ndarray rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param np.ndarray rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param float discount_factor: in [0, 1]. + :param float epsilon: for precision control in value iteration. + """ + + def __init__( + self, + trans_count_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + discount_factor: float, + epsilon: float, + ) -> None: + self.trans_count = trans_count_prior + self.n_state, self.n_action = rew_mean_prior.shape + self.rew_mean = rew_mean_prior + self.rew_std = rew_std_prior + self.rew_square_sum = np.zeros_like(rew_mean_prior) + self.rew_std_prior = rew_std_prior + self.discount_factor = discount_factor + self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight + self.eps = epsilon + self.policy: np.ndarray + self.value = np.zeros(self.n_state) + self.updated = False + self.__eps = np.finfo(np.float32).eps.item() + + def observe( + self, + trans_count: np.ndarray, + rew_sum: np.ndarray, + rew_square_sum: np.ndarray, + rew_count: np.ndarray, + ) -> None: + """Add data into memory pool. + + For rewards, we have a normal prior at first. After we observed a + reward for a given state-action pair, we use the mean value of our + observations instead of the prior mean as the posterior mean. The + standard deviations are in inverse proportion to the number of the + corresponding observations. + + :param np.ndarray trans_count: the number of observations, with shape + (n_state, n_action, n_state). + :param np.ndarray rew_sum: total rewards, with shape + (n_state, n_action). + :param np.ndarray rew_square_sum: total rewards' squares, with shape + (n_state, n_action). + :param np.ndarray rew_count: the number of rewards, with shape + (n_state, n_action). + """ + self.updated = False + self.trans_count += trans_count + sum_count = self.rew_count + rew_count + self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count + self.rew_square_sum += rew_square_sum + raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2 + self.rew_std = np.sqrt(1 / ( + sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2)) + self.rew_count = sum_count + + def sample_trans_prob(self) -> np.ndarray: + sample_prob = torch.distributions.Dirichlet( + torch.from_numpy(self.trans_count)).sample().numpy() + return sample_prob + + def sample_reward(self) -> np.ndarray: + return np.random.normal(self.rew_mean, self.rew_std) + + def solve_policy(self) -> None: + self.updated = True + self.policy, self.value = self.value_iteration( + self.sample_trans_prob(), + self.sample_reward(), + self.discount_factor, + self.eps, + self.value, + ) + + @staticmethod + def value_iteration( + trans_prob: np.ndarray, + rew: np.ndarray, + discount_factor: float, + eps: float, + value: np.ndarray, + ) -> np.ndarray: + """Value iteration solver for MDPs. + + :param np.ndarray trans_prob: transition probabilities, with shape + (n_state, n_action, n_state). + :param np.ndarray rew: rewards, with shape (n_state, n_action). + :param float eps: for precision control. + :param float discount_factor: in [0, 1]. + :param np.ndarray value: the initialize value of value array, with + shape (n_state, ). + + :return: the optimal policy with shape (n_state, ). + """ + Q = rew + discount_factor * trans_prob.dot(value) + new_value = Q.max(axis=1) + while not np.allclose(new_value, value, eps): + value = new_value + Q = rew + discount_factor * trans_prob.dot(value) + new_value = Q.max(axis=1) + # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly + Q += eps * np.random.randn(*Q.shape) + return Q.argmax(axis=1), new_value + + def __call__( + self, + obs: np.ndarray, + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> np.ndarray: + if not self.updated: + self.solve_policy() + return self.policy[obs] + + +class PSRLPolicy(BasePolicy): + """Implementation of Posterior Sampling Reinforcement Learning. + + Reference: Strens M. A Bayesian framework for reinforcement learning [C] + //ICML. 2000, 2000: 943-950. + + :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param np.ndarray rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param np.ndarray rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param float discount_factor: in [0, 1]. + :param float epsilon: for precision control in value iteration. + :param bool add_done_loop: whether to add an extra self-loop for the + terminal state in MDP, defaults to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + trans_count_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + discount_factor: float = 0.99, + epsilon: float = 0.01, + add_done_loop: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert ( + 0.0 <= discount_factor <= 1.0 + ), "discount factor should be in [0, 1]" + self.model = PSRLModel( + trans_count_prior, rew_mean_prior, rew_std_prior, + discount_factor, epsilon) + self._add_done_loop = add_done_loop + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data with PSRL model. + + :return: A :class:`~tianshou.data.Batch` with "act" key containing + the action. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + act = self.model(batch.obs, state=state, info=batch.info) + return Batch(act=act) + + def learn( + self, batch: Batch, *args: Any, **kwargs: Any + ) -> Dict[str, float]: + n_s, n_a = self.model.n_state, self.model.n_action + trans_count = np.zeros((n_s, n_a, n_s)) + rew_sum = np.zeros((n_s, n_a)) + rew_square_sum = np.zeros((n_s, n_a)) + rew_count = np.zeros((n_s, n_a)) + for b in batch.split(size=1): + obs, act, obs_next = b.obs, b.act, b.obs_next + trans_count[obs, act, obs_next] += 1 + rew_sum[obs, act] += b.rew + rew_square_sum[obs, act] += b.rew ** 2 + rew_count[obs, act] += 1 + if self._add_done_loop and b.done: + # special operation for terminal states: add a self-loop + trans_count[obs_next, :, obs_next] += 1 + rew_count[obs_next, :] += 1 + self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) + return { + "psrl/rew_mean": self.model.rew_mean.mean(), + "psrl/rew_std": self.model.rew_std.mean(), + }