30 lines
1.0 KiB
Python
Raw Normal View History

2020-03-13 17:49:22 +08:00
import numpy as np
2020-03-11 09:09:56 +08:00
class Batch(object):
    """Container of named fields, stored as plain instance attributes.

    Suggested keys: [obs, act, rew, done, obs_next, info]

    ``Batch(obs=o).obs`` returns ``o``. Fields holding ``np.ndarray`` or
    ``list`` values can be grown with :meth:`append`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        """Set (or overwrite) the given fields on this batch."""
        self.__dict__.update(kwargs)

    def append(self, batch):
        """Concatenate every non-``None`` field of ``batch`` onto this batch.

        For each key of ``batch`` whose value is not ``None``:

        * if this batch has no value for it (or holds ``None``), the value
          from ``batch`` is stored as-is;
        * if ``batch``'s value is an ``np.ndarray``, or it is a ``list`` and
          the existing value is an ``np.ndarray``, the two are joined with
          ``np.concatenate``;
        * if ``batch``'s value is a ``list``, a new list ``old + new`` is
          stored.

        Existing values are never extended in place, so lists owned by the
        caller or shared with another batch are not mutated.

        :param batch: the :class:`Batch` whose fields are appended.
        :raises AssertionError: if ``batch`` is not a :class:`Batch`.
        :raises TypeError: if this batch already holds a value for a key and
            ``batch``'s value for it is neither an ``np.ndarray`` nor a
            ``list``.
        """
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for key, new in batch.__dict__.items():
            if new is None:
                continue
            # Look up via __dict__ rather than hasattr(): hasattr() also sees
            # class attributes (e.g. the ``update``/``append`` methods), and
            # the subsequent __dict__ lookup would then raise KeyError.
            old = self.__dict__.get(key)
            if old is None:
                self.__dict__[key] = new
            elif isinstance(new, np.ndarray) or (
                    isinstance(new, list) and isinstance(old, np.ndarray)):
                # ``ndarray += list`` would broadcast-add element-wise instead
                # of appending, so the mixed case is concatenated here too.
                self.__dict__[key] = np.concatenate([old, new])
            elif isinstance(new, list):
                # Build a new list: ``+=`` would extend ``old`` in place and
                # mutate any list it aliases (e.g. one taken from a batch
                # appended earlier, or one passed to the constructor).
                self.__dict__[key] = old + new
            else:
                raise TypeError(
                    'Do not support append with type {} in class Batch.'
                    .format(type(new)))