import numpy as np
import torch

from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic


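# Exploration noise generators should return a sample with exactly the requested shape.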
def test_noise():
    noise = GaussianNoise()
    size = (3, 4, 5)
    assert np.allclose(noise(size).shape, size)
    noise = OUNoise()
    noise.reset()
    assert np.allclose(noise(size).shape, size)


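# MovAvg(10) keeps a moving window over the last 10 values added (scalars, lists,
# numpy arrays, or torch tensors); after adding 1..5 the window mean is 3 and the
# variance is 2.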
def test_moving_average():
    stat = MovAvg(10)
    assert np.allclose(stat.get(), 0)
    assert np.allclose(stat.mean(), 0)
    assert np.allclose(stat.std() ** 2, 0)
    stat.add(torch.tensor([1]))
    stat.add(np.array([2]))
    stat.add([3, 4])
    stat.add(5.0)
    assert np.allclose(stat.get(), 3)
    assert np.allclose(stat.mean(), 3)
    assert np.allclose(stat.std() ** 2, 2)


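# RunningMeanStd tracks element-wise running mean and variance; the expected values
# below are the per-element statistics of the three (2, 2) samples passed to update().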
def test_rms():
    rms = RunningMeanStd()
    assert np.allclose(rms.mean, 0)
    assert np.allclose(rms.var, 1)
    rms.update(np.array([[[1, 2], [3, 5]]]))
    rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]]))
    assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3)
    assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.0]]), atol=1e-3)


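# Shape checks for network helpers not covered by other tests: plain MLP (including
# the identity case), Net with norm_layer/activation options, dueling and concat
# variants, and the recurrent actor/critic for continuous control.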
def test_net():
    # here, test the networks that do not appear in the other test scripts
    bsz = 64
    # MLP
    data = torch.rand([bsz, 3])
    mlp = MLP(3, 6, hidden_sizes=[128])
    assert list(mlp(data).shape) == [bsz, 6]
    # output_dim == 0 and len(hidden_sizes) == 0 means identity model
    mlp = MLP(6, 0)
    assert data.shape == mlp(data).shape
    # common net
    state_shape = (10, 2)
    action_shape = (5,)
    data = torch.rand([bsz, *state_shape])
    expect_output_shape = [bsz, *action_shape]
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=[128, 128],
        norm_layer=torch.nn.LayerNorm,
        activation=None,
    )
    assert list(net(data)[0].shape) == expect_output_shape
    assert str(net).count("LayerNorm") == 2
    assert str(net).count("ReLU") == 0
    # dueling Q/V heads
    Q_param = V_param = {"hidden_sizes": [128, 128]}
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=[128, 128],
        dueling_param=(Q_param, V_param),
    )
    assert list(net(data)[0].shape) == expect_output_shape
    # concat
    net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True)
    data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)])
    expect_output_shape = [bsz, 128]
    assert list(net(data)[0].shape) == expect_output_shape
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=[128],
        concat=True,
        dueling_param=(Q_param, V_param),
    )
    assert list(net(data)[0].shape) == expect_output_shape
    # recurrent actor/critic
    data = torch.rand([bsz, *state_shape]).flatten(1)
    expect_output_shape = [bsz, *action_shape]
    net = RecurrentActorProb(3, state_shape, action_shape)
    mu, sigma = net(data)[0]
    assert mu.shape == sigma.shape
    assert list(mu.shape) == [bsz, 5]
    net = RecurrentCritic(3, state_shape, action_shape)
    data = torch.rand([bsz, 8, np.prod(state_shape)])
    act = torch.rand(expect_output_shape)
    assert list(net(data, act).shape) == [bsz, 1]


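# MultipleLRSchedulers wraps several schedulers so one step() call advances all of
# them; after 10 steps each StepLR should decay its optimizer's learning rate to
# initial_lr * gamma ** (10 // step_size).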
def test_lr_schedulers():
    initial_lr_1 = 10.0
    step_size_1 = 1
    gamma_1 = 0.5
    net_1 = torch.nn.Linear(2, 3)
    optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1)
    sched_1 = torch.optim.lr_scheduler.StepLR(optim_1, step_size=step_size_1, gamma=gamma_1)

    initial_lr_2 = 5.0
    step_size_2 = 2
    gamma_2 = 0.3
    net_2 = torch.nn.Linear(3, 2)
    optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2)
    sched_2 = torch.optim.lr_scheduler.StepLR(optim_2, step_size=step_size_2, gamma=gamma_2)
    schedulers = MultipleLRSchedulers(sched_1, sched_2)
    for _ in range(10):
        loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum()
        optim_1.zero_grad()
        loss_1.backward()
        optim_1.step()
        loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum()
        optim_2.zero_grad()
        loss_2.backward()
        optim_2.step()
        schedulers.step()
    assert optim_1.state_dict()["param_groups"][0]["lr"] == (
        initial_lr_1 * gamma_1 ** (10 // step_size_1)
    )
    assert optim_2.state_dict()["param_groups"][0]["lr"] == (
        initial_lr_2 * gamma_2 ** (10 // step_size_2)
    )


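# Running this module directly executes every test without pytest.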
if __name__ == "__main__":
    test_noise()
    test_moving_average()
    test_rms()
    test_net()
    test_lr_schedulers()