Add Intrinsic Curiosity Module (#503)

parent a2d76d1276
commit a59d96d041
@@ -43,6 +43,7 @@
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
+- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)

 Here is Tianshou's other features:

@@ -137,6 +137,11 @@ Model-based
    :undoc-members:
    :show-inheritance:

+.. autoclass:: tianshou.policy.ICMPolicy
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Multi-agent
 -----------

@@ -32,6 +32,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
 * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
 * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
+* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

@@ -135,3 +135,4 @@ Huayu
 Strens
 Ornstein
 Uhlenbeck
+mse
@@ -11,8 +11,10 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DQNPolicy
+from tianshou.policy.modelbased.icm import ICMPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
+from tianshou.utils.net.discrete import IntrinsicCuriosityModule


 def get_args():
@@ -55,6 +57,24 @@ def get_args():
         help='watch the play of pre-trained policy only'
     )
     parser.add_argument('--save-buffer-name', type=str, default=None)
+    parser.add_argument(
+        '--icm-lr-scale',
+        type=float,
+        default=0.,
+        help='use intrinsic curiosity module with this lr scale'
+    )
+    parser.add_argument(
+        '--icm-reward-scale',
+        type=float,
+        default=0.01,
+        help='scaling factor for intrinsic curiosity reward'
+    )
+    parser.add_argument(
+        '--icm-forward-loss-weight',
+        type=float,
+        default=0.2,
+        help='weight for the forward model loss in ICM'
+    )
     return parser.parse_args()


@@ -101,6 +121,24 @@ def test_dqn(args=get_args()):
         args.n_step,
         target_update_freq=args.target_update_freq
     )
+    if args.icm_lr_scale > 0:
+        feature_net = DQN(
+            *args.state_shape, args.action_shape, args.device, features_only=True
+        )
+        action_dim = np.prod(args.action_shape)
+        feature_dim = feature_net.output_dim
+        icm_net = IntrinsicCuriosityModule(
+            feature_net.net,
+            feature_dim,
+            action_dim,
+            hidden_sizes=[512],
+            device=args.device
+        )
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
+        policy = ICMPolicy(
+            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
+            args.icm_forward_loss_weight
+        ).to(args.device)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
@@ -118,7 +156,8 @@ def test_dqn(args=get_args()):
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
+    log_path = os.path.join(args.logdir, args.task, log_name)
     if args.logger == "tensorboard":
         writer = SummaryWriter(log_path)
         writer.add_text("args", str(args))
@@ -127,7 +166,7 @@ def test_dqn(args=get_args()):
         logger = WandbLogger(
             save_interval=1,
             project=args.task,
-            name='dqn',
+            name=log_name,
             run_id=args.resume_id,
             config=args,
         )
examples/vizdoom/vizdoom_a2c_icm.py (new file, 225 lines)
@@ -0,0 +1,225 @@
import argparse
import os
import pprint

import numpy as np
import torch
from env import Env
from network import DQN
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import A2CPolicy, ICMPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='D2_navigation')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=2000000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=300)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--episode-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--skip-num', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    parser.add_argument(
        '--save-lmp',
        default=False,
        action='store_true',
        help='save lmp file for replay whole episode'
    )
    parser.add_argument('--save-buffer-name', type=str, default=None)
    parser.add_argument(
        '--icm-lr-scale',
        type=float,
        default=0.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--icm-reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--icm-forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    return parser.parse_args()


def test_a2c(args=get_args()):
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res)
            for _ in range(args.training_num)
        ]
    )
    test_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
            for _ in range(min(os.cpu_count() - 1, args.test_num))
        ]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(
        *args.state_shape, args.action_shape, device=args.device, features_only=True
    )
    actor = Actor(
        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
    # define policy
    dist = torch.distributions.Categorical
    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
    if args.icm_lr_scale > 0:
        feature_net = DQN(
            *args.state_shape,
            args.action_shape,
            device=args.device,
            features_only=True
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.net,
            feature_dim,
            action_dim,
            hidden_sizes=args.hidden_sizes,
            device=args.device
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
            args.icm_forward_loss_weight
        ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'a2c')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return False

    def watch():
        # watch agent's performance
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        test_in_train=False
    )

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_a2c(get_args())
test/modelbased/test_dqn_icm.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import argparse
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy, ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
    )
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument(
        '--lr-scale',
        type=float,
        default=1.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    args = parser.parse_known_args()[0]
    return args


def test_dqn_icm(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # Q_param = V_param = {"hidden_sizes": [128]}
    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        # dueling=(Q_param, V_param),
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net,
        optim,
        args.gamma,
        args.n_step,
        target_update_freq=args.target_update_freq,
    )
    feature_dim = args.hidden_sizes[-1]
    feature_net = MLP(
        np.prod(args.state_shape),
        output_dim=feature_dim,
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device
    )
    action_dim = np.prod(args.action_shape)
    icm_net = IntrinsicCuriosityModule(
        feature_net,
        feature_dim,
        action_dim,
        hidden_sizes=args.hidden_sizes[-1:],
        device=args.device
    ).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(
        policy, icm_net, icm_optim, args.lr_scale, args.reward_scale,
        args.forward_loss_weight
    )
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn_icm')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_pdqn_icm(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    args.seed = 1
    test_dqn_icm(args)


if __name__ == '__main__':
    test_dqn_icm(get_args())
test/modelbased/test_ppo_icm.py (new file, 192 lines)
@@ -0,0 +1,192 @@
import argparse
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, ActorCritic, DataParallelNet, Net
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=50000)
    parser.add_argument('--step-per-collect', type=int, default=2000)
    parser.add_argument('--repeat-per-collect', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=20)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    # ppo special
    parser.add_argument('--vf-coef', type=float, default=0.5)
    parser.add_argument('--ent-coef', type=float, default=0.0)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=0)
    parser.add_argument('--norm-adv', type=int, default=0)
    parser.add_argument('--recompute-adv', type=int, default=0)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=0)
    parser.add_argument(
        '--lr-scale',
        type=float,
        default=1.,
        help='use intrinsic curiosity module with this lr scale'
    )
    parser.add_argument(
        '--reward-scale',
        type=float,
        default=0.01,
        help='scaling factor for intrinsic curiosity reward'
    )
    parser.add_argument(
        '--forward-loss-weight',
        type=float,
        default=0.2,
        help='weight for the forward model loss in ICM'
    )
    args = parser.parse_known_args()[0]
    return args


def test_ppo(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    if torch.cuda.is_available():
        actor = DataParallelNet(
            Actor(net, args.action_shape, device=None).to(args.device)
        )
        critic = DataParallelNet(Critic(net, device=None).to(args.device))
    else:
        actor = Actor(net, args.action_shape, device=args.device).to(args.device)
        critic = Critic(net, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        gae_lambda=args.gae_lambda,
        reward_normalization=args.rew_norm,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        action_space=env.action_space,
        deterministic_eval=True,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv
    )
    feature_dim = args.hidden_sizes[-1]
    feature_net = MLP(
        np.prod(args.state_shape),
        output_dim=feature_dim,
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device
    )
    action_dim = np.prod(args.action_shape)
    icm_net = IntrinsicCuriosityModule(
        feature_net,
        feature_dim,
        action_dim,
        hidden_sizes=args.hidden_sizes[-1:],
        device=args.device
    ).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(
        policy, icm_net, icm_optim, args.lr_scale, args.reward_scale,
        args.forward_loss_weight
    )
    # collector
    train_collector = Collector(
        policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
    )
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo_icm')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_ppo()
@@ -24,6 +24,7 @@ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
 from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
 from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
 from tianshou.policy.modelbased.psrl import PSRLPolicy
+from tianshou.policy.modelbased.icm import ICMPolicy
 from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager

 __all__ = [
@@ -50,5 +51,6 @@ __all__ = [
     "DiscreteCQLPolicy",
     "DiscreteCRRPolicy",
     "PSRLPolicy",
+    "ICMPolicy",
     "MultiAgentPolicyManager",
 ]
tianshou/policy/modelbased/icm.py (new file, 121 lines)
@@ -0,0 +1,121 @@
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.policy import BasePolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


class ICMPolicy(BasePolicy):
    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.

    :param BasePolicy policy: a base policy to add ICM to.
    :param IntrinsicCuriosityModule model: the ICM model.
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float lr_scale: the scaling factor for ICM learning.
    :param float reward_scale: the scaling factor for intrinsic curiosity reward.
    :param float forward_loss_weight: the weight for forward model loss.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        model: IntrinsicCuriosityModule,
        optim: torch.optim.Optimizer,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.policy = policy
        self.model = model
        self.optim = optim
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def train(self, mode: bool = True) -> "ICMPolicy":
        """Set the module in training mode."""
        self.policy.train(mode)
        self.training = mode
        self.model.train(mode)
        return self

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data by inner policy.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        return self.policy.forward(batch, state, **kwargs)

    def exploration_noise(self, act: Union[np.ndarray, Batch],
                          batch: Batch) -> Union[np.ndarray, Batch]:
        return self.policy.exploration_noise(act, batch)

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        if hasattr(self.policy, "set_eps"):
            self.policy.set_eps(eps)  # type: ignore
        else:
            raise NotImplementedError()

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        """
        mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
        batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss)
        batch.rew += to_numpy(mse_loss * self.reward_scale)
        return self.policy.process_fn(batch, buffer, indices)

    def post_process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> None:
        """Post-process the data from the provided replay buffer.

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`.
        """
        self.policy.post_process_fn(batch, buffer, indices)
        batch.rew = batch.policy.orig_rew  # restore original reward

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        res = self.policy.learn(batch, **kwargs)
        self.optim.zero_grad()
        act_hat = batch.policy.act_hat
        act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
        inverse_loss = F.cross_entropy(act_hat, act).mean()  # type: ignore
        forward_loss = batch.policy.mse_loss.mean()
        loss = (
            (1 - self.forward_loss_weight) * inverse_loss +
            self.forward_loss_weight * forward_loss
        ) * self.lr_scale
        loss.backward()
        self.optim.step()
        res.update(
            {
                "loss/icm": loss.item(),
                "loss/icm/forward": forward_loss.item(),
                "loss/icm/inverse": inverse_loss.item()
            }
        )
        return res
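For orientation, here is a minimal sketch of how the new wrapper composes with an existing policy, condensed from test/modelbased/test_dqn_icm.py above; the CartPole-sized shapes and hyper-parameter values are illustrative assumptions, not part of the commit.

# Illustrative sketch only (not part of the commit): wrap a DQNPolicy with the
# new ICMPolicy, mirroring test/modelbased/test_dqn_icm.py.
import numpy as np
import torch

from tianshou.policy import DQNPolicy, ICMPolicy
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

state_shape, action_dim, device = (4,), 2, "cpu"  # assumed toy dimensions

# base policy: a plain DQN on the flattened observation
net = Net(state_shape, action_dim, hidden_sizes=[128, 128], device=device)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.9,
                   estimation_step=3, target_update_freq=320)

# ICM: a feature net phi(s); forward/inverse heads are built inside the module
feature_dim = 128
feature_net = MLP(int(np.prod(state_shape)), output_dim=feature_dim,
                  hidden_sizes=[128], device=device)
icm_net = IntrinsicCuriosityModule(feature_net, feature_dim, action_dim,
                                   hidden_sizes=[128], device=device).to(device)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=1e-3)

# wrap: positional order matches __init__ above
# (lr_scale, reward_scale, forward_loss_weight)
policy = ICMPolicy(policy, icm_net, icm_optim, 1.0, 0.01, 0.2)

The wrapper adds the scaled forward-model error to batch.rew in process_fn, restores the original reward in post_process_fn, and trains the ICM networks in learn, so the base policy's own update loop is unchanged.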
@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.utils.net.common import MLP


@@ -392,3 +392,58 @@ def sample_noise(model: nn.Module) -> bool:
             m.sample()
             done = True
     return done
+
+
+class IntrinsicCuriosityModule(nn.Module):
+    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.
+
+    :param torch.nn.Module feature_net: a self-defined feature_net which output a
+        flattened hidden state.
+    :param int feature_dim: output dimension of the feature net.
+    :param int action_dim: dimension of the action space.
+    :param hidden_sizes: hidden layer sizes for forward and inverse models.
+    :param device: device for the module.
+    """
+
+    def __init__(
+        self,
+        feature_net: nn.Module,
+        feature_dim: int,
+        action_dim: int,
+        hidden_sizes: Sequence[int] = (),
+        device: Union[str, torch.device] = "cpu"
+    ) -> None:
+        super().__init__()
+        self.feature_net = feature_net
+        self.forward_model = MLP(
+            feature_dim + action_dim,
+            output_dim=feature_dim,
+            hidden_sizes=hidden_sizes,
+            device=device
+        )
+        self.inverse_model = MLP(
+            feature_dim * 2,
+            output_dim=action_dim,
+            hidden_sizes=hidden_sizes,
+            device=device
+        )
+        self.feature_dim = feature_dim
+        self.action_dim = action_dim
+        self.device = device
+
+    def forward(
+        self, s1: Union[np.ndarray, torch.Tensor],
+        act: Union[np.ndarray, torch.Tensor], s2: Union[np.ndarray,
+                                                        torch.Tensor], **kwargs: Any
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Mapping: s1, act, s2 -> mse_loss, act_hat."""
+        s1 = to_torch(s1, dtype=torch.float32, device=self.device)
+        s2 = to_torch(s2, dtype=torch.float32, device=self.device)
+        phi1, phi2 = self.feature_net(s1), self.feature_net(s2)
+        act = to_torch(act, dtype=torch.long, device=self.device)
+        phi2_hat = self.forward_model(
+            torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1)
+        )
+        mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1)
+        act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1))
+        return mse_loss, act_hat
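The module's forward pass can also be exercised on its own; below is a small sketch under assumed toy dimensions (the feature net, sizes, and batch are made up for illustration).

# Sketch: exercise IntrinsicCuriosityModule.forward on random data.
# Dimensions and the feature net below are assumptions for illustration.
import numpy as np
import torch

from tianshou.utils.net.common import MLP
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

obs_dim, feature_dim, action_dim, batch = 4, 16, 2, 8
feature_net = MLP(obs_dim, output_dim=feature_dim, hidden_sizes=[32])
icm = IntrinsicCuriosityModule(feature_net, feature_dim, action_dim,
                               hidden_sizes=[32])

s1 = np.random.randn(batch, obs_dim).astype(np.float32)   # obs
s2 = np.random.randn(batch, obs_dim).astype(np.float32)   # obs_next
act = np.random.randint(action_dim, size=batch)           # discrete actions

mse_loss, act_hat = icm(s1, act, s2)
assert mse_loss.shape == (batch,)             # forward-model error per transition
assert act_hat.shape == (batch, action_dim)   # inverse-model logits

Here mse_loss is the per-transition forward-model error that ICMPolicy.process_fn scales into an intrinsic reward, and act_hat supplies the logits for the cross-entropy inverse loss in ICMPolicy.learn.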