Saving and loading replay buffer with HDF5 (#261)

As mentioned in #260, this pull request implements saving and loading the replay buffer with HDF5.

This commit is contained in:
parent cd481423dc
commit 5d13d8a453
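
For readers skimming the diff, here is a minimal usage sketch of the new API on top of this branch; the buffer contents and file name are illustrative, not part of the change:

import numpy as np
from tianshou.data import Batch, ReplayBuffer

# fill a small buffer with dummy transitions
buf = ReplayBuffer(size=10)
for i in range(3):
    buf.add(obs=Batch(index=np.array([i])), act=i, rew=i, done=0,
            info={"number": {"n": i}})

buf.save_hdf5("buf.hdf5")                  # persist the whole buffer
buf2 = ReplayBuffer.load_hdf5("buf.hdf5")  # restore a fresh instance
assert len(buf2) == len(buf)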
.gitignore (vendored, 1 change)

@@ -144,3 +144,4 @@ MUJOCO_LOG.TXT
 .DS_Store
 *.zip
 *.pstats
+*.swp
docs/bibtex.json (new file, 9 changes)

@@ -0,0 +1,9 @@
+{
+    "cited": {
+        "tutorials/dqn": [
+            "DQN",
+            "DDPG",
+            "PPO"
+        ]
+    }
+}
@@ -70,6 +70,7 @@ autodoc_default_options = {
     ]
 )
 }
+bibtex_bibfiles = ['refs.bib']
 
 # -- Options for HTML output -------------------------------------------------
 
setup.py (3 changes)

@@ -47,10 +47,11 @@ setup(
     install_requires=[
         "gym>=0.15.4",
         "tqdm",
-        "numpy",
+        "numpy!=1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard",
         "torch>=1.4.0",
         "numba>=0.51.0",
+        "h5py>=3.1.0"
     ],
     extras_require={
         "dev": [
@@ -1,11 +1,15 @@
+import os
 import torch
 import pickle
 import pytest
+import tempfile
+import h5py
 import numpy as np
 from timeit import timeit
 
 from tianshou.data import Batch, SegmentTree, \
     ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.utils.converter import to_hdf5
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -278,7 +282,73 @@ def test_pickle():
         pbuf.weight[np.arange(len(pbuf))])
 
 
+def test_hdf5():
+    size = 100
+    buffers = {
+        "array": ReplayBuffer(size, stack_num=2),
+        "list": ListReplayBuffer(),
+        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
+    }
+    buffer_types = {k: b.__class__ for k, b in buffers.items()}
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rew = torch.tensor([1.]).to(device)
+    for i in range(4):
+        kwargs = {
+            'obs': Batch(index=np.array([i])),
+            'act': i,
+            'rew': rew,
+            'done': 0,
+            'info': {"number": {"n": i}, 'extra': None},
+        }
+        buffers["array"].add(**kwargs)
+        buffers["list"].add(**kwargs)
+        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)
+
+    # save
+    paths = {}
+    for k, buf in buffers.items():
+        f, path = tempfile.mkstemp(suffix='.hdf5')
+        os.close(f)
+        buf.save_hdf5(path)
+        paths[k] = path
+
+    # load replay buffer
+    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
+
+    # compare
+    for k in buffers.keys():
+        assert len(_buffers[k]) == len(buffers[k])
+        assert np.allclose(_buffers[k].act, buffers[k].act)
+        assert _buffers[k].stack_num == buffers[k].stack_num
+        assert _buffers[k]._maxsize == buffers[k]._maxsize
+        assert _buffers[k]._index == buffers[k]._index
+        assert np.all(_buffers[k]._indices == buffers[k]._indices)
+    for k in ["array", "prioritized"]:
+        assert isinstance(buffers[k].get(0, "info"), Batch)
+        assert isinstance(_buffers[k].get(0, "info"), Batch)
+    for k in ["array"]:
+        assert np.all(
+            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
+        assert np.all(
+            buffers[k][:].info.extra == _buffers[k][:].info.extra)
+
+    for path in paths.values():
+        os.remove(path)
+
+    # raise exception when value cannot be pickled
+    data = {"not_supported": lambda x: x*x}
+    grp = h5py.Group
+    with pytest.raises(NotImplementedError):
+        to_hdf5(data, grp)
+    # ndarray with data type not supported by HDF5 that cannot be pickled
+    data = {"not_supported": np.array(lambda x: x*x)}
+    grp = h5py.Group
+    with pytest.raises(RuntimeError):
+        to_hdf5(data, grp)
+
+
 if __name__ == '__main__':
+    test_hdf5()
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
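
The two pytest.raises blocks above pin down the converter's error contract: a value that can neither be mapped to an HDF5 type nor pickled raises NotImplementedError, while an object-dtype ndarray whose contents cannot be pickled raises RuntimeError. A standalone sketch of the same contract against a real file (the path is illustrative):

import h5py
import numpy as np
import pytest
from tianshou.data.utils.converter import to_hdf5

with h5py.File("scratch.hdf5", "w") as f:
    # lambdas cannot be pickled -> NotImplementedError
    with pytest.raises(NotImplementedError):
        to_hdf5({"bad": lambda x: x * x}, f)
    # an object array wrapping a lambda cannot be pickled either -> RuntimeError
    with pytest.raises(RuntimeError):
        to_hdf5({"bad": np.array(lambda x: x * x)}, f)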
@@ -1,10 +1,12 @@
+import h5py
 import torch
 import numpy as np
 from numbers import Number
 from typing import Any, Dict, List, Tuple, Union, Optional
 
-from tianshou.data import Batch, SegmentTree, to_numpy
 from tianshou.data.batch import _create_value
+from tianshou.data import Batch, SegmentTree, to_numpy
+from tianshou.data.utils.converter import to_hdf5, from_hdf5
 
 
 class ReplayBuffer:
@@ -38,7 +40,10 @@ class ReplayBuffer:
     >>> # but there are only three valid items, so len(buf) == 3.
     >>> len(buf)
     3
-    >>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
+    >>> # save to file "buf.pkl"
+    >>> pickle.dump(buf, open('buf.pkl', 'wb'))
+    >>> # save to HDF5 file
+    >>> buf.save_hdf5('buf.hdf5')
     >>> buf2 = ReplayBuffer(size=10)
     >>> for i in range(15):
     ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,7 +59,7 @@ class ReplayBuffer:
            0., 0., 0., 0., 0., 0., 0.])
 
     >>> # get a random sample from buffer
-    >>> # the batch_data is equal to buf[incide].
+    >>> # the batch_data is equal to buf[indice].
     >>> batch_data, indice = buf.sample(batch_size=4)
     >>> batch_data.obs == buf[indice].obs
     array([ True,  True,  True,  True])
@@ -63,6 +68,15 @@ class ReplayBuffer:
     >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
     >>> len(buf)
     3
+    >>> # load complete buffer from HDF5 file
+    >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
+    >>> len(buf)
+    3
+    >>> # load contents of HDF5 file into existing buffer
+    >>> # (only possible if size of buffer and data in file match)
+    >>> buf.load_contents_hdf5('buf.hdf5')
+    >>> len(buf)
+    3
 
     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
     (typically for RNN usage, see issue#19), ignoring storing the next
@@ -167,8 +181,14 @@ class ReplayBuffer:
         We need it because pickling buffer does not work out-of-the-box
         ("buffer.__getattr__" is customized).
         """
+        self._indices = np.arange(state["_maxsize"])
         self.__dict__.update(state)
 
+    def __getstate__(self) -> dict:
+        exclude = {"_indices"}
+        state = {k: v for k, v in self.__dict__.items() if k not in exclude}
+        return state
+
     def _add_to_buffer(self, name: str, inst: Any) -> None:
         try:
             value = self._meta.__dict__[name]
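
The __getstate__/__setstate__ pair above keeps the large _indices array out of the serialized state and rebuilds it from "_maxsize" on restore, which keeps both pickle and HDF5 files compact. A hedged check of that behavior (it touches private attributes, for illustration only):

import pickle
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
state = buf.__getstate__()
assert "_indices" not in state          # excluded from the saved state
buf2 = pickle.loads(pickle.dumps(buf))  # round-trips via __setstate__
assert np.all(buf2._indices == np.arange(buf2._maxsize))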
@@ -359,6 +379,21 @@ class ReplayBuffer:
             policy=self.get(index, "policy"),
         )
 
+    def save_hdf5(self, path: str) -> None:
+        """Save replay buffer to HDF5 file."""
+        with h5py.File(path, "w") as f:
+            to_hdf5(self.__getstate__(), f)
+
+    @classmethod
+    def load_hdf5(
+        cls, path: str, device: Optional[str] = None
+    ) -> "ReplayBuffer":
+        """Load replay buffer from HDF5 file."""
+        with h5py.File(path, "r") as f:
+            buf = cls.__new__(cls)
+            buf.__setstate__(from_hdf5(f, device=device))
+        return buf
+
 
 class ListReplayBuffer(ReplayBuffer):
     """List-based replay buffer.
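
Because load_hdf5 forwards a device argument to from_hdf5, tensors stored in the file can be mapped onto a target device at load time. A small sketch (the file name is illustrative):

import torch
from tianshou.data import ReplayBuffer

# restore on CPU even if the buffer was saved with CUDA tensors
buf = ReplayBuffer.load_hdf5("buf.hdf5", device="cpu")
# or map onto the GPU when one is available
if torch.cuda.is_available():
    buf_gpu = ReplayBuffer.load_hdf5("buf.hdf5", device="cuda")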
@@ -1,8 +1,10 @@
+import h5py
 import torch
+import pickle
 import numpy as np
 from copy import deepcopy
 from numbers import Number
-from typing import Union, Optional
+from typing import Dict, Union, Optional
 
 from tianshou.data.batch import _parse_value, Batch
 
@@ -80,3 +82,90 @@ def to_torch_as(
     """
     assert isinstance(y, torch.Tensor)
     return to_torch(x, dtype=y.dtype, device=y.device)
+
+
+# Note: object is used as a proxy for objects that can be pickled
+# Note: mypy does not support cyclic definition currently
+Hdf5ConvertibleValues = Union[  # type: ignore
+    int, float, Batch, np.ndarray, torch.Tensor, object,
+    'Hdf5ConvertibleType',  # type: ignore
+]
+
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+
+
+def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
+    """Copy object into HDF5 group."""
+
+    def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
+        """Pickle, convert to numpy array and write to HDF5 dataset."""
+        data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
+        y.create_dataset(key, data=data)
+
+    for k, v in x.items():
+        if isinstance(v, (Batch, dict)):
+            # dicts and batches are both represented by groups
+            subgrp = y.create_group(k)
+            if isinstance(v, Batch):
+                subgrp_data = v.__getstate__()
+                subgrp.attrs["__data_type__"] = "Batch"
+            else:
+                subgrp_data = v
+            to_hdf5(subgrp_data, subgrp)
+        elif isinstance(v, torch.Tensor):
+            # PyTorch tensors are written to datasets
+            y.create_dataset(k, data=to_numpy(v))
+            y[k].attrs["__data_type__"] = "Tensor"
+        elif isinstance(v, np.ndarray):
+            try:
+                # NumPy arrays are written to datasets
+                y.create_dataset(k, data=v)
+                y[k].attrs["__data_type__"] = "ndarray"
+            except TypeError:
+                # If data type is not supported by HDF5 fall back to pickle.
+                # This happens if dtype=object (e.g. due to entries being None)
+                # and possibly in other cases like structured arrays.
+                try:
+                    to_hdf5_via_pickle(v, y, k)
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Attempted to pickle {v.__class__.__name__} due to "
+                        "data type not supported by HDF5 and failed."
+                    ) from e
+                y[k].attrs["__data_type__"] = "pickled_ndarray"
+        elif isinstance(v, (int, float)):
+            # ints and floats are stored as attributes of groups
+            y.attrs[k] = v
+        else:  # resort to pickle for any other type of object
+            try:
+                to_hdf5_via_pickle(v, y, k)
+            except Exception as e:
+                raise NotImplementedError(
+                    f"No conversion to HDF5 for object of type '{type(v)}' "
+                    "implemented and fallback to pickle failed."
+                ) from e
+            y[k].attrs["__data_type__"] = v.__class__.__name__
+
+
+def from_hdf5(
+    x: h5py.Group, device: Optional[str] = None
+) -> Hdf5ConvertibleType:
+    """Restore object from HDF5 group."""
+    if isinstance(x, h5py.Dataset):
+        # handle datasets
+        if x.attrs["__data_type__"] == "ndarray":
+            y = np.array(x)
+        elif x.attrs["__data_type__"] == "Tensor":
+            y = torch.tensor(x, device=device)
+        else:
+            y = pickle.loads(x[()])
+    else:
+        # handle groups representing a dict or a Batch
+        y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
+        for k, v in x.items():
+            y[k] = from_hdf5(v, device)
+        if "__data_type__" in x.attrs:
+            # if dictionary represents Batch, convert to Batch
+            if x.attrs["__data_type__"] == "Batch":
+                y = Batch(y)
+    return y
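
To make the resulting file layout concrete, here is a hedged round-trip through to_hdf5/from_hdf5 directly (the file name and sample data are made up): ints and floats land in group attributes, arrays and tensors in datasets, and dicts or Batch objects in subgroups tagged via the "__data_type__" attribute.

import h5py
import numpy as np
from tianshou.data import Batch
from tianshou.data.utils.converter import to_hdf5, from_hdf5

data = {
    "count": 3,                     # int -> group attribute
    "obs": np.zeros((4, 2)),        # ndarray -> dataset
    "info": Batch(n=np.arange(4)),  # Batch -> subgroup tagged "Batch"
}
with h5py.File("demo.hdf5", "w") as f:
    to_hdf5(data, f)
with h5py.File("demo.hdf5", "r") as f:
    restored = from_hdf5(f)
assert restored["count"] == 3
assert np.allclose(restored["obs"], data["obs"])
assert isinstance(restored["info"], Batch)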