add an example of bullet env (experiment from jiqizhixin) (#15)
* add_pybullet_envs_test: test on pybullet envs, modify some log config
* delete DS_Store file
* add pybullet_envs test: add HalfCheetahBulletEnv-v0 test, modify log config
* fix PEP 8 errors
* add pybullet to dev
* delete a line
* bypass F401
* add log_interval to onpolicy_trainer
* add comments
* Update halfcheetahBullet_v0_sac.py
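For reference, the new example can be launched from the shell (`python examples/halfcheetahBullet_v0_sac.py`) or driven programmatically. A minimal, hedged driver sketch follows; the `--run-id` and `--log-interval` values are illustrative and not part of this commit:

```python
# Hedged driver sketch (not part of this commit): run the new example
# programmatically. Assumes pybullet is installed and that this file lives
# next to examples/halfcheetahBullet_v0_sac.py, so the example module and
# its continuous_net helper are importable.
import sys

# Set CLI-style arguments before the import, because test_sac's default
# argument (args=get_args()) is evaluated from sys.argv at import time.
sys.argv = [
    'halfcheetahBullet_v0_sac.py',
    '--task', 'HalfCheetahBulletEnv-v0',
    '--run-id', 'demo',        # keeps TensorBoard logs for this run separate
    '--log-interval', '100',   # the new logging option added in this commit
]

from halfcheetahBullet_v0_sac import get_args, test_sac  # noqa: E402

if __name__ == '__main__':     # SubprocVectorEnv spawns worker processes
    test_sac(get_args())
```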
parent 974ade8019
commit 9380368ca3

examples/halfcheetahBullet_v0_sac.py | 120 (new file)
@@ -0,0 +1,120 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import SubprocVectorEnv
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

try:
    import pybullet_envs
except ImportError:
    pass

from continuous_net import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='HalfCheetahBulletEnv-v0')
    parser.add_argument('--run-id', type=str, default='test')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--layer-num', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=4)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--log-interval', type=int, default=100)
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_sac(args=get_args()):
    torch.set_num_threads(1)
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = ActorProb(
        args.layer_num, args.state_shape, args.action_shape,
        args.max_action, args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1 = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
    policy = SACPolicy(
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        args.tau, args.gamma, args.alpha,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True, ignore_done=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
    writer = SummaryWriter(log_path)

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn,
        writer=writer, log_interval=args.log_interval)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
    __all__ = ('pybullet_envs',)  # Avoid F401 error :)
    test_sac()
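The trailing `__all__ = ('pybullet_envs',)` line exists only to keep flake8 from flagging the seemingly unused import (F401). An equivalent and arguably more conventional way to silence that warning, shown here as a sketch and not part of this commit, is a `# noqa` directive at the import site:

```python
# Alternative to the __all__ workaround (not in this commit): suppress
# flake8's F401 "imported but unused" warning directly on the import line.
try:
    import pybullet_envs  # noqa: F401  (the import registers the Bullet envs with gym)
except ImportError:
    pass
```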
							
								
								
									
setup.py | 3
@@ -69,5 +69,8 @@ setup(
         'mujoco': [
             'mujoco_py',
         ],
+        'pybullet': [
+            'pybullet',
+        ],
     },
 )
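setup.py now exposes pybullet as an optional extra, so the dependency can be pulled in on demand (e.g. `pip install tianshou[pybullet]`, or `pip install -e ".[pybullet]"` from a source checkout). A small hedged smoke check, assuming gym and the extra are installed:

```python
# Hedged smoke check (not part of this commit): after installing the extra,
# importing pybullet_envs registers the Bullet tasks with gym's registry.
import gym
import pybullet_envs  # noqa: F401

env = gym.make('HalfCheetahBulletEnv-v0')
print(env.observation_space.shape, env.action_space.shape)
env.close()
```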
@@ -8,7 +8,7 @@ from tianshou.trainer import test_episode, gather_info
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
                       batch_size, train_fn=None, test_fn=None, stop_fn=None,
-                      writer=None, verbose=True, task=''):
+                      writer=None, log_interval=1, verbose=True, task=''):
     global_step = 0
     best_epoch, best_reward = -1, -1
     stat = {}
@@ -45,7 +45,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     losses = policy.learn(train_collector.sample(batch_size))
                     for k in result.keys():
                         data[k] = f'{result[k]:.2f}'
-                        if writer:
+                        if writer and global_step % log_interval == 0:
                             writer.add_scalar(
                                 k + '_' + task if task else k,
                                 result[k], global_step=global_step)
@@ -54,7 +54,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                             stat[k] = MovAvg()
                         stat[k].add(losses[k])
                         data[k] = f'{stat[k].get():.6f}'
-                        if writer:
+                        if writer and global_step % log_interval == 0:
                             writer.add_scalar(
                                 k + '_' + task if task else k,
                                 stat[k].get(), global_step=global_step)
@@ -9,7 +9,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
                      train_fn=None, test_fn=None, stop_fn=None,
-                     writer=None, verbose=True, task=''):
+                     writer=None, log_interval=1, verbose=True, task=''):
     global_step = 0
     best_epoch, best_reward = -1, -1
     stat = {}
@@ -50,7 +50,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                 global_step += step
                 for k in result.keys():
                     data[k] = f'{result[k]:.2f}'
-                    if writer:
+                    if writer and global_step % log_interval == 0:
                         writer.add_scalar(
                             k + '_' + task if task else k,
                             result[k], global_step=global_step)
@@ -59,7 +59,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                         stat[k] = MovAvg()
                     stat[k].add(losses[k])
                     data[k] = f'{stat[k].get():.6f}'
-                    if writer and global_step:
+                    if writer and global_step % log_interval == 0:
                         writer.add_scalar(
                             k + '_' + task if task else k,
                             stat[k].get(), global_step=global_step)
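Both trainers now accept a `log_interval` keyword (default 1, which preserves the old behaviour) that throttles how often scalars are written to TensorBoard; the example above passes `--log-interval 100`. A standalone toy sketch of the same gating pattern, using only `SummaryWriter` and illustrative names:

```python
# Toy illustration of the log_interval gating added above: write a scalar to
# TensorBoard only every `log_interval` global steps instead of on every update.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('log/log_interval_demo')  # hypothetical log directory
log_interval = 100
global_step = 0
for step in range(1000):
    global_step += 1
    fake_loss = 1.0 / global_step  # stand-in for a real training statistic
    if writer and global_step % log_interval == 0:
        writer.add_scalar('loss', fake_loss, global_step=global_step)
writer.close()
```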