Andriy Drozdyuk 8a5e2190f7
Add Weights and Biases Logger (#427)
- rename BasicLogger to TensorboardLogger
- refactor logger code
- add WandbLogger

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
2021-08-30 22:35:02 +08:00

87 lines
2.9 KiB
Python

import os
import pprint
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
from tianshou.utils import TensorboardLogger
from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch
def get_args():
    """Return the parsed command-line options for the self-play script.

    Starts from the parser supplied by ``tic_tac_toe.get_parser`` and
    registers one extra integer option, ``--self_play_round`` (default 20).
    """
    arg_parser = get_parser()
    arg_parser.add_argument('--self_play_round', type=int, default=20)
    # parse_known_args yields (namespace, leftover_argv); unrecognised
    # command-line flags are tolerated and simply discarded here.
    known, _leftover = arg_parser.parse_known_args()
    return known
def gomoku(args=None):
    """Train one agent by repeated self-play against a pool of past selves.

    Each round evaluates the learner against every opponent in the pool,
    turns the mean rewards into sampling probabilities that favour the
    opponents the learner scores worst against, trains for ``args.epoch``
    single-epoch ``train_agent`` calls (re-sampling the opponent before
    each one), and finally appends a frozen copy of the learner to the pool.

    Args:
        args: option namespace as produced by ``get_args()``. When omitted
            (``None``) the command line is parsed at call time. The
            namespace is mutated: ``logger`` and ``model_save_path`` are
            set on it, and ``epoch`` is temporarily forced to 1 during
            training and restored afterwards.

    Returns:
        None. When ``args.watch`` is truthy the function only calls
        ``watch(args)`` and returns without training.
    """
    # Parse lazily. A ``get_args()`` default would be evaluated once, when
    # the function is defined, and the single namespace it produced (which
    # is mutated below) would be shared by every call.
    if args is None:
        args = get_args()

    # Report the reward entry of the learning agent.
    # NOTE(review): assumes args.agent_id is 1 or 2 -- confirm in get_parser.
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    writer = SummaryWriter(log_path)
    args.logger = TensorboardLogger(writer)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)

    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    for r in range(args.self_play_round):
        rews = []
        agent_learn.set_eps(0.0)
        # compute the reward over previous learner
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rews'].mean())
        rews = np.array(rews)
        # weight opponent by their difficulty level: a lower mean reward
        # gives a larger exp(-10 * rew), hence a higher sampling probability
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)

        total_epoch = args.epoch
        # Call train_agent one epoch at a time so that a new opponent can be
        # drawn before each epoch (presumably train_agent reads args.epoch).
        args.epoch = 1
        try:
            for epoch in range(total_epoch):
                # sample one opponent
                opp_id = np.random.choice(
                    len(opponent_pool), size=1, p=rews)
                print(f'selection probability {rews.tolist()}')
                print(f'selected opponent {opp_id}')
                opponent = opponent_pool[opp_id.item(0)]
                agent = RandomPolicy()
                # previous learner can only be used for forward
                agent.forward = opponent.forward
                args.model_save_path = os.path.join(
                    args.logdir, 'Gomoku', 'dqn',
                    f'policy_round_{r}_epoch_{epoch}.pth')
                result, agent_learn = train_agent(
                    args, agent_learn=agent_learn,
                    agent_opponent=agent, optim=optim)
                print(f'round_{r}_epoch_{epoch}')
                pprint.pprint(result)
        finally:
            # Restore the caller's epoch count even if training raised.
            args.epoch = total_epoch

        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)

    if __name__ == '__main__':
        # Let's watch its performance!
        # The last pool entry is a copy of the learner itself, so play the
        # one before it. With zero self-play rounds the pool holds only the
        # initial opponent, so fall back to that instead of raising.
        if len(opponent_pool) > 1:
            opponent = opponent_pool[-2]
        else:
            opponent = opponent_pool[-1]
        watch(args, agent_learn, opponent)
if __name__ == '__main__':
    # Script entry point: parse the command line, then run self-play.
    cli_args = get_args()
    gomoku(cli_args)