137 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			137 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from functools import partial
from typing import Optional, Tuple, Union

import gym
import numpy as np

from tianshou.env import MultiAgentEnv
|  | 
 | ||
|  | 
 | ||
class TicTacToeEnv(MultiAgentEnv):
    """This is a simple implementation of the Tic-Tac-Toe game, where two
    agents play against each other.

    The implementation is intended to show how to wrap an environment to
    satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`.

    Agent 1 places ``1`` stones (rendered ``x``) and agent 2 places ``-1``
    stones (rendered ``o``) on a ``size x size`` board; agent 1 always moves
    first after :meth:`reset`.

    :param size: the size of the board (square board)
    :param win_size: how many units in a row is considered to win
    """

    def __init__(self, size: int = 3, win_size: int = 3):
        super().__init__()
        assert size > 0, f'board size should be positive, but got {size}'
        self.size = size
        assert win_size > 0, f'win-size should be positive, but got {win_size}'
        self.win_size = win_size
        assert win_size <= size, f'win-size {win_size} should not ' \
            f'be larger than board size {size}'
        # Sliding-window summer: a window whose absolute sum equals win_size
        # contains win_size consecutive stones of a single player.
        self.convolve_kernel = np.ones(win_size)
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=(size, size), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(size * size)
        # Game state is populated by reset(); current_agent is None before
        # reset() and after the game has ended.
        self.current_board = None
        self.current_agent = None
        self._last_move = None
        self.step_num = None

    def reset(self) -> dict:
        """Start a new game on an empty board with agent 1 to move.

        :return: dict with ``'agent_id'`` (agent to move), ``'obs'`` (a copy
            of the board) and ``'mask'`` (flat bool array, True = empty cell).
        """
        self.current_board = np.zeros((self.size, self.size), dtype=np.int32)
        self.current_agent = 1
        self._last_move = (-1, -1)
        self.step_num = 0
        return {
            'agent_id': self.current_agent,
            'obs': np.array(self.current_board),
            'mask': self.current_board.flatten() == 0
        }

    def step(self, action: Union[int, np.ndarray]
             ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        """Place a stone for the current agent at flat index ``action``.

        If, after this move, nobody has won and exactly one empty cell is
        left, the opponent's forced last move is played automatically and
        the game ends.

        :param action: flat board index ``row * size + col`` of an empty cell.
        :return: ``(obs, vec_rew, done, info)``. ``obs`` has keys
            ``'agent_id'`` (whose turn it is after this step), ``'obs'``
            (a copy of the board) and ``'mask'`` (flat bool array of empty
            cells). ``vec_rew`` is a float32 array
            ``[reward_agent_1, reward_agent_2]``. ``done`` is a 0-d bool
            array. ``info`` is an empty dict.
        :raises ValueError: if called before :meth:`reset` or after the game
            has ended.
        """
        if self.current_agent is None:
            raise ValueError(
                "calling step() of unreset environment is prohibited!")
        assert 0 <= action < self.size * self.size
        assert self.current_board.item(action) == 0
        mover = self.current_agent
        self._move(action)
        is_win = self._test_win()
        is_opponent_win = False
        done = is_win
        mask = self.current_board.flatten() == 0
        # The game is over when one wins or there is only one empty place.
        # The forced last move is only played if the game is not already
        # won; otherwise it would alter a finished board.
        if not is_win and mask.sum() == 1:
            done = True
            self._move(np.where(mask)[0][0])
            is_opponent_win = self._test_win()
            # Keep the returned mask consistent with the returned board.
            mask = self.current_board.flatten() == 0
        if is_win:
            reward = 1
        elif is_opponent_win:
            reward = -1
        else:
            reward = 0
        obs = {
            'agent_id': self.current_agent,
            'obs': np.array(self.current_board),
            'mask': mask
        }
        # reward is from the mover's perspective; the other agent gets -reward.
        rew_agent_1 = reward if mover == 1 else -reward
        rew_agent_2 = reward if mover == 2 else -reward
        vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32)
        if done:
            # Further step() calls raise until reset() is called.
            self.current_agent = None
        return obs, vec_rew, np.array(done), {}

    def _move(self, action):
        """Place the current agent's stone at flat index ``action``, then
        hand the turn to the other agent and advance ``step_num``."""
        row, col = action // self.size, action % self.size
        self.current_board[row, col] = 1 if self.current_agent == 1 else -1
        self.current_agent = 3 - self.current_agent  # toggles 1 <-> 2
        self._last_move = (row, col)
        self.step_num += 1

    def _test_win(self) -> bool:
        """Test if someone wins by checking the lines through the last move.

        The row, column, diagonal and anti-diagonal passing through
        ``self._last_move`` are scanned for ``win_size`` consecutive stones
        of one player.
        """
        row, col = self._last_move
        board = self.current_board
        lines = (
            board[row, :],
            board[:, col],
            # Top-left -> bottom-right diagonal through (row, col).
            np.diagonal(board, offset=col - row),
            # Anti-diagonal through (row, col): after mirroring the columns
            # the cell sits at (row, size - 1 - col) on a regular diagonal.
            np.diagonal(np.fliplr(board), offset=self.size - 1 - col - row),
        )
        # Lines shorter than the kernel produce sums below win_size, so they
        # can never report a (false) win.
        return any(
            (np.abs(np.convolve(line, self.convolve_kernel, mode='valid'))
             == self.win_size).any()
            for line in lines)

    def seed(self, seed: Optional[int] = None) -> Optional[int]:
        """No-op: the game itself uses no randomness. Returns None."""
        pass

    def render(self, **kwargs) -> None:
        """Print the board; the stone from the last move is upper-cased."""
        print(f'board (step {self.step_num}):')
        pad = '==='
        top = pad + '=' * (2 * self.size - 1) + pad
        print(top)

        def f(i, data):
            j, number = data
            last_move = i == self._last_move[0] and j == self._last_move[1]
            if number == 1:
                return 'X' if last_move else 'x'
            if number == -1:
                return 'O' if last_move else 'o'
            return '_'
        for i, row in enumerate(self.current_board):
            print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad)
        print(top)

    def close(self) -> None:
        """No resources to release."""
        pass