Changed .keys() to .get_keys() in the Batch class (#1105)
Resolves the inconsistency that iter(Batch) is not the same as Batch.keys() by "deprecating" the implicit .keys() method in favor of an explicit get_keys() method. Closes: #922
This commit is contained in:
parent
03e9af04b7
commit
e2a2a6856d
@ -556,7 +556,7 @@ def test_batch_standard_compatibility() -> None:
|
||||
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
|
||||
batch_mean = np.mean(batch)
|
||||
assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. `batch.rst`
|
||||
assert sorted(batch_mean.keys()) == ["a", "b", "c"] # type: ignore
|
||||
assert sorted(batch_mean.get_keys()) == ["a", "b", "c"] # type: ignore
|
||||
with pytest.raises(TypeError):
|
||||
len(batch_mean)
|
||||
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
|
||||
|
@ -1379,11 +1379,14 @@ def test_custom_key() -> None:
|
||||
buffer.add(batch)
|
||||
sampled_batch, _ = buffer.sample(1)
|
||||
# Check if they have the same keys
|
||||
assert set(batch.keys()) == set(
|
||||
sampled_batch.keys(),
|
||||
), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys()))
|
||||
assert set(batch.get_keys()) == set(
|
||||
sampled_batch.get_keys(),
|
||||
), "Batches have different keys: {} and {}".format(
|
||||
set(batch.get_keys()),
|
||||
set(sampled_batch.get_keys()),
|
||||
)
|
||||
# Compare the values for each key
|
||||
for key in batch.keys():
|
||||
for key in batch.get_keys():
|
||||
if isinstance(batch.__dict__[key], np.ndarray) and isinstance(
|
||||
sampled_batch.__dict__[key],
|
||||
np.ndarray,
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pprint
|
||||
import warnings
|
||||
from collections.abc import Collection, Iterable, Iterator, Sequence
|
||||
from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from types import EllipsisType
|
||||
@ -185,8 +185,8 @@ def alloc_by_keys_diff(
|
||||
|
||||
This mainly is an internal method, use it only if you know what you are doing.
|
||||
"""
|
||||
for key in batch.keys():
|
||||
if key in meta.keys():
|
||||
for key in batch.get_keys():
|
||||
if key in meta.get_keys():
|
||||
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
|
||||
alloc_by_keys_diff(meta[key], batch[key], size, stack)
|
||||
elif isinstance(meta[key], Batch) and meta[key].is_empty():
|
||||
@ -441,6 +441,9 @@ class Batch(BatchProtocol):
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
def get_keys(self) -> KeysView:
|
||||
return self.__dict__.keys()
|
||||
|
||||
def to_list_of_dicts(self) -> list[dict[str, Any]]:
|
||||
return [entry.to_dict() for entry in self]
|
||||
|
||||
|
@ -253,12 +253,12 @@ class ReplayBuffer:
|
||||
"""
|
||||
# preprocess batch
|
||||
new_batch = Batch()
|
||||
for key in batch.keys():
|
||||
for key in batch.get_keys():
|
||||
new_batch.__dict__[key] = batch[key]
|
||||
batch = new_batch
|
||||
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
|
||||
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
|
||||
batch.keys(),
|
||||
batch.get_keys(),
|
||||
) # important to do after preprocess batch
|
||||
stacked_batch = buffer_ids is not None
|
||||
if stacked_batch:
|
||||
|
@ -127,11 +127,11 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
"""
|
||||
# preprocess batch
|
||||
new_batch = Batch()
|
||||
for key in set(self._reserved_keys).intersection(batch.keys()):
|
||||
for key in set(self._reserved_keys).intersection(batch.get_keys()):
|
||||
new_batch.__dict__[key] = batch[key]
|
||||
batch = new_batch
|
||||
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
|
||||
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
|
||||
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys())
|
||||
if self._save_only_last_obs:
|
||||
batch.obs = batch.obs[:, -1]
|
||||
if not self._save_obs_next:
|
||||
|
@ -225,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
|
||||
continue
|
||||
tmp_batch = batch[agent_index]
|
||||
if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
|
||||
if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray):
|
||||
# reward can be empty Batch (after initial reset) or nparray.
|
||||
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
|
||||
if not hasattr(tmp_batch.obs, "mask"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user