This PR focuses on refactoring the logging code to fix the NaN-reward and log-interval bugs. After these two PRs, the fundamental changes to tianshou/data should be finished, and we can finally concentrate on building benchmarks for Tianshou. Changes: 1. the trainers now accept a logger (BasicLogger or LazyLogger) instead of a writer; 2. remove utils.SummaryWriter.
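For reference, a minimal sketch of how a trainer call changes under this PR; policy, train_collector, and test_collector are assumed to be set up as usual, and the keyword arguments other than logger are illustrative and may differ from the actual trainer signatures:

    from torch.utils.tensorboard import SummaryWriter
    from tianshou.trainer import offpolicy_trainer
    from tianshou.utils import BasicLogger

    # wrap the TensorBoard writer in a logger; the logger now owns the log interval
    logger = BasicLogger(SummaryWriter('log/dqn'))
    # before this PR: offpolicy_trainer(..., writer=writer)
    # after this PR:  offpolicy_trainer(..., logger=logger)
    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, step_per_collect=10,
        episode_per_test=100, batch_size=64, logger=logger)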
import os
import pprint
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
from tianshou.utils import BasicLogger

from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch


def get_args():
    parser = get_parser()
    parser.add_argument('--self_play_round', type=int, default=20)
    args = parser.parse_known_args()[0]
    return args


def gomoku(args=get_args()):
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    writer = SummaryWriter(log_path)
    args.logger = BasicLogger(writer)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)
    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
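    # Self-play schedule: freeze the learner (eps = 0), evaluate it against every
    # opponent in the pool, then turn the mean rewards into sampling weights via
    # exp(-10 * rew), so opponents the learner still loses to are sampled more often
    # (e.g. mean rews [0.8, -0.2] -> weights [exp(-8), exp(2)] -> probs ~ [0.00005, 0.99995]).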
    for r in range(args.self_play_round):
        rews = []
        agent_learn.set_eps(0.0)
        # compute the reward over previous learner
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rews'].mean())
        rews = np.array(rews)
        # weight opponent by their difficulty level
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent
            opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
            print(f'selection probability {rews.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id.item(0)]
            agent = RandomPolicy()
            # previous learner can only be used for forward
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn',
                f'policy_round_{r}_epoch_{epoch}.pth')
            result, agent_learn = train_agent(
                args, agent_learn=agent_learn,
                agent_opponent=agent, optim=optim)
            print(f'round_{r}_epoch_{epoch}')
            pprint.pprint(result)
            learnt_agent = deepcopy(agent_learn)
            learnt_agent.set_eps(0.0)
            opponent_pool.append(learnt_agent)
        args.epoch = total_epoch
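    # note: this check runs inside gomoku(), where __name__ still refers to the
    # module, so the final match below is only watched when the file is executed
    # as a script (not when gomoku() is imported and called from a test).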
    if __name__ == '__main__':
        # Let's watch its performance!
        opponent = opponent_pool[-2]
        watch(args, agent_learn, opponent)


if __name__ == '__main__':
    gomoku(get_args())