Improve data loading from D4RL and convert RL Unplugged to D4RL format (#624)

Yi Su 2022-05-03 13:37:52 -07:00 committed by GitHub
parent dd16818ce4
commit a7c789f851
11 changed files with 180 additions and 116 deletions

atari_bcq.py
View File

@@ -12,7 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from examples.offline.utils import load_buffer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -118,18 +119,19 @@ def test_discrete_bcq(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_dqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        if args.buffer_from_rl_unplugged:
-            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
-        else:
-            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
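
With the branch above, training from an RL Unplugged shard only requires pointing the script at a converted file, e.g. python atari_bcq.py --task PongNoFrameskip-v4 --load-buffer-name pong_shard.hdf5 --buffer-from-rl-unplugged (the path and task name here are illustrative, not taken from the diff).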

atari_cql.py
View File

@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 from examples.atari.atari_network import QRDQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
@ -57,6 +58,9 @@ def get_args():
parser.add_argument(
"--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
)
parser.add_argument(
"--buffer-from-rl-unplugged", action="store_true", default=False
)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
@@ -100,15 +104,19 @@ def test_discrete_cql(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

atari_crr.py
View File

@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
@@ -59,6 +60,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,15 +124,19 @@ def test_discrete_crr(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

atari_il.py
View File

@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -50,6 +51,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
    )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -85,15 +89,19 @@ def test_il(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith('.pkl'):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith('.hdf5'):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

convert_rl_unplugged_atari.py
View File

@@ -3,7 +3,7 @@
 # Adapted from
 # https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
 #
-"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+"""Convert Atari RL Unplugged datasets to HDF5 format.
 
 Examples in the dataset represent SARSA transitions stored during a
 DQN training run as described in https://arxiv.org/pdf/1907.04543.
@@ -30,11 +30,13 @@ Every transition in the dataset is a tuple containing the following features:
 import os
 from argparse import ArgumentParser
 
+import h5py
+import numpy as np
 import requests
 import tensorflow as tf
 from tqdm import tqdm
 
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch
 
 tf.config.set_visible_devices([], 'GPU')
@@ -108,7 +110,7 @@ def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
       pngs: String Tensor of size (4,) containing PNG encoded images.
 
     Returns:
-      4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
+      4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
     """
     # Statically unroll png decoding
     frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
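
The docstring change above ((84, 84, 4) to (4, 84, 84)) switches the packing to channel-first, matching the (size, 4, 84, 84) arrays written below. A minimal sketch of one way to finish the function under that contract — the exact stack/squeeze combination is an assumption, not quoted from the commit:

    # Each decoded frame is (84, 84, 1); stacking gives (4, 84, 84, 1), and
    # squeezing the trailing channel axis yields a (4, 84, 84) uint8 tensor.
    return tf.squeeze(tf.stack(frames), axis=-1)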
@@ -195,17 +197,30 @@ def download(url: str, fname: str, chunk_size=1024):
 
 def process_shard(url: str, fname: str, ofname: str) -> None:
     download(url, fname)
+    maxsize = 500000
+    obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    act = np.ndarray((maxsize, ), dtype="int64")
+    rew = np.ndarray((maxsize, ), dtype="float32")
+    done = np.ndarray((maxsize, ), dtype="bool")
+    obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    i = 0
     file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
-    buffer = ReplayBuffer(500000)
-    cnt = 0
     for example in file_ds:
         batch = _tf_example_to_tianshou_batch(example)
-        buffer.add(batch)
-        cnt += 1
-        if cnt % 1000 == 0:
-            print(f"...{cnt}", end="", flush=True)
-    print("\nReplayBuffer size:", len(buffer))
-    buffer.save_hdf5(ofname, compression="gzip")
+        obs[i], act[i], rew[i], done[i], obs_next[i] = (
+            batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
+        )
+        i += 1
+        if i % 1000 == 0:
+            print(f"...{i}", end="", flush=True)
+    print("\nDataset size:", i)
+    # Following D4RL dataset naming conventions
+    with h5py.File(ofname, "w") as f:
+        f.create_dataset("observations", data=obs, compression="gzip")
+        f.create_dataset("actions", data=act, compression="gzip")
+        f.create_dataset("rewards", data=rew, compression="gzip")
+        f.create_dataset("terminals", data=done, compression="gzip")
+        f.create_dataset("next_observations", data=obs_next, compression="gzip")
 
 
 def process_dataset(
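
Because process_shard now writes D4RL-style keys, a converted shard can be sanity-checked with plain h5py. A quick sketch (the file path is a placeholder):

    import h5py

    with h5py.File("Pong/some_converted_shard.hdf5", "r") as f:
        # The five datasets created above, in D4RL naming.
        print(sorted(f.keys()))
        # -> ['actions', 'next_observations', 'observations', 'rewards', 'terminals']
        print(f["observations"].shape)  # (500000, 4, 84, 84)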
@@ -227,19 +242,19 @@ def main(args):
     if args.task not in ALL_GAMES:
         raise KeyError(f"`{args.task}` is not in the list of games.")
     fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
-    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
-    if os.path.exists(buffer_path):
-        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(dataset_path):
+        raise IOError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
     args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
     args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
-    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
-    dataset_path = os.path.join(args.dataset_dir, args.task)
-    os.makedirs(dataset_path, exist_ok=True)
-    dst_path = os.path.join(args.buffer_dir, args.task)
     cache_path = os.path.join(args.cache_dir, args.task)
     os.makedirs(cache_path, exist_ok=True)
+    dst_path = os.path.join(args.dataset_dir, args.task)
     os.makedirs(dst_path, exist_ok=True)
     process_dataset(
         args.task,
         args.cache_dir,
         args.dataset_dir,
-        args.buffer_dir,
         run_id=args.run_id,
         shard_id=args.shard_id,
         total_num_shards=args.total_num_shards
@@ -267,12 +282,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--dataset-dir",
         default=os.path.expanduser("~/.rl_unplugged/datasets"),
-        help="Directory for downloaded original datasets.",
+        help="Directory for converted hdf5 files.",
     )
     parser.add_argument(
-        "--buffer-dir",
-        default=os.path.expanduser("~/.rl_unplugged/buffers"),
-        help="Directory for converted replay buffers.",
+        "--cache-dir",
+        default=os.path.expanduser("~/.rl_unplugged/cache"),
+        help="Directory for downloaded original datasets.",
     )
     args = parser.parse_args()
     main(args)
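
Assuming the remaining flags follow the same naming scheme as the two shown here (--task, --run-id, --shard-id, and --total-num-shards are inferred from args.task, args.run_id, args.shard_id, and args.total_num_shards, not visible in this hunk), converting one shard looks like python convert_rl_unplugged_atari.py --task Pong --run-id 1 --shard-id 0: the raw TFRecord is cached under --cache-dir and the D4RL-style HDF5 lands under --dataset-dir.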

d4rl_bcq.py
View File

@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import BCQPolicy
 from tianshou.trainer import offline_trainer
@@ -211,23 +211,7 @@ def test_bcq():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,

d4rl_cql.py
View File

@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import CQLPolicy
 from tianshou.trainer import offline_trainer
@@ -206,23 +206,7 @@ def test_cql():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,

d4rl_il.py
View File

@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -148,23 +148,7 @@ def test_il():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,

examples/offline/utils.py Normal file
View File

@@ -0,0 +1,29 @@
+import d4rl
+import gym
+import h5py
+
+from tianshou.data import ReplayBuffer
+
+
+def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
+    dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
+    replay_buffer = ReplayBuffer.from_data(
+        obs=dataset["observations"],
+        act=dataset["actions"],
+        rew=dataset["rewards"],
+        done=dataset["terminals"],
+        obs_next=dataset["next_observations"]
+    )
+    return replay_buffer
+
+
+def load_buffer(buffer_path: str) -> ReplayBuffer:
+    with h5py.File(buffer_path, "r") as dataset:
+        buffer = ReplayBuffer.from_data(
+            obs=dataset["observations"],
+            act=dataset["actions"],
+            rew=dataset["rewards"],
+            done=dataset["terminals"],
+            obs_next=dataset["next_observations"]
+        )
+    return buffer
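
Both helpers return a ready-to-use ReplayBuffer, which is what lets the example scripts above drop their hand-rolled copy loops. A usage sketch (task name and shard path are illustrative):

    from examples.offline.utils import load_buffer, load_buffer_d4rl

    # D4RL: the d4rl package downloads the dataset on first access.
    buffer = load_buffer_d4rl("halfcheetah-expert-v2")

    # RL Unplugged: read a shard previously converted to D4RL-style HDF5
    # by convert_rl_unplugged_atari.py.
    buffer = load_buffer("Pong/some_converted_shard.hdf5")
    print("Replay buffer size:", len(buffer))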

test/base/test_buffer.py
View File

@@ -1019,6 +1019,31 @@ def test_multibuf_hdf5():
     os.remove(path)
 
 
+def test_from_data():
+    obs_data = np.ndarray((10, 3, 3), dtype="uint8")
+    for i in range(10):
+        obs_data[i] = i * np.ones((3, 3), dtype="uint8")
+    obs_next_data = np.zeros_like(obs_data)
+    obs_next_data[:-1] = obs_data[1:]
+    f, path = tempfile.mkstemp(suffix='.hdf5')
+    os.close(f)
+    with h5py.File(path, "w") as f:
+        obs = f.create_dataset("obs", data=obs_data)
+        act = f.create_dataset("act", data=np.arange(10, dtype="int32"))
+        rew = f.create_dataset("rew", data=np.arange(10, dtype="float32"))
+        done = f.create_dataset("done", data=np.zeros(10, dtype="bool"))
+        obs_next = f.create_dataset("obs_next", data=obs_next_data)
+        buf = ReplayBuffer.from_data(obs, act, rew, done, obs_next)
+    assert len(buf) == 10
+    batch = buf[3]
+    assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8"))
+    assert batch.act == 3
+    assert batch.rew == 3.0
+    assert not batch.done
+    assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8"))
+    os.remove(path)
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
@@ -1032,3 +1057,4 @@ if __name__ == '__main__':
     test_cachedbuffer()
     test_multibuf_stack()
     test_multibuf_hdf5()
+    test_from_data()

tianshou/data/buffer/base.py
View File

@@ -99,6 +99,22 @@ class ReplayBuffer:
             buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf
 
+    @classmethod
+    def from_data(
+        cls, obs: h5py.Dataset, act: h5py.Dataset, rew: h5py.Dataset,
+        done: h5py.Dataset, obs_next: h5py.Dataset
+    ) -> "ReplayBuffer":
+        size = len(obs)
+        assert all(len(dset) == size for dset in [obs, act, rew, done, obs_next]), \
+            "Lengths of all hdf5 datasets need to be equal."
+        buf = cls(size)
+        if size == 0:
+            return buf
+        batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next)
+        buf.set_batch(batch)
+        buf._size = size
+        return buf
+
     def reset(self, keep_statistics: bool = False) -> None:
         """Clear all the data in replay buffer and episode statistics."""
         self.last_index = np.array([0])
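
from_data adopts a whole dataset in one step — wrap the five h5py datasets in a Batch, install it with set_batch, and set _size — instead of issuing half a million add() calls, which is the point of the refactor. A self-contained sketch (the file name and toy shapes are made up; note the test above reads the buffer after the file is closed, which suggests the Batch materializes the data in memory):

    import h5py
    import numpy as np

    from tianshou.data import ReplayBuffer

    n = 5
    with h5py.File("toy.hdf5", "w") as f:
        f.create_dataset("obs", data=np.zeros((n, 3), dtype="float32"))
        f.create_dataset("act", data=np.arange(n, dtype="int64"))
        f.create_dataset("rew", data=np.ones(n, dtype="float32"))
        f.create_dataset("done", data=np.zeros(n, dtype="bool"))
        f.create_dataset("obs_next", data=np.zeros((n, 3), dtype="float32"))
        buf = ReplayBuffer.from_data(
            f["obs"], f["act"], f["rew"], f["done"], f["obs_next"]
        )
    print(len(buf))  # 5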