v0.3.2 (#292)
Throw a warning in ListReplayBuffer. This version update is needed because of #289, the previous v0.3.1 cannot work well under torch<=1.6.0 with cuda environment.
This commit is contained in:
parent
d003c8e566
commit
cb65b56b13
@ -1,7 +1,7 @@
|
|||||||
from tianshou import data, env, utils, policy, trainer, exploration
|
from tianshou import data, env, utils, policy, trainer, exploration
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.3.1"
|
__version__ = "0.3.2"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"env",
|
"env",
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import h5py
|
import h5py
|
||||||
import torch
|
import torch
|
||||||
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||||
@ -412,6 +413,7 @@ class ListReplayBuffer(ReplayBuffer):
|
|||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
||||||
|
warnings.warn("ListReplayBuffer will be removed in version 0.4.0.")
|
||||||
|
|
||||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||||
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user