Tianshou/test/offline/test_cql.py
ChenDRAG 1423eeb3b2
Add warnings for duplicate usage of action-bounded actor and action scaling method (#850)
- Fix the current bug discussed in #844 in `test_ppo.py`.
- Add warning for `ActorProb` if both `max_action` and
`unbounded=True` are used for model initialization.
- Add warning for PGPolicy and DDPGPolicy if they find duplicate usage
of an action-bounded actor and the action scaling method.
2023-04-23 16:03:31 -07:00

230 lines
7.9 KiB
Python

import argparse
import datetime
import os
import pickle
import pprint
import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, and every non-empty string (including ``"False"``) is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, matching what ``bool("")`` used to give.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command-line options for the CQL offline test.

    Unknown options are ignored (``parse_known_args``), so the parser does not
    fail on arguments it did not declare.

    Returns:
        argparse.Namespace: the parsed options, with defaults applied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v1')
    parser.add_argument('--reward-threshold', type=float, default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--alpha', type=float, default=0.2)
    # `--auto-alpha` is on by default, so the flag alone can never turn it
    # off; `--no-auto-alpha` writes False to the same destination.
    parser.add_argument('--auto-alpha', default=True, action='store_true')
    parser.add_argument('--no-auto-alpha', dest='auto_alpha', action='store_false')
    parser.add_argument('--alpha-lr', type=float, default=1e-3)
    parser.add_argument('--cql-alpha-lr', type=float, default=1e-3)
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=500)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--cql-weight", type=float, default=1.0)
    # Parsed with _str2bool so that `--with-lagrange False` really yields False.
    parser.add_argument("--with-lagrange", type=_str2bool, default=True)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--eval-freq", type=int, default=1)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=1 / 35)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only',
    )
    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
    args = parser.parse_known_args()[0]
    return args
def test_cql(args=get_args()):
    """Train a CQL policy offline on Pendulum data and check the reward threshold.

    The replay buffer is loaded from ``args.load_buffer_name`` when that file
    exists (HDF5 or pickle) and gathered with ``gather_data()`` otherwise.
    An actor and two critics are built, trained with ``OfflineTrainer``, and
    the best test reward is asserted to reach ``args.reward_threshold``.

    Args:
        args: parsed options; the default is the namespace produced by
            ``get_args()`` when this function is defined. The namespace is
            mutated in place (shapes, threshold, alpha).

    Raises:
        AssertionError: if the best reward stays below the threshold.
    """
    if os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE(security): unpickling runs arbitrary code; only load buffer
            # files that come from a trusted source.
            with open(args.load_buffer_name, "rb") as buffer_file:
                buffer = pickle.load(buffer_file)
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    if args.reward_threshold is None:
        # too low?
        default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = ActorProb(
        net_a,
        action_shape=args.action_shape,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic networks: two separately initialised critics with the same layout
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Replace the scalar alpha with a (target entropy, log-alpha, optimizer)
        # tuple; the target entropy is minus the number of action dimensions.
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = CQLPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        cql_alpha_lr=args.cql_alpha_lr,
        cql_weight=args.cql_weight,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        temperature=args.temperature,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        min_action=np.min(env.action_space.low),
        max_action=np.max(env.action_space.high),
        device=args.device,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    # The buffer has already been gathered, so only a test collector is built.
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql'
    log_path = os.path.join(args.logdir, args.task, 'cql', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    # NOTE(review): `watch` is defined but never called in this function, and
    # `args.watch` is not consulted here.
    def watch():
        policy.load_state_dict(
            torch.load(
                os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
            )
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    # trainer
    trainer = OfflineTrainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.test_num,
        args.batch_size,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
    )

    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)

    assert stop_fn(info["best_reward"])

    # Let's watch its performance!
    if __name__ == "__main__":
        pprint.pprint(info)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
# Script entry point: run the CQL test with the default parsed arguments
# when this file is executed directly rather than imported.
if __name__ == '__main__':
    test_cql()