Changed .keys() to get_keys() in batch class (#1105)
Solves the inconsistency that iter(Batch) is not the same as Batch.keys() by "deprecating" the implicit .keys() method Closes: #922
This commit is contained in:
		
							parent
							
								
									03e9af04b7
								
							
						
					
					
						commit
						e2a2a6856d
					
				@ -556,7 +556,7 @@ def test_batch_standard_compatibility() -> None:
 | 
			
		||||
    batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
 | 
			
		||||
    batch_mean = np.mean(batch)
 | 
			
		||||
    assert isinstance(batch_mean, Batch)  # type: ignore  # mypy doesn't know but it works, cf. `batch.rst`
 | 
			
		||||
    assert sorted(batch_mean.keys()) == ["a", "b", "c"]  # type: ignore
 | 
			
		||||
    assert sorted(batch_mean.get_keys()) == ["a", "b", "c"]  # type: ignore
 | 
			
		||||
    with pytest.raises(TypeError):
 | 
			
		||||
        len(batch_mean)
 | 
			
		||||
    assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
 | 
			
		||||
 | 
			
		||||
@ -1379,11 +1379,14 @@ def test_custom_key() -> None:
 | 
			
		||||
    buffer.add(batch)
 | 
			
		||||
    sampled_batch, _ = buffer.sample(1)
 | 
			
		||||
    # Check if they have the same keys
 | 
			
		||||
    assert set(batch.keys()) == set(
 | 
			
		||||
        sampled_batch.keys(),
 | 
			
		||||
    ), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys()))
 | 
			
		||||
    assert set(batch.get_keys()) == set(
 | 
			
		||||
        sampled_batch.get_keys(),
 | 
			
		||||
    ), "Batches have different keys: {} and {}".format(
 | 
			
		||||
        set(batch.get_keys()),
 | 
			
		||||
        set(sampled_batch.get_keys()),
 | 
			
		||||
    )
 | 
			
		||||
    # Compare the values for each key
 | 
			
		||||
    for key in batch.keys():
 | 
			
		||||
    for key in batch.get_keys():
 | 
			
		||||
        if isinstance(batch.__dict__[key], np.ndarray) and isinstance(
 | 
			
		||||
            sampled_batch.__dict__[key],
 | 
			
		||||
            np.ndarray,
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
import pprint
 | 
			
		||||
import warnings
 | 
			
		||||
from collections.abc import Collection, Iterable, Iterator, Sequence
 | 
			
		||||
from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from numbers import Number
 | 
			
		||||
from types import EllipsisType
 | 
			
		||||
@ -185,8 +185,8 @@ def alloc_by_keys_diff(
 | 
			
		||||
 | 
			
		||||
    This mainly is an internal method, use it only if you know what you are doing.
 | 
			
		||||
    """
 | 
			
		||||
    for key in batch.keys():
 | 
			
		||||
        if key in meta.keys():
 | 
			
		||||
    for key in batch.get_keys():
 | 
			
		||||
        if key in meta.get_keys():
 | 
			
		||||
            if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
 | 
			
		||||
                alloc_by_keys_diff(meta[key], batch[key], size, stack)
 | 
			
		||||
            elif isinstance(meta[key], Batch) and meta[key].is_empty():
 | 
			
		||||
@ -441,6 +441,9 @@ class Batch(BatchProtocol):
 | 
			
		||||
            result[k] = v
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def get_keys(self) -> KeysView:
 | 
			
		||||
        return self.__dict__.keys()
 | 
			
		||||
 | 
			
		||||
    def to_list_of_dicts(self) -> list[dict[str, Any]]:
 | 
			
		||||
        return [entry.to_dict() for entry in self]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -253,12 +253,12 @@ class ReplayBuffer:
 | 
			
		||||
        """
 | 
			
		||||
        # preprocess batch
 | 
			
		||||
        new_batch = Batch()
 | 
			
		||||
        for key in batch.keys():
 | 
			
		||||
        for key in batch.get_keys():
 | 
			
		||||
            new_batch.__dict__[key] = batch[key]
 | 
			
		||||
        batch = new_batch
 | 
			
		||||
        batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
 | 
			
		||||
        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
 | 
			
		||||
            batch.keys(),
 | 
			
		||||
            batch.get_keys(),
 | 
			
		||||
        )  # important to do after preprocess batch
 | 
			
		||||
        stacked_batch = buffer_ids is not None
 | 
			
		||||
        if stacked_batch:
 | 
			
		||||
 | 
			
		||||
@ -127,11 +127,11 @@ class ReplayBufferManager(ReplayBuffer):
 | 
			
		||||
        """
 | 
			
		||||
        # preprocess batch
 | 
			
		||||
        new_batch = Batch()
 | 
			
		||||
        for key in set(self._reserved_keys).intersection(batch.keys()):
 | 
			
		||||
        for key in set(self._reserved_keys).intersection(batch.get_keys()):
 | 
			
		||||
            new_batch.__dict__[key] = batch[key]
 | 
			
		||||
        batch = new_batch
 | 
			
		||||
        batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
 | 
			
		||||
        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
 | 
			
		||||
        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys())
 | 
			
		||||
        if self._save_only_last_obs:
 | 
			
		||||
            batch.obs = batch.obs[:, -1]
 | 
			
		||||
        if not self._save_obs_next:
 | 
			
		||||
 | 
			
		||||
@ -225,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy):
 | 
			
		||||
                results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
 | 
			
		||||
                continue
 | 
			
		||||
            tmp_batch = batch[agent_index]
 | 
			
		||||
            if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
 | 
			
		||||
            if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray):
 | 
			
		||||
                # reward can be empty Batch (after initial reset) or nparray.
 | 
			
		||||
                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
 | 
			
		||||
            if not hasattr(tmp_batch.obs, "mask"):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user