Implement TD3+BC for offline RL (#660)
- implement TD3+BC for offline RL; - fix a bug in trainer about test reward not logged because self.env_step is not set for offline setting;
This commit is contained in:
parent
9ce0a554dc
commit
df35718992
@ -31,6 +31,7 @@
|
||||
- Vanilla Imitation Learning
|
||||
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
||||
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
|
||||
- [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
|
||||
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
|
||||
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
|
||||
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
|
||||
|
@ -129,6 +129,11 @@ Imitation
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.TD3BCPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
@ -31,6 +31,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
||||
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
||||
* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning <https://arxiv.org/pdf/2106.06860.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
|
||||
|
@ -35,6 +35,24 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
|
||||
| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
|
||||
| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
|
||||
|
||||
### TD3+BC
|
||||
|
||||
| Environment | Dataset | CQL | Parameters |
|
||||
| --------------------- | --------------------- | --------------- | -------------------------------------------------------- |
|
||||
| HalfCheetah-v2 | halfcheetah-expert-v2 | 11788.25 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
|
||||
| HalfCheetah-v2 | halfcheetah-medium-v2 | 5741.13 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
|
||||
|
||||
#### Observation normalization
|
||||
|
||||
Following the original paper, we use observation normalization by default. You can turn it off by setting `--norm-obs 0`. The difference are small but consistent.
|
||||
|
||||
| Dataset | w/ norm-obs | w/o norm-obs |
|
||||
| :--- | :--- | :--- |
|
||||
| halfcheeta-medium-v2 | 5741.13 | 5724.41 |
|
||||
| halfcheeta-expert-v2 | 11788.25 | 11665.77 |
|
||||
| walker2d-medium-v2 | 4051.76 | 3985.59 |
|
||||
| walker2d-expert-v2 | 5068.15 | 5027.75 |
|
||||
|
||||
## Discrete control
|
||||
|
||||
For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
|
||||
|
227
examples/offline/d4rl_td3_bc.py
Normal file
227
examples/offline/d4rl_td3_bc.py
Normal file
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
|
||||
from tianshou.data import Collector
|
||||
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.policy import TD3BCPolicy
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="HalfCheetah-v2")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--expert-data-task", type=str, default="halfcheetah-expert-v2"
|
||||
)
|
||||
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
parser.add_argument("--actor-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--critic-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--epoch", type=int, default=200)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=5000)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
|
||||
parser.add_argument("--alpha", type=float, default=2.5)
|
||||
parser.add_argument("--exploration-noise", type=float, default=0.1)
|
||||
parser.add_argument("--policy-noise", type=float, default=0.2)
|
||||
parser.add_argument("--noise-clip", type=float, default=0.5)
|
||||
parser.add_argument("--update-actor-freq", type=int, default=2)
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--norm-obs", type=int, default=1)
|
||||
|
||||
parser.add_argument("--eval-freq", type=int, default=1)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=1 / 35)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_td3_bc():
|
||||
args = get_args()
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0] # float
|
||||
print("device:", args.device)
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||
|
||||
args.state_dim = args.state_shape[0]
|
||||
args.action_dim = args.action_shape[0]
|
||||
print("Max_action", args.max_action)
|
||||
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||
)
|
||||
if args.norm_obs:
|
||||
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
|
||||
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
# model
|
||||
# actor network
|
||||
net_a = Net(
|
||||
args.state_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device,
|
||||
)
|
||||
actor = Actor(
|
||||
net_a,
|
||||
action_shape=args.action_shape,
|
||||
max_action=args.max_action,
|
||||
device=args.device,
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
|
||||
# critic network
|
||||
net_c1 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
net_c2 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
policy = TD3BCPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic1,
|
||||
critic1_optim,
|
||||
critic2,
|
||||
critic2_optim,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||
policy_noise=args.policy_noise,
|
||||
update_actor_freq=args.update_actor_freq,
|
||||
noise_clip=args.noise_clip,
|
||||
alpha=args.alpha,
|
||||
estimation_step=args.n_step,
|
||||
action_space=env.action_space,
|
||||
)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
test_collector = Collector(policy, test_envs)
|
||||
|
||||
# log
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "td3_bc"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def watch():
|
||||
if args.resume_path is None:
|
||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||
|
||||
policy.load_state_dict(
|
||||
torch.load(args.resume_path, map_location=torch.device("cpu"))
|
||||
)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
|
||||
if not args.watch:
|
||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||
if args.norm_obs:
|
||||
replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
|
||||
test_envs.set_obs_rms(obs_rms)
|
||||
# trainer
|
||||
result = offline_trainer(
|
||||
policy,
|
||||
replay_buffer,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
)
|
||||
pprint.pprint(result)
|
||||
else:
|
||||
watch()
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_td3_bc()
|
@ -1,8 +1,12 @@
|
||||
from typing import Tuple
|
||||
|
||||
import d4rl
|
||||
import gym
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from tianshou.data import ReplayBuffer
|
||||
from tianshou.utils import RunningMeanStd
|
||||
|
||||
|
||||
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
|
||||
@ -27,3 +31,18 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
|
||||
obs_next=dataset["next_observations"]
|
||||
)
|
||||
return buffer
|
||||
|
||||
|
||||
def normalize_all_obs_in_replay_buffer(
|
||||
replay_buffer: ReplayBuffer
|
||||
) -> Tuple[ReplayBuffer, RunningMeanStd]:
|
||||
# compute obs mean and var
|
||||
obs_rms = RunningMeanStd()
|
||||
obs_rms.update(replay_buffer.obs)
|
||||
_eps = np.finfo(np.float32).eps.item()
|
||||
# normalize obs
|
||||
replay_buffer._meta["obs"] = (replay_buffer.obs -
|
||||
obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
|
||||
replay_buffer._meta["obs_next"] = (replay_buffer.obs_next -
|
||||
obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
|
||||
return replay_buffer, obs_rms
|
||||
|
215
test/offline/test_td3_bc.py
Normal file
215
test/offline/test_td3_bc.py
Normal file
@ -0,0 +1,215 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.policy import TD3BCPolicy
|
||||
from tianshou.trainer import OfflineTrainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
if __name__ == "__main__":
|
||||
from gather_pendulum_data import expert_file_name, gather_data
|
||||
else: # pytest
|
||||
from test.offline.gather_pendulum_data import expert_file_name, gather_data
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v1')
|
||||
parser.add_argument('--reward-threshold', type=float, default=None)
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=500)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--alpha', type=float, default=2.5)
|
||||
parser.add_argument("--exploration-noise", type=float, default=0.1)
|
||||
parser.add_argument("--policy-noise", type=float, default=0.2)
|
||||
parser.add_argument("--noise-clip", type=float, default=0.5)
|
||||
parser.add_argument("--update-actor-freq", type=int, default=2)
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
|
||||
parser.add_argument("--eval-freq", type=int, default=1)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=1 / 35)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only',
|
||||
)
|
||||
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_td3_bc(args=get_args()):
|
||||
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
|
||||
if args.load_buffer_name.endswith(".hdf5"):
|
||||
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
|
||||
else:
|
||||
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||
else:
|
||||
buffer = gather_data()
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0] # float
|
||||
if args.reward_threshold is None:
|
||||
# too low?
|
||||
default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200}
|
||||
args.reward_threshold = default_reward_threshold.get(
|
||||
args.task, env.spec.reward_threshold
|
||||
)
|
||||
|
||||
args.state_dim = args.state_shape[0]
|
||||
args.action_dim = args.action_shape[0]
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
# model
|
||||
# actor network
|
||||
net_a = Net(
|
||||
args.state_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device,
|
||||
)
|
||||
actor = Actor(
|
||||
net_a,
|
||||
action_shape=args.action_shape,
|
||||
max_action=args.max_action,
|
||||
device=args.device,
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
|
||||
# critic network
|
||||
net_c1 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
net_c2 = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True,
|
||||
device=args.device,
|
||||
)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
policy = TD3BCPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic1,
|
||||
critic1_optim,
|
||||
critic2,
|
||||
critic2_optim,
|
||||
tau=args.tau,
|
||||
gamma=args.gamma,
|
||||
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||
policy_noise=args.policy_noise,
|
||||
update_actor_freq=args.update_actor_freq,
|
||||
noise_clip=args.noise_clip,
|
||||
alpha=args.alpha,
|
||||
estimation_step=args.n_step,
|
||||
action_space=env.action_space,
|
||||
)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
# buffer has been gathered
|
||||
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc'
|
||||
log_path = os.path.join(args.logdir, args.task, 'td3_bc', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
def watch():
|
||||
policy.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
|
||||
)
|
||||
)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
|
||||
# trainer
|
||||
trainer = OfflineTrainer(
|
||||
policy,
|
||||
buffer,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
save_best_fn=save_best_fn,
|
||||
stop_fn=stop_fn,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
for epoch, epoch_stat, info in trainer:
|
||||
print(f"Epoch: {epoch}")
|
||||
print(epoch_stat)
|
||||
print(info)
|
||||
|
||||
assert stop_fn(info["best_reward"])
|
||||
|
||||
# Let's watch its performance!
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(info)
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_td3_bc()
|
@ -23,6 +23,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||
from tianshou.policy.imitation.base import ImitationPolicy
|
||||
from tianshou.policy.imitation.bcq import BCQPolicy
|
||||
from tianshou.policy.imitation.cql import CQLPolicy
|
||||
from tianshou.policy.imitation.td3_bc import TD3BCPolicy
|
||||
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
||||
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
|
||||
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
|
||||
@ -54,6 +55,7 @@ __all__ = [
|
||||
"ImitationPolicy",
|
||||
"BCQPolicy",
|
||||
"CQLPolicy",
|
||||
"TD3BCPolicy",
|
||||
"DiscreteBCQPolicy",
|
||||
"DiscreteCQLPolicy",
|
||||
"DiscreteCRRPolicy",
|
||||
|
107
tianshou/policy/imitation/td3_bc.py
Normal file
107
tianshou/policy/imitation/td3_bc.py
Normal file
@ -0,0 +1,107 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tianshou.data import Batch, to_torch_as
|
||||
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||
from tianshou.policy import TD3Policy
|
||||
|
||||
|
||||
class TD3BCPolicy(TD3Policy):
|
||||
"""Implementation of TD3+BC. arXiv:2106.06860.
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
|
||||
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
||||
critic network.
|
||||
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
|
||||
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||
critic network.
|
||||
:param float tau: param for soft update of the target network. Default to 0.005.
|
||||
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
|
||||
:param float exploration_noise: the exploration noise, add to the action.
|
||||
Default to ``GaussianNoise(sigma=0.1)``
|
||||
:param float policy_noise: the noise used in updating policy network.
|
||||
Default to 0.2.
|
||||
:param int update_actor_freq: the update frequency of actor network.
|
||||
Default to 2.
|
||||
:param float noise_clip: the clipping range used in updating policy network.
|
||||
Default to 0.5.
|
||||
:param float alpha: the value of alpha, which controls the weight for TD3 learning
|
||||
relative to behavior cloning.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param bool action_scaling: whether to map actions from range [-1, 1] to range
|
||||
[action_spaces.low, action_spaces.high]. Default to True.
|
||||
:param str action_bound_method: method to bound action to range [-1, 1], can be
|
||||
either "clip" (for simply clipping the action) or empty string for no bounding.
|
||||
Default to "clip".
|
||||
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic1: torch.nn.Module,
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
|
||||
policy_noise: float = 0.2,
|
||||
update_actor_freq: int = 2,
|
||||
noise_clip: float = 0.5,
|
||||
alpha: float = 2.5,
|
||||
reward_normalization: bool = False,
|
||||
estimation_step: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
|
||||
gamma, exploration_noise, policy_noise, update_actor_freq, noise_clip,
|
||||
reward_normalization, estimation_step, **kwargs
|
||||
)
|
||||
self._alpha = alpha
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
# critic 1&2
|
||||
td1, critic1_loss = self._mse_optimizer(
|
||||
batch, self.critic1, self.critic1_optim
|
||||
)
|
||||
td2, critic2_loss = self._mse_optimizer(
|
||||
batch, self.critic2, self.critic2_optim
|
||||
)
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
|
||||
# actor
|
||||
if self._cnt % self._freq == 0:
|
||||
act = self(batch, eps=0.0).act
|
||||
q_value = self.critic1(batch.obs, act)
|
||||
lmbda = self._alpha / q_value.abs().mean().detach()
|
||||
actor_loss = -lmbda * q_value.mean() + F.mse_loss(
|
||||
act, to_torch_as(batch.act, act)
|
||||
)
|
||||
self.actor_optim.zero_grad()
|
||||
actor_loss.backward()
|
||||
self._last = actor_loss.item()
|
||||
self.actor_optim.step()
|
||||
self.sync_weight()
|
||||
self._cnt += 1
|
||||
return {
|
||||
"loss/actor": self._last,
|
||||
"loss/critic1": critic1_loss.item(),
|
||||
"loss/critic2": critic2_loss.item(),
|
||||
}
|
@ -301,6 +301,10 @@ class BaseTrainer(ABC):
|
||||
if t.n <= t.total and not self.stop_fn_flag:
|
||||
t.update()
|
||||
|
||||
# for offline RL
|
||||
if self.train_collector is None:
|
||||
self.env_step = self.gradient_step * self.batch_size
|
||||
|
||||
if not self.stop_fn_flag:
|
||||
self.logger.save_data(
|
||||
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
|
||||
|
Loading…
x
Reference in New Issue
Block a user