69 lines
2.1 KiB
Python
Raw Normal View History

import torch
import numpy as np
from tianshou.utils import MovAvg
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DQN
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
def test_noise():
    """Check that Gaussian and OU noise samples have the requested shape."""
    shape = (3, 4, 5)

    gaussian = GaussianNoise()
    sample = gaussian(shape)
    assert np.allclose(sample.shape, shape)

    ou = OUNoise()
    # reset() is called before the first sample is drawn from the OU noise.
    ou.reset()
    sample = ou(shape)
    assert np.allclose(sample.shape, shape)
def test_moving_average():
    """Exercise MovAvg before any data is added and after mixed-type adds."""
    tracker = MovAvg(10)

    # With nothing added yet, every statistic is expected to be zero.
    current = tracker.get()
    assert np.allclose(current, 0)
    mean = tracker.mean()
    assert np.allclose(mean, 0)
    variance = tracker.std() ** 2
    assert np.allclose(variance, 0)

    # Feed the values 1..5 as a tensor, an array, a list and a float.
    tracker.add(torch.tensor([1]))
    tracker.add(np.array([2]))
    tracker.add([3, 4])
    tracker.add(5.)

    # For 1..5 the mean is 3 and the (population) variance is 2.
    current = tracker.get()
    assert np.allclose(current, 3)
    mean = tracker.mean()
    assert np.allclose(mean, 3)
    variance = tracker.std() ** 2
    assert np.allclose(variance, 2)
def test_net():
    """Check output shapes of networks not exercised by other test scripts."""
    batch = 64

    # Common Net: once with a norm layer, once with a dueling configuration.
    obs_shape = (10, 2)
    act_shape = (5, )
    obs = torch.rand([batch, *obs_shape])
    want = [batch, *act_shape]
    model = Net(3, obs_shape, act_shape, norm_layer=torch.nn.LayerNorm)
    logits = model(obs)[0]
    assert list(logits.shape) == want
    model = Net(3, obs_shape, act_shape, dueling=(2, 2))
    logits = model(obs)[0]
    assert list(logits.shape) == want

    # Recurrent actor: fed the observation flattened from dim 1 onward.
    flat_obs = obs.flatten(1)
    actor = RecurrentActorProb(3, obs_shape, act_shape)
    mu, sigma = actor(flat_obs)[0]
    assert mu.shape == sigma.shape
    assert list(mu.shape) == [batch, 5]

    # Recurrent critic: fed a (batch, 8, flattened-state) tensor plus an
    # action tensor; the 8 is presumably a sequence length -- TODO confirm.
    critic = RecurrentCritic(3, obs_shape, act_shape)
    seq_obs = torch.rand([batch, 8, np.prod(obs_shape)])
    act = torch.rand(want)
    value = critic(seq_obs, act)
    assert list(value.shape) == [batch, 1]

    # DQN: fed a NumPy array rather than a torch tensor.
    obs_shape = (4, 84, 84)
    act_shape = (6, )
    frames = np.random.rand(batch, *obs_shape)
    want = [batch, *act_shape]
    dqn = DQN(*obs_shape, act_shape)
    q_values = dqn(frames)[0]
    assert list(q_values.shape) == want
if __name__ == '__main__':
    # Run every test in this module, in order, when executed as a script.
    for run_test in (test_noise, test_moving_average, test_net):
        run_test()