Add Intrinsic Curiosity Module (#503)
parent a2d76d1276
commit a59d96d041
@@ -43,6 +43,7 @@
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)

Here are Tianshou's other features:
@@ -137,6 +137,11 @@ Model-based
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.ICMPolicy
    :members:
    :undoc-members:
    :show-inheritance:

Multi-agent
-----------
@@ -32,6 +32,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
@@ -135,3 +135,4 @@ Huayu
Strens
Ornstein
Uhlenbeck
mse
@@ -11,8 +11,10 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
@@ -55,6 +57,24 @@ def get_args():
        help='watch the play of pre-trained policy only'
    )
    parser.add_argument('--save-buffer-name', type=str, default=None)
    parser.add_argument(
        '--icm-lr-scale',
        type=float,
        default=0.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--icm-reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--icm-forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    return parser.parse_args()


@@ -101,6 +121,24 @@ def test_dqn(args=get_args()):
        args.n_step,
        target_update_freq=args.target_update_freq
    )
    if args.icm_lr_scale > 0:
        feature_net = DQN(
            *args.state_shape, args.action_shape, args.device, features_only=True
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.net,
            feature_dim,
            action_dim,
            hidden_sizes=[512],
            device=args.device
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
            args.icm_forward_loss_weight
        ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
@@ -118,7 +156,8 @@ def test_dqn(args=get_args()):
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
    log_path = os.path.join(args.logdir, args.task, log_name)
    if args.logger == "tensorboard":
        writer = SummaryWriter(log_path)
        writer.add_text("args", str(args))
@@ -127,7 +166,7 @@ def test_dqn(args=get_args()):
        logger = WandbLogger(
            save_interval=1,
            project=args.task,
            name='dqn',
            name=log_name,
            run_id=args.resume_id,
            config=args,
        )
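The ICM wiring above is self-contained and works the same way outside the Atari script. Below is a minimal sketch of wrapping an existing DQNPolicy, assuming hypothetical CartPole-sized networks (not part of this commit); it simply mirrors the calls in the diff above and in test/modelbased/test_dqn_icm.py further down.

import numpy as np
import torch

from tianshou.policy import DQNPolicy, ICMPolicy
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

state_shape, action_shape, device = (4,), 2, "cpu"  # hypothetical CartPole sizes
net = Net(state_shape, action_shape, hidden_sizes=[64, 64], device=device)
policy = DQNPolicy(net, torch.optim.Adam(net.parameters(), lr=1e-3), 0.99, 3)

# feature net maps observations to a flat feature vector for ICM
feature_dim = 64
feature_net = MLP(int(np.prod(state_shape)), output_dim=feature_dim, device=device)
icm_net = IntrinsicCuriosityModule(
    feature_net, feature_dim, int(np.prod(action_shape)), hidden_sizes=[64], device=device
)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=1e-3)
# lr_scale, reward_scale, forward_loss_weight correspond to the new --icm-* flags above
policy = ICMPolicy(policy, icm_net, icm_optim, 1.0, 0.01, 0.2)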
examples/vizdoom/vizdoom_a2c_icm.py (new file, 225 lines)
@@ -0,0 +1,225 @@
import argparse
import os
import pprint

import numpy as np
import torch
from env import Env
from network import DQN
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import A2CPolicy, ICMPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='D2_navigation')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=2000000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=300)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--episode-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--skip-num', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    parser.add_argument(
        '--save-lmp',
        default=False,
        action='store_true',
        help='save lmp file for replay whole episode'
    )
    parser.add_argument('--save-buffer-name', type=str, default=None)
    parser.add_argument(
        '--icm-lr-scale',
        type=float,
        default=0.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--icm-reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--icm-forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    return parser.parse_args()


def test_a2c(args=get_args()):
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res)
            for _ in range(args.training_num)
        ]
    )
    test_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
            for _ in range(min(os.cpu_count() - 1, args.test_num))
        ]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(
        *args.state_shape, args.action_shape, device=args.device, features_only=True
    )
    actor = Actor(
        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
    # define policy
    dist = torch.distributions.Categorical
    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
    if args.icm_lr_scale > 0:
        feature_net = DQN(
            *args.state_shape,
            args.action_shape,
            device=args.device,
            features_only=True
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.net,
            feature_dim,
            action_dim,
            hidden_sizes=args.hidden_sizes,
            device=args.device
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
            args.icm_forward_loss_weight
        ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'a2c')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return False

    def watch():
        # watch agent's performance
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        test_in_train=False
    )

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_a2c(get_args())
test/modelbased/test_dqn_icm.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import argparse
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy, ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
    )
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument(
        '--lr-scale',
        type=float,
        default=1.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    args = parser.parse_known_args()[0]
    return args


def test_dqn_icm(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # Q_param = V_param = {"hidden_sizes": [128]}
    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        # dueling=(Q_param, V_param),
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net,
        optim,
        args.gamma,
        args.n_step,
        target_update_freq=args.target_update_freq,
    )
    feature_dim = args.hidden_sizes[-1]
    feature_net = MLP(
        np.prod(args.state_shape),
        output_dim=feature_dim,
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device
    )
    action_dim = np.prod(args.action_shape)
    icm_net = IntrinsicCuriosityModule(
        feature_net,
        feature_dim,
        action_dim,
        hidden_sizes=args.hidden_sizes[-1:],
        device=args.device
    ).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(
        policy, icm_net, icm_optim, args.lr_scale, args.reward_scale,
        args.forward_loss_weight
    )
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn_icm')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_pdqn_icm(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    args.seed = 1
    test_dqn_icm(args)


if __name__ == '__main__':
    test_dqn_icm(get_args())
test/modelbased/test_ppo_icm.py (new file, 192 lines)
@@ -0,0 +1,192 @@
import argparse
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, ActorCritic, DataParallelNet, Net
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=50000)
    parser.add_argument('--step-per-collect', type=int, default=2000)
    parser.add_argument('--repeat-per-collect', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=20)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    # ppo special
    parser.add_argument('--vf-coef', type=float, default=0.5)
    parser.add_argument('--ent-coef', type=float, default=0.0)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=0)
    parser.add_argument('--norm-adv', type=int, default=0)
    parser.add_argument('--recompute-adv', type=int, default=0)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=0)
    parser.add_argument(
        '--lr-scale',
        type=float,
        default=1.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    args = parser.parse_known_args()[0]
    return args


def test_ppo(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    if torch.cuda.is_available():
        actor = DataParallelNet(
            Actor(net, args.action_shape, device=None).to(args.device)
        )
        critic = DataParallelNet(Critic(net, device=None).to(args.device))
    else:
        actor = Actor(net, args.action_shape, device=args.device).to(args.device)
        critic = Critic(net, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        gae_lambda=args.gae_lambda,
        reward_normalization=args.rew_norm,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        action_space=env.action_space,
        deterministic_eval=True,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv
    )
    feature_dim = args.hidden_sizes[-1]
    feature_net = MLP(
        np.prod(args.state_shape),
        output_dim=feature_dim,
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device
    )
    action_dim = np.prod(args.action_shape)
    icm_net = IntrinsicCuriosityModule(
        feature_net,
        feature_dim,
        action_dim,
        hidden_sizes=args.hidden_sizes[-1:],
        device=args.device
    ).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(
        policy, icm_net, icm_optim, args.lr_scale, args.reward_scale,
        args.forward_loss_weight
    )
    # collector
    train_collector = Collector(
        policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
    )
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo_icm')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_ppo()
@@ -24,6 +24,7 @@ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager

__all__ = [
@@ -50,5 +51,6 @@ __all__ = [
    "DiscreteCQLPolicy",
    "DiscreteCRRPolicy",
    "PSRLPolicy",
    "ICMPolicy",
    "MultiAgentPolicyManager",
]
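With the re-export above, the wrapper is importable from the package root once this commit is applied; a one-line sketch:

from tianshou.policy import ICMPolicy  # same class as tianshou.policy.modelbased.icm.ICMPolicy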
tianshou/policy/modelbased/icm.py (new file, 121 lines)
@@ -0,0 +1,121 @@
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.policy import BasePolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


class ICMPolicy(BasePolicy):
    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.

    :param BasePolicy policy: a base policy to add ICM to.
    :param IntrinsicCuriosityModule model: the ICM model.
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float lr_scale: the scaling factor for ICM learning.
    :param float reward_scale: the scaling factor for intrinsic curiosity reward.
    :param float forward_loss_weight: the weight for forward model loss.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        model: IntrinsicCuriosityModule,
        optim: torch.optim.Optimizer,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.policy = policy
        self.model = model
        self.optim = optim
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def train(self, mode: bool = True) -> "ICMPolicy":
        """Set the module in training mode."""
        self.policy.train(mode)
        self.training = mode
        self.model.train(mode)
        return self

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data by inner policy.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        return self.policy.forward(batch, state, **kwargs)

    def exploration_noise(self, act: Union[np.ndarray, Batch],
                          batch: Batch) -> Union[np.ndarray, Batch]:
        return self.policy.exploration_noise(act, batch)

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        if hasattr(self.policy, "set_eps"):
            self.policy.set_eps(eps)  # type: ignore
        else:
            raise NotImplementedError()

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        """
        mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
        batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss)
        batch.rew += to_numpy(mse_loss * self.reward_scale)
        return self.policy.process_fn(batch, buffer, indices)

    def post_process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> None:
        """Post-process the data from the provided replay buffer.

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`.
        """
        self.policy.post_process_fn(batch, buffer, indices)
        batch.rew = batch.policy.orig_rew  # restore original reward

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        res = self.policy.learn(batch, **kwargs)
        self.optim.zero_grad()
        act_hat = batch.policy.act_hat
        act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
        inverse_loss = F.cross_entropy(act_hat, act).mean()  # type: ignore
        forward_loss = batch.policy.mse_loss.mean()
        loss = (
            (1 - self.forward_loss_weight) * inverse_loss +
            self.forward_loss_weight * forward_loss
        ) * self.lr_scale
        loss.backward()
        self.optim.step()
        res.update(
            {
                "loss/icm": loss.item(),
                "loss/icm/forward": forward_loss.item(),
                "loss/icm/inverse": inverse_loss.item()
            }
        )
        return res
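In short, process_fn adds the scaled forward-model error to each transition's reward before the wrapped policy builds its learning targets, and post_process_fn restores the original reward afterwards. A rough standalone sketch of that shaping (hypothetical helper, not part of this commit; assumes one-hot discrete actions and pre-computed features):

import torch
import torch.nn.functional as F

def shaped_reward(rew, phi_s, phi_s_next, forward_model, act, action_dim, reward_scale=0.01):
    # rew + reward_scale * 0.5 * ||f(phi(s), a) - phi(s')||^2, mirroring ICMPolicy.process_fn
    one_hot = F.one_hot(act.long(), num_classes=action_dim).float()
    phi_hat = forward_model(torch.cat([phi_s, one_hot], dim=1))
    mse = 0.5 * F.mse_loss(phi_hat, phi_s_next, reduction="none").sum(1)
    return rew + reward_scale * mse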
@@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch
from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP


@@ -392,3 +392,58 @@ def sample_noise(model: nn.Module) -> bool:
            m.sample()
            done = True
    return done


class IntrinsicCuriosityModule(nn.Module):
    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.

    :param torch.nn.Module feature_net: a self-defined feature_net which outputs a
        flattened hidden state.
    :param int feature_dim: output dimension of the feature net.
    :param int action_dim: dimension of the action space.
    :param hidden_sizes: hidden layer sizes for forward and inverse models.
    :param device: device for the module.
    """

    def __init__(
        self,
        feature_net: nn.Module,
        feature_dim: int,
        action_dim: int,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, torch.device] = "cpu"
    ) -> None:
        super().__init__()
        self.feature_net = feature_net
        self.forward_model = MLP(
            feature_dim + action_dim,
            output_dim=feature_dim,
            hidden_sizes=hidden_sizes,
            device=device
        )
        self.inverse_model = MLP(
            feature_dim * 2,
            output_dim=action_dim,
            hidden_sizes=hidden_sizes,
            device=device
        )
        self.feature_dim = feature_dim
        self.action_dim = action_dim
        self.device = device

    def forward(
        self, s1: Union[np.ndarray, torch.Tensor],
        act: Union[np.ndarray, torch.Tensor], s2: Union[np.ndarray, torch.Tensor],
        **kwargs: Any
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Mapping: s1, act, s2 -> mse_loss, act_hat."""
        s1 = to_torch(s1, dtype=torch.float32, device=self.device)
        s2 = to_torch(s2, dtype=torch.float32, device=self.device)
        phi1, phi2 = self.feature_net(s1), self.feature_net(s2)
        act = to_torch(act, dtype=torch.long, device=self.device)
        phi2_hat = self.forward_model(
            torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1)
        )
        mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1)
        act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1))
        return mse_loss, act_hat
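As a quick sanity check of the shapes involved, the module can be exercised on random tensors; a minimal sketch (hypothetical sizes, not part of this commit):

import torch
from tianshou.utils.net.common import MLP
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

feature_dim, action_dim, batch = 16, 4, 8
feature_net = MLP(10, output_dim=feature_dim)  # hypothetical 10-dimensional observations
icm = IntrinsicCuriosityModule(feature_net, feature_dim, action_dim, hidden_sizes=[32])
s1, s2 = torch.randn(batch, 10), torch.randn(batch, 10)
act = torch.randint(action_dim, (batch,))
mse_loss, act_hat = icm(s1, act, s2)
print(mse_loss.shape, act_hat.shape)  # expected: torch.Size([8]) torch.Size([8, 4])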