Improve data loading from D4RL and convert RL Unplugged to D4RL format (#624)
commit a7c789f851 (parent dd16818ce4)
examples/offline/atari_bcq.py
@@ -12,7 +12,8 @@ from torch.utils.tensorboard import SummaryWriter

 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from examples.offline.utils import load_buffer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -118,18 +119,19 @@ def test_discrete_bcq(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_dqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        if args.buffer_from_rl_unplugged:
-            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
-        else:
-            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
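Note (annotation, not part of the diff): the RL Unplugged branch now runs before any file-extension sniffing, and converted data goes through the shared `load_buffer` helper introduced in `examples/offline/utils.py` further down in this commit, instead of `ReplayBuffer.load_hdf5`. A minimal usage sketch, with a hypothetical shard path:

import os

from examples.offline.utils import load_buffer

# Hypothetical path to a shard written by convert_rl_unplugged_atari.py.
path = os.path.expanduser("~/.rl_unplugged/datasets/Pong/run_1-00001-of-00100.hdf5")
buffer = load_buffer(path)
print("Replay buffer size:", len(buffer))

The same three-way dispatch (RL Unplugged HDF5, pickled buffer, Tianshou-native HDF5) is applied to the CQL, CRR, and IL examples below.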
examples/offline/atari_cql.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter

 from examples.atari.atari_network import QRDQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
@@ -57,6 +58,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -100,15 +104,19 @@ def test_discrete_cql(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
examples/offline/atari_crr.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter

 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
@@ -59,6 +60,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,15 +124,19 @@ def test_discrete_crr(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
examples/offline/atari_il.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter

 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -50,6 +51,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
    )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -85,15 +89,19 @@ def test_il(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith('.pkl'):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith('.hdf5'):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
examples/offline/convert_rl_unplugged_atari.py
@@ -3,7 +3,7 @@
 # Adapted from
 # https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
 #
-"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+"""Convert Atari RL Unplugged datasets to HDF5 format.

 Examples in the dataset represent SARSA transitions stored during a
 DQN training run as described in https://arxiv.org/pdf/1907.04543.
@@ -30,11 +30,13 @@ Every transition in the dataset is a tuple containing the following features:
 import os
 from argparse import ArgumentParser

+import h5py
 import numpy as np
 import requests
 import tensorflow as tf
 from tqdm import tqdm

-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch

 tf.config.set_visible_devices([], 'GPU')

@@ -108,7 +110,7 @@ def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
     pngs: String Tensor of size (4,) containing PNG encoded images.

   Returns:
-    4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
+    4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
   """
   # Statically unroll png decoding
   frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
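Note (annotation, not part of the diff): the docstring fix reflects that the frame stack is packed channel-first, matching the (4, 84, 84) input layout of the Atari networks used by these examples. A standalone shape illustration under that assumption, not the file's actual implementation:

import tensorflow as tf

# Four decoded grayscale frames, each (84, 84, 1), as _decode_frames produces them.
frames = [tf.zeros((84, 84, 1), dtype=tf.uint8) for _ in range(4)]
nhwc = tf.concat(frames, axis=-1)                     # (84, 84, 4): what the old docstring claimed
nchw = tf.squeeze(tf.stack(frames, axis=0), axis=-1)  # (4, 84, 84): what is actually returned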
@@ -195,17 +197,30 @@ def download(url: str, fname: str, chunk_size=1024):

 def process_shard(url: str, fname: str, ofname: str) -> None:
     download(url, fname)
+    maxsize = 500000
+    obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    act = np.ndarray((maxsize, ), dtype="int64")
+    rew = np.ndarray((maxsize, ), dtype="float32")
+    done = np.ndarray((maxsize, ), dtype="bool")
+    obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    i = 0
     file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
-    buffer = ReplayBuffer(500000)
-    cnt = 0
     for example in file_ds:
         batch = _tf_example_to_tianshou_batch(example)
-        buffer.add(batch)
-        cnt += 1
-        if cnt % 1000 == 0:
-            print(f"...{cnt}", end="", flush=True)
-    print("\nReplayBuffer size:", len(buffer))
-    buffer.save_hdf5(ofname, compression="gzip")
+        obs[i], act[i], rew[i], done[i], obs_next[i] = (
+            batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
+        )
+        i += 1
+        if i % 1000 == 0:
+            print(f"...{i}", end="", flush=True)
+    print("\nDataset size:", i)
+    # Following D4RL dataset naming conventions
+    with h5py.File(ofname, "w") as f:
+        f.create_dataset("observations", data=obs, compression="gzip")
+        f.create_dataset("actions", data=act, compression="gzip")
+        f.create_dataset("rewards", data=rew, compression="gzip")
+        f.create_dataset("terminals", data=done, compression="gzip")
+        f.create_dataset("next_observations", data=obs_next, compression="gzip")


 def process_dataset(
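Note (annotation, not part of the diff): `process_shard` no longer serializes a `ReplayBuffer`; it writes five flat arrays under D4RL's standard key names, which is exactly the layout `load_buffer` in `examples/offline/utils.py` expects. A sketch of inspecting one converted shard, with a hypothetical file name:

import h5py

with h5py.File("run_1-00001-of-00100.hdf5", "r") as f:
    print(sorted(f.keys()))
    # ['actions', 'next_observations', 'observations', 'rewards', 'terminals']
    print(f["observations"].shape, f["observations"].dtype)  # (500000, 4, 84, 84) uint8

One consequence of the code as written: the preallocated arrays are saved at full `maxsize` length, so a shard with fewer than 500000 transitions still produces datasets of that length.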
@@ -227,19 +242,19 @@ def main(args):
     if args.task not in ALL_GAMES:
         raise KeyError(f"`{args.task}` is not in the list of games.")
     fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
-    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
-    if os.path.exists(buffer_path):
-        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(dataset_path):
+        raise IOError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
+    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
     args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
-    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
-    dataset_path = os.path.join(args.dataset_dir, args.task)
-    os.makedirs(dataset_path, exist_ok=True)
-    dst_path = os.path.join(args.buffer_dir, args.task)
+    cache_path = os.path.join(args.cache_dir, args.task)
+    os.makedirs(cache_path, exist_ok=True)
+    dst_path = os.path.join(args.dataset_dir, args.task)
     os.makedirs(dst_path, exist_ok=True)
     process_dataset(
         args.task,
+        args.cache_dir,
         args.dataset_dir,
-        args.buffer_dir,
         run_id=args.run_id,
         shard_id=args.shard_id,
         total_num_shards=args.total_num_shards
@@ -267,12 +282,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--dataset-dir",
         default=os.path.expanduser("~/.rl_unplugged/datasets"),
-        help="Directory for downloaded original datasets.",
+        help="Directory for converted hdf5 files.",
     )
     parser.add_argument(
-        "--buffer-dir",
-        default=os.path.expanduser("~/.rl_unplugged/buffers"),
-        help="Directory for converted replay buffers.",
+        "--cache-dir",
+        default=os.path.expanduser("~/.rl_unplugged/cache"),
+        help="Directory for downloaded original datasets.",
     )
     args = parser.parse_args()
     main(args)
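Note (annotation, not part of the diff): the directory roles are reshuffled here: downloads now land under `--cache-dir` and the converted D4RL-style files under `--dataset-dir`; the old `--buffer-dir` is gone. The environment overrides read in `main()` amount to:

import os

# Defaults mirror the argparse values above; the RLU_* variables take precedence.
cache_dir = os.environ.get("RLU_CACHE_DIR", os.path.expanduser("~/.rl_unplugged/cache"))
dataset_dir = os.environ.get("RLU_DATASET_DIR", os.path.expanduser("~/.rl_unplugged/datasets"))
# Downloads go to <cache_dir>/<task>/, converted shards to <dataset_dir>/<task>/<shard>.hdf5.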
examples/offline/d4rl_bcq.py
@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint

-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import BCQPolicy
 from tianshou.trainer import offline_trainer
@@ -211,23 +211,7 @@ def test_bcq():
         collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
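Note (annotation, not part of the diff): the per-transition copy loop collapses into a single call; the new helper hands the whole arrays to `ReplayBuffer.from_data` (see `examples/offline/utils.py` and `tianshou/data/buffer/base.py` below) instead of calling `add()` once per transition. Usage sketch, with an arbitrary D4RL task id for illustration:

from examples.offline.utils import load_buffer_d4rl

replay_buffer = load_buffer_d4rl("halfcheetah-expert-v2")  # any D4RL task id
print("Replay buffer size:", len(replay_buffer))

The identical simplification is applied to d4rl_cql.py and d4rl_il.py below.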
examples/offline/d4rl_cql.py
@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint

-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import CQLPolicy
 from tianshou.trainer import offline_trainer
@@ -206,23 +206,7 @@ def test_cql():
         collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
examples/offline/d4rl_il.py
@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint

-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -148,23 +148,7 @@ def test_il():
         collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
examples/offline/utils.py (new file, +29 lines)
@@ -0,0 +1,29 @@
+import d4rl
+import gym
+import h5py
+
+from tianshou.data import ReplayBuffer
+
+
+def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
+    dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
+    replay_buffer = ReplayBuffer.from_data(
+        obs=dataset["observations"],
+        act=dataset["actions"],
+        rew=dataset["rewards"],
+        done=dataset["terminals"],
+        obs_next=dataset["next_observations"]
+    )
+    return replay_buffer
+
+
+def load_buffer(buffer_path: str) -> ReplayBuffer:
+    with h5py.File(buffer_path, "r") as dataset:
+        buffer = ReplayBuffer.from_data(
+            obs=dataset["observations"],
+            act=dataset["actions"],
+            rew=dataset["rewards"],
+            done=dataset["terminals"],
+            obs_next=dataset["next_observations"]
+        )
+    return buffer
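Note (annotation, not part of the diff) on a subtlety here: `load_buffer` calls `ReplayBuffer.from_data` while the HDF5 file is still open, and `Batch` materializes the h5py datasets into in-memory numpy arrays at construction time, so the returned buffer stays valid after the `with` block closes the file. A quick check under that assumption:

import numpy as np

from examples.offline.utils import load_buffer

path = "run_1-00001-of-00100.hdf5"  # hypothetical converted shard, as above
buf = load_buffer(path)
assert isinstance(buf.obs, np.ndarray)  # data was copied out of the now-closed file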
test/base/test_buffer.py
@@ -1019,6 +1019,31 @@ def test_multibuf_hdf5():
     os.remove(path)


+def test_from_data():
+    obs_data = np.ndarray((10, 3, 3), dtype="uint8")
+    for i in range(10):
+        obs_data[i] = i * np.ones((3, 3), dtype="uint8")
+    obs_next_data = np.zeros_like(obs_data)
+    obs_next_data[:-1] = obs_data[1:]
+    f, path = tempfile.mkstemp(suffix='.hdf5')
+    os.close(f)
+    with h5py.File(path, "w") as f:
+        obs = f.create_dataset("obs", data=obs_data)
+        act = f.create_dataset("act", data=np.arange(10, dtype="int32"))
+        rew = f.create_dataset("rew", data=np.arange(10, dtype="float32"))
+        done = f.create_dataset("done", data=np.zeros(10, dtype="bool"))
+        obs_next = f.create_dataset("obs_next", data=obs_next_data)
+        buf = ReplayBuffer.from_data(obs, act, rew, done, obs_next)
+    assert len(buf) == 10
+    batch = buf[3]
+    assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8"))
+    assert batch.act == 3
+    assert batch.rew == 3.0
+    assert not batch.done
+    assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8"))
+    os.remove(path)
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
@@ -1032,3 +1057,4 @@ if __name__ == '__main__':
     test_cachedbuffer()
     test_multibuf_stack()
     test_multibuf_hdf5()
+    test_from_data()
tianshou/data/buffer/base.py
@@ -99,6 +99,22 @@ class ReplayBuffer:
             buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf

+    @classmethod
+    def from_data(
+        cls, obs: h5py.Dataset, act: h5py.Dataset, rew: h5py.Dataset,
+        done: h5py.Dataset, obs_next: h5py.Dataset
+    ) -> "ReplayBuffer":
+        size = len(obs)
+        assert all(len(dset) == size for dset in [obs, act, rew, done, obs_next]), \
+            "Lengths of all hdf5 datasets need to be equal."
+        buf = cls(size)
+        if size == 0:
+            return buf
+        batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next)
+        buf.set_batch(batch)
+        buf._size = size
+        return buf
+
     def reset(self, keep_statistics: bool = False) -> None:
         """Clear all the data in replay buffer and episode statistics."""
         self.last_index = np.array([0])
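Note (annotation, not part of the diff): `from_data` bypasses the normal `add()` path entirely: it sizes the buffer exactly, installs the whole batch with `set_batch`, and then fixes up the private `_size` by hand so the buffer reports itself as full. The signature is typed for h5py datasets, but since `Batch` converts its inputs to numpy arrays, plain arrays appear to work as well; a sketch under that assumption, not covered by the test above:

import numpy as np

from tianshou.data import ReplayBuffer

n = 5
buf = ReplayBuffer.from_data(
    obs=np.arange(n),
    act=np.zeros(n, dtype=np.int64),
    rew=np.ones(n, dtype=np.float32),
    done=np.zeros(n, dtype=bool),
    obs_next=np.arange(1, n + 1),
)
assert len(buf) == n
assert buf[2].obs == 2 and buf[2].obs_next == 3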