This is the third PR of the 6 commits mentioned in #274, which features a refactor of Collector to fix #245. See #274 for more detail. Changes in this PR: 1. refactor the collector to be cleaner, and split out AsyncCollector to support asyncvenv; 2. change the buffer.add API to add(batch, buffer_ids), and add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.); 3. add policy.exploration_noise(act, batch) -> act; 4. small change in BasePolicy.compute_*_returns; 5. move reward_metric from collector to trainer; 6. fix an np.asanyarray issue (different numpy versions produce different output); 7. set flake8 max-line-length=88; 8. polish docs and fix tests. Co-authored-by: n+e <trinkle23897@gmail.com>
31 lines
796 B
Python
31 lines
796 B
Python
from tianshou.data.batch import Batch
|
|
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
|
|
from tianshou.data.utils.segtree import SegmentTree
|
|
from tianshou.data.buffer import (
|
|
ReplayBuffer,
|
|
PrioritizedReplayBuffer,
|
|
ReplayBufferManager,
|
|
PrioritizedReplayBufferManager,
|
|
VectorReplayBuffer,
|
|
PrioritizedVectorReplayBuffer,
|
|
CachedReplayBuffer,
|
|
)
|
|
from tianshou.data.collector import Collector, AsyncCollector
|
|
|
|
# Names exported by a wildcard import (`from <this module> import *`).
# Every entry matches a name imported at the top of this file.
__all__ = [
|
|
"Batch",
|
|
"to_numpy",
|
|
"to_torch",
|
|
"to_torch_as",
|
|
"SegmentTree",
|
|
"ReplayBuffer",
|
|
"PrioritizedReplayBuffer",
|
|
"ReplayBufferManager",
|
|
"PrioritizedReplayBufferManager",
|
|
"VectorReplayBuffer",
|
|
"PrioritizedVectorReplayBuffer",
|
|
"CachedReplayBuffer",
|
|
"Collector",
|
|
"AsyncCollector",
|
|
|
]
|