add PSRL policy (#202)
Add PSRL policy in tianshou/policy/modelbase/psrl.py. Co-authored-by: n+e <trinkle23897@cmu.edu>
This commit is contained in:
parent
bf39b9ef7d
commit
dcfcbb37f4
@ -31,6 +31,7 @@
|
||||
- Vanilla Imitation Learning
|
||||
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
|
||||
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
|
||||
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
|
||||
|
||||
Here is Tianshou's other features:
|
||||
|
||||
|
@ -19,6 +19,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
||||
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
|
||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
|
||||
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
|
||||
|
7
examples/modelbase/README.md
Normal file
7
examples/modelbase/README.md
Normal file
@ -0,0 +1,7 @@
|
||||
# PSRL
|
||||
|
||||
`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1`
|
||||
|
||||
`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20`
|
||||
|
||||
`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20`
|
1
examples/modelbase/psrl.py
Symbolic link
1
examples/modelbase/psrl.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../test/modelbase/test_psrl.py
|
0
test/modelbase/__init__.py
Normal file
0
test/modelbase/__init__.py
Normal file
97
test/modelbase/test_psrl.py
Normal file
97
test/modelbase/test_psrl.py
Normal file
@ -0,0 +1,97 @@
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import PSRLPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='NChain-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--buffer-size', type=int, default=50000)
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5)
|
||||
parser.add_argument('--collect-per-step', type=int, default=1)
|
||||
parser.add_argument('--training-num', type=int, default=1)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.0)
|
||||
parser.add_argument('--rew-mean-prior', type=float, default=0.0)
|
||||
parser.add_argument('--rew-std-prior', type=float, default=1.0)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--eps', type=float, default=0.01)
|
||||
parser.add_argument('--add-done-loop', action='store_true')
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
def test_psrl(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
if args.task == "NChain-v0":
|
||||
env.spec.reward_threshold = 3647 # described in PSRL paper
|
||||
print("reward threshold:", env.spec.reward_threshold)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
n_action = args.action_shape
|
||||
n_state = args.state_shape
|
||||
trans_count_prior = np.ones((n_state, n_action, n_state))
|
||||
rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior)
|
||||
rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
|
||||
policy = PSRLPolicy(
|
||||
trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps,
|
||||
args.add_done_loop)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + args.task)
|
||||
|
||||
def stop_fn(x):
|
||||
if env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
else:
|
||||
return False
|
||||
|
||||
train_collector.collect(n_step=args.buffer_size, random=True)
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, 1,
|
||||
args.test_num, 0, stop_fn=stop_fn, writer=writer,
|
||||
test_in_train=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=[1] * args.test_num,
|
||||
render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
elif env.spec.reward_threshold:
|
||||
assert result["best_reward"] >= env.spec.reward_threshold
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_psrl()
|
@ -9,6 +9,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
|
||||
from tianshou.policy.modelfree.td3 import TD3Policy
|
||||
from tianshou.policy.modelfree.sac import SACPolicy
|
||||
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||
from tianshou.policy.modelbase.psrl import PSRLPolicy
|
||||
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
|
||||
|
||||
|
||||
@ -24,5 +25,6 @@ __all__ = [
|
||||
"TD3Policy",
|
||||
"SACPolicy",
|
||||
"DiscreteSACPolicy",
|
||||
"PSRLPolicy",
|
||||
"MultiAgentPolicyManager",
|
||||
]
|
||||
|
0
tianshou/policy/modelbase/__init__.py
Normal file
0
tianshou/policy/modelbase/__init__.py
Normal file
220
tianshou/policy/modelbase/psrl.py
Normal file
220
tianshou/policy/modelbase/psrl.py
Normal file
@ -0,0 +1,220 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
class PSRLModel(object):
|
||||
"""Implementation of Posterior Sampling Reinforcement Learning Model.
|
||||
|
||||
:param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape
|
||||
(n_state, n_action, n_state).
|
||||
:param np.ndarray rew_mean_prior: means of the normal priors of rewards,
|
||||
with shape (n_state, n_action).
|
||||
:param np.ndarray rew_std_prior: standard deviations of the normal priors
|
||||
of rewards, with shape (n_state, n_action).
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param float epsilon: for precision control in value iteration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trans_count_prior: np.ndarray,
|
||||
rew_mean_prior: np.ndarray,
|
||||
rew_std_prior: np.ndarray,
|
||||
discount_factor: float,
|
||||
epsilon: float,
|
||||
) -> None:
|
||||
self.trans_count = trans_count_prior
|
||||
self.n_state, self.n_action = rew_mean_prior.shape
|
||||
self.rew_mean = rew_mean_prior
|
||||
self.rew_std = rew_std_prior
|
||||
self.rew_square_sum = np.zeros_like(rew_mean_prior)
|
||||
self.rew_std_prior = rew_std_prior
|
||||
self.discount_factor = discount_factor
|
||||
self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight
|
||||
self.eps = epsilon
|
||||
self.policy: np.ndarray
|
||||
self.value = np.zeros(self.n_state)
|
||||
self.updated = False
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
|
||||
def observe(
|
||||
self,
|
||||
trans_count: np.ndarray,
|
||||
rew_sum: np.ndarray,
|
||||
rew_square_sum: np.ndarray,
|
||||
rew_count: np.ndarray,
|
||||
) -> None:
|
||||
"""Add data into memory pool.
|
||||
|
||||
For rewards, we have a normal prior at first. After we observed a
|
||||
reward for a given state-action pair, we use the mean value of our
|
||||
observations instead of the prior mean as the posterior mean. The
|
||||
standard deviations are in inverse proportion to the number of the
|
||||
corresponding observations.
|
||||
|
||||
:param np.ndarray trans_count: the number of observations, with shape
|
||||
(n_state, n_action, n_state).
|
||||
:param np.ndarray rew_sum: total rewards, with shape
|
||||
(n_state, n_action).
|
||||
:param np.ndarray rew_square_sum: total rewards' squares, with shape
|
||||
(n_state, n_action).
|
||||
:param np.ndarray rew_count: the number of rewards, with shape
|
||||
(n_state, n_action).
|
||||
"""
|
||||
self.updated = False
|
||||
self.trans_count += trans_count
|
||||
sum_count = self.rew_count + rew_count
|
||||
self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
|
||||
self.rew_square_sum += rew_square_sum
|
||||
raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2
|
||||
self.rew_std = np.sqrt(1 / (
|
||||
sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2))
|
||||
self.rew_count = sum_count
|
||||
|
||||
def sample_trans_prob(self) -> np.ndarray:
|
||||
sample_prob = torch.distributions.Dirichlet(
|
||||
torch.from_numpy(self.trans_count)).sample().numpy()
|
||||
return sample_prob
|
||||
|
||||
def sample_reward(self) -> np.ndarray:
|
||||
return np.random.normal(self.rew_mean, self.rew_std)
|
||||
|
||||
def solve_policy(self) -> None:
|
||||
self.updated = True
|
||||
self.policy, self.value = self.value_iteration(
|
||||
self.sample_trans_prob(),
|
||||
self.sample_reward(),
|
||||
self.discount_factor,
|
||||
self.eps,
|
||||
self.value,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def value_iteration(
|
||||
trans_prob: np.ndarray,
|
||||
rew: np.ndarray,
|
||||
discount_factor: float,
|
||||
eps: float,
|
||||
value: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Value iteration solver for MDPs.
|
||||
|
||||
:param np.ndarray trans_prob: transition probabilities, with shape
|
||||
(n_state, n_action, n_state).
|
||||
:param np.ndarray rew: rewards, with shape (n_state, n_action).
|
||||
:param float eps: for precision control.
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param np.ndarray value: the initialize value of value array, with
|
||||
shape (n_state, ).
|
||||
|
||||
:return: the optimal policy with shape (n_state, ).
|
||||
"""
|
||||
Q = rew + discount_factor * trans_prob.dot(value)
|
||||
new_value = Q.max(axis=1)
|
||||
while not np.allclose(new_value, value, eps):
|
||||
value = new_value
|
||||
Q = rew + discount_factor * trans_prob.dot(value)
|
||||
new_value = Q.max(axis=1)
|
||||
# this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly
|
||||
Q += eps * np.random.randn(*Q.shape)
|
||||
return Q.argmax(axis=1), new_value
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
obs: np.ndarray,
|
||||
state: Optional[Any] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> np.ndarray:
|
||||
if not self.updated:
|
||||
self.solve_policy()
|
||||
return self.policy[obs]
|
||||
|
||||
|
||||
class PSRLPolicy(BasePolicy):
|
||||
"""Implementation of Posterior Sampling Reinforcement Learning.
|
||||
|
||||
Reference: Strens M. A Bayesian framework for reinforcement learning [C]
|
||||
//ICML. 2000, 2000: 943-950.
|
||||
|
||||
:param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape
|
||||
(n_state, n_action, n_state).
|
||||
:param np.ndarray rew_mean_prior: means of the normal priors of rewards,
|
||||
with shape (n_state, n_action).
|
||||
:param np.ndarray rew_std_prior: standard deviations of the normal priors
|
||||
of rewards, with shape (n_state, n_action).
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param float epsilon: for precision control in value iteration.
|
||||
:param bool add_done_loop: whether to add an extra self-loop for the
|
||||
terminal state in MDP, defaults to False.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trans_count_prior: np.ndarray,
|
||||
rew_mean_prior: np.ndarray,
|
||||
rew_std_prior: np.ndarray,
|
||||
discount_factor: float = 0.99,
|
||||
epsilon: float = 0.01,
|
||||
add_done_loop: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert (
|
||||
0.0 <= discount_factor <= 1.0
|
||||
), "discount factor should be in [0, 1]"
|
||||
self.model = PSRLModel(
|
||||
trans_count_prior, rew_mean_prior, rew_std_prior,
|
||||
discount_factor, epsilon)
|
||||
self._add_done_loop = add_done_loop
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data with PSRL model.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` with "act" key containing
|
||||
the action.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
|
||||
more detailed explanation.
|
||||
"""
|
||||
act = self.model(batch.obs, state=state, info=batch.info)
|
||||
return Batch(act=act)
|
||||
|
||||
def learn(
|
||||
self, batch: Batch, *args: Any, **kwargs: Any
|
||||
) -> Dict[str, float]:
|
||||
n_s, n_a = self.model.n_state, self.model.n_action
|
||||
trans_count = np.zeros((n_s, n_a, n_s))
|
||||
rew_sum = np.zeros((n_s, n_a))
|
||||
rew_square_sum = np.zeros((n_s, n_a))
|
||||
rew_count = np.zeros((n_s, n_a))
|
||||
for b in batch.split(size=1):
|
||||
obs, act, obs_next = b.obs, b.act, b.obs_next
|
||||
trans_count[obs, act, obs_next] += 1
|
||||
rew_sum[obs, act] += b.rew
|
||||
rew_square_sum[obs, act] += b.rew ** 2
|
||||
rew_count[obs, act] += 1
|
||||
if self._add_done_loop and b.done:
|
||||
# special operation for terminal states: add a self-loop
|
||||
trans_count[obs_next, :, obs_next] += 1
|
||||
rew_count[obs_next, :] += 1
|
||||
self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
|
||||
return {
|
||||
"psrl/rew_mean": self.model.rew_mean.mean(),
|
||||
"psrl/rew_std": self.model.rew_std.mean(),
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user