Tianshou/tianshou/data/utils/converter.py
Nico Gürtler 5d13d8a453
Saving and loading replay buffer with HDF5 (#261)
As mentioned in #260, this pull request is about an implementation of saving and loading the replay buffer with HDF5.
2020-12-17 08:58:43 +08:00

172 lines
6.3 KiB
Python

import h5py
import torch
import pickle
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Dict, Union, Optional
from tianshou.data.batch import _parse_value, Batch
def to_numpy(
    x: Optional[Union[Batch, dict, list, tuple, np.number, np.bool_, Number,
                      np.ndarray, torch.Tensor]]
) -> Union[Batch, dict, list, tuple, np.ndarray]:
    """Return an object without torch.Tensor.

    Conversion rules, checked in this order:

    * ``torch.Tensor`` -> detached, moved to CPU and converted with
      ``.numpy()``.
    * ``np.ndarray`` -> returned as is (no copy).
    * NumPy / Python numeric or boolean scalar -> ``np.asanyarray(x)``.
    * ``None`` -> 0-d array of dtype ``object`` holding ``None``.
    * ``Batch`` -> a deep copy on which ``to_numpy()`` has been called; the
      input batch is left untouched.
    * ``dict`` -> new dict with every value converted recursively.
    * ``list`` / ``tuple`` -> the result of converting ``_parse_value(x)``;
      if that raises ``TypeError`` the elements are converted one by one
      and a ``list`` is returned (also for tuple input).
    * anything else -> ``np.asanyarray(x)``.

    :param x: the object to convert.
    :return: the converted object.
    """
    if isinstance(x, torch.Tensor):  # most often case
        return x.detach().cpu().numpy()
    elif isinstance(x, np.ndarray):  # second often case
        return x
    elif isinstance(x, (np.number, np.bool_, Number)):
        return np.asanyarray(x)
    elif x is None:
        # ``np.object`` was a plain alias of the builtin ``object``; the alias
        # was deprecated in NumPy 1.20 and removed in 1.24, so use the builtin.
        return np.array(None, dtype=object)
    elif isinstance(x, Batch):
        # Batch.to_numpy() works in place, so operate on a copy.
        x = deepcopy(x)
        x.to_numpy()
        return x
    elif isinstance(x, dict):
        return {k: to_numpy(v) for k, v in x.items()}
    elif isinstance(x, (list, tuple)):
        try:
            # Try to convert the sequence as a whole first.
            return to_numpy(_parse_value(x))
        except TypeError:
            # Fall back to an element-wise conversion.
            return [to_numpy(e) for e in x]
    else:  # fallback
        return np.asanyarray(x)
def to_torch(
    x: Union[Batch, dict, list, tuple, np.number, np.bool_, Number, np.ndarray,
             torch.Tensor],
    dtype: Optional[torch.dtype] = None,
    device: Union[str, int, torch.device] = "cpu",
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
    """Convert ``x`` so that it no longer contains np.ndarray.

    :param x: the object to convert. Numeric/boolean arrays, tensors, numeric
        scalars, dicts, ``Batch`` objects and lists/tuples are supported.
    :param dtype: if given, tensors are cast to this type via
        ``Tensor.type``.
    :param device: the device the resulting tensors are moved to.
    :return: the converted object. A ``Batch`` input is deep-copied before
        being converted; a list/tuple that cannot be converted as a whole is
        returned as a list of individually converted elements.
    :raises TypeError: if ``x`` is of a type that cannot be converted.
    """
    # Most frequent case: an array holding numbers or booleans.
    if isinstance(x, np.ndarray):
        holds_numbers = issubclass(x.dtype.type, (np.bool_, np.number))
        if holds_numbers:
            tensor = torch.from_numpy(x).to(device)  # type: ignore
            return tensor if dtype is None else tensor.type(dtype)

    # Second most frequent case: already a tensor, cast first then move.
    if isinstance(x, torch.Tensor):
        casted = x if dtype is None else x.type(dtype)
        return casted.to(device)  # type: ignore

    # Scalars go through the array path above.
    if isinstance(x, (np.number, np.bool_, Number)):
        return to_torch(np.asanyarray(x), dtype, device)

    if isinstance(x, dict):
        converted = {}
        for key, value in x.items():
            converted[key] = to_torch(value, dtype, device)
        return converted

    if isinstance(x, Batch):
        # Batch.to_torch() works in place, so operate on a copy.
        batch_copy = deepcopy(x)
        batch_copy.to_torch(dtype, device)
        return batch_copy

    if isinstance(x, (list, tuple)):
        try:
            # Attempt to convert the sequence as a whole.
            return to_torch(_parse_value(x), dtype, device)
        except TypeError:
            # Otherwise convert element by element.
            return [to_torch(item, dtype, device) for item in x]

    # Nothing matched (this includes arrays of non-numeric dtype).
    raise TypeError(f"object {x} cannot be converted to torch.")
def to_torch_as(
    x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
    y: torch.Tensor,
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
    """Convert ``x`` to torch, using ``y`` as the dtype/device template.

    Equivalent to ``to_torch(x, dtype=y.dtype, device=y.device)``.

    :param x: the object to convert.
    :param y: the reference tensor; must be a ``torch.Tensor``.
    :return: the converted object.
    """
    assert isinstance(y, torch.Tensor)
    target_dtype = y.dtype
    target_device = y.device
    return to_torch(x, dtype=target_dtype, device=target_device)
# Type aliases describing what to_hdf5() accepts.
# Note: ``object`` is used as a proxy for any object that can be pickled.
# Note: mypy does not currently support cyclic (self-referencing) type
# definitions, which is why the ``# type: ignore`` markers below are needed.
# Hdf5ConvertibleValues: the kinds of value that may appear in the mapping;
# the quoted name refers to the alias defined right after this one.
Hdf5ConvertibleValues = Union[ # type: ignore
int, float, Batch, np.ndarray, torch.Tensor, object,
'Hdf5ConvertibleType', # type: ignore
]
# Hdf5ConvertibleType: a string-keyed dict of such values (possibly nested).
Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues] # type: ignore
def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
    """Write the content of the mapping ``x`` into the HDF5 group ``y``.

    Every entry of ``x`` is stored under its key according to its type:

    * ``Batch`` / ``dict`` -> a sub-group filled recursively (a ``Batch``
      is stored through its ``__getstate__()`` result and tagged "Batch").
    * ``torch.Tensor`` -> a dataset tagged "Tensor".
    * ``np.ndarray`` -> a dataset tagged "ndarray"; if HDF5 rejects the
      data type, the array is pickled and tagged "pickled_ndarray".
    * ``int`` / ``float`` -> an attribute of ``y``.
    * anything else -> pickled into a dataset tagged with its class name.

    The tag is kept in the ``"__data_type__"`` attribute of the created
    group or dataset.

    :param x: the mapping to store.
    :param y: the HDF5 group to write into.
    :raises RuntimeError: if pickling an unsupported array fails.
    :raises NotImplementedError: if pickling any other object fails.
    """

    def _store_pickled(obj: object, group: h5py.Group, name: str) -> None:
        """Pickle ``obj`` and write the bytes as a dataset called ``name``."""
        raw = pickle.dumps(obj)
        group.create_dataset(name, data=np.frombuffer(raw, dtype=np.byte))

    for key, value in x.items():
        if isinstance(value, (Batch, dict)):
            # dicts and batches are both represented by groups
            child = y.create_group(key)
            if isinstance(value, Batch):
                payload = value.__getstate__()
                child.attrs["__data_type__"] = "Batch"
            else:
                payload = value
            to_hdf5(payload, child)
            continue

        if isinstance(value, torch.Tensor):
            # PyTorch tensors are written to datasets
            y.create_dataset(key, data=to_numpy(value))
            y[key].attrs["__data_type__"] = "Tensor"
            continue

        if isinstance(value, np.ndarray):
            try:
                # NumPy arrays are written to datasets
                y.create_dataset(key, data=value)
                y[key].attrs["__data_type__"] = "ndarray"
            except TypeError:
                # The data type is not supported by HDF5, so fall back to
                # pickle. This happens for dtype=object (e.g. entries being
                # None) and possibly for other cases like structured arrays.
                try:
                    _store_pickled(value, y, key)
                except Exception as e:
                    raise RuntimeError(
                        f"Attempted to pickle {value.__class__.__name__} due to "
                        "data type not supported by HDF5 and failed."
                    ) from e
                y[key].attrs["__data_type__"] = "pickled_ndarray"
            continue

        if isinstance(value, (int, float)):
            # ints and floats are stored as attributes of groups
            y.attrs[key] = value
            continue

        # Any other type of object: resort to pickle.
        try:
            _store_pickled(value, y, key)
        except Exception as e:
            raise NotImplementedError(
                f"No conversion to HDF5 for object of type '{type(value)}' "
                "implemented and fallback to pickle failed."
            ) from e
        y[key].attrs["__data_type__"] = value.__class__.__name__
def from_hdf5(
    x: h5py.Group, device: Optional[str] = None
) -> Hdf5ConvertibleType:
    """Restore object from HDF5 group.

    Inverse of ``to_hdf5``: the ``"__data_type__"`` attribute written there
    decides how each group or dataset is decoded.

    * Dataset tagged "ndarray" -> ``np.ndarray``.
    * Dataset tagged "Tensor" -> ``torch.Tensor`` created on ``device``.
    * Any other dataset -> its content is unpickled.
    * Group -> dict built from the group's attributes (without the tag) and
      its recursively restored members; converted to ``Batch`` if the group
      is tagged "Batch".

    :param x: the HDF5 group (or, during recursion, dataset) to read.
    :param device: device passed to ``torch.tensor`` for restored tensors.
    :return: the restored object.
    :raises KeyError: if a dataset has no ``"__data_type__"`` attribute.
    """
    if isinstance(x, h5py.Dataset):
        # handle datasets
        data_type = x.attrs["__data_type__"]
        if data_type == "ndarray":
            return np.array(x)
        if data_type == "Tensor":
            # Read the dataset into a NumPy array first (as the "ndarray"
            # branch does) instead of handing the h5py.Dataset object itself
            # to torch.tensor.
            return torch.tensor(np.array(x), device=device)
        # NOTE(security): unpickling executes arbitrary code, so only load
        # HDF5 files that come from a trusted source.
        return pickle.loads(x[()])

    # handle groups representing a dict or a Batch
    y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
    for k, v in x.items():
        y[k] = from_hdf5(v, device)
    # if dictionary represents Batch, convert to Batch
    if "__data_type__" in x.attrs and x.attrs["__data_type__"] == "Batch":
        return Batch(y)
    return y