Closes: #1058

### API Extensions
- `Batch` received two new methods: `to_dict` and `to_list_of_dicts`. #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063

### Internal Improvements
- `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for variables in `Collector`s. #1063
- Generally improved readability of the `Collector` code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within `Collector`. #1063

### Breaking Changes
- Removed the `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user may have to call `reset` explicitly or pass `reset_before_collect=True` (see the sketch below). #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
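As a quick illustration of the new collection workflow, here is a minimal sketch (not taken from this PR; it assumes an already-constructed tianshou `policy` and uses `DummyVectorEnv` and `VectorReplayBuffer` as stand-ins):

```python
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# `policy` is assumed to be an already-constructed tianshou policy.
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
collector = Collector(policy, envs, VectorReplayBuffer(20000, len(envs)))

# The Collector no longer resets the environments on construction.
# Either reset explicitly before collecting ...
collector.reset()
stats = collector.collect(n_step=64)

# ... or ask collect() to do it for you.
stats = collector.collect(n_step=64, reset_before_collect=True)
```

And a similarly hedged sketch of the new `Batch` helpers and the `iter` fix (the exact signatures and return types are assumptions based on the method names above):

```python
import numpy as np

from tianshou.data import Batch

b = Batch(obs=np.zeros((2, 3)), act=np.array([0, 1]))

d = b.to_dict()              # plain (nested) dict representation of the Batch
rows = b.to_list_of_dicts()  # one dict per entry along the leading dimension

# iter(b) is now consistent with b.__iter__(): both yield the
# sub-Batches b[0], b[1], ... along the first dimension.
for sub_batch in b:
    print(sub_batch)
```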
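The Python test file reproduced below (A2C on CartPole followed by imitation learning) exercises the updated API end to end; note the explicit `train_collector.reset()` and `test_collector.reset()` calls before the trainers are run.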
```python
import argparse
import os
import pprint

import gymnasium as gym
import numpy as np
import pytest
import torch
from gymnasium.spaces import Box
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import A2CPolicy, ImitationPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

try:
    import envpool
except ImportError:
    envpool = None


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v0")
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--buffer-size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--il-lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--step-per-epoch", type=int, default=50000)
    parser.add_argument("--il-step-per-epoch", type=int, default=1000)
    parser.add_argument("--episode-per-collect", type=int, default=16)
    parser.add_argument("--step-per-collect", type=int, default=16)
    parser.add_argument("--update-per-step", type=float, default=1 / 16)
    parser.add_argument("--repeat-per-collect", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128])
    parser.add_argument("--training-num", type=int, default=16)
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    # a2c special
    parser.add_argument("--vf-coef", type=float, default=0.5)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--max-grad-norm", type=float, default=None)
    parser.add_argument("--gae-lambda", type=float, default=1.0)
    parser.add_argument("--rew-norm", action="store_true", default=False)
    return parser.parse_known_args()[0]


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
    # if you want to use python vector env, please refer to other test scripts
    train_envs = env = envpool.make(
        args.task,
        env_type="gymnasium",
        num_envs=args.training_num,
        seed=args.seed,
    )
    test_envs = envpool.make(
        args.task,
        env_type="gymnasium",
        num_envs=args.test_num,
        seed=args.seed,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
    critic = Critic(net, device=args.device).to(args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy: A2CPolicy = A2CPolicy(
        actor=actor,
        critic=critic,
        optim=optim,
        dist_fn=dist,
        action_scaling=isinstance(env.action_space, Box),
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        max_grad_norm=args.max_grad_norm,
        reward_normalization=args.rew_norm,
        action_space=env.action_space,
    )
    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
    )
    train_collector.reset()
    test_collector = Collector(policy, test_envs)
    test_collector.reset()
    # log
    log_path = os.path.join(args.logdir, args.task, "a2c")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    # trainer
    result = OnpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        repeat_per_collect=args.repeat_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    ).run()
    assert stop_fn(result.best_reward)

    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        collector_stats = collector.collect(n_episode=1, render=args.render)
        print(collector_stats)

    policy.eval()
    # here we define an imitation collector with a trivial policy
    # if args.task == 'CartPole-v0':
    #     env.spec.reward_threshold = 190  # lower the goal
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    net = Actor(net, args.action_shape, device=args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
    il_policy: ImitationPolicy = ImitationPolicy(
        actor=net,
        optim=optim,
        action_space=env.action_space,
    )
    il_test_collector = Collector(
        il_policy,
        envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed),
    )
    train_collector.reset()
    result = OffpolicyTrainer(
        policy=il_policy,
        train_collector=train_collector,
        test_collector=il_test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.il_step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    ).run()
    assert stop_fn(result.best_reward)

    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        il_policy.eval()
        collector = Collector(il_policy, env)
        collector_stats = collector.collect(n_episode=1, render=args.render)
        print(collector_stats)


if __name__ == "__main__":
    test_a2c_with_il()
```