add speed stat

Trinkle23897 2020-03-16 15:04:58 +08:00
parent cef5de8b83
commit 8b0b970c9b
6 changed files with 39 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-name: Python package
+name: Unittest
 on:
   push:

View File

@@ -1,6 +1,6 @@
 MIT License
-Copyright (c) 2019 n+e
+Copyright (c) 2020 TSAIL
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

View File

@@ -41,7 +41,7 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch',
+        'torch>=1.2.0',  # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
     ],

View File

@@ -14,7 +14,7 @@ from tianshou.data import Collector, ReplayBuffer
 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -93,8 +93,9 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(args.logdir)
     best_epoch = -1
     best_reward = -1e10
-    for epoch in range(args.epoch):
-        desc = f"Epoch #{epoch + 1}"
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f"Epoch #{epoch}"
         # train
         policy.train()
         policy.sync_weight()
@@ -102,9 +103,9 @@ def test_dqn(args=get_args()):
         with tqdm.trange(
                 0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
             for _ in t:
-                training_collector.collect(n_step=args.collect_per_step)
+                result = training_collector.collect(
+                    n_step=args.collect_per_step)
                 global_step += 1
-                result = training_collector.stat()
                 loss = policy.learn(training_collector.sample(args.batch_size))
                 stat_loss.add(loss)
                 writer.add_scalar(
@@ -113,26 +114,34 @@ def test_dqn(args=get_args()):
                     'length', result['length'], global_step=global_step)
                 writer.add_scalar(
                     'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}')
+                              length=f'{result["length"]:.6f}',
+                              speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
         policy.eval()
         policy.set_eps(args.eps_test)
-        test_collector.collect(n_episode=args.test_num)
-        result = test_collector.stat()
+        result = test_collector.collect(n_episode=args.test_num)
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
     if __name__ == '__main__':
-        # let's watch its performance!
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
         env = gym.make(args.task)
         obs = env.reset()
         done = False
@@ -143,10 +152,9 @@ def test_dqn(args=get_args()):
             obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
             total += rew
             env.render()
-            time.sleep(1 / 100)
+            time.sleep(1 / 35)
         env.close()
         print(f'Final test: {total}')
-    return best_reward


 if __name__ == '__main__':

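Note: the training script above and the collector below smooth their statistics with a MovAvg helper (stat_loss, stat_reward, stat_length, and the new stat_speed), which this commit does not touch. The following is a minimal sketch of the add()/get() interface those call sites assume; the window handling and the NumPy-based mean are assumptions, not the actual tianshou.utils implementation.

import numpy as np


class MovAvg:
    # Sketch only: a moving average with the add()/get() interface used by
    # stat_loss / stat_reward / stat_length / stat_speed in this commit.
    def __init__(self, size=100):
        self.size = size      # window length, e.g. the collector's stat_size
        self.cache = []

    def add(self, x):
        # accept plain floats or 0-dim tensors and keep the last `size` values
        if hasattr(x, 'item'):
            x = x.item()
        self.cache.append(x)
        if len(self.cache) > self.size:
            self.cache = self.cache[-self.size:]
        return self.get()

    def get(self):
        # mean of the current window; 0 before anything has been recorded
        if not self.cache:
            return 0
        return float(np.mean(self.cache))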
View File

@@ -15,6 +15,7 @@ class Collector(object):
         super().__init__()
         self.env = env
         self.env_num = 1
+        self.collect_step = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -40,6 +41,7 @@ class Collector(object):
         self.state = None
         self.stat_reward = MovAvg(stat_size)
         self.stat_length = MovAvg(stat_size)
+        self.stat_speed = MovAvg(stat_size)

     def reset_buffer(self):
         if self._multi_buf:
@@ -78,6 +80,8 @@ class Collector(object):
             return [data]

     def collect(self, n_step=0, n_episode=0, render=0):
+        start_time = time.time()
+        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
@@ -119,9 +123,11 @@ class Collector(object):
                 elif self._multi_buf:
                     self.buffer[i].add(**data)
                     cur_step += 1
+                    self.collect_step += 1
                 else:
                     self.buffer.add(**data)
                     cur_step += 1
+                    self.collect_step += 1
                 if self._done[i]:
                     cur_episode[i] += 1
                     self.stat_reward.add(self.reward[i])
@@ -130,6 +136,7 @@ class Collector(object):
                     if self._cached_buf:
                         self.buffer.update(self._cached_buf[i])
                         cur_step += len(self._cached_buf[i])
+                        self.collect_step += len(self._cached_buf[i])
                         self._cached_buf[i].reset()
                     if isinstance(self.state, list):
                         self.state[i] = None
@@ -145,6 +152,7 @@ class Collector(object):
                 self._obs, self._act[0], self._rew,
                 self._done, obs_next, self._info)
             cur_step += 1
+            self.collect_step += 1
             if self._done:
                 cur_episode += 1
                 self.stat_reward.add(self.reward)
@@ -158,6 +166,13 @@ class Collector(object):
                 break
             self._obs = obs_next
         self._obs = obs_next
+        self.stat_speed.add((self.collect_step - start_step) / (
+            time.time() - start_time))
+        return {
+            'reward': self.stat_reward.get(),
+            'length': self.stat_length.get(),
+            'speed': self.stat_speed.get(),
+        }

     def sample(self, batch_size):
         if self._multi_buf:
@@ -179,9 +194,3 @@ class Collector(object):
         batch_data, indice = self.buffer.sample(batch_size)
         batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
-
-    def stat(self):
-        return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-        }

View File

@@ -10,7 +10,7 @@ class BasePolicy(ABC):
     @abstractmethod
     def __call__(self, batch, hidden_state=None):
-        # return Batch(policy, action, hidden)
+        # return Batch(act=np.array(), state=None, ...)
         pass

     @abstractmethod
@@ -22,6 +22,3 @@ class BasePolicy(ABC):
     def sync_weight(self):
         pass
-
-    def exploration(self):
-        pass
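The rewritten comment pins down the __call__ contract: a policy returns a Batch carrying at least an act array plus an optional recurrent state. A minimal sketch of a policy that follows this convention; the RandomPolicy class, its action_num argument, and the assumption that batch.obs is a batched array are illustrative, not part of the commit.

import numpy as np

from tianshou.data import Batch


class RandomPolicy:
    # Hypothetical example of the documented return shape, not a real policy.
    def __init__(self, action_num):
        self.action_num = action_num

    def __call__(self, batch, hidden_state=None):
        # one uniform random discrete action per observation, no hidden state
        act = np.random.randint(0, self.action_num, size=len(batch.obs))
        return Batch(act=act, state=None)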