Tianshou/test/base/test_utils.py
wizardsheng c6f2648e87
Add C51 algorithm (#266)
This is the PR for the C51 algorithm: https://arxiv.org/abs/1707.06887

1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", we get 'best_result': '20.50 ± 0.50' in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", we get best_reward: 407.400000 ± 31.155096 in epoch 39.
2021-01-06 10:17:45 +08:00

92 lines
3.0 KiB
Python

import torch
import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils import SummaryWriter
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DQN, C51
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
def test_noise():
    """Check that each noise generator yields a sample of the requested shape."""
    size = (3, 4, 5)

    gaussian = GaussianNoise()
    gaussian_sample = gaussian(size)
    assert np.allclose(gaussian_sample.shape, size)

    # the OU generator is reset before its first sample is drawn
    ou = OUNoise()
    ou.reset()
    ou_sample = ou(size)
    assert np.allclose(ou_sample.shape, size)
def test_moving_average():
    """Exercise MovAvg before and after feeding it values of mixed types."""
    avg = MovAvg(10)

    # nothing has been added yet: every statistic is expected to be zero
    assert np.allclose(avg.get(), 0)
    assert np.allclose(avg.mean(), 0)
    assert np.allclose(avg.std() ** 2, 0)

    # feed the values 1..5 as a tensor, an array, a list and a plain float
    samples = (torch.tensor([1]), np.array([2]), [3, 4], 5.)
    for sample in samples:
        avg.add(sample)

    # expected mean 3 and variance 2 over the five values
    assert np.allclose(avg.get(), 3)
    assert np.allclose(avg.mean(), 3)
    assert np.allclose(avg.std() ** 2, 2)
def test_net():
    """Shape-check networks that, per the original note, no other script tests."""
    batch = 64

    # ---- common Net: LayerNorm variant, then dueling variant ----
    obs_shape = (10, 2)
    act_shape = (5, )
    obs = torch.rand([batch, *obs_shape])
    wanted = [batch, *act_shape]
    model = Net(3, obs_shape, act_shape, norm_layer=torch.nn.LayerNorm)
    out = model(obs)[0]
    assert list(out.shape) == wanted
    model = Net(3, obs_shape, act_shape, dueling=(2, 2))
    out = model(obs)[0]
    assert list(out.shape) == wanted

    # ---- recurrent actor / critic ----
    # the actor is fed the observation with all non-batch dims flattened
    obs = obs.flatten(1)
    model = RecurrentActorProb(3, obs_shape, act_shape)
    mu, sigma = model(obs)[0]
    assert mu.shape == sigma.shape
    assert list(mu.shape) == [batch, 5]
    model = RecurrentCritic(3, obs_shape, act_shape)
    # the critic is fed a 3-D input (batch, 8, flattened obs) plus an action
    obs = torch.rand([batch, 8, np.prod(obs_shape)])
    act = torch.rand(wanted)
    value = model(obs, act)
    assert list(value.shape) == [batch, 1]

    # ---- DQN, fed a numpy array of shape (batch, 4, 84, 84) ----
    obs_shape = (4, 84, 84)
    act_shape = (6, )
    obs = np.random.rand(batch, *obs_shape)
    wanted = [batch, *act_shape]
    model = DQN(*obs_shape, act_shape)
    out = model(obs)[0]
    assert list(out.shape) == wanted

    # ---- C51: same input, output gains a trailing num_atoms axis ----
    num_atoms = 51
    model = C51(*obs_shape, act_shape, num_atoms)
    wanted = [batch, *act_shape, num_atoms]
    out = model(obs)[0]
    assert list(out.shape) == wanted
def test_summary_writer():
    """Check that SummaryWriter.get_instance returns one shared writer per key."""
    first_dir = "log/test_sw/first"
    second_dir = "log/test_sw/second"

    # the first writer is created under an explicit key; a later call with
    # no arguments is expected to return that same object
    first = SummaryWriter.get_instance(key="first", log_dir=first_dir)
    assert first.log_dir == first_dir
    default = SummaryWriter.get_instance()
    assert first is default

    # a new key is expected to create a separate writer, and asking for
    # that key again should return it rather than build another one
    second = SummaryWriter.get_instance(key="second", log_dir=second_dir)
    assert second.log_dir == second_dir
    second_again = SummaryWriter.get_instance(key="second")
    assert second is second_again

    # the two keyed writers must be distinct objects with distinct log dirs
    assert first is not second
    assert first.log_dir != second_again.log_dir
if __name__ == '__main__':
    # executed as a script: run every test of this file in definition order
    for _run_test in (
        test_noise,
        test_moving_average,
        test_net,
        test_summary_writer,
    ):
        _run_test()