EfficientZeroV2/ez/data/global_storage.py
“Shengjiewang-Jason” 1367bca203 first commit
2024-06-07 16:02:01 +08:00

105 lines
3.1 KiB
Python

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import os
import time
import ray
import numpy as np
@ray.remote
class GlobalStorage:
    """Central store for shared model weights, counters and training logs.

    In this file the class is wrapped by ``@ray.remote``, so it is normally
    instantiated as a Ray actor and its methods are invoked remotely.

    Three model slots are kept, under the names ``'self_play'``,
    ``'reanalyze'`` and ``'latest'``. Each stored model object must expose
    ``get_weights()`` and ``set_weights(weights)``.
    """

    def __init__(self, self_play_model, reanalyze_model, latest_model):
        self.models = {
            'self_play': self_play_model,
            'reanalyze': reanalyze_model,
            'latest': latest_model,
        }
        # Logs accumulated between calls to ``get_log``, which resets them.
        self.log_scalar = {}
        self.eval_log_scalar = {}
        self.log_distribution = {}
        self.counter = 0
        self.eval_counter = 0
        self.best_score = -np.inf
        self.start = False

    # ------------------------------------------------------------------ models
    def _get_model(self, model_name):
        """Return the model stored under ``model_name``.

        Raises:
            KeyError: if ``model_name`` is not one of the known slots. An
                explicit check is used rather than ``assert`` so validation
                is not stripped under ``python -O``.
        """
        if model_name not in self.models:
            raise KeyError(
                'unknown model name {!r}; expected one of {}'.format(
                    model_name, sorted(self.models)
                )
            )
        return self.models[model_name]

    def get_weights(self, model_name):
        """Return ``get_weights()`` of the model named ``model_name``."""
        return self._get_model(model_name).get_weights()

    def set_weights(self, weights, model_name):
        """Pass ``weights`` to ``set_weights`` of the model ``model_name``.

        Returns whatever the model's ``set_weights`` returns.
        """
        return self._get_model(model_name).set_weights(weights)

    # ---------------------------------------------------------------- counters
    def increase_counter(self):
        """Increment the shared counter by one."""
        self.counter += 1

    def get_counter(self):
        """Return the shared counter."""
        return self.counter

    def set_eval_counter(self, counter):
        """Overwrite the evaluation counter with ``counter``."""
        self.eval_counter = counter

    def get_eval_counter(self):
        """Return the evaluation counter."""
        return self.eval_counter

    def set_start_signal(self):
        """Raise the start flag (this class never clears it)."""
        self.start = True

    def get_start_signal(self):
        """Return ``True`` once ``set_start_signal`` has been called."""
        return self.start

    def set_best_score(self, score):
        """Keep ``score`` only if it beats the best score recorded so far."""
        self.best_score = max(self.best_score, score)

    def get_best_score(self):
        """Return the best score recorded so far (``-inf`` initially)."""
        return self.best_score

    # -------------------------------------------------------------------- logs
    @staticmethod
    def _append_scalars(store, dic):
        """Append each value of ``dic`` to the list kept under its key in ``store``."""
        for key, val in dic.items():
            store.setdefault(key, []).append(val)

    @staticmethod
    def _flatten_samples(val):
        """Return ``val`` (array, list or scalar) as a flat Python list."""
        # Use the value's own ``tolist`` when it has one (e.g. numpy arrays);
        # ``np.ravel`` then removes any nesting and also wraps bare scalars,
        # so contributions of different shapes can share one key.
        raw = val.tolist() if hasattr(val, 'tolist') else val
        return np.ravel(raw).tolist()

    def add_log_scalar(self, dic):
        """Record one value per key of ``dic`` for the training scalar log."""
        self._append_scalars(self.log_scalar, dic)

    def add_eval_log_scalar(self, dic):
        """Record one value per key of ``dic`` for the evaluation scalar log."""
        self._append_scalars(self.eval_log_scalar, dic)

    def add_log_distribution(self, dic):
        """Accumulate samples for distribution logs.

        Each value in ``dic`` is flattened before it is stored. Scalars, lists
        and arrays of differing shapes can therefore be mixed under one key.
        """
        for key, val in dic.items():
            self.log_distribution.setdefault(key, []).extend(
                self._flatten_samples(val)
            )

    def get_log(self):
        """Summarise and then reset every accumulated log.

        Returns:
            tuple: ``(eval_scalar, scalar, distribution)``. The two scalar
            dicts map each key to the mean of its logged values.
            ``distribution`` maps each key to a 1-D ``np.ndarray`` of all
            samples logged under it.
        """
        scalar = {key: np.mean(vals) for key, vals in self.log_scalar.items()}
        eval_scalar = {
            key: np.mean(vals) for key, vals in self.eval_log_scalar.items()
        }
        distribution = {
            key: np.array(vals).flatten()
            for key, vals in self.log_distribution.items()
        }
        self.log_scalar = {}
        self.eval_log_scalar = {}
        self.log_distribution = {}
        return eval_scalar, scalar, distribution
# ======================================================================================================================
# global storage server
# ======================================================================================================================