Throw a warning in ListReplayBuffer.

This version update is needed because of #289, the previous v0.3.1 cannot work well under torch<=1.6.0 with cuda environment.
This commit is contained in:
n+e 2021-02-16 09:31:46 +08:00 committed by GitHub
parent d003c8e566
commit cb65b56b13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 1 deletions

View File

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration from tianshou import data, env, utils, policy, trainer, exploration
__version__ = "0.3.1" __version__ = "0.3.2"
__all__ = [ __all__ = [
"env", "env",

View File

@ -1,5 +1,6 @@
import h5py import h5py
import torch import torch
import warnings
import numpy as np import numpy as np
from numbers import Number from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Dict, List, Tuple, Union, Optional
@ -412,6 +413,7 @@ class ListReplayBuffer(ReplayBuffer):
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
super().__init__(size=0, ignore_obs_next=False, **kwargs) super().__init__(size=0, ignore_obs_next=False, **kwargs)
warnings.warn("ListReplayBuffer will be removed in version 0.4.0.")
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
raise NotImplementedError("ListReplayBuffer cannot be sampled!") raise NotImplementedError("ListReplayBuffer cannot be sampled!")