“Shengjiewang-Jason” 1367bca203 first commit
2024-06-07 16:02:01 +08:00

80 lines
2.7 KiB
Python

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import os
import time
import torch
import ray
from ez.data.replay_buffer import ReplayBuffer
from ez.data.global_storage import GlobalStorage
class Worker:
    """Base class for workers that pull model weights from a shared storage.

    The class keeps references to the objects handed to the constructor,
    tracks which weight snapshot was fetched last, and accumulates log values
    in ``log_info``. Subclasses are expected to implement :meth:`run` and to
    assign ``model`` / ``latest_model`` and a positive
    ``model_update_interval`` before calling the weight-refresh helpers.
    """

    # Number of trained steps between two refreshes of ``latest_model``.
    LATEST_MODEL_UPDATE_INTERVAL = 30

    def __init__(self, rank: int, agent, replay_buffer: "ReplayBuffer", storage: "GlobalStorage", config):
        """Store the collaborators and initialise the bookkeeping state.

        Args:
            rank: Identifier of this worker; stored as ``self.rank``.
            agent: Agent object; stored as ``self.agent``.
            replay_buffer: Replay buffer handle; stored as-is.
            storage: Storage handle whose ``get_weights.remote(name)`` is
                queried through ``ray.get`` when weights are refreshed.
            config: Configuration object. ``config.train.training_steps`` and
                ``config.train.offline_training_steps`` are read here.
        """
        self.rank = rank
        self.agent = agent
        self.replay_buffer = replay_buffer
        self.storage = storage
        self.config = config

        # Models are not created here; they stay ``None`` until assigned.
        self.model = None
        self.latest_model = None
        # 0 means "not configured"; get_recent_model() requires a value > 0.
        self.model_update_interval = 0

        # Index of the last snapshot fetched; -1 forces a fetch on first use.
        self.last_model_index = -1
        self.last_latest_model_index = -1
        self.last_log_index = -1
        self.log_info = {}

        train_cfg = self.config.train
        self.total_steps = train_cfg.training_steps + train_cfg.offline_training_steps

    def run(self, **kwargs):
        """Entry point of the worker; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_recent_model(self, trained_steps, model_name):
        """Refresh ``self.model`` once per ``model_update_interval`` steps.

        The snapshot index is ``trained_steps // model_update_interval``;
        weights are fetched only when that index exceeds the last one seen.

        Args:
            trained_steps: Number of training steps completed so far.
            model_name: Name passed to ``storage.get_weights.remote``.

        Raises:
            AssertionError: If ``model_update_interval`` is not positive or
                ``model`` has not been assigned.
        """
        assert self.model_update_interval > 0, 'model_update_interval must be set to a positive value'
        # Identity check: a model object may legitimately be falsy (e.g. a
        # container type with length 0), which a truthiness test would reject.
        assert self.model is not None, 'model must be assigned before fetching weights'

        new_model_index = trained_steps // self.model_update_interval
        if new_model_index <= self.last_model_index:
            return

        self.last_model_index = new_model_index
        # update model
        weights = ray.get(self.storage.get_weights.remote(model_name))
        self.model.set_weights(weights)
        self.model.cuda()
        self.model.eval()
        if self.config.ray.single_process:
            print('[Update {}] get recent model at step {}'.format(model_name, trained_steps))

    def get_latest_model(self, trained_steps, model_name):
        """Refresh ``self.latest_model`` every ``LATEST_MODEL_UPDATE_INTERVAL`` steps.

        Args:
            trained_steps: Number of training steps completed so far.
            model_name: Name passed to ``storage.get_weights.remote``.
        """
        new_model_index = trained_steps // self.LATEST_MODEL_UPDATE_INTERVAL
        if new_model_index <= self.last_latest_model_index:
            return

        self.last_latest_model_index = new_model_index
        weights = ray.get(self.storage.get_weights.remote(model_name))
        self.latest_model.set_weights(weights)
        self.latest_model.cuda()
        self.latest_model.eval()

    def resume_model(self):
        """Load a state dict from ``config.train.load_model_path`` if it exists.

        Nothing happens when the path is unset, empty, or does not exist.
        """
        load_path = self.config.train.load_model_path
        # os.path.exists(None) raises TypeError, so rule out an unset path first.
        if not load_path or not os.path.exists(load_path):
            return

        print('[Worker] resume model from path: ', load_path)
        weights = torch.load(load_path)
        self.model.load_state_dict(weights)

    def is_finished(self, trained_steps):
        """Return True once ``trained_steps`` has reached ``total_steps``.

        Args:
            trained_steps: Number of training steps completed so far.

        Returns:
            bool: True when training is complete, False otherwise.
        """
        if trained_steps < self.total_steps:
            return False
        # NOTE(review): the one-second pause presumably gives other workers
        # time to wind down before this one reports completion -- confirm.
        time.sleep(1)
        return True

    def reset_log_info(self):
        """Discard all accumulated log values."""
        self.log_info = {}

    def log(self, key, val):
        """Append ``val`` to the list of values recorded under ``key``.

        Args:
            key: Name of the metric.
            val: Value to record.
        """
        # Index by the ``key`` argument, not the literal string 'key', so each
        # metric keeps its own history instead of sharing (and resetting) one.
        self.log_info.setdefault(key, []).append(val)