add speed stat
parent cef5de8b83
commit 8b0b970c9b
.github/workflows/pytest.yml (vendored, 2 changes)
@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

-name: Python package
+name: Unittest

 on:
   push:
LICENSE (2 changes)
@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2019 n+e
+Copyright (c) 2020 TSAIL

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
setup.py (2 changes)
@@ -41,7 +41,7 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch',
+        'torch>=1.2.0',  # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
     ],
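Note: the tightened pin matches the training script's use of PyTorch's bundled TensorBoard writer (the SummaryWriter(args.logdir) call in the hunks below). A minimal sketch of the import this pin supports; the 'log' directory name is only an example:

from torch.utils.tensorboard import SummaryWriter  # needs torch>=1.2.0, per the pin above

writer = SummaryWriter('log')                      # illustrative log directory
writer.add_scalar('speed', 123.4, global_step=0)   # same call pattern as the script below
writer.close()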
@@ -14,7 +14,7 @@ from tianshou.data import Collector, ReplayBuffer


 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -93,8 +93,9 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(args.logdir)
     best_epoch = -1
     best_reward = -1e10
-    for epoch in range(args.epoch):
-        desc = f"Epoch #{epoch + 1}"
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f"Epoch #{epoch}"
         # train
         policy.train()
         policy.sync_weight()
@@ -102,9 +103,9 @@ def test_dqn(args=get_args()):
         with tqdm.trange(
                 0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
             for _ in t:
-                training_collector.collect(n_step=args.collect_per_step)
+                result = training_collector.collect(
+                    n_step=args.collect_per_step)
                 global_step += 1
-                result = training_collector.stat()
                 loss = policy.learn(training_collector.sample(args.batch_size))
                 stat_loss.add(loss)
                 writer.add_scalar(
@@ -113,26 +114,34 @@ def test_dqn(args=get_args()):
                     'length', result['length'], global_step=global_step)
                 writer.add_scalar(
                     'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}')
+                              length=f'{result["length"]:.6f}',
+                              speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
         policy.eval()
         policy.set_eps(args.eps_test)
-        test_collector.collect(n_episode=args.test_num)
-        result = test_collector.stat()
+        result = test_collector.collect(n_episode=args.test_num)
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
     if __name__ == '__main__':
-        # let's watch its performance!
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
         env = gym.make(args.task)
         obs = env.reset()
         done = False
@@ -143,10 +152,9 @@ def test_dqn(args=get_args()):
             obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
             total += rew
             env.render()
-            time.sleep(1 / 100)
+            time.sleep(1 / 35)
         env.close()
         print(f'Final test: {total}')
     return best_reward


 if __name__ == '__main__':
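Note: taken together, the script-side changes amount to counting collected frames against wall-clock time and logging the ratio. A self-contained toy sketch of that bookkeeping (the loop below stands in for the real collect/learn loop; names mirror the diff):

import time

def toy_run(n_frames=100000):
    collect_step = 0                          # mirrors Collector.collect_step
    start_time = time.time()
    for _ in range(n_frames):
        collect_step += 1                     # one simulated environment frame
    duration = time.time() - start_time
    print(f'Collect {collect_step} frame in {duration:.2f}s, '
          f'speed: {collect_step / max(duration, 1e-9):.2f}it/s')

if __name__ == '__main__':
    toy_run()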
@@ -15,6 +15,7 @@ class Collector(object):
         super().__init__()
         self.env = env
         self.env_num = 1
+        self.collect_step = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -40,6 +41,7 @@ class Collector(object):
         self.state = None
         self.stat_reward = MovAvg(stat_size)
         self.stat_length = MovAvg(stat_size)
+        self.stat_speed = MovAvg(stat_size)

     def reset_buffer(self):
         if self._multi_buf:
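Note: the new speed figure is smoothed the same way as reward and length, through a fixed-size moving average. A simplified stand-in for MovAvg (the idea only, not the library's actual implementation):

import numpy as np

class SimpleMovAvg:
    """Keep the last `size` samples and report their mean."""
    def __init__(self, size=100):
        self.size = size
        self.cache = []

    def add(self, x):
        self.cache.append(float(x))
        if len(self.cache) > self.size:
            self.cache = self.cache[-self.size:]
        return self.get()

    def get(self):
        return float(np.mean(self.cache)) if self.cache else 0.0

stat_speed = SimpleMovAvg(100)
stat_speed.add(812.5)        # e.g. frames per second measured by one collect() call
print(stat_speed.get())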
@@ -78,6 +80,8 @@ class Collector(object):
             return [data]

     def collect(self, n_step=0, n_episode=0, render=0):
+        start_time = time.time()
+        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
@@ -119,9 +123,11 @@ class Collector(object):
                 elif self._multi_buf:
                     self.buffer[i].add(**data)
                     cur_step += 1
+                    self.collect_step += 1
                 else:
                     self.buffer.add(**data)
                     cur_step += 1
+                    self.collect_step += 1
                 if self._done[i]:
                     cur_episode[i] += 1
                     self.stat_reward.add(self.reward[i])
@@ -130,6 +136,7 @@ class Collector(object):
                     if self._cached_buf:
                         self.buffer.update(self._cached_buf[i])
                         cur_step += len(self._cached_buf[i])
+                        self.collect_step += len(self._cached_buf[i])
                         self._cached_buf[i].reset()
                     if isinstance(self.state, list):
                         self.state[i] = None
@@ -145,6 +152,7 @@ class Collector(object):
                 self._obs, self._act[0], self._rew,
                 self._done, obs_next, self._info)
             cur_step += 1
+            self.collect_step += 1
             if self._done:
                 cur_episode += 1
                 self.stat_reward.add(self.reward)
@@ -158,6 +166,13 @@ class Collector(object):
                 break
             self._obs = obs_next
         self._obs = obs_next
+        self.stat_speed.add((self.collect_step - start_step) / (
+            time.time() - start_time))
+        return {
+            'reward': self.stat_reward.get(),
+            'length': self.stat_length.get(),
+            'speed': self.stat_speed.get(),
+        }

     def sample(self, batch_size):
         if self._multi_buf:
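Note: with this hunk collect() measures its own throughput and returns the statistics directly, which is what makes the separate stat() method (removed in the next hunk) redundant. A stripped-down sketch of the speed bookkeeping, with the env/buffer machinery replaced by a no-op loop:

import time

class ToyCollector:
    """Only the counters relevant to the speed stat; everything else elided."""
    def __init__(self):
        self.collect_step = 0
        self.speed_samples = []

    def collect(self, n_step):
        start_time = time.time()
        start_step = self.collect_step
        for _ in range(n_step):
            self.collect_step += 1            # one frame added to the buffer
        elapsed = max(time.time() - start_time, 1e-9)
        self.speed_samples.append((self.collect_step - start_step) / elapsed)
        return {'speed': sum(self.speed_samples) / len(self.speed_samples)}

print(ToyCollector().collect(n_step=10000))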
@@ -179,9 +194,3 @@ class Collector(object):
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
-
-    def stat(self):
-        return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-        }
@@ -10,7 +10,7 @@ class BasePolicy(ABC):

     @abstractmethod
     def __call__(self, batch, hidden_state=None):
-        # return Batch(policy, action, hidden)
+        # return Batch(act=np.array(), state=None, ...)
         pass

     @abstractmethod
@@ -22,6 +22,3 @@ class BasePolicy(ABC):

     def sync_weight(self):
         pass
-
-    def exploration(self):
-        pass
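Note: the reworded comment documents what a concrete policy's __call__ should hand back. A hypothetical minimal policy mirroring that interface (standalone for illustration, not derived from the real BasePolicy; Batch is used here as a plain attribute container and the action count is arbitrary):

import numpy as np
from tianshou.data import Batch

class ToyRandomPolicy:
    """Illustrative only: uniform random discrete actions in the documented format."""
    def __init__(self, action_num=2):
        self.action_num = action_num

    def __call__(self, batch, hidden_state=None):
        act = np.random.randint(0, self.action_num, size=len(batch.obs))
        return Batch(act=act, state=None)

obs_batch = Batch(obs=np.zeros((4, 8)))      # four fake observations
print(ToyRandomPolicy(action_num=3)(obs_batch).act)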