/*
This file is part of Leela Zero.
Copyright (C) 2017 Gian-Carlo Pascutto
Leela Zero is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Zero is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Zero. If not, see .
*/
#ifndef NETWORK_H_INCLUDED
#define NETWORK_H_INCLUDED
#include "config.h"
#include
#include
#include
#include
#include
#ifdef USE_OPENCL
#include
class UCTNode;
#endif
#include "FastState.h"
#include "GameState.h"
class Network {
public:
enum Ensemble {
DIRECT, RANDOM_ROTATION
};
static const int board_size = 19;
using BoardPlane = std::bitset;
using NNPlanes = std::vector;
using scored_node = std::pair;
using Netresult = std::pair, float>;
static Netresult get_scored_moves(GameState * state,
Ensemble ensemble,
int rotation = -1);
static constexpr int INPUT_CHANNELS = 18;
static constexpr int MAX_CHANNELS = 256;
static void initialize();
static void benchmark(GameState * state);
static void show_heatmap(FastState * state, Netresult & netres, bool topmoves);
static void softmax(const std::vector& input,
std::vector& output,
float temperature = 1.0f);
// tianshou_code
static void show_once(std::string hash_key) {
printf("%s\n", hash_key.c_str());
}
private:
static Netresult get_scored_moves_internal(
GameState * state, NNPlanes & planes, int rotation);
static void gather_features(GameState * state, NNPlanes & planes);
static int rotate_nn_idx(const int vertex, int symmetry);
static int rev_rotate_nn_idx(const int vertex, int symmetry);
};
#endif