add speed stat

Trinkle23897 2020-03-16 15:04:58 +08:00
parent cef5de8b83
commit 8b0b970c9b
6 changed files with 39 additions and 25 deletions

View File

@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Python package
name: Unittest
on:
push:

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2019 n+e
Copyright (c) 2020 TSAIL
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@ -41,7 +41,7 @@ setup(
'gym',
'tqdm',
'numpy',
'torch',
'torch>=1.2.0', # for supporting tensorboard
'cloudpickle',
'tensorboard',
],
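
The pinned torch>=1.2.0 is, per the inline comment, what brings in tensorboard support (torch.utils.tensorboard), which the test script in this commit logs its new speed scalar through. A minimal sketch of that logging path, not part of the commit, with a made-up log directory and value:

    from torch.utils.tensorboard import SummaryWriter  # needs torch>=1.2.0

    writer = SummaryWriter('log/dqn')                   # hypothetical log dir
    writer.add_scalar('speed', 123.4, global_step=1)    # frames per second
    writer.close()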

View File

@ -14,7 +14,7 @@ from tianshou.data import Collector, ReplayBuffer
class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape, device):
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
super().__init__()
self.device = device
self.model = [
@ -93,8 +93,9 @@ def test_dqn(args=get_args()):
writer = SummaryWriter(args.logdir)
best_epoch = -1
best_reward = -1e10
for epoch in range(args.epoch):
desc = f"Epoch #{epoch + 1}"
start_time = time.time()
for epoch in range(1, 1 + args.epoch):
desc = f"Epoch #{epoch}"
# train
policy.train()
policy.sync_weight()
@ -102,9 +103,9 @@ def test_dqn(args=get_args()):
with tqdm.trange(
0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
for _ in t:
training_collector.collect(n_step=args.collect_per_step)
result = training_collector.collect(
n_step=args.collect_per_step)
global_step += 1
result = training_collector.stat()
loss = policy.learn(training_collector.sample(args.batch_size))
stat_loss.add(loss)
writer.add_scalar(
@ -113,26 +114,34 @@ def test_dqn(args=get_args()):
'length', result['length'], global_step=global_step)
writer.add_scalar(
'loss', stat_loss.get(), global_step=global_step)
writer.add_scalar(
'speed', result['speed'], global_step=global_step)
t.set_postfix(loss=f'{stat_loss.get():.6f}',
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.6f}')
length=f'{result["length"]:.6f}',
speed=f'{result["speed"]:.2f}')
# eval
test_collector.reset_env()
test_collector.reset_buffer()
policy.eval()
policy.set_eps(args.eps_test)
test_collector.collect(n_episode=args.test_num)
result = test_collector.stat()
result = test_collector.collect(n_episode=args.test_num)
if best_reward < result['reward']:
best_reward = result['reward']
best_epoch = epoch
print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
if args.task == 'CartPole-v0' and best_reward >= 200:
break
assert best_reward >= 200
if __name__ == '__main__':
# let's watch its performance!
train_cnt = training_collector.collect_step
test_cnt = test_collector.collect_step
duration = time.time() - start_time
print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
f'in {duration:.2f}s, '
f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
# Let's watch its performance!
env = gym.make(args.task)
obs = env.reset()
done = False
@ -143,10 +152,9 @@ def test_dqn(args=get_args()):
obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
total += rew
env.render()
time.sleep(1 / 100)
time.sleep(1 / 35)
env.close()
print(f'Final test: {total}')
return best_reward
if __name__ == '__main__':
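
The training loop above now reads reward, length and the new speed figure straight from the dict returned by Collector.collect() (the separate stat() call is gone), and prints a throughput summary once training stops. A standalone sketch of that end-of-run summary, with stand-in frame counts in place of the collectors' collect_step:

    import time

    start_time = time.time()
    train_cnt, test_cnt = 10_000, 800  # stand-ins for collector.collect_step
    time.sleep(0.5)                    # stand-in for the training/testing loop
    duration = time.time() - start_time
    print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
          f'in {duration:.2f}s, '
          f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')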

View File

@ -15,6 +15,7 @@ class Collector(object):
super().__init__()
self.env = env
self.env_num = 1
self.collect_step = 0
self.buffer = buffer
self.policy = policy
self.process_fn = policy.process_fn
@ -40,6 +41,7 @@ class Collector(object):
self.state = None
self.stat_reward = MovAvg(stat_size)
self.stat_length = MovAvg(stat_size)
self.stat_speed = MovAvg(stat_size)
def reset_buffer(self):
if self._multi_buf:
@ -78,6 +80,8 @@ class Collector(object):
return [data]
def collect(self, n_step=0, n_episode=0, render=0):
start_time = time.time()
start_step = self.collect_step
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
"One and only one collection number specification permitted!"
cur_step = 0
@ -119,9 +123,11 @@ class Collector(object):
elif self._multi_buf:
self.buffer[i].add(**data)
cur_step += 1
self.collect_step += 1
else:
self.buffer.add(**data)
cur_step += 1
self.collect_step += 1
if self._done[i]:
cur_episode[i] += 1
self.stat_reward.add(self.reward[i])
@ -130,6 +136,7 @@ class Collector(object):
if self._cached_buf:
self.buffer.update(self._cached_buf[i])
cur_step += len(self._cached_buf[i])
self.collect_step += len(self._cached_buf[i])
self._cached_buf[i].reset()
if isinstance(self.state, list):
self.state[i] = None
@ -145,6 +152,7 @@ class Collector(object):
self._obs, self._act[0], self._rew,
self._done, obs_next, self._info)
cur_step += 1
self.collect_step += 1
if self._done:
cur_episode += 1
self.stat_reward.add(self.reward)
@ -158,6 +166,13 @@ class Collector(object):
break
self._obs = obs_next
self._obs = obs_next
self.stat_speed.add((self.collect_step - start_step) / (
time.time() - start_time))
return {
'reward': self.stat_reward.get(),
'length': self.stat_length.get(),
'speed': self.stat_speed.get(),
}
def sample(self, batch_size):
if self._multi_buf:
@ -179,9 +194,3 @@ class Collector(object):
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
def stat(self):
return {
'reward': self.stat_reward.get(),
'length': self.stat_length.get(),
}
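
Inside the collector, every frame written to a buffer now also bumps collect_step, and each collect() call feeds its own frames-per-second figure into a moving average, so the returned dict carries 'reward', 'length' and 'speed' directly and the old stat() method can be dropped. A reduced, self-contained sketch of that bookkeeping, with MovAvg as a simplified stand-in for tianshou's helper class:

    import time
    from collections import deque

    class MovAvg:
        """Simplified stand-in for tianshou's moving-average helper."""
        def __init__(self, size=100):
            self.cache = deque(maxlen=size)

        def add(self, x):
            self.cache.append(x)

        def get(self):
            return sum(self.cache) / len(self.cache) if self.cache else 0.0

    class SpeedTracker:
        """Hypothetical reduction of Collector to its speed bookkeeping."""
        def __init__(self):
            self.collect_step = 0
            self.stat_speed = MovAvg()

        def collect(self, n_step):
            start_time = time.time()
            start_step = self.collect_step
            for _ in range(n_step):      # stand-in for stepping the env
                time.sleep(0.001)
                self.collect_step += 1
            self.stat_speed.add(
                (self.collect_step - start_step) / (time.time() - start_time))
            return {'speed': self.stat_speed.get()}

    tracker = SpeedTracker()
    print(tracker.collect(64))           # roughly the loop's frames per second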

View File

@ -10,7 +10,7 @@ class BasePolicy(ABC):
@abstractmethod
def __call__(self, batch, hidden_state=None):
# return Batch(policy, action, hidden)
# return Batch(act=np.array(), state=None, ...)
pass
@abstractmethod
@ -22,6 +22,3 @@ class BasePolicy(ABC):
def sync_weight(self):
pass
def exploration(self):
pass
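
The updated comment spells out what __call__ is expected to hand back: a batch with at least an act array and an optional recurrent state. A hedged sketch of a policy honoring that contract, with Batch as a tiny stand-in rather than tianshou's real class:

    import numpy as np

    class Batch:
        """Tiny stand-in for tianshou's Batch: attribute access over kwargs."""
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    class RandomPolicy:
        """Hypothetical policy following the __call__ contract noted above."""
        def __init__(self, action_num):
            self.action_num = action_num

        def __call__(self, batch, hidden_state=None):
            act = np.random.randint(self.action_num, size=len(batch.obs))
            return Batch(act=act, state=None)

    policy = RandomPolicy(action_num=2)
    print(policy(Batch(obs=np.zeros((4, 8)))).act)  # four random discrete actions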