137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
|
from functools import partial
from typing import Tuple, Optional, Union

import gym
import numpy as np

from tianshou.env import MultiAgentEnv
|
||
|
|
||
|
|
||
|
class TicTacToeEnv(MultiAgentEnv):
    """A simple implementation of the Tic-Tac-Toe game, where two agents
    play against each other.

    The implementation is intended to show how to wrap an environment to
    satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`.

    Agent 1 moves first after :meth:`reset`. On the board, cells taken by
    agent 1 hold ``1``, cells taken by agent 2 hold ``-1`` and empty cells
    hold ``0``.

    :param size: the size of the board (square board)
    :param win_size: how many units in a row is considered to win
    :raises ValueError: if ``size`` or ``win_size`` is not positive, or if
        ``win_size`` is larger than ``size``
    """

    def __init__(self, size: int = 3, win_size: int = 3):
        super().__init__()
        # Explicit raises (rather than ``assert``) keep the validation
        # active when Python runs with ``-O``.
        if size <= 0:
            raise ValueError(
                f'board size should be positive, but got {size}')
        if win_size <= 0:
            raise ValueError(
                f'win-size should be positive, but got {win_size}')
        if win_size > size:
            raise ValueError(
                f'win-size {win_size} should not '
                f'be larger than board size {size}')
        self.size = size
        self.win_size = win_size
        # Convolving a line of the board with a vector of ones yields the
        # sum of every window of ``win_size`` consecutive cells.
        self.convolve_kernel = np.ones(win_size)
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=(size, size), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(size * size)
        # The game state stays ``None`` until ``reset()`` is called.
        self.current_board = None
        self.current_agent = None
        self._last_move = None
        self.step_num = None

    def reset(self) -> dict:
        """Start a new game with an empty board and agent 1 to move.

        :return: a dict with the id of the agent to move (``'agent_id'``),
            a copy of the board (``'obs'``) and a flat boolean array that
            is ``True`` for every empty cell (``'mask'``)
        """
        self.current_board = np.zeros((self.size, self.size), dtype=np.int32)
        self.current_agent = 1
        self._last_move = (-1, -1)
        self.step_num = 0
        return {
            'agent_id': self.current_agent,
            'obs': np.array(self.current_board),
            'mask': self.current_board.flatten() == 0
        }

    def step(self, action: Union[int, np.ndarray]
             ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        """Place the current agent's mark on the cell indexed by ``action``.

        :param action: flat (row-major) index of the cell to take; either
            an integer or an array holding exactly one integer
        :return: a tuple ``(obs, vec_rew, done, info)`` where ``obs`` has
            the same layout as the dict returned by :meth:`reset`,
            ``vec_rew`` is a float32 array ``[reward of agent 1, reward of
            agent 2]``, ``done`` is a 0-d boolean array and ``info`` is an
            empty dict
        :raises ValueError: if the environment has not been reset (or the
            previous game is over), if ``action`` is out of range, or if
            the chosen cell is already occupied
        """
        if self.current_agent is None:
            raise ValueError(
                "calling step() of unreset environment is prohibited!")
        # Collapse size-1 arrays to a plain Python scalar so that the
        # checks and the indexing below behave the same for both inputs.
        action = np.asarray(action).item()
        if not 0 <= action < self.size * self.size:
            raise ValueError(
                f'action should be in [0, {self.size * self.size}), '
                f'but got {action}')
        if self.current_board.item(action) != 0:
            raise ValueError(f'cell {action} is already occupied')

        mover = self.current_agent
        self._move(action)
        # NOTE: the mask is taken before any forced last move below, so in
        # that case it still marks the final cell as available.
        mask = self.current_board.flatten() == 0
        is_win = self._test_win()
        is_opponent_win = False
        # the game is over when one wins or there is only one empty place
        done = is_win
        if not is_win and mask.sum() == 1:
            # Only one cell is left, so the opponent's move is forced:
            # play it right away and finish the game. This is skipped when
            # the mover has already won, so the board is not altered after
            # the game has been decided.
            done = True
            self._move(int(np.flatnonzero(mask)[0]))
            is_opponent_win = self._test_win()

        if is_win:
            reward = 1
        elif is_opponent_win:
            reward = -1
        else:
            reward = 0

        obs = {
            'agent_id': self.current_agent,
            'obs': np.array(self.current_board),
            'mask': mask
        }
        # ``reward`` is from the mover's point of view; the other agent
        # receives the negated value.
        if mover == 1:
            vec_rew = np.array([reward, -reward], dtype=np.float32)
        else:
            vec_rew = np.array([-reward, reward], dtype=np.float32)
        if done:
            # Force a ``reset()`` before the next ``step()``.
            self.current_agent = None
        return obs, vec_rew, np.array(done), {}

    def _move(self, action):
        """Mark cell ``action`` for the current agent and pass the turn."""
        row, col = divmod(action, self.size)
        self.current_board[row, col] = 1 if self.current_agent == 1 else -1
        # Agent ids are 1 and 2, so ``3 - id`` switches to the other one.
        self.current_agent = 3 - self.current_agent
        self._last_move = (row, col)
        self.step_num += 1

    def _test_win(self):
        """test if someone wins by checking the situation around last move"""
        row, col = self._last_move
        board = self.current_board
        # The row, the column and the two full diagonals that pass through
        # the cell of the last move.
        lines = (
            board[row, :],
            board[:, col],
            np.diagonal(np.fliplr(board), offset=self.size - 1 - col - row),
            np.diagonal(board, offset=col - row),
        )
        # Cells hold 1, -1 or 0, so a window of ``win_size`` cells sums to
        # +/- ``win_size`` only when one agent owns all of them. Lines
        # shorter than ``win_size`` cannot contain a winning window.
        return any(
            (np.abs(np.convolve(line, self.convolve_kernel, mode='valid'))
             == self.win_size).any()
            for line in lines if len(line) >= self.win_size
        )

    def seed(self, seed: Optional[int] = None) -> int:
        """Do nothing: the environment uses no random numbers.

        The method returns ``None``.
        """
        pass

    def render(self, **kwargs) -> None:
        """Print the board to stdout.

        Agent 1's marks are shown as ``x`` and agent 2's as ``o``; the
        most recent move is printed in upper case and empty cells are
        shown as ``_``.
        """
        print(f'board (step {self.step_num}):')
        pad = '==='
        top = pad + '=' * (2 * self.size - 1) + pad
        print(top)

        def symbol(i, j, number):
            last_move = i == self._last_move[0] and j == self._last_move[1]
            if number == 1:
                return 'X' if last_move else 'x'
            if number == -1:
                return 'O' if last_move else 'o'
            return '_'

        for i, row in enumerate(self.current_board):
            cells = [symbol(i, j, number) for j, number in enumerate(row)]
            print(pad + ' '.join(cells) + pad)
        print(top)

    def close(self) -> None:
        """Do nothing: the environment holds no resources to release."""
        pass
|