Implement TD3+BC for offline RL (#660)

- implement TD3+BC for offline RL;
- fix a bug in the trainer where the test reward was not logged because `self.env_step` is not set in the offline setting;
Yi Su 2022-06-06 09:39:37 -07:00 committed by GitHub
parent 9ce0a554dc
commit df35718992
10 changed files with 599 additions and 0 deletions

@@ -31,6 +31,7 @@
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
- [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)

@@ -129,6 +129,11 @@ Imitation
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.TD3BCPolicy
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
:members:
:undoc-members:

@@ -31,6 +31,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning <https://arxiv.org/pdf/2106.06860.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_

@@ -35,6 +35,24 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
### TD3+BC
| Environment | Dataset | TD3+BC | Parameters |
| --------------------- | --------------------- | --------------- | -------------------------------------------------------- |
| HalfCheetah-v2 | halfcheetah-expert-v2 | 11788.25 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
| HalfCheetah-v2 | halfcheetah-medium-v2 | 5741.13 | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
#### Observation normalization
Following the original paper, we use observation normalization by default; you can turn it off by setting `--norm-obs 0`. The differences are small but consistent (a minimal sketch of how the normalization is wired is shown after the table).
| Dataset | w/ norm-obs | w/o norm-obs |
| :--- | :--- | :--- |
| halfcheetah-medium-v2 | 5741.13 | 5724.41 |
| halfcheetah-expert-v2 | 11788.25 | 11665.77 |
| walker2d-medium-v2 | 4051.76 | 3985.59 |
| walker2d-expert-v2 | 5068.15 | 5027.75 |
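
A minimal sketch of the wiring, assuming the `normalize_all_obs_in_replay_buffer` helper and the `VectorEnvNormObs` wrapper added in this PR (the full `d4rl_td3_bc.py` script additionally handles seeding, networks, and training): statistics are computed once over the whole buffer, the buffer is normalized in place, and the same statistics are pushed into the evaluation envs.

```python
# Sketch only: normalize the offline dataset and reuse the same statistics at eval time.
import gym

from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs

# evaluation envs apply the obs statistics but never update them
test_envs = VectorEnvNormObs(
    SubprocVectorEnv([lambda: gym.make("HalfCheetah-v2") for _ in range(10)]),
    update_obs_rms=False,
)

replay_buffer = load_buffer_d4rl("halfcheetah-medium-v2")
# compute mean/var over the whole buffer and normalize obs/obs_next in place
replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
# evaluation observations are normalized with the exact same statistics
test_envs.set_obs_rms(obs_rms)
```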
## Discrete control
For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3BCPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="HalfCheetah-v2")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--expert-data-task", type=str, default="halfcheetah-expert-v2"
)
parser.add_argument("--buffer-size", type=int, default=1000000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
parser.add_argument("--actor-lr", type=float, default=3e-4)
parser.add_argument("--critic-lr", type=float, default=3e-4)
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--step-per-epoch", type=int, default=5000)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--alpha", type=float, default=2.5)
parser.add_argument("--exploration-noise", type=float, default=0.1)
parser.add_argument("--policy-noise", type=float, default=0.2)
parser.add_argument("--noise-clip", type=float, default=0.5)
parser.add_argument("--update-actor-freq", type=int, default=2)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--norm-obs", type=int, default=1)
parser.add_argument("--eval-freq", type=int, default=1)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=1 / 35)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
parser.add_argument(
"--watch",
default=False,
action="store_true",
help="watch the play of pre-trained policy only",
)
return parser.parse_args()
def test_td3_bc():
args = get_args()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
print("device:", args.device)
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
print("Max_action", args.max_action)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
if args.norm_obs:
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
# actor network
net_a = Net(
args.state_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = Actor(
net_a,
action_shape=args.action_shape,
max_action=args.max_action,
device=args.device,
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
# critic network
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3BCPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=args.tau,
gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
alpha=args.alpha,
estimation_step=args.n_step,
action_space=env.action_space,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
test_collector = Collector(policy, test_envs)
# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "td3_bc"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def watch():
if args.resume_path is None:
args.resume_path = os.path.join(log_path, "policy.pth")
policy.load_state_dict(
torch.load(args.resume_path, map_location=torch.device("cpu"))
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
replay_buffer = load_buffer_d4rl(args.expert_data_task)
if args.norm_obs:
replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
test_envs.set_obs_rms(obs_rms)
# trainer
result = offline_trainer(
policy,
replay_buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_best_fn=save_best_fn,
logger=logger,
)
pprint.pprint(result)
else:
watch()
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")
if __name__ == "__main__":
test_td3_bc()

@@ -1,8 +1,12 @@
from typing import Tuple
import d4rl
import gym
import h5py
import numpy as np
from tianshou.data import ReplayBuffer
from tianshou.utils import RunningMeanStd
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
@@ -27,3 +31,18 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
obs_next=dataset["next_observations"]
)
return buffer
def normalize_all_obs_in_replay_buffer(
replay_buffer: ReplayBuffer
) -> Tuple[ReplayBuffer, RunningMeanStd]:
# compute obs mean and var
obs_rms = RunningMeanStd()
obs_rms.update(replay_buffer.obs)
_eps = np.finfo(np.float32).eps.item()
# normalize obs
replay_buffer._meta["obs"] = (replay_buffer.obs -
obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
replay_buffer._meta["obs_next"] = (replay_buffer.obs_next -
obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
return replay_buffer, obs_rms

test/offline/test_td3_bc.py

@@ -0,0 +1,215 @@
import argparse
import datetime
import os
import pickle
import pprint
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3BCPolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v1')
parser.add_argument('--reward-threshold', type=float, default=None)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--alpha', type=float, default=2.5)
parser.add_argument("--exploration-noise", type=float, default=0.1)
parser.add_argument("--policy-noise", type=float, default=0.2)
parser.add_argument("--noise-clip", type=float, default=0.5)
parser.add_argument("--update-actor-freq", type=int, default=2)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--eval-freq", type=int, default=1)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1 / 35)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only',
)
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
args = parser.parse_known_args()[0]
return args
def test_td3_bc(args=get_args()):
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
if args.reward_threshold is None:
# too low?
default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200}
args.reward_threshold = default_reward_threshold.get(
args.task, env.spec.reward_threshold
)
args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
# actor network
net_a = Net(
args.state_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = Actor(
net_a,
action_shape=args.action_shape,
max_action=args.max_action,
device=args.device,
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
# critic network
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3BCPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=args.tau,
gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
alpha=args.alpha,
estimation_step=args.n_step,
action_space=env.action_space,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
# buffer has been gathered
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc'
log_path = os.path.join(args.logdir, args.task, 'td3_bc', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold
def watch():
policy.load_state_dict(
torch.load(
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
)
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
# trainer
trainer = OfflineTrainer(
policy,
buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_best_fn=save_best_fn,
stop_fn=stop_fn,
logger=logger,
)
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
assert stop_fn(info["best_reward"])
# Let's watch its performance!
if __name__ == "__main__":
pprint.pprint(info)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == '__main__':
test_td3_bc()

@@ -23,6 +23,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.bcq import BCQPolicy
from tianshou.policy.imitation.cql import CQLPolicy
from tianshou.policy.imitation.td3_bc import TD3BCPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
@@ -54,6 +55,7 @@ __all__ = [
"ImitationPolicy",
"BCQPolicy",
"CQLPolicy",
"TD3BCPolicy",
"DiscreteBCQPolicy",
"DiscreteCQLPolicy",
"DiscreteCRRPolicy",

@@ -0,0 +1,107 @@
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from tianshou.data import Batch, to_torch_as
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import TD3Policy
class TD3BCPolicy(TD3Policy):
"""Implementation of TD3+BC. arXiv:2106.06860.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
critic network.
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param float tau: param for soft update of the target network. Default to 0.005.
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param float exploration_noise: the exploration noise added to the action.
Default to ``GaussianNoise(sigma=0.1)``
:param float policy_noise: the noise used in updating policy network.
Default to 0.2.
:param int update_actor_freq: the update frequency of actor network.
Default to 2.
:param float noise_clip: the clipping range used in updating policy network.
Default to 0.5.
:param float alpha: the value of alpha, which controls the weight for TD3 learning
relative to behavior cloning.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action) or empty string for no bounding.
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
policy_noise: float = 0.2,
update_actor_freq: int = 2,
noise_clip: float = 0.5,
alpha: float = 2.5,
reward_normalization: bool = False,
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
gamma, exploration_noise, policy_noise, update_actor_freq, noise_clip,
reward_normalization, estimation_step, **kwargs
)
self._alpha = alpha
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# critic 1&2
td1, critic1_loss = self._mse_optimizer(
batch, self.critic1, self.critic1_optim
)
td2, critic2_loss = self._mse_optimizer(
batch, self.critic2, self.critic2_optim
)
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
if self._cnt % self._freq == 0:
act = self(batch, eps=0.0).act
q_value = self.critic1(batch.obs, act)
lmbda = self._alpha / q_value.abs().mean().detach()
actor_loss = -lmbda * q_value.mean() + F.mse_loss(
act, to_torch_as(batch.act, act)
)
self.actor_optim.zero_grad()
actor_loss.backward()
self._last = actor_loss.item()
self.actor_optim.step()
self.sync_weight()
self._cnt += 1
return {
"loss/actor": self._last,
"loss/critic1": critic1_loss.item(),
"loss/critic2": critic2_loss.item(),
}
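
For reference, the actor step in `learn` above minimizes the TD3+BC objective from the paper, restated here with the code's symbols (this is only a restatement of the code above, not additional behavior): the scaling factor is detached from the computation graph, and `alpha` defaults to 2.5.

```latex
% TD3+BC actor loss as implemented in learn() above
\mathcal{L}_{\text{actor}}
    = -\lambda \, \frac{1}{N} \sum_{i=1}^{N} Q_{\theta_1}\!\bigl(s_i, \pi_\phi(s_i)\bigr)
      + \frac{1}{N} \sum_{i=1}^{N} \bigl(\pi_\phi(s_i) - a_i\bigr)^2,
\qquad
\lambda = \frac{\alpha}{\frac{1}{N} \sum_{i=1}^{N} \bigl| Q_{\theta_1}(s_i, \pi_\phi(s_i)) \bigr|}
```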

@@ -301,6 +301,10 @@ class BaseTrainer(ABC):
if t.n <= t.total and not self.stop_fn_flag:
t.update()
# for offline RL
if self.train_collector is None:
self.env_step = self.gradient_step * self.batch_size
if not self.stop_fn_flag:
self.logger.save_data(
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn