n+e fc251ab0b8
bump to v0.4.3 (#432)
* add makefile
* bump version
* add isort and yapf
* update contributing.md
* update PR template
* spelling check
2021-09-03 05:05:04 +08:00

88 lines
2.9 KiB
Python

import os
import pprint
from copy import deepcopy
import numpy as np
from tic_tac_toe import get_agents, get_parser, train_agent, watch
from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.policy import RandomPolicy
from tianshou.utils import TensorboardLogger
def get_args():
    """Return the parsed arguments for the Gomoku self-play script.

    Starts from the parser produced by ``get_parser()`` and registers one
    extra option, ``--self_play_round`` (int, default 20). Only the first
    element of ``parse_known_args()`` is returned; the second element is
    discarded (presumably the unrecognised arguments, if ``get_parser``
    returns an ``argparse.ArgumentParser`` -- confirm against tic_tac_toe).
    """
    arg_parser = get_parser()
    arg_parser.add_argument('--self_play_round', type=int, default=20)
    return arg_parser.parse_known_args()[0]
def gomoku(args=None):
    """Train the learning agent on Gomoku through rounds of self-play.

    Every round:

    1. evaluates the learner (with exploration epsilon set to 0) against each
       policy currently in the opponent pool,
    2. turns the mean evaluation rewards into sampling probabilities via
       ``exp(-10 * reward)`` normalised to sum to 1, so opponents against
       which the reward is lower are sampled more often,
    3. runs ``args.epoch`` single-epoch calls to ``train_agent`` against
       opponents sampled from the pool, and
    4. appends a deep copy of the learner (epsilon 0) to the pool.

    Args:
        args: Namespace as returned by ``get_args()``. When ``None`` (the
            default) the arguments are parsed at call time.

    Returns:
        None.

    Side effects:
        Overrides ``Collector._default_rew_metric``; sets ``args.logger`` and
        ``args.model_save_path``; temporarily sets ``args.epoch`` to 1 while
        training and restores it at the end of each round; writes
        TensorBoard logs under ``<args.logdir>/Gomoku/dqn``.
    """
    # Resolve the default lazily. A ``get_args()`` default expression would be
    # evaluated once when the module is imported, and that single namespace
    # (which this function mutates) would be shared by every default call.
    if args is None:
        args = get_args()

    # Pick the entry belonging to the learning agent (agent_id is 1-based).
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    writer = SummaryWriter(log_path)
    args.logger = TensorboardLogger(writer)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)

    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    for r in range(args.self_play_round):
        mean_rews = []
        agent_learn.set_eps(0.0)
        # compute the reward over previous learner
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            mean_rews.append(results['rews'].mean())
        # weight opponent by their difficulty level
        sample_probs = np.exp(-np.array(mean_rews) * 10.0)
        sample_probs /= np.sum(sample_probs)

        # train_agent is driven one epoch at a time so that a new opponent
        # can be sampled before each epoch.
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent
            opp_id = np.random.choice(
                len(opponent_pool), size=1, p=sample_probs
            )
            print(f'selection probability {sample_probs.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id.item(0)]
            agent = RandomPolicy()
            # previous learner can only be used for forward
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn',
                f'policy_round_{r}_epoch_{epoch}.pth'
            )
            result, agent_learn = train_agent(
                args,
                agent_learn=agent_learn,
                agent_opponent=agent,
                optim=optim
            )
            print(f'round_{r}_epoch_{epoch}')
            pprint.pprint(result)
        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)
        args.epoch = total_epoch

    if __name__ == '__main__':
        # Let's watch its performance!
        # After at least one round the last pool entry is a copy of the
        # current learner, so play against the entry before it. If no round
        # ran, the pool only holds the initial opponent; use that instead of
        # raising IndexError.
        if len(opponent_pool) > 1:
            opponent = opponent_pool[-2]
        else:
            opponent = opponent_pool[-1]
        watch(args, agent_learn, opponent)
if __name__ == '__main__':
    # Script entry point: parse the command line explicitly and pass the
    # resulting namespace to the self-play run.
    cli_args = get_args()
    gomoku(cli_args)