Implement TD3+BC for offline RL (#660)
- implement TD3+BC for offline RL; - fix a bug in trainer about test reward not logged because self.env_step is not set for offline setting;
This commit is contained in:
parent
9ce0a554dc
commit
df35718992
@ -31,6 +31,7 @@
|
|||||||
- Vanilla Imitation Learning
|
- Vanilla Imitation Learning
|
||||||
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
||||||
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
|
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
|
||||||
|
- [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
|
||||||
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
|
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
|
||||||
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
|
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
|
||||||
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
|
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
|
||||||
|
@ -129,6 +129,11 @@ Imitation
|
|||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autoclass:: tianshou.policy.TD3BCPolicy
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
|
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
|
@ -31,6 +31,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||||
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
||||||
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
||||||
|
* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning <https://arxiv.org/pdf/2106.06860.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
|
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
|
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
|
||||||
|
@ -35,6 +35,24 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
|
|||||||
| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
|
| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
|
||||||
| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
|
| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
|
||||||
|
|
||||||
|
### TD3+BC
|
||||||
|
|
||||||
|
| Environment | Dataset | CQL | Parameters |
|
||||||
|
| --------------------- | --------------------- | --------------- | -------------------------------------------------------- |
|
||||||
|
| HalfCheetah-v2 | halfcheetah-expert-v2 | 11788.25 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
|
||||||
|
| HalfCheetah-v2 | halfcheetah-medium-v2 | 5741.13 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
|
||||||
|
|
||||||
|
#### Observation normalization
|
||||||
|
|
||||||
|
Following the original paper, we use observation normalization by default. You can turn it off by setting `--norm-obs 0`. The difference are small but consistent.
|
||||||
|
|
||||||
|
| Dataset | w/ norm-obs | w/o norm-obs |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| halfcheeta-medium-v2 | 5741.13 | 5724.41 |
|
||||||
|
| halfcheeta-expert-v2 | 11788.25 | 11665.77 |
|
||||||
|
| walker2d-medium-v2 | 4051.76 | 3985.59 |
|
||||||
|
| walker2d-expert-v2 | 5068.15 | 5027.75 |
|
||||||
|
|
||||||
## Discrete control
|
## Discrete control
|
||||||
|
|
||||||
For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
|
For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
|
||||||
|
227
examples/offline/d4rl_td3_bc.py
Normal file
227
examples/offline/d4rl_td3_bc.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
|
||||||
|
from tianshou.data import Collector
|
||||||
|
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
|
||||||
|
from tianshou.exploration import GaussianNoise
|
||||||
|
from tianshou.policy import TD3BCPolicy
|
||||||
|
from tianshou.trainer import offline_trainer
|
||||||
|
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||||
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.utils.net.continuous import Actor, Critic
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--task", type=str, default="HalfCheetah-v2")
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument(
|
||||||
|
"--expert-data-task", type=str, default="halfcheetah-expert-v2"
|
||||||
|
)
|
||||||
|
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||||
|
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||||
|
parser.add_argument("--actor-lr", type=float, default=3e-4)
|
||||||
|
parser.add_argument("--critic-lr", type=float, default=3e-4)
|
||||||
|
parser.add_argument("--epoch", type=int, default=200)
|
||||||
|
parser.add_argument("--step-per-epoch", type=int, default=5000)
|
||||||
|
parser.add_argument("--n-step", type=int, default=3)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=256)
|
||||||
|
|
||||||
|
parser.add_argument("--alpha", type=float, default=2.5)
|
||||||
|
parser.add_argument("--exploration-noise", type=float, default=0.1)
|
||||||
|
parser.add_argument("--policy-noise", type=float, default=0.2)
|
||||||
|
parser.add_argument("--noise-clip", type=float, default=0.5)
|
||||||
|
parser.add_argument("--update-actor-freq", type=int, default=2)
|
||||||
|
parser.add_argument("--tau", type=float, default=0.005)
|
||||||
|
parser.add_argument("--gamma", type=float, default=0.99)
|
||||||
|
parser.add_argument("--norm-obs", type=int, default=1)
|
||||||
|
|
||||||
|
parser.add_argument("--eval-freq", type=int, default=1)
|
||||||
|
parser.add_argument("--test-num", type=int, default=10)
|
||||||
|
parser.add_argument("--logdir", type=str, default="log")
|
||||||
|
parser.add_argument("--render", type=float, default=1 / 35)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
parser.add_argument("--resume-path", type=str, default=None)
|
||||||
|
parser.add_argument("--resume-id", type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--logger",
|
||||||
|
type=str,
|
||||||
|
default="tensorboard",
|
||||||
|
choices=["tensorboard", "wandb"],
|
||||||
|
)
|
||||||
|
parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--watch",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="watch the play of pre-trained policy only",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def test_td3_bc():
|
||||||
|
args = get_args()
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0] # float
|
||||||
|
print("device:", args.device)
|
||||||
|
print("Observations shape:", args.state_shape)
|
||||||
|
print("Actions shape:", args.action_shape)
|
||||||
|
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||||
|
|
||||||
|
args.state_dim = args.state_shape[0]
|
||||||
|
args.action_dim = args.action_shape[0]
|
||||||
|
print("Max_action", args.max_action)
|
||||||
|
|
||||||
|
test_envs = SubprocVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
if args.norm_obs:
|
||||||
|
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
|
||||||
|
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
|
||||||
|
# model
|
||||||
|
# actor network
|
||||||
|
net_a = Net(
|
||||||
|
args.state_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
actor = Actor(
|
||||||
|
net_a,
|
||||||
|
action_shape=args.action_shape,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
|
||||||
|
# critic network
|
||||||
|
net_c1 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
net_c2 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||||
|
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||||
|
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||||
|
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
policy = TD3BCPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critic1,
|
||||||
|
critic1_optim,
|
||||||
|
critic2,
|
||||||
|
critic2_optim,
|
||||||
|
tau=args.tau,
|
||||||
|
gamma=args.gamma,
|
||||||
|
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||||
|
policy_noise=args.policy_noise,
|
||||||
|
update_actor_freq=args.update_actor_freq,
|
||||||
|
noise_clip=args.noise_clip,
|
||||||
|
alpha=args.alpha,
|
||||||
|
estimation_step=args.n_step,
|
||||||
|
action_space=env.action_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
|
||||||
|
# collector
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
|
||||||
|
# log
|
||||||
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
|
args.algo_name = "td3_bc"
|
||||||
|
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||||
|
log_path = os.path.join(args.logdir, log_name)
|
||||||
|
|
||||||
|
# logger
|
||||||
|
if args.logger == "wandb":
|
||||||
|
logger = WandbLogger(
|
||||||
|
save_interval=1,
|
||||||
|
name=log_name.replace(os.path.sep, "__"),
|
||||||
|
run_id=args.resume_id,
|
||||||
|
config=args,
|
||||||
|
project=args.wandb_project,
|
||||||
|
)
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
if args.logger == "tensorboard":
|
||||||
|
logger = TensorboardLogger(writer)
|
||||||
|
else: # wandb
|
||||||
|
logger.load(writer)
|
||||||
|
|
||||||
|
def save_best_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||||
|
|
||||||
|
def watch():
|
||||||
|
if args.resume_path is None:
|
||||||
|
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||||
|
|
||||||
|
policy.load_state_dict(
|
||||||
|
torch.load(args.resume_path, map_location=torch.device("cpu"))
|
||||||
|
)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
collector.collect(n_episode=1, render=1 / 35)
|
||||||
|
|
||||||
|
if not args.watch:
|
||||||
|
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||||
|
if args.norm_obs:
|
||||||
|
replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
|
||||||
|
test_envs.set_obs_rms(obs_rms)
|
||||||
|
# trainer
|
||||||
|
result = offline_trainer(
|
||||||
|
policy,
|
||||||
|
replay_buffer,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
save_best_fn=save_best_fn,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
pprint.pprint(result)
|
||||||
|
else:
|
||||||
|
watch()
|
||||||
|
|
||||||
|
# Let's watch its performance!
|
||||||
|
policy.eval()
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
test_collector.reset()
|
||||||
|
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||||
|
print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_td3_bc()
|
@ -1,8 +1,12 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import d4rl
|
import d4rl
|
||||||
import gym
|
import gym
|
||||||
import h5py
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.data import ReplayBuffer
|
from tianshou.data import ReplayBuffer
|
||||||
|
from tianshou.utils import RunningMeanStd
|
||||||
|
|
||||||
|
|
||||||
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
|
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
|
||||||
@ -27,3 +31,18 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
|
|||||||
obs_next=dataset["next_observations"]
|
obs_next=dataset["next_observations"]
|
||||||
)
|
)
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_all_obs_in_replay_buffer(
|
||||||
|
replay_buffer: ReplayBuffer
|
||||||
|
) -> Tuple[ReplayBuffer, RunningMeanStd]:
|
||||||
|
# compute obs mean and var
|
||||||
|
obs_rms = RunningMeanStd()
|
||||||
|
obs_rms.update(replay_buffer.obs)
|
||||||
|
_eps = np.finfo(np.float32).eps.item()
|
||||||
|
# normalize obs
|
||||||
|
replay_buffer._meta["obs"] = (replay_buffer.obs -
|
||||||
|
obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
|
||||||
|
replay_buffer._meta["obs_next"] = (replay_buffer.obs_next -
|
||||||
|
obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
|
||||||
|
return replay_buffer, obs_rms
|
||||||
|
215
test/offline/test_td3_bc.py
Normal file
215
test/offline/test_td3_bc.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
|
from tianshou.env import DummyVectorEnv
|
||||||
|
from tianshou.exploration import GaussianNoise
|
||||||
|
from tianshou.policy import TD3BCPolicy
|
||||||
|
from tianshou.trainer import OfflineTrainer
|
||||||
|
from tianshou.utils import TensorboardLogger
|
||||||
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.utils.net.continuous import Actor, Critic
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from gather_pendulum_data import expert_file_name, gather_data
|
||||||
|
else: # pytest
|
||||||
|
from test.offline.gather_pendulum_data import expert_file_name, gather_data
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='Pendulum-v1')
|
||||||
|
parser.add_argument('--reward-threshold', type=float, default=None)
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||||
|
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--epoch', type=int, default=5)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=500)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
|
parser.add_argument('--alpha', type=float, default=2.5)
|
||||||
|
parser.add_argument("--exploration-noise", type=float, default=0.1)
|
||||||
|
parser.add_argument("--policy-noise", type=float, default=0.2)
|
||||||
|
parser.add_argument("--noise-clip", type=float, default=0.5)
|
||||||
|
parser.add_argument("--update-actor-freq", type=int, default=2)
|
||||||
|
parser.add_argument("--tau", type=float, default=0.005)
|
||||||
|
parser.add_argument("--gamma", type=float, default=0.99)
|
||||||
|
|
||||||
|
parser.add_argument("--eval-freq", type=int, default=1)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=1 / 35)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
'--watch',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only',
|
||||||
|
)
|
||||||
|
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test_td3_bc(args=get_args()):
|
||||||
|
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
|
||||||
|
if args.load_buffer_name.endswith(".hdf5"):
|
||||||
|
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
|
||||||
|
else:
|
||||||
|
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||||
|
else:
|
||||||
|
buffer = gather_data()
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0] # float
|
||||||
|
if args.reward_threshold is None:
|
||||||
|
# too low?
|
||||||
|
default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200}
|
||||||
|
args.reward_threshold = default_reward_threshold.get(
|
||||||
|
args.task, env.spec.reward_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
args.state_dim = args.state_shape[0]
|
||||||
|
args.action_dim = args.action_shape[0]
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
|
||||||
|
# model
|
||||||
|
# actor network
|
||||||
|
net_a = Net(
|
||||||
|
args.state_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
actor = Actor(
|
||||||
|
net_a,
|
||||||
|
action_shape=args.action_shape,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
|
||||||
|
# critic network
|
||||||
|
net_c1 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
net_c2 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||||
|
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||||
|
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||||
|
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
policy = TD3BCPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critic1,
|
||||||
|
critic1_optim,
|
||||||
|
critic2,
|
||||||
|
critic2_optim,
|
||||||
|
tau=args.tau,
|
||||||
|
gamma=args.gamma,
|
||||||
|
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||||
|
policy_noise=args.policy_noise,
|
||||||
|
update_actor_freq=args.update_actor_freq,
|
||||||
|
noise_clip=args.noise_clip,
|
||||||
|
alpha=args.alpha,
|
||||||
|
estimation_step=args.n_step,
|
||||||
|
action_space=env.action_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
|
||||||
|
# collector
|
||||||
|
# buffer has been gathered
|
||||||
|
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
# log
|
||||||
|
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||||
|
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc'
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'td3_bc', log_file)
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
|
def save_best_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
return mean_rewards >= args.reward_threshold
|
||||||
|
|
||||||
|
def watch():
|
||||||
|
policy.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
collector.collect(n_episode=1, render=1 / 35)
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
trainer = OfflineTrainer(
|
||||||
|
policy,
|
||||||
|
buffer,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
save_best_fn=save_best_fn,
|
||||||
|
stop_fn=stop_fn,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
for epoch, epoch_stat, info in trainer:
|
||||||
|
print(f"Epoch: {epoch}")
|
||||||
|
print(epoch_stat)
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
assert stop_fn(info["best_reward"])
|
||||||
|
|
||||||
|
# Let's watch its performance!
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pprint.pprint(info)
|
||||||
|
env = gym.make(args.task)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
rews, lens = result["rews"], result["lens"]
|
||||||
|
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_td3_bc()
|
@ -23,6 +23,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
|||||||
from tianshou.policy.imitation.base import ImitationPolicy
|
from tianshou.policy.imitation.base import ImitationPolicy
|
||||||
from tianshou.policy.imitation.bcq import BCQPolicy
|
from tianshou.policy.imitation.bcq import BCQPolicy
|
||||||
from tianshou.policy.imitation.cql import CQLPolicy
|
from tianshou.policy.imitation.cql import CQLPolicy
|
||||||
|
from tianshou.policy.imitation.td3_bc import TD3BCPolicy
|
||||||
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
||||||
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
|
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
|
||||||
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
|
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
|
||||||
@ -54,6 +55,7 @@ __all__ = [
|
|||||||
"ImitationPolicy",
|
"ImitationPolicy",
|
||||||
"BCQPolicy",
|
"BCQPolicy",
|
||||||
"CQLPolicy",
|
"CQLPolicy",
|
||||||
|
"TD3BCPolicy",
|
||||||
"DiscreteBCQPolicy",
|
"DiscreteBCQPolicy",
|
||||||
"DiscreteCQLPolicy",
|
"DiscreteCQLPolicy",
|
||||||
"DiscreteCRRPolicy",
|
"DiscreteCRRPolicy",
|
||||||
|
107
tianshou/policy/imitation/td3_bc.py
Normal file
107
tianshou/policy/imitation/td3_bc.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from tianshou.data import Batch, to_torch_as
|
||||||
|
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||||
|
from tianshou.policy import TD3Policy
|
||||||
|
|
||||||
|
|
||||||
|
class TD3BCPolicy(TD3Policy):
|
||||||
|
"""Implementation of TD3+BC. arXiv:2106.06860.
|
||||||
|
|
||||||
|
:param torch.nn.Module actor: the actor network following the rules in
|
||||||
|
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||||
|
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||||
|
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
|
||||||
|
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
||||||
|
critic network.
|
||||||
|
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
|
||||||
|
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||||
|
critic network.
|
||||||
|
:param float tau: param for soft update of the target network. Default to 0.005.
|
||||||
|
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
|
||||||
|
:param float exploration_noise: the exploration noise, add to the action.
|
||||||
|
Default to ``GaussianNoise(sigma=0.1)``
|
||||||
|
:param float policy_noise: the noise used in updating policy network.
|
||||||
|
Default to 0.2.
|
||||||
|
:param int update_actor_freq: the update frequency of actor network.
|
||||||
|
Default to 2.
|
||||||
|
:param float noise_clip: the clipping range used in updating policy network.
|
||||||
|
Default to 0.5.
|
||||||
|
:param float alpha: the value of alpha, which controls the weight for TD3 learning
|
||||||
|
relative to behavior cloning.
|
||||||
|
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||||
|
Default to False.
|
||||||
|
:param bool action_scaling: whether to map actions from range [-1, 1] to range
|
||||||
|
[action_spaces.low, action_spaces.high]. Default to True.
|
||||||
|
:param str action_bound_method: method to bound action to range [-1, 1], can be
|
||||||
|
either "clip" (for simply clipping the action) or empty string for no bounding.
|
||||||
|
Default to "clip".
|
||||||
|
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||||
|
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||||
|
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||||
|
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
actor: torch.nn.Module,
|
||||||
|
actor_optim: torch.optim.Optimizer,
|
||||||
|
critic1: torch.nn.Module,
|
||||||
|
critic1_optim: torch.optim.Optimizer,
|
||||||
|
critic2: torch.nn.Module,
|
||||||
|
critic2_optim: torch.optim.Optimizer,
|
||||||
|
tau: float = 0.005,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
|
||||||
|
policy_noise: float = 0.2,
|
||||||
|
update_actor_freq: int = 2,
|
||||||
|
noise_clip: float = 0.5,
|
||||||
|
alpha: float = 2.5,
|
||||||
|
reward_normalization: bool = False,
|
||||||
|
estimation_step: int = 1,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
|
||||||
|
gamma, exploration_noise, policy_noise, update_actor_freq, noise_clip,
|
||||||
|
reward_normalization, estimation_step, **kwargs
|
||||||
|
)
|
||||||
|
self._alpha = alpha
|
||||||
|
|
||||||
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||||
|
# critic 1&2
|
||||||
|
td1, critic1_loss = self._mse_optimizer(
|
||||||
|
batch, self.critic1, self.critic1_optim
|
||||||
|
)
|
||||||
|
td2, critic2_loss = self._mse_optimizer(
|
||||||
|
batch, self.critic2, self.critic2_optim
|
||||||
|
)
|
||||||
|
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||||
|
|
||||||
|
# actor
|
||||||
|
if self._cnt % self._freq == 0:
|
||||||
|
act = self(batch, eps=0.0).act
|
||||||
|
q_value = self.critic1(batch.obs, act)
|
||||||
|
lmbda = self._alpha / q_value.abs().mean().detach()
|
||||||
|
actor_loss = -lmbda * q_value.mean() + F.mse_loss(
|
||||||
|
act, to_torch_as(batch.act, act)
|
||||||
|
)
|
||||||
|
self.actor_optim.zero_grad()
|
||||||
|
actor_loss.backward()
|
||||||
|
self._last = actor_loss.item()
|
||||||
|
self.actor_optim.step()
|
||||||
|
self.sync_weight()
|
||||||
|
self._cnt += 1
|
||||||
|
return {
|
||||||
|
"loss/actor": self._last,
|
||||||
|
"loss/critic1": critic1_loss.item(),
|
||||||
|
"loss/critic2": critic2_loss.item(),
|
||||||
|
}
|
@ -301,6 +301,10 @@ class BaseTrainer(ABC):
|
|||||||
if t.n <= t.total and not self.stop_fn_flag:
|
if t.n <= t.total and not self.stop_fn_flag:
|
||||||
t.update()
|
t.update()
|
||||||
|
|
||||||
|
# for offline RL
|
||||||
|
if self.train_collector is None:
|
||||||
|
self.env_step = self.gradient_step * self.batch_size
|
||||||
|
|
||||||
if not self.stop_fn_flag:
|
if not self.stop_fn_flag:
|
||||||
self.logger.save_data(
|
self.logger.save_data(
|
||||||
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
|
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
|
||||||
|
Loading…
x
Reference in New Issue
Block a user