ChenDRAG 150d0ec51b
Step collector implementation (#280)
This is the third PR of 6 commits mentioned in #274, which features refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor collector to be cleaner, split AsyncCollector to support asyncvenv;
2. change buffer.add api to add(batch, buffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different numpy versions produce different output)
7. flake8 maxlength=88
8. polish docs and fix test

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00

85 lines
2.8 KiB
Python

import os
import pprint
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch
def get_args():
    """Build the CLI namespace: the shared tic-tac-toe parser plus the
    number of self-play rounds used by the Gomoku training loop."""
    parser = get_parser()
    parser.add_argument('--self_play_round', type=int, default=20)
    # parse_known_args tolerates extra flags (e.g. pytest's own options);
    # keep only the recognized namespace, discard the leftovers.
    known, _unknown = parser.parse_known_args()
    return known
def gomoku(args=None):
    """Train a Gomoku DQN agent by self-play against a growing opponent pool.

    Each self-play round evaluates the current learner against every stored
    opponent, converts those scores into a sampling distribution that favors
    the hardest opponents, trains one epoch against a sampled opponent per
    original epoch, and finally freezes a copy of the learner into the pool.

    :param args: parsed CLI namespace. When ``None`` (the default), arguments
        are parsed via ``get_args()`` at call time. (The original default
        ``args=get_args()`` ran the parser at import time — a mutable/eager
        default-argument pitfall — so a ``None`` sentinel is used instead.)
    """
    if args is None:
        args = get_args()
    # Monkeypatch the collector's reward metric to report only the learning
    # agent's component of the multi-agent reward vector.
    # NOTE(review): this pokes a private Collector attribute — verify it still
    # exists in the installed tianshou version.
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        # Evaluation-only mode: render/watch a pretrained agent and exit.
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    args.writer = SummaryWriter(log_path)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)
    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])

    for r in range(args.self_play_round):
        rews = []
        # Greedy evaluation: no exploration noise while scoring opponents.
        agent_learn.set_eps(0.0)
        # compute the reward over previous learner
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rews'].mean())
        rews = np.array(rews)
        # weight opponent by their difficulty level: lower reward for the
        # learner -> harder opponent -> higher sampling probability
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)

        # Temporarily train one epoch at a time so a fresh opponent can be
        # sampled for each epoch; args.epoch is restored after the loop.
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent, weighted by difficulty
            opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
            print(f'selection probability {rews.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id.item(0)]
            agent = RandomPolicy()
            # previous learner can only be used for forward (no training on
            # the frozen opponent — only its action selection is borrowed)
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn',
                f'policy_round_{r}_epoch_{epoch}.pth')
            result, agent_learn = train_agent(
                args, agent_learn=agent_learn,
                agent_opponent=agent, optim=optim)
            print(f'round_{r}_epoch_{epoch}')
            pprint.pprint(result)
        # Freeze a greedy snapshot of the learner and add it to the pool.
        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)
        args.epoch = total_epoch

    if __name__ == '__main__':
        # Let's watch its performance! Only when run as a script, so test
        # imports skip the interactive watch.
        # assumes at least one self-play round ran, so pool has >= 2 entries
        # — TODO confirm for self_play_round=0
        opponent = opponent_pool[-2]
        watch(args, agent_learn, opponent)
# Script entry point: parse CLI arguments and run the self-play loop.
if __name__ == '__main__':
    gomoku(get_args())