Implement CQLPolicy and offline_cql example (#506)
parent a59d96d041
commit bc53ead273
@ -37,6 +37,7 @@
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)

@ -114,6 +114,11 @@ Imitation
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.CQLPolicy
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.DiscreteBCQPolicy
    :members:
    :undoc-members:

@ -28,6 +28,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_

@ -2,10 +2,12 @@

In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected once by an arbitrary policy; the agent no longer interacts with the environment during training.

## Continous control

## Continuous control

Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agents for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets; a minimal loading sketch follows.
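A minimal sketch of fetching a d4rl dataset (assuming `d4rl` and its MuJoCo dependencies are installed); the array keys below are exactly the ones consumed by the new `offline_cql.py`:

```python
import d4rl  # importing d4rl registers the offline environments in gym
import gym

env = gym.make("halfcheetah-expert-v1")
dataset = d4rl.qlearning_dataset(env)  # dict of aligned numpy arrays

# one row per transition: (s, a, r, done, s')
print(dataset["observations"].shape, dataset["actions"].shape)
print(dataset["rewards"].shape, dataset["terminals"].shape)
print(dataset["next_observations"].shape)
```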

We provide implementations of the BCQ and CQL algorithms for continuous control.

### Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can convert a d4rl dataset into a `ReplayBuffer` and pass it as the `buffer` parameter of `offline_trainer`. `offline_bcq.py` is an example of offline RL using a d4rl dataset; a sketch of the workflow is shown below.
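A rough sketch of that workflow, continuing the snippet above and mirroring the new `offline_cql.py` in this commit (`policy`, `test_collector`, `save_fn`, `logger`, and hyperparameters such as `epoch` are assumed to be set up as in that script):

```python
from tianshou.data import Batch, ReplayBuffer
from tianshou.trainer import offline_trainer

dataset = d4rl.qlearning_dataset(env)
dataset_size = dataset["rewards"].size

# copy the d4rl transitions into a Tianshou ReplayBuffer
replay_buffer = ReplayBuffer(dataset_size)
for i in range(dataset_size):
    replay_buffer.add(
        Batch(
            obs=dataset["observations"][i],
            act=dataset["actions"][i],
            rew=dataset["rewards"][i],
            done=dataset["terminals"][i],
            obs_next=dataset["next_observations"][i],
        )
    )

# the buffer takes the place of a train_collector
result = offline_trainer(
    policy, replay_buffer, test_collector,
    epoch, step_per_epoch, test_num, batch_size,
    save_fn=save_fn, logger=logger,
)
```

With the scripts in this commit, the same pipeline can be run end to end via, e.g., `python offline_cql.py --task halfcheetah-expert-v1`.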
@ -20,7 +22,7 @@ After 1M steps:



`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the offpolicy algorithms in mujoco environment.
`halfcheetah-expert-v1` is a MuJoCo environment. The hyperparameter settings are similar to those used for the off-policy algorithms in the MuJoCo environments.

## Results

236 examples/offline/offline_cql.py Normal file
@ -0,0 +1,236 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import d4rl
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.policy import CQLPolicy
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='halfcheetah-medium-v1')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-4)
|
||||
parser.add_argument('--critic-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--alpha', type=float, default=0.2)
|
||||
parser.add_argument('--auto-alpha', default=True, action='store_true')
|
||||
parser.add_argument('--alpha-lr', type=float, default=1e-4)
|
||||
parser.add_argument('--cql-alpha-lr', type=float, default=3e-4)
|
||||
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||
parser.add_argument('--epoch', type=int, default=200)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--cql-weight", type=float, default=1.0)
|
||||
parser.add_argument("--with-lagrange", type=bool, default=True)
|
||||
parser.add_argument("--lagrange-threshold", type=float, default=10.0)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
|
||||
parser.add_argument("--eval-freq", type=int, default=1)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=1 / 35)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only',
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_cql():
|
||||
args = get_args()
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0] # float
|
||||
print("device:", args.device)
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||
|
||||
args.state_dim = args.state_shape[0]
|
||||
args.action_dim = args.action_shape[0]
|
||||
print("Max_action", args.max_action)
|
||||
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||
)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
# model
|
||||
# actor network
|
||||
net_a = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device,
|
||||
)
|
||||
actor = ActorProb(
|
||||
net_a,
|
||||
action_shape=args.action_shape,
|
||||
max_action=args.max_action,
|
||||
device=args.device,
|
||||
unbounded=True,
|
||||
conditioned_sigma=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
|
||||
# critic network
|
||||
net_c1 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
net_c2 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
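# if --auto-alpha is set, replace the fixed alpha with a
# (target_entropy, log_alpha, alpha_optim) tuple so that CQLPolicy
# (via SACPolicy) tunes the entropy coefficient automatically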
if args.auto_alpha:
|
||||
target_entropy = -np.prod(env.action_space.shape)
|
||||
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||
|
||||
policy = CQLPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic1,
|
||||
critic1_optim,
|
||||
critic2,
|
||||
critic2_optim,
|
||||
cql_alpha_lr=args.cql_alpha_lr,
|
||||
cql_weight=args.cql_weight,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
alpha=args.alpha,
|
||||
temperature=args.temperature,
|
||||
with_lagrange=args.with_lagrange,
|
||||
lagrange_threshold=args.lagrange_threshold,
|
||||
min_action=np.min(env.action_space.low),
|
||||
max_action=np.max(env.action_space.high),
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql'
|
||||
log_path = os.path.join(args.logdir, args.task, 'cql', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def watch():
|
||||
if args.resume_path is None:
|
||||
args.resume_path = os.path.join(log_path, 'policy.pth')
|
||||
|
||||
policy.load_state_dict(
|
||||
torch.load(args.resume_path, map_location=torch.device('cpu'))
|
||||
)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
|
||||
if not args.watch:
|
||||
dataset = d4rl.qlearning_dataset(env)
|
||||
dataset_size = dataset['rewards'].size
|
||||
|
||||
print("dataset_size", dataset_size)
|
||||
replay_buffer = ReplayBuffer(dataset_size)
|
||||
|
||||
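# copy the d4rl arrays into the buffer one transition at a time;
# d4rl's `terminals` field is stored as Tianshou's `done` flag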
for i in range(dataset_size):
|
||||
replay_buffer.add(
|
||||
Batch(
|
||||
obs=dataset['observations'][i],
|
||||
act=dataset['actions'][i],
|
||||
rew=dataset['rewards'][i],
|
||||
done=dataset['terminals'][i],
|
||||
obs_next=dataset['next_observations'][i],
|
||||
)
|
||||
)
|
||||
print("dataset loaded")
|
||||
# trainer
|
||||
result = offline_trainer(
|
||||
policy,
|
||||
replay_buffer,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
save_fn=save_fn,
|
||||
logger=logger,
|
||||
)
|
||||
pprint.pprint(result)
|
||||
else:
|
||||
watch()
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cql()

219 test/offline/test_cql.py Normal file
@ -0,0 +1,219 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.policy import CQLPolicy
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
if __name__ == "__main__":
|
||||
from gather_pendulum_data import gather_data
|
||||
else: # pytest
|
||||
from test.offline.gather_pendulum_data import gather_data
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--alpha', type=float, default=0.2)
|
||||
parser.add_argument('--auto-alpha', default=True, action='store_true')
|
||||
parser.add_argument('--alpha-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--cql-alpha-lr', type=float, default=1e-3)
|
||||
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||
parser.add_argument('--epoch', type=int, default=20)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=2000)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--cql-weight", type=float, default=1.0)
|
||||
parser.add_argument("--with-lagrange", type=bool, default=True)
|
||||
parser.add_argument("--lagrange-threshold", type=float, default=10.0)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
|
||||
parser.add_argument("--eval-freq", type=int, default=1)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=1 / 35)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
|
||||
)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_cql(args=get_args()):
|
||||
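# reuse a previously saved expert buffer if one exists; otherwise collect a
# fresh Pendulum buffer via gather_pendulum_data.gather_data()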
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
|
||||
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||
else:
|
||||
buffer = gather_data()
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0] # float
|
||||
if args.task == 'Pendulum-v0':
|
||||
env.spec.reward_threshold = -1200 # too low?
|
||||
|
||||
args.state_dim = args.state_shape[0]
|
||||
args.action_dim = args.action_shape[0]
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
# model
|
||||
# actor network
|
||||
net_a = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device,
|
||||
)
|
||||
actor = ActorProb(
|
||||
net_a,
|
||||
action_shape=args.action_shape,
|
||||
max_action=args.max_action,
|
||||
device=args.device,
|
||||
unbounded=True,
|
||||
conditioned_sigma=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
|
||||
# critic network
|
||||
net_c1 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
net_c2 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
if args.auto_alpha:
|
||||
target_entropy = -np.prod(env.action_space.shape)
|
||||
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||
|
||||
policy = CQLPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic1,
|
||||
critic1_optim,
|
||||
critic2,
|
||||
critic2_optim,
|
||||
cql_alpha_lr=args.cql_alpha_lr,
|
||||
cql_weight=args.cql_weight,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
alpha=args.alpha,
|
||||
temperature=args.temperature,
|
||||
with_lagrange=args.with_lagrange,
|
||||
lagrange_threshold=args.lagrange_threshold,
|
||||
min_action=np.min(env.action_space.low),
|
||||
max_action=np.max(env.action_space.high),
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
# buffer has been gathered
|
||||
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql'
|
||||
log_path = os.path.join(args.logdir, args.task, 'cql', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def watch():
|
||||
policy.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
|
||||
)
|
||||
)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
|
||||
# trainer
|
||||
result = offline_trainer(
|
||||
policy,
|
||||
buffer,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
save_fn=save_fn,
|
||||
stop_fn=stop_fn,
|
||||
logger=logger,
|
||||
)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
# Let's watch its performance!
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cql()
|
@ -20,6 +20,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.bcq import BCQPolicy
from tianshou.policy.imitation.cql import CQLPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
@ -47,6 +48,7 @@ __all__ = [
    "DiscreteSACPolicy",
    "ImitationPolicy",
    "BCQPolicy",
    "CQLPolicy",
    "DiscreteBCQPolicy",
    "DiscreteCQLPolicy",
    "DiscreteCRRPolicy",

293 tianshou/policy/imitation/cql.py Normal file
@ -0,0 +1,293 @@
from typing import Any, Dict, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import SACPolicy
from tianshou.utils.net.continuous import ActorProb


class CQLPolicy(SACPolicy):
    """Implementation of the CQL algorithm. arXiv:2006.04779.

    :param ActorProb actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> a)
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param float cql_alpha_lr: the learning rate of cql_log_alpha. Default to 1e-4.
    :param float cql_weight: the weight of the conservative (CQL) loss term,
        i.e. alpha in the CQL paper. Default to 1.0.
    :param float tau: param for the soft update of the target network.
        Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient. Default to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param float temperature: the temperature used in the log-sum-exp term.
        Default to 1.0.
    :param bool with_lagrange: whether to use the Lagrange variant,
        CQL(Lagrange). Default to True.
    :param float lagrange_threshold: the value of tau in CQL(Lagrange).
        Default to 10.0.
    :param float min_action: the minimum value of each dimension of the action.
        Default to -1.0.
    :param float max_action: the maximum value of each dimension of the action.
        Default to 1.0.
    :param int num_repeat_actions: the number of sampled actions per state used
        when estimating the log-sum-exp term. Default to 10.
    :param float alpha_min: lower bound for clipping cql_alpha. Default to 0.0.
    :param float alpha_max: upper bound for clipping cql_alpha. Default to 1e6.
    :param float clip_grad: the gradient-clipping norm used when updating the
        critic networks. Default to 1.0.
    :param Union[str, torch.device] device: which device to create this model on.
        Default to "cpu".

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

def __init__(
|
||||
self,
|
||||
actor: ActorProb,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic1: torch.nn.Module,
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
cql_alpha_lr: float = 1e-4,
|
||||
cql_weight: float = 1.0,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
|
||||
temperature: float = 1.0,
|
||||
with_lagrange: bool = True,
|
||||
lagrange_threshold: float = 10.0,
|
||||
min_action: float = -1.0,
|
||||
max_action: float = 1.0,
|
||||
num_repeat_actions: int = 10,
|
||||
alpha_min: float = 0.0,
|
||||
alpha_max: float = 1e6,
|
||||
clip_grad: float = 1.0,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
|
||||
gamma, alpha, **kwargs
|
||||
)
|
||||
# There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy.
|
||||
self.device = device
|
||||
self.temperature = temperature
|
||||
self.with_lagrange = with_lagrange
|
||||
self.lagrange_threshold = lagrange_threshold
|
||||
|
||||
self.cql_weight = cql_weight
|
||||
|
||||
self.cql_log_alpha = torch.tensor([0.0], requires_grad=True)
|
||||
self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)
|
||||
self.cql_log_alpha = self.cql_log_alpha.to(device)
|
||||
|
||||
self.min_action = min_action
|
||||
self.max_action = max_action
|
||||
|
||||
self.num_repeat_actions = num_repeat_actions
|
||||
|
||||
self.alpha_min = alpha_min
|
||||
self.alpha_max = alpha_max
|
||||
self.clip_grad = clip_grad
|
||||
|
||||
def train(self, mode: bool = True) -> "CQLPolicy":
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = mode
|
||||
self.actor.train(mode)
|
||||
self.critic1.train(mode)
|
||||
self.critic2.train(mode)
|
||||
return self
|
||||
|
||||
def sync_weight(self) -> None:
|
||||
"""Soft-update the weight for the target network."""
|
||||
for net, net_old in [
|
||||
[self.critic1, self.critic1_old], [self.critic2, self.critic2_old]
|
||||
]:
|
||||
for param, target_param in zip(net.parameters(), net_old.parameters()):
|
||||
target_param.data.copy_(
|
||||
self._tau * param.data + (1 - self._tau) * target_param.data
|
||||
)
|
||||
|
||||
def actor_pred(self, obs: torch.Tensor) -> \
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch = Batch(obs=obs, info=None)
|
||||
obs_result = self(batch)
|
||||
return obs_result.act, obs_result.log_prob
|
||||
|
||||
def calc_actor_loss(self, obs: torch.Tensor) -> \
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
act_pred, log_pi = self.actor_pred(obs)
|
||||
q1 = self.critic1(obs, act_pred)
|
||||
q2 = self.critic2(obs, act_pred)
|
||||
min_Q = torch.min(q1, q2)
|
||||
self._alpha: Union[float, torch.Tensor]
|
||||
actor_loss = (self._alpha * log_pi - min_Q).mean()
|
||||
# actor_loss.shape: (), log_pi.shape: (batch_size, 1)
|
||||
return actor_loss, log_pi
|
||||
|
||||
def calc_pi_values(self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor) -> \
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
act_pred, log_pi = self.actor_pred(obs_pi)
|
||||
|
||||
q1 = self.critic1(obs_to_pred, act_pred)
|
||||
q2 = self.critic2(obs_to_pred, act_pred)
|
||||
|
||||
return q1 - log_pi.detach(), q2 - log_pi.detach()
|
||||
|
||||
def calc_random_values(self, obs: torch.Tensor, act: torch.Tensor) -> \
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
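# correct the Q-values of uniformly sampled actions by the log-density of the
# uniform proposal over [-1, 1]^act_dim, i.e. log((1/2) ** act_dim)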
random_value1 = self.critic1(obs, act)
|
||||
random_log_prob1 = np.log(0.5**act.shape[-1])
|
||||
|
||||
random_value2 = self.critic2(obs, act)
|
||||
random_log_prob2 = np.log(0.5**act.shape[-1])
|
||||
|
||||
return random_value1 - random_log_prob1, random_value2 - random_log_prob2
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
|
||||
) -> Batch:
|
||||
return batch
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
batch: Batch = to_torch( # type: ignore
|
||||
batch, dtype=torch.float, device=self.device,
|
||||
)
|
||||
obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
|
||||
batch_size = obs.shape[0]
|
||||
|
||||
# compute actor loss and update actor
|
||||
actor_loss, log_pi = self.calc_actor_loss(obs)
|
||||
self.actor_optim.zero_grad()
|
||||
actor_loss.backward()
|
||||
self.actor_optim.step()
|
||||
|
||||
# compute alpha loss
|
||||
if self._is_auto_alpha:
|
||||
log_pi = log_pi + self._target_entropy
|
||||
alpha_loss = -(self._log_alpha * log_pi.detach()).mean()
|
||||
self._alpha_optim.zero_grad()
|
||||
# update log_alpha
|
||||
alpha_loss.backward()
|
||||
self._alpha_optim.step()
|
||||
# update alpha
|
||||
self._alpha = self._log_alpha.detach().exp()
|
||||
|
||||
# compute target_Q
|
||||
with torch.no_grad():
|
||||
act_next, new_log_pi = self.actor_pred(obs_next)
|
||||
|
||||
target_Q1 = self.critic1_old(obs_next, act_next)
|
||||
target_Q2 = self.critic2_old(obs_next, act_next)
|
||||
|
||||
target_Q = torch.min(target_Q1, target_Q2) - self._alpha * new_log_pi
|
||||
|
||||
target_Q = \
|
||||
rew + self._gamma * (1 - batch.done) * target_Q.flatten()
|
||||
# shape: (batch_size)
|
||||
|
||||
# compute critic loss
|
||||
current_Q1 = self.critic1(obs, act).flatten()
|
||||
current_Q2 = self.critic2(obs, act).flatten()
|
||||
# shape: (batch_size)
|
||||
|
||||
critic1_loss = F.mse_loss(current_Q1, target_Q)
|
||||
critic2_loss = F.mse_loss(current_Q2, target_Q)
|
||||
|
||||
# CQL
|
||||
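# conservative penalty: approximate logsumexp_a Q(s, a) with Q-values of
# uniformly sampled actions and of actions drawn from the current policy
# (at both s and s_next), then subtract the dataset Q(s, a) below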
random_actions = torch.FloatTensor(
|
||||
batch_size * self.num_repeat_actions, act.shape[-1]
|
||||
).uniform_(-self.min_action, self.max_action).to(self.device)
|
||||
tmp_obs = obs.unsqueeze(1) \
|
||||
.repeat(1, self.num_repeat_actions, 1) \
|
||||
.view(batch_size * self.num_repeat_actions, obs.shape[-1])
|
||||
tmp_obs_next = obs_next.unsqueeze(1) \
|
||||
.repeat(1, self.num_repeat_actions, 1) \
|
||||
.view(batch_size * self.num_repeat_actions, obs.shape[-1])
|
||||
# tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)
|
||||
|
||||
current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
|
||||
next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs)
|
||||
|
||||
random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions)
|
||||
|
||||
for value in [
|
||||
current_pi_value1, current_pi_value2, next_pi_value1, next_pi_value2,
|
||||
random_value1, random_value2
|
||||
]:
|
||||
value.reshape(batch_size, self.num_repeat_actions, 1)
|
||||
|
||||
# cat q values
|
||||
cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
|
||||
cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
|
||||
# shape: (batch_size, 3 * num_repeat, 1)
|
||||
|
||||
cql1_scaled_loss = \
|
||||
torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * \
|
||||
self.cql_weight * self.temperature - current_Q1.mean() * \
|
||||
self.cql_weight
|
||||
cql2_scaled_loss = \
|
||||
torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * \
|
||||
self.cql_weight * self.temperature - current_Q2.mean() * \
|
||||
self.cql_weight
|
||||
# shape: (1)
|
||||
|
||||
if self.with_lagrange:
|
||||
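# CQL(Lagrange): cql_alpha is optimized so that the conservative gap stays
# close to lagrange_threshold, and it rescales the penalty added to the critics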
cql_alpha = torch.clamp(
|
||||
self.cql_log_alpha.exp(),
|
||||
self.alpha_min,
|
||||
self.alpha_max,
|
||||
)
|
||||
cql1_scaled_loss = \
|
||||
cql_alpha * (cql1_scaled_loss - self.lagrange_threshold)
|
||||
cql2_scaled_loss = \
|
||||
cql_alpha * (cql2_scaled_loss - self.lagrange_threshold)
|
||||
|
||||
self.cql_alpha_optim.zero_grad()
|
||||
cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5
|
||||
cql_alpha_loss.backward(retain_graph=True)
|
||||
self.cql_alpha_optim.step()
|
||||
|
||||
critic1_loss = critic1_loss + cql1_scaled_loss
|
||||
critic2_loss = critic2_loss + cql2_scaled_loss
|
||||
|
||||
# update critic
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward(retain_graph=True)
|
||||
# clip gradients to stabilize the critic update (guards against exploding
# gradients); it does not seem strictly necessary here
|
||||
clip_grad_norm_(self.critic1.parameters(), self.clip_grad)
|
||||
self.critic1_optim.step()
|
||||
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
clip_grad_norm_(self.critic2.parameters(), self.clip_grad)
|
||||
self.critic2_optim.step()
|
||||
|
||||
self.sync_weight()
|
||||
|
||||
result = {
|
||||
"loss/actor": actor_loss.item(),
|
||||
"loss/critic1": critic1_loss.item(),
|
||||
"loss/critic2": critic2_loss.item(),
|
||||
}
|
||||
if self._is_auto_alpha:
|
||||
result["loss/alpha"] = alpha_loss.item()
|
||||
result["alpha"] = self._alpha.item() # type: ignore
|
||||
if self.with_lagrange:
|
||||
result["loss/cql_alpha"] = cql_alpha_loss.item()
|
||||
result["cql_alpha"] = cql_alpha.item()
|
||||
return result
|