Improve data loading from D4RL and convert RL Unplugged to D4RL format (#624)

Yi Su authored on 2022-05-03 13:37:52 -07:00, committed by GitHub
parent dd16818ce4
commit a7c789f851
11 changed files with 180 additions and 116 deletions

examples/atari/atari_bcq.py

@@ -12,7 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from examples.offline.utils import load_buffer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -118,18 +119,19 @@ def test_discrete_bcq(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_dqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        if args.buffer_from_rl_unplugged:
-            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
-        else:
-            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
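The same restructuring is applied to atari_cql.py, atari_crr.py, and atari_il.py below. As a usage sketch (the shard path is a placeholder, not part of this commit):

    # python atari_bcq.py --task PongNoFrameskip-v4 \
    #     --load-buffer-name ~/.rl_unplugged/datasets/Pong/run_1-00001-of-00100.hdf5 \
    #     --buffer-from-rl-unplugged
    from examples.offline.utils import load_buffer

    # load_buffer (added in examples/offline/utils.py below) reads a converted
    # D4RL-format HDF5 shard straight into a Tianshou ReplayBuffer.
    buffer = load_buffer("run_1-00001-of-00100.hdf5")  # placeholder path
    print("Replay buffer size:", len(buffer))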

examples/atari/atari_cql.py

@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import QRDQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
@@ -57,6 +58,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -100,15 +104,19 @@ def test_discrete_cql(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

examples/atari/atari_crr.py

@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
@@ -59,6 +60,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,15 +124,19 @@ def test_discrete_crr(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

examples/atari/atari_il.py

@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -50,6 +51,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
    parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -85,15 +89,19 @@ def test_il(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith('.pkl'):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith('.hdf5'):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

examples/offline/convert_rl_unplugged_atari.py

@@ -3,7 +3,7 @@
 # Adapted from
 # https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
 #
-"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+"""Convert Atari RL Unplugged datasets to HDF5 format.
 
 Examples in the dataset represent SARSA transitions stored during a
 DQN training run as described in https://arxiv.org/pdf/1907.04543.
@@ -30,11 +30,13 @@ Every transition in the dataset is a tuple containing the following features:
 import os
 from argparse import ArgumentParser
 
+import h5py
+import numpy as np
 import requests
 import tensorflow as tf
 from tqdm import tqdm
 
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch
 
 tf.config.set_visible_devices([], 'GPU')
@@ -108,7 +110,7 @@ def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
         pngs: String Tensor of size (4,) containing PNG encoded images.
 
     Returns:
-        4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
+        4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
     """
     # Statically unroll png decoding
     frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
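The docstring fix reflects that frames are now packed channel-first, matching what PyTorch conv nets expect. A sketch of such packing (illustrative; not necessarily this commit's exact implementation):

    import tensorflow as tf

    def pack_frames(pngs: tf.Tensor) -> tf.Tensor:
        # Decode four (84, 84, 1) grayscale PNGs and stack them channel-first.
        frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
        return tf.squeeze(tf.stack(frames, axis=0), axis=-1)  # (4, 84, 84) uint8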
@@ -195,17 +197,30 @@ def download(url: str, fname: str, chunk_size=1024):
 
 def process_shard(url: str, fname: str, ofname: str) -> None:
     download(url, fname)
+    maxsize = 500000
+    obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    act = np.ndarray((maxsize, ), dtype="int64")
+    rew = np.ndarray((maxsize, ), dtype="float32")
+    done = np.ndarray((maxsize, ), dtype="bool")
+    obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    i = 0
     file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
-    buffer = ReplayBuffer(500000)
-    cnt = 0
     for example in file_ds:
         batch = _tf_example_to_tianshou_batch(example)
-        buffer.add(batch)
-        cnt += 1
-        if cnt % 1000 == 0:
-            print(f"...{cnt}", end="", flush=True)
-    print("\nReplayBuffer size:", len(buffer))
-    buffer.save_hdf5(ofname, compression="gzip")
+        obs[i], act[i], rew[i], done[i], obs_next[i] = (
+            batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
+        )
+        i += 1
+        if i % 1000 == 0:
+            print(f"...{i}", end="", flush=True)
+    print("\nDataset size:", i)
+    # Following D4RL dataset naming conventions
+    with h5py.File(ofname, "w") as f:
+        f.create_dataset("observations", data=obs, compression="gzip")
+        f.create_dataset("actions", data=act, compression="gzip")
+        f.create_dataset("rewards", data=rew, compression="gzip")
+        f.create_dataset("terminals", data=done, compression="gzip")
+        f.create_dataset("next_observations", data=obs_next, compression="gzip")
 
 
 def process_dataset(
@@ -227,19 +242,19 @@ def main(args):
     if args.task not in ALL_GAMES:
         raise KeyError(f"`{args.task}` is not in the list of games.")
     fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
-    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
-    if os.path.exists(buffer_path):
-        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(dataset_path):
+        raise IOError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
+    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
     args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
-    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
-    dataset_path = os.path.join(args.dataset_dir, args.task)
-    os.makedirs(dataset_path, exist_ok=True)
-    dst_path = os.path.join(args.buffer_dir, args.task)
+    cache_path = os.path.join(args.cache_dir, args.task)
+    os.makedirs(cache_path, exist_ok=True)
+    dst_path = os.path.join(args.dataset_dir, args.task)
     os.makedirs(dst_path, exist_ok=True)
     process_dataset(
         args.task,
+        args.cache_dir,
         args.dataset_dir,
-        args.buffer_dir,
         run_id=args.run_id,
         shard_id=args.shard_id,
         total_num_shards=args.total_num_shards
@@ -267,12 +282,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--dataset-dir",
         default=os.path.expanduser("~/.rl_unplugged/datasets"),
-        help="Directory for downloaded original datasets.",
+        help="Directory for converted hdf5 files.",
     )
     parser.add_argument(
-        "--buffer-dir",
-        default=os.path.expanduser("~/.rl_unplugged/buffers"),
-        help="Directory for converted replay buffers.",
+        "--cache-dir",
+        default=os.path.expanduser("~/.rl_unplugged/cache"),
+        help="Directory for downloaded original datasets.",
     )
     args = parser.parse_args()
     main(args)

examples/offline/d4rl_bcq.py

@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import BCQPolicy
 from tianshou.trainer import offline_trainer
@@ -211,23 +211,7 @@ def test_bcq():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,

examples/offline/d4rl_cql.py

@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import CQLPolicy
 from tianshou.trainer import offline_trainer
@@ -206,23 +206,7 @@ def test_cql():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,

examples/offline/d4rl_il.py

@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -148,23 +148,7 @@ def test_il():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,

examples/offline/utils.py (new file, 29 additions)

@@ -0,0 +1,29 @@
+import d4rl
+import gym
+import h5py
+
+from tianshou.data import ReplayBuffer
+
+
+def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
+    dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
+    replay_buffer = ReplayBuffer.from_data(
+        obs=dataset["observations"],
+        act=dataset["actions"],
+        rew=dataset["rewards"],
+        done=dataset["terminals"],
+        obs_next=dataset["next_observations"]
+    )
+    return replay_buffer
+
+
+def load_buffer(buffer_path: str) -> ReplayBuffer:
+    with h5py.File(buffer_path, "r") as dataset:
+        buffer = ReplayBuffer.from_data(
+            obs=dataset["observations"],
+            act=dataset["actions"],
+            rew=dataset["rewards"],
+            done=dataset["terminals"],
+            obs_next=dataset["next_observations"]
+        )
+    return buffer
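Both helpers delegate to the new ReplayBuffer.from_data (added in tianshou/data/buffer/base.py below); note that load_buffer calls it inside the with block, while the h5py datasets are still open. Usage sketch (the D4RL task name is illustrative):

    from examples.offline.utils import load_buffer_d4rl

    # Any environment registered by the d4rl package should work here.
    buffer = load_buffer_d4rl("halfcheetah-medium-v2")
    print(len(buffer), buffer[0].obs.shape)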

test/base/test_buffer.py

@@ -1019,6 +1019,31 @@ def test_multibuf_hdf5():
     os.remove(path)
 
 
+def test_from_data():
+    obs_data = np.ndarray((10, 3, 3), dtype="uint8")
+    for i in range(10):
+        obs_data[i] = i * np.ones((3, 3), dtype="uint8")
+    obs_next_data = np.zeros_like(obs_data)
+    obs_next_data[:-1] = obs_data[1:]
+    f, path = tempfile.mkstemp(suffix='.hdf5')
+    os.close(f)
+    with h5py.File(path, "w") as f:
+        obs = f.create_dataset("obs", data=obs_data)
+        act = f.create_dataset("act", data=np.arange(10, dtype="int32"))
+        rew = f.create_dataset("rew", data=np.arange(10, dtype="float32"))
+        done = f.create_dataset("done", data=np.zeros(10, dtype="bool"))
+        obs_next = f.create_dataset("obs_next", data=obs_next_data)
+        buf = ReplayBuffer.from_data(obs, act, rew, done, obs_next)
+    assert len(buf) == 10
+    batch = buf[3]
+    assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8"))
+    assert batch.act == 3
+    assert batch.rew == 3.0
+    assert not batch.done
+    assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8"))
+    os.remove(path)
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
@@ -1032,3 +1057,4 @@ if __name__ == '__main__':
     test_cachedbuffer()
     test_multibuf_stack()
     test_multibuf_hdf5()
+    test_from_data()

tianshou/data/buffer/base.py

@@ -99,6 +99,22 @@ class ReplayBuffer:
             buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf
 
+    @classmethod
+    def from_data(
+        cls, obs: h5py.Dataset, act: h5py.Dataset, rew: h5py.Dataset,
+        done: h5py.Dataset, obs_next: h5py.Dataset
+    ) -> "ReplayBuffer":
+        size = len(obs)
+        assert all(len(dset) == size for dset in [obs, act, rew, done, obs_next]), \
+            "Lengths of all hdf5 datasets need to be equal."
+        buf = cls(size)
+        if size == 0:
+            return buf
+        batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next)
+        buf.set_batch(batch)
+        buf._size = size
+        return buf
+
     def reset(self, keep_statistics: bool = False) -> None:
         """Clear all the data in replay buffer and episode statistics."""
         self.last_index = np.array([0])
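Although from_data is annotated with h5py.Dataset, the body is not h5py-specific: Batch converts whatever it receives into numpy arrays, which is why load_buffer_d4rl can pass the plain arrays returned by d4rl.qlearning_dataset. A minimal sketch with in-memory data:

    import numpy as np

    from tianshou.data import ReplayBuffer

    size = 5
    buf = ReplayBuffer.from_data(
        obs=np.zeros((size, 4), dtype=np.float32),
        act=np.arange(size, dtype=np.int64),
        rew=np.ones(size, dtype=np.float32),
        done=np.zeros(size, dtype=bool),
        obs_next=np.zeros((size, 4), dtype=np.float32),
    )
    assert len(buf) == size  # buffer is full and ready for offline training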