Changed .keys() to get_keys() in batch class (#1105)

Resolves the inconsistency that `iter(Batch)` is not the same as `Batch.keys()` by adding an explicit `get_keys()` method and "deprecating" the implicit `.keys()` method.

Closes: #922
This commit is contained in:
Erni 2024-04-12 12:15:37 +02:00 committed by GitHub
parent 03e9af04b7
commit e2a2a6856d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 19 additions and 13 deletions

View File

@ -556,7 +556,7 @@ def test_batch_standard_compatibility() -> None:
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
batch_mean = np.mean(batch)
assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. `batch.rst`
assert sorted(batch_mean.keys()) == ["a", "b", "c"] # type: ignore
assert sorted(batch_mean.get_keys()) == ["a", "b", "c"] # type: ignore
with pytest.raises(TypeError):
len(batch_mean)
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))

View File

@ -1379,11 +1379,14 @@ def test_custom_key() -> None:
buffer.add(batch)
sampled_batch, _ = buffer.sample(1)
# Check if they have the same keys
assert set(batch.keys()) == set(
sampled_batch.keys(),
), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys()))
assert set(batch.get_keys()) == set(
sampled_batch.get_keys(),
), "Batches have different keys: {} and {}".format(
set(batch.get_keys()),
set(sampled_batch.get_keys()),
)
# Compare the values for each key
for key in batch.keys():
for key in batch.get_keys():
if isinstance(batch.__dict__[key], np.ndarray) and isinstance(
sampled_batch.__dict__[key],
np.ndarray,

View File

@ -1,6 +1,6 @@
import pprint
import warnings
from collections.abc import Collection, Iterable, Iterator, Sequence
from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence
from copy import deepcopy
from numbers import Number
from types import EllipsisType
@ -185,8 +185,8 @@ def alloc_by_keys_diff(
This mainly is an internal method, use it only if you know what you are doing.
"""
for key in batch.keys():
if key in meta.keys():
for key in batch.get_keys():
if key in meta.get_keys():
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
alloc_by_keys_diff(meta[key], batch[key], size, stack)
elif isinstance(meta[key], Batch) and meta[key].is_empty():
@ -441,6 +441,9 @@ class Batch(BatchProtocol):
result[k] = v
return result
def get_keys(self) -> KeysView[str]:
    """Return a view of the keys held in this object's ``__dict__``.

    The result is a regular dictionary keys view, not a copy: it supports
    iteration and membership tests, and it reflects attributes that are
    added to or removed from the instance after the call.

    :return: a live ``KeysView`` over the instance's attribute names.
    """
    return self.__dict__.keys()
def to_list_of_dicts(self) -> list[dict[str, Any]]:
    """Convert this object into a list of plain dictionaries.

    Iterates over ``self`` and calls ``to_dict()`` on every entry that the
    iteration yields, collecting the results in iteration order.

    :return: one dictionary per entry produced by iterating over ``self``;
        an empty list if the iteration yields nothing.
    """
    return [entry.to_dict() for entry in self]

View File

@ -253,12 +253,12 @@ class ReplayBuffer:
"""
# preprocess batch
new_batch = Batch()
for key in batch.keys():
for key in batch.get_keys():
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
batch.keys(),
batch.get_keys(),
) # important to do after preprocess batch
stacked_batch = buffer_ids is not None
if stacked_batch:

View File

@ -127,11 +127,11 @@ class ReplayBufferManager(ReplayBuffer):
"""
# preprocess batch
new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()):
for key in set(self._reserved_keys).intersection(batch.get_keys()):
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys())
if self._save_only_last_obs:
batch.obs = batch.obs[:, -1]
if not self._save_obs_next:

View File

@ -225,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy):
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue
tmp_batch = batch[agent_index]
if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
if not hasattr(tmp_batch.obs, "mask"):