Custom keys support in ReplayBuffer (#903)
Issue: Custom keys support in ReplayBuffer #902 Modified `ReplayBuffer` `add` and `__getitem__` methods. Added `test_custom_key()` to test_buffer.py
This commit is contained in:
parent
61182450b6
commit
80a698be52
@ -166,3 +166,13 @@ isort
|
||||
yapf
|
||||
pydocstyle
|
||||
Args
|
||||
tuples
|
||||
tuple
|
||||
Multi
|
||||
multi
|
||||
parameterized
|
||||
Proximal
|
||||
metadata
|
||||
GPU
|
||||
Dopamine
|
||||
builtin
|
||||
|
@ -1336,6 +1336,65 @@ def test_from_data():
|
||||
os.remove(path)
|
||||
|
||||
|
||||
def test_custom_key():
|
||||
batch = Batch(
|
||||
**{
|
||||
'obs_next':
|
||||
np.array(
|
||||
[
|
||||
[
|
||||
1.174, -0.1151, -0.609, -0.5205, -0.9316, 3.236, -2.418, 0.386,
|
||||
0.2227, -0.5117, 2.293
|
||||
]
|
||||
]
|
||||
),
|
||||
'rew':
|
||||
np.array([4.28125]),
|
||||
'act':
|
||||
np.array([[-0.3088, -0.4636, 0.4956]]),
|
||||
'truncated':
|
||||
np.array([False]),
|
||||
'obs':
|
||||
np.array(
|
||||
[
|
||||
[
|
||||
1.193, -0.1203, -0.6123, -0.519, -0.9434, 3.32, -2.266, 0.9116,
|
||||
0.623, 0.1259, 0.363
|
||||
]
|
||||
]
|
||||
),
|
||||
'terminated':
|
||||
np.array([False]),
|
||||
'done':
|
||||
np.array([False]),
|
||||
'returns':
|
||||
np.array([74.70343082]),
|
||||
'info':
|
||||
Batch(),
|
||||
'policy':
|
||||
Batch(),
|
||||
}
|
||||
)
|
||||
buffer_size = len(batch.rew)
|
||||
buffer = ReplayBuffer(buffer_size)
|
||||
buffer.add(batch)
|
||||
sampled_batch, _ = buffer.sample(1)
|
||||
# Check if they have the same keys
|
||||
assert set(batch.keys()) == set(sampled_batch.keys()), \
|
||||
"Batches have different keys: {} and {}".format(
|
||||
set(batch.keys()), set(sampled_batch.keys()))
|
||||
# Compare the values for each key
|
||||
for key in batch.keys():
|
||||
if isinstance(batch.__dict__[key], np.ndarray
|
||||
) and isinstance(sampled_batch.__dict__[key], np.ndarray):
|
||||
assert np.allclose(batch.__dict__[key], sampled_batch.__dict__[key]), \
|
||||
"Value mismatch for key: {}".format(key)
|
||||
if isinstance(batch.__dict__[key],
|
||||
Batch) and isinstance(sampled_batch.__dict__[key], Batch):
|
||||
assert batch.__dict__[key].is_empty()
|
||||
assert sampled_batch.__dict__[key].is_empty()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_replaybuffer()
|
||||
test_ignore_obs_next()
|
||||
@ -1351,3 +1410,4 @@ if __name__ == '__main__':
|
||||
test_multibuf_hdf5()
|
||||
test_from_data()
|
||||
test_herreplaybuffer()
|
||||
test_custom_key()
|
||||
|
@ -220,9 +220,8 @@ class ReplayBuffer:
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Add a batch of data into replay buffer.
|
||||
|
||||
:param Batch batch: the input data batch. Its keys must belong to the 7
|
||||
input keys, and "obs", "act", "rew", "terminated", "truncated" is
|
||||
required.
|
||||
:param Batch batch: the input data batch. "obs", "act", "rew",
|
||||
"terminated", "truncated" are required keys.
|
||||
:param buffer_ids: to make consistent with other buffer's add function; if it
|
||||
is not None, we assume the input batch's first dimension is always 1.
|
||||
|
||||
@ -232,12 +231,12 @@ class ReplayBuffer:
|
||||
"""
|
||||
# preprocess batch
|
||||
new_batch = Batch()
|
||||
for key in set(self._input_keys).intersection(batch.keys()):
|
||||
for key in batch.keys():
|
||||
new_batch.__dict__[key] = batch[key]
|
||||
batch = new_batch
|
||||
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
|
||||
assert set(["obs", "act", "rew", "terminated", "truncated",
|
||||
"done"]).issubset(batch.keys())
|
||||
assert set(["obs", "act", "rew", "terminated", "truncated", "done"]
|
||||
).issubset(batch.keys()) # important to do after preprocess batch
|
||||
stacked_batch = buffer_ids is not None
|
||||
if stacked_batch:
|
||||
assert len(batch) == 1
|
||||
@ -376,14 +375,18 @@ class ReplayBuffer:
|
||||
obs_next = self.get(indices, "obs_next", Batch())
|
||||
else:
|
||||
obs_next = self.get(self.next(indices), "obs", Batch())
|
||||
return Batch(
|
||||
obs=obs,
|
||||
act=self.act[indices],
|
||||
rew=self.rew[indices],
|
||||
terminated=self.terminated[indices],
|
||||
truncated=self.truncated[indices],
|
||||
done=self.done[indices],
|
||||
obs_next=obs_next,
|
||||
info=self.get(indices, "info", Batch()),
|
||||
policy=self.get(indices, "policy", Batch()),
|
||||
)
|
||||
batch_dict = {
|
||||
"obs": obs,
|
||||
"act": self.act[indices],
|
||||
"rew": self.rew[indices],
|
||||
"terminated": self.terminated[indices],
|
||||
"truncated": self.truncated[indices],
|
||||
"done": self.done[indices],
|
||||
"obs_next": obs_next,
|
||||
"info": self.get(indices, "info", Batch()),
|
||||
"policy": self.get(indices, "policy", Batch()),
|
||||
}
|
||||
for key in self._meta.__dict__.keys():
|
||||
if key not in self._input_keys:
|
||||
batch_dict[key] = self._meta[key][indices]
|
||||
return Batch(batch_dict)
|
||||
|
Loading…
x
Reference in New Issue
Block a user