add speed stat

parent cef5de8b83
commit 8b0b970c9b

.github/workflows/pytest.yml (vendored) | 2 +-
@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 
-name: Python package
+name: Unittest
 
 on:
   push:
LICENSE | 2 +-
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2019 n+e
+Copyright (c) 2020 TSAIL
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
setup.py | 2 +-
@@ -41,7 +41,7 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch',
+        'torch>=1.2.0',  # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
     ],
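The version pin matches the inline comment: the tests below log their metrics through the tensorboard writer that ships with newer torch releases. A quick self-contained check of that import (an illustration, not part of the commit):

# Sanity check: torch.utils.tensorboard provides the SummaryWriter used by the
# tests; it needs a recent torch plus the tensorboard package from setup.py.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('log')                   # writes event files under ./log
writer.add_scalar('demo', 1.0, global_step=0)   # one scalar data point
writer.close()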

@@ -14,7 +14,7 @@ from tianshou.data import Collector, ReplayBuffer
 
 
 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -93,8 +93,9 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(args.logdir)
     best_epoch = -1
     best_reward = -1e10
-    for epoch in range(args.epoch):
-        desc = f"Epoch #{epoch + 1}"
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f"Epoch #{epoch}"
         # train
         policy.train()
         policy.sync_weight()
@@ -102,9 +103,9 @@ def test_dqn(args=get_args()):
         with tqdm.trange(
                 0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
             for _ in t:
-                training_collector.collect(n_step=args.collect_per_step)
+                result = training_collector.collect(
+                    n_step=args.collect_per_step)
                 global_step += 1
-                result = training_collector.stat()
                 loss = policy.learn(training_collector.sample(args.batch_size))
                 stat_loss.add(loss)
                 writer.add_scalar(
@@ -113,26 +114,34 @@ def test_dqn(args=get_args()):
                     'length', result['length'], global_step=global_step)
                 writer.add_scalar(
                     'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}')
+                              length=f'{result["length"]:.6f}',
+                              speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
         policy.eval()
         policy.set_eps(args.eps_test)
-        test_collector.collect(n_episode=args.test_num)
-        result = test_collector.stat()
+        result = test_collector.collect(n_episode=args.test_num)
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
     if __name__ == '__main__':
-        # let's watch its performance!
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
         env = gym.make(args.task)
         obs = env.reset()
         done = False
@@ -143,10 +152,9 @@ def test_dqn(args=get_args()):
             obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
             total += rew
             env.render()
-            time.sleep(1 / 100)
+            time.sleep(1 / 35)
         env.close()
         print(f'Final test: {total}')
-    return best_reward
 
 
 if __name__ == '__main__':
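The block added inside `if __name__ == '__main__':` above reports overall throughput: all frames gathered (training plus test) divided by the wall-clock time of the run. A standalone sketch of that measurement pattern, with a hypothetical `step_fn` standing in for one collection step:

import time

def measure_speed(step_fn, n_frames=1000):
    """Run step_fn n_frames times and report frames per second (it/s)."""
    start = time.time()
    for _ in range(n_frames):
        step_fn()
    duration = time.time() - start
    return n_frames / duration

if __name__ == '__main__':
    speed = measure_speed(lambda: None)  # trivial stand-in for a collect step
    print(f'speed: {speed:.2f}it/s')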

@@ -15,6 +15,7 @@ class Collector(object):
         super().__init__()
         self.env = env
         self.env_num = 1
+        self.collect_step = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -40,6 +41,7 @@ class Collector(object):
         self.state = None
         self.stat_reward = MovAvg(stat_size)
         self.stat_length = MovAvg(stat_size)
+        self.stat_speed = MovAvg(stat_size)
 
     def reset_buffer(self):
         if self._multi_buf:
@@ -78,6 +80,8 @@ class Collector(object):
             return [data]
 
     def collect(self, n_step=0, n_episode=0, render=0):
+        start_time = time.time()
+        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
@@ -119,9 +123,11 @@ class Collector(object):
                 elif self._multi_buf:
                     self.buffer[i].add(**data)
                     cur_step += 1
+                    self.collect_step += 1
                 else:
                     self.buffer.add(**data)
                     cur_step += 1
+                    self.collect_step += 1
                 if self._done[i]:
                     cur_episode[i] += 1
                     self.stat_reward.add(self.reward[i])
@@ -130,6 +136,7 @@ class Collector(object):
                 if self._cached_buf:
                     self.buffer.update(self._cached_buf[i])
                     cur_step += len(self._cached_buf[i])
+                    self.collect_step += len(self._cached_buf[i])
                     self._cached_buf[i].reset()
                 if isinstance(self.state, list):
                     self.state[i] = None
@@ -145,6 +152,7 @@ class Collector(object):
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
                 cur_step += 1
+                self.collect_step += 1
                 if self._done:
                     cur_episode += 1
                     self.stat_reward.add(self.reward)
@@ -158,6 +166,13 @@ class Collector(object):
                     break
             self._obs = obs_next
         self._obs = obs_next
+        self.stat_speed.add((self.collect_step - start_step) / (
+            time.time() - start_time))
+        return {
+            'reward': self.stat_reward.get(),
+            'length': self.stat_length.get(),
+            'speed': self.stat_speed.get(),
+        }
 
     def sample(self, batch_size):
         if self._multi_buf:
@@ -179,9 +194,3 @@ class Collector(object):
             batch_data, indice = self.buffer.sample(batch_size)
         batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
-
-    def stat(self):
-        return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-        }
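Taken together, the Collector changes count every stored frame in `collect_step`, and each `collect()` call feeds its own frames-per-second figure into a `MovAvg` window before returning it alongside the reward and length statistics (the separate `stat()` method is dropped). A minimal, self-contained sketch of that bookkeeping, assuming `MovAvg` simply averages the last `stat_size` values:

import time
from collections import deque

class SimpleMovAvg:
    """Assumed behaviour of tianshou's MovAvg: mean over the last `size` values."""
    def __init__(self, size=100):
        self.cache = deque(maxlen=size)

    def add(self, x):
        self.cache.append(x)
        return self.get()

    def get(self):
        return sum(self.cache) / len(self.cache) if self.cache else 0.0

class SpeedTracker:
    """Hypothetical stand-in for the collector's new speed bookkeeping."""
    def __init__(self, stat_size=100):
        self.collect_step = 0
        self.stat_speed = SimpleMovAvg(stat_size)

    def collect(self, n_step, step_fn):
        start_time = time.time()
        start_step = self.collect_step
        for _ in range(n_step):
            step_fn()                      # one environment transition
            self.collect_step += 1
        self.stat_speed.add((self.collect_step - start_step) /
                            (time.time() - start_time))
        return {'speed': self.stat_speed.get()}

Averaging over several calls presumably smooths the very short individual collect() calls into a steadier it/s reading.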
@@ -10,7 +10,7 @@ class BasePolicy(ABC):
 
     @abstractmethod
     def __call__(self, batch, hidden_state=None):
-        # return Batch(policy, action, hidden)
+        # return Batch(act=np.array(), state=None, ...)
         pass
 
     @abstractmethod
@@ -22,6 +22,3 @@ class BasePolicy(ABC):
 
     def sync_weight(self):
         pass
-
-    def exploration(self):
-        pass
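The rewritten comment in `__call__` spells out the expected return shape: a Batch whose `act` field holds the chosen actions and whose `state` carries any recurrent hidden state. A dependency-free illustration of that convention (the `FakeBatch` container below is hypothetical, standing in for tianshou's Batch):

import numpy as np

class FakeBatch(dict):
    """Minimal stand-in for tianshou's Batch: a dict with attribute access."""
    __getattr__ = dict.__getitem__

def dummy_call(batch, hidden_state=None):
    # Always pick action 0 for every observation in the batch.
    obs = np.asarray(batch.obs)
    act = np.zeros(len(obs), dtype=np.int64)
    return FakeBatch(act=act, state=None)

sample = FakeBatch(obs=np.zeros((4, 8)))   # 4 observations of dimension 8
print(dummy_call(sample).act)              # -> [0 0 0 0]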