Tianshou/test/offline/test_bcq.py

import argparse
import datetime
import os
import pickle
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation

if __name__ == "__main__":
    from gather_pendulum_data import expert_file_name, gather_data
else:  # pytest
    from test.offline.gather_pendulum_data import expert_file_name, gather_data
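
# End-to-end test of BCQ (batch-constrained deep Q-learning; Fujimoto et al.,
# "Off-Policy Deep Reinforcement Learning without Exploration", 2019), trained
# purely offline on a pre-collected Pendulum replay buffer.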


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--step-per-epoch", type=int, default=500)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[32, 32])
    # default to 2 * action_dim
    parser.add_argument("--latent_dim", type=int, default=None)
    # type=float so command-line overrides are parsed as numbers, not strings
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    # Weighting for Clipped Double Q-learning in BCQ
    parser.add_argument("--lmbda", type=float, default=0.75)
    # Max perturbation hyper-parameter for BCQ
    parser.add_argument("--phi", type=float, default=0.05)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
    parser.add_argument("--show-progress", action="store_true")
    return parser.parse_known_args()[0]


def test_bcq(args=get_args()):
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            with open(args.load_buffer_name, "rb") as f:
                buffer = pickle.load(f)
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
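    # max_action bounds both the VAE decoder's output and the perturbation
    # (each applies a tanh scaled by max_action), keeping actions in range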
    if args.reward_threshold is None:
        # -1100 is a fairly loose (possibly too low) threshold for Pendulum
        default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    # perturbation network
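    # BCQ's actor: it takes an action sampled from the VAE and adds a bounded
    # correction (at most phi * max_action), keeping the policy near the data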
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to(
        args.device,
    )
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
    # vae
    # output_dim = 0, so the last Module in the encoder is ReLU
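    # the VAE is trained to reconstruct buffer actions conditioned on state;
    # BCQ then samples candidate actions from it, so the policy only considers
    # actions close to those seen in the dataset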
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())
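
    # BCQ policy: perturbed VAE actions are scored by two critics; the target
    # value mixes the min and max of the two Q-estimates via lmbda (clipped
    # double Q-learning), and target networks track the live ones at rate tau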
    policy = BCQPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        vae,
        vae_optim,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    # buffer has been gathered
    # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
    log_path = os.path.join(args.logdir, args.task, "bcq", log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def watch():
        policy.load_state_dict(
            torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    # trainer
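    # OfflineTrainer performs no environment interaction during training: each
    # gradient step samples batch_size transitions from the fixed buffer, and
    # test_collector evaluates the current policy after every epoch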
    result = OfflineTrainer(
        policy=policy,
        buffer=buffer,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
        show_progress=args.show_progress,
    ).run()
    assert stop_fn(result["best_reward"])

    # Let's watch its performance!
    if __name__ == "__main__":
        pprint.pprint(result)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
    test_bcq()