add speed stat
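In brief: Collector gains a collect_step frame counter and a stat_speed moving average; collect() now returns the reward/length/speed stats directly, replacing the separate stat() method; and the DQN test logs the new speed number to TensorBoard, shows it in the tqdm postfix, and prints a final frames-per-second summary. The commit also renames the CI workflow to "Unittest", updates the LICENSE copyright, and pins torch>=1.2.0 for TensorBoard support.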
parent cef5de8b83
commit 8b0b970c9b

.github/workflows/pytest.yml (2 changes, vendored)

@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 
-name: Python package
+name: Unittest
 
 on:
   push:

LICENSE (2 changes)

@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2019 n+e
+Copyright (c) 2020 TSAIL
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

setup.py (2 changes)

@@ -41,7 +41,7 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch',
+        'torch>=1.2.0',  # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
     ],

@@ -14,7 +14,7 @@ from tianshou.data import Collector, ReplayBuffer
 
 
 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -93,8 +93,9 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(args.logdir)
     best_epoch = -1
     best_reward = -1e10
-    for epoch in range(args.epoch):
-        desc = f"Epoch #{epoch + 1}"
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f"Epoch #{epoch}"
         # train
         policy.train()
         policy.sync_weight()
@@ -102,9 +103,9 @@ def test_dqn(args=get_args()):
         with tqdm.trange(
                 0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
             for _ in t:
-                training_collector.collect(n_step=args.collect_per_step)
+                result = training_collector.collect(
+                    n_step=args.collect_per_step)
                 global_step += 1
-                result = training_collector.stat()
                 loss = policy.learn(training_collector.sample(args.batch_size))
                 stat_loss.add(loss)
                 writer.add_scalar(
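The calling-convention change these lines make, restated as a before/after sketch (all names taken from the diff itself):

    # before: collect() discarded its stats; a second call fetched them
    training_collector.collect(n_step=args.collect_per_step)
    result = training_collector.stat()

    # after: collect() returns {'reward': ..., 'length': ..., 'speed': ...} itself
    result = training_collector.collect(n_step=args.collect_per_step)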
@@ -113,26 +114,34 @@ def test_dqn(args=get_args()):
                     'length', result['length'], global_step=global_step)
                 writer.add_scalar(
                     'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}')
+                              length=f'{result["length"]:.6f}',
+                              speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
         policy.eval()
         policy.set_eps(args.eps_test)
-        test_collector.collect(n_episode=args.test_num)
-        result = test_collector.stat()
+        result = test_collector.collect(n_episode=args.test_num)
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
     if __name__ == '__main__':
-        # let's watch its performance!
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
         env = gym.make(args.task)
         obs = env.reset()
         done = False
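A worked example of the new summary line, with made-up numbers: collecting 80000 training frames and 4000 test frames in 21.00 s would print

    Collect 80000 training frame and 4000 test frame in 21.00s, speed: 4000.00it/s

since (80000 + 4000) / 21.00 = 4000.00.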
@@ -143,10 +152,9 @@ def test_dqn(args=get_args()):
             obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
             total += rew
             env.render()
-            time.sleep(1 / 100)
+            time.sleep(1 / 35)
         env.close()
         print(f'Final test: {total}')
-    return best_reward
 
 
 if __name__ == '__main__':

@@ -15,6 +15,7 @@ class Collector(object):
         super().__init__()
         self.env = env
         self.env_num = 1
+        self.collect_step = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -40,6 +41,7 @@ class Collector(object):
         self.state = None
         self.stat_reward = MovAvg(stat_size)
         self.stat_length = MovAvg(stat_size)
+        self.stat_speed = MovAvg(stat_size)
 
     def reset_buffer(self):
         if self._multi_buf:
@@ -78,6 +80,8 @@ class Collector(object):
             return [data]
 
     def collect(self, n_step=0, n_episode=0, render=0):
+        start_time = time.time()
+        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
@@ -119,9 +123,11 @@ class Collector(object):
                     elif self._multi_buf:
                         self.buffer[i].add(**data)
                         cur_step += 1
+                        self.collect_step += 1
                     else:
                         self.buffer.add(**data)
                         cur_step += 1
+                        self.collect_step += 1
                     if self._done[i]:
                         cur_episode[i] += 1
                         self.stat_reward.add(self.reward[i])
@@ -130,6 +136,7 @@ class Collector(object):
                         if self._cached_buf:
                             self.buffer.update(self._cached_buf[i])
                             cur_step += len(self._cached_buf[i])
+                            self.collect_step += len(self._cached_buf[i])
                             self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
@@ -145,6 +152,7 @@ class Collector(object):
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
                 cur_step += 1
+                self.collect_step += 1
                 if self._done:
                     cur_episode += 1
                     self.stat_reward.add(self.reward)
@@ -158,6 +166,13 @@ class Collector(object):
                 break
             self._obs = obs_next
         self._obs = obs_next
+        self.stat_speed.add((self.collect_step - start_step) / (
+            time.time() - start_time))
+        return {
+            'reward': self.stat_reward.get(),
+            'length': self.stat_length.get(),
+            'speed': self.stat_speed.get(),
+        }
 
     def sample(self, batch_size):
         if self._multi_buf:
@@ -179,9 +194,3 @@ class Collector(object):
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
-
-    def stat(self):
-        return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-        }
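As orientation for the Collector hunks above (not library code): a minimal, self-contained sketch of how the speed stat behaves. MovAvg here is a stand-in for tianshou's moving-average helper; each collect() call contributes one frames-per-second sample, and the reported speed is the moving average of recent samples.

    import time
    from collections import deque

    class MovAvg:
        """Stand-in for tianshou's moving-average helper: mean of the last `size` samples."""
        def __init__(self, size=100):
            self.cache = deque(maxlen=size)

        def add(self, x):
            self.cache.append(x)

        def get(self):
            return sum(self.cache) / len(self.cache) if self.cache else 0

    stat_speed = MovAvg()
    collect_step = 0

    def collect(n_step):
        # Mirror the diff: count frames, time the whole call, and record
        # (frames collected) / (elapsed seconds) as one speed sample.
        global collect_step
        start_time, start_step = time.time(), collect_step
        for _ in range(n_step):
            time.sleep(0.001)  # stand-in for env.step() + policy forward
            collect_step += 1
        stat_speed.add((collect_step - start_step) / (time.time() - start_time))
        return {'speed': stat_speed.get()}

    print(collect(100)['speed'])  # a bit under 1000 it/s on this toy workload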

@@ -10,7 +10,7 @@ class BasePolicy(ABC):
 
     @abstractmethod
     def __call__(self, batch, hidden_state=None):
-        # return Batch(policy, action, hidden)
+        # return Batch(act=np.array(), state=None, ...)
         pass
 
     @abstractmethod
@@ -22,6 +22,3 @@ class BasePolicy(ABC):
 
     def sync_weight(self):
         pass
-
-    def exploration(self):
-        pass
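To make the updated return convention concrete, a hypothetical minimal subclass (illustrative only: RandomPolicy is invented, and the import paths and learn() signature are assumptions based on how the test code above uses the policy):

    import numpy as np
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class RandomPolicy(BasePolicy):
        """Invented example: uniform random discrete actions."""
        def __init__(self, action_num):
            super().__init__()
            self.action_num = action_num

        def __call__(self, batch, hidden_state=None):
            # The updated contract: return Batch(act=np.array(), state=None, ...)
            act = np.random.randint(self.action_num, size=len(batch.obs))
            return Batch(act=act, state=None)

        def learn(self, batch, batch_size=None):
            # Nothing to learn here; a real policy returns its loss, which the
            # training loop feeds into stat_loss.add(loss).
            return 0.0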