Preparation for #914 and #920

Changes formatting to ruff and black; removes Python 3.8 support.

## Additional Changes

- Removed the flake8 dependencies.
- Adjusted pre-commit. CI and Make now both invoke pre-commit, reducing the duplication of linting calls (see the sketch below).
- Removed the check-docstyle option (ruff covers that).
- Merged format and lint. In CI the format-lint step fails if it changes any files, so it also fulfills the lint function.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
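For context on the pre-commit item above: with this kind of setup the hooks are declared once in `.pre-commit-config.yaml`, both CI and the Makefile simply run `pre-commit run --all-files`, and a run that modifies any file exits non-zero, which is why the merged format-lint step can double as the lint check. The sketch below is illustrative only; the hook revisions and the exact hook selection are assumptions, not this repository's actual configuration.

```yaml
# Illustrative .pre-commit-config.yaml -- revisions and hook choices are
# placeholders, not the repository's actual pins.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.6  # placeholder revision
    hooks:
      - id: ruff          # lint; covers the former flake8/docstyle checks
        args: [--fix]
  - repo: https://github.com/psf/black
    rev: 23.12.1  # placeholder revision
    hooks:
      - id: black         # format
```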
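# Branching Dueling Q-Network (BDQ) example on BipedalWalker-v3: the continuous
# action space is discretized by ContinuousToDiscrete, so each action dimension
# becomes a branch with `action_per_branch` discrete bins; BranchingNet shares a
# common torso and adds a value stream plus one action stream per branch.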
import argparse
import datetime
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import BranchingNet


def get_args():
    parser = argparse.ArgumentParser()
    # task
    parser.add_argument("--task", type=str, default="BipedalWalker-v3")
    # network architecture
    parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[512, 256])
    parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128])
    parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128])
    parser.add_argument("--action-per-branch", type=int, default=25)
    # training hyperparameters
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--eps-test", type=float, default=0.0)
    parser.add_argument("--eps-train", type=float, default=0.73)
    parser.add_argument("--eps-decay", type=float, default=5e-6)
    parser.add_argument("--buffer-size", type=int, default=100000)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--target-update-freq", type=int, default=1000)
    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=80000)
    parser.add_argument("--step-per-collect", type=int, default=16)
    parser.add_argument("--update-per-step", type=float, default=0.0625)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--training-num", type=int, default=20)
    parser.add_argument("--test-num", type=int, default=10)
    # other
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_args()


def test_bdq(args=get_args()):
    env = gym.make(args.task)
    env = ContinuousToDiscrete(env, args.action_per_branch)

    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.num_branches = (
        args.action_shape if isinstance(args.action_shape, int) else args.action_shape[0]
    )

    print("Observations shape:", args.state_shape)
    print("Num branches:", args.num_branches)
    print("Actions per branch:", args.action_per_branch)

    # train_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = SubprocVectorEnv(
        [
            lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
            for _ in range(args.training_num)
        ],
    )
    # test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
    test_envs = SubprocVectorEnv(
        [
            lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
            for _ in range(args.test_num)
        ],
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = BranchingNet(
        args.state_shape,
        args.num_branches,
        args.action_per_branch,
        args.common_hidden_sizes,
        args.value_hidden_sizes,
        args.action_hidden_sizes,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = BranchingDQNPolicy(net, optim, args.gamma, target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=False)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = os.path.join(args.logdir, "bdq", args.task, current_time)
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):  # exp decay
        eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
        policy.set_eps(eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        update_per_step=args.update_per_step,
        # stop_fn=stop_fn,
        train_fn=train_fn,
        test_fn=test_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    ).run()

    # assert stop_fn(result["best_reward"])
    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
    test_bdq(get_args())
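# Hypothetical quick run (the script filename is assumed; the flags are those
# defined in get_args above):
#   python bipedal_bdq.py --epoch 5 --training-num 4 --test-num 2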