2020-08-02 18:24:40 +08:00
|
|
|
import copy
|
|
|
|
import pickle
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from tianshou.data import Batch
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def data():
    """Build the inputs shared by every profiling test in this module.

    The fixture is module-scoped, so the objects are created once and the
    same instances are handed to each test (tests that assign into them
    therefore see earlier modifications).

    Returns a dict with the keys ``batch_set``, ``batch0``, ``batchs1``,
    ``batchs2``, ``batch3``, ``indexs``, ``dict_set``, ``slice_dict`` and
    ``batch4``.
    """
    print("Initializing data...")
    # Seed NumPy's global RNG so the indices drawn below are reproducible.
    np.random.seed(0)

    n_copies = 10_000
    batch_len = int(1e4)
    slice_len = batch_len // 10

    # 10k small Batch objects; ``c`` carries the NumPy integer loop value.
    batch_set = [
        Batch(
            a=list(np.arange(1e3)),
            b={"b1": (3.14, 3.14), "b2": np.arange(1e3)},
            c=i,
        )
        for i in np.arange(int(1e4))
    ]

    # Small nested Batch mixing NumPy, torch and plain-list leaves.
    small_inner = Batch(
        c=np.ones((1,), dtype=np.float64),
        d=torch.ones((3, 3, 3), dtype=torch.float32),
        e=list(range(3)),
    )
    batch0 = Batch(a=np.ones((3, 4), dtype=np.float64), b=small_inner)

    # Two independent pools of deep copies, one pool per consuming test.
    batchs1 = [copy.deepcopy(batch0) for _ in range(n_copies)]
    batchs2 = [copy.deepcopy(batch0) for _ in range(n_copies)]

    # Flat batch used by the get/set item and attribute tests.
    obs_rows = [np.arange(20) for _ in range(batch_len)]
    batch3 = Batch(obs=obs_rows, reward=np.arange(batch_len))

    # A tenth of the rows, drawn without replacement.
    indexs = np.random.choice(batch_len, size=slice_len, replace=False)

    # Replacement values sized to match ``indexs``.
    slice_dict = {
        "obs": [np.arange(20) for _ in range(slice_len)],
        "reward": np.arange(slice_len),
    }

    # 100 plain dicts, each with its own freshly built ``obs`` array.
    dict_set = [
        {
            "obs": np.arange(20),
            "info": "this is info",
            "reward": 0,
        }
        for _ in range(100)
    ]

    # Larger nested Batch used by the conversion and pickle tests.
    large_inner = Batch(
        c=np.ones((1,), dtype=np.float64),
        d=torch.ones((1000, 1000), dtype=torch.float32),
        e=np.arange(1000),
    )
    batch4 = Batch(a=np.ones((10000, 4), dtype=np.float64), b=large_inner)

    print("Initialized")
    return dict(
        batch_set=batch_set,
        batch0=batch0,
        batchs1=batchs1,
        batchs2=batchs2,
        batch3=batch3,
        indexs=indexs,
        dict_set=dict_set,
        slice_dict=slice_dict,
        batch4=batch4,
    )
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_init(data):
    """Profile ``Batch(...)`` construction from the ``batch_set`` list.

    The constructor is invoked ten times on the same list; the resulting
    object is discarded each time.
    """
    source_batches = data["batch_set"]
    for _ in range(10):
        Batch(source_batches)
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_item(data):
    """Profile indexing ``batch3`` with the ``indexs`` array.

    The same lookup is repeated 100,000 times and the result is thrown away.
    """
    batch = data["batch3"]
    picks = data["indexs"]
    for _ in range(100_000):
        batch[picks]
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_attr(data):
    """Profile reading ``obs`` and ``reward`` from ``batch3``.

    Each of the 1,000,000 iterations reads both fields twice: once through
    ``.get(name)`` and once through plain attribute access.
    """
    batch = data["batch3"]
    for _ in range(1_000_000):
        # Keyed access through the ``get`` method.
        batch.get("obs")
        batch.get("reward")
        # Attribute access, in the same obs-then-reward order.
        _ = batch.obs
        _ = batch.reward
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_set_item(data):
    """Profile item assignment into ``batch3``.

    ``slice_dict`` is assigned at the positions in ``indexs`` 10,000 times.
    This modifies the module-scoped ``batch3`` object.
    """
    batch = data["batch3"]
    picks = data["indexs"]
    replacement = data["slice_dict"]
    for _ in range(10_000):
        batch[picks] = replacement
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_set_attr(data):
    """Profile attribute assignment on ``batch3``.

    Each of the 10,000 iterations sets ``c`` to a freshly built array and
    ``obs`` to the ``dict_set`` list. This modifies the module-scoped
    ``batch3`` object.
    """
    batch = data["batch3"]
    dict_rows = data["dict_set"]
    for _ in range(10_000):
        # A new array is built every iteration, as in the original loop.
        batch.c = np.arange(1e3)
        batch.obs = dict_rows
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_numpy_torch_convert(data):
    """Profile round-trip conversion of ``batch4`` between torch and numpy.

    ``to_torch()`` followed by ``to_numpy()`` is called 10,000 times on the
    same module-scoped object.
    """
    batch = data["batch4"]
    # NOTE(review): the original author flagged an unexplained problem here
    # with torch==1.10.0 ("not sure what's wrong") -- details unknown.
    for _ in range(10_000):
        batch.to_torch()
        batch.to_numpy()
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_pickle(data):
    """Profile a pickle round trip of ``batch4``.

    The batch is serialized with ``pickle.dumps`` and immediately restored
    with ``pickle.loads`` 10,000 times; the restored object is discarded.
    """
    batch = data["batch4"]
    for _ in range(10_000):
        payload = pickle.dumps(batch)
        # The payload was produced on the line above, not read from outside.
        pickle.loads(payload)
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_cat(data):
    """Profile ``Batch.cat`` and the ``cat_`` method.

    Each of the 10,000 iterations concatenates ``batch0`` with itself via
    ``Batch.cat`` and then calls ``cat_`` on the i-th entry of ``batchs1``
    (presumably an in-place operation, hence one pre-made copy per
    iteration -- confirm against the Batch API).
    """
    base = data["batch0"]
    targets = data["batchs1"]
    for idx in range(10_000):
        Batch.cat((base, base))
        target = targets[idx]
        target.cat_(base)
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_stack(data):
    """Profile ``Batch.stack`` and the ``stack_`` method.

    Each of the 10,000 iterations stacks ``batch0`` with itself via
    ``Batch.stack`` and then calls ``stack_`` with a one-element list on the
    i-th entry of ``batchs2`` (presumably an in-place operation, hence one
    pre-made copy per iteration -- confirm against the Batch API).
    """
    base = data["batch0"]
    targets = data["batchs2"]
    for idx in range(10_000):
        Batch.stack((base, base))
        target = targets[idx]
        # A new single-item list is built on every pass, as in the original.
        target.stack_([base])
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
2023-08-25 23:40:56 +02:00
|
|
|
if __name__ == "__main__":
    # Allow running this module directly instead of through the pytest CLI.
    # "-k batch_profile" restricts the run by keyword (presumably matching
    # this file's name -- verify); "--durations=0" asks pytest to report the
    # duration of every test, which is the point of a profiling module.
    cli_args = ["-s", "-k batch_profile", "--durations=0", "-v"]
    pytest.main(cli_args)
|