Tianshou/test/base/test_utils.py
ChenDRAG 9b61bc620c add logger (#295)
This PR focuses on refactoring the logging method to fix the bugs with NaN rewards and the log interval. After these two PRs, the fundamental changes to tianshou/data should be finished, so we can finally concentrate on building benchmarks for Tianshou.

Things changed:

1. trainer now accepts a logger (BasicLogger or LazyLogger) instead of a writer (see the sketch below);
2. removed utils.SummaryWriter.
2021-02-24 14:48:42 +08:00
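
For reference, a minimal sketch of the new logging pattern described above: a TensorBoard SummaryWriter wrapped by BasicLogger and passed to the trainer via logger=. The trainer call is only indicated in a comment because it needs a policy and collectors, and its remaining arguments are not shown here; treat the exact keyword arguments as illustrative rather than the definitive signature.

from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter("log/dqn")  # plain TensorBoard writer, as before
logger = BasicLogger(writer)       # wraps the writer for the trainer
# result = offpolicy_trainer(policy, train_collector, test_collector,
#                            ..., logger=logger)  # previously: writer=writer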


import torch
import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils.net.common import MLP, Net
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic


def test_noise():
    noise = GaussianNoise()
    size = (3, 4, 5)
    assert np.allclose(noise(size).shape, size)
    noise = OUNoise()
    noise.reset()
    assert np.allclose(noise(size).shape, size)


def test_moving_average():
    stat = MovAvg(10)
    assert np.allclose(stat.get(), 0)
    assert np.allclose(stat.mean(), 0)
    assert np.allclose(stat.std() ** 2, 0)
    stat.add(torch.tensor([1]))
    stat.add(np.array([2]))
    stat.add([3, 4])
    stat.add(5.)
    # the window now holds 1, 2, 3, 4, 5 -> mean 3, variance 2
    assert np.allclose(stat.get(), 3)
    assert np.allclose(stat.mean(), 3)
    assert np.allclose(stat.std() ** 2, 2)


def test_net():
    # here test the networks that do not appear in the other scripts
    bsz = 64
    # MLP
    data = torch.rand([bsz, 3])
    mlp = MLP(3, 6, hidden_sizes=[128])
    assert list(mlp(data).shape) == [bsz, 6]
    # output_dim == 0 and len(hidden_sizes) == 0 means identity model
    mlp = MLP(6, 0)
    assert data.shape == mlp(data).shape
    # common net
    state_shape = (10, 2)
    action_shape = (5, )
    data = torch.rand([bsz, *state_shape])
    expect_output_shape = [bsz, *action_shape]
    net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
              norm_layer=torch.nn.LayerNorm, activation=None)
    assert list(net(data)[0].shape) == expect_output_shape
    assert str(net).count("LayerNorm") == 2
    assert str(net).count("ReLU") == 0
    # dueling head
    Q_param = V_param = {"hidden_sizes": [128, 128]}
    net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
              dueling_param=(Q_param, V_param))
    assert list(net(data)[0].shape) == expect_output_shape
    # concat
    net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True)
    data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)])
    expect_output_shape = [bsz, 128]
    assert list(net(data)[0].shape) == expect_output_shape
    net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True,
              dueling_param=(Q_param, V_param))
    assert list(net(data)[0].shape) == expect_output_shape
    # recurrent actor/critic
    data = torch.rand([bsz, *state_shape]).flatten(1)
    expect_output_shape = [bsz, *action_shape]
    net = RecurrentActorProb(3, state_shape, action_shape)
    mu, sigma = net(data)[0]
    assert mu.shape == sigma.shape
    assert list(mu.shape) == [bsz, 5]
    net = RecurrentCritic(3, state_shape, action_shape)
    data = torch.rand([bsz, 8, np.prod(state_shape)])
    act = torch.rand(expect_output_shape)
    assert list(net(data, act).shape) == [bsz, 1]


if __name__ == '__main__':
    test_noise()
    test_moving_average()
    test_net()