80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
import time
|
|
import torch
|
|
import ray
|
|
from ez.data.replay_buffer import ReplayBuffer
|
|
from ez.data.global_storage import GlobalStorage
|
|
|
|
|
|
class Worker:
    """Base class for workers that share a replay buffer and a global storage.

    Subclasses are expected to override :meth:`run`. The helper methods here
    pull model weights from ``storage`` at a throttled rate, resume weights
    from a checkpoint path, detect the end of training and accumulate
    per-key log values.
    """

    # Number of trained steps between two refreshes of ``latest_model``
    # (see ``get_latest_model``).
    LATEST_MODEL_UPDATE_INTERVAL = 30

    # The project types are quoted (forward references) so that defining the
    # class does not require evaluating them.
    def __init__(self, rank: int, agent, replay_buffer: "ReplayBuffer",
                 storage: "GlobalStorage", config):
        """Store the collaborators and initialise the bookkeeping state.

        Args:
            rank: Identifier of this worker.
            agent: Agent object; only stored here.
            replay_buffer: Replay buffer handle; only stored here.
            storage: Global storage handle; ``get_weights.remote(name)`` is
                called on it by the model-refresh helpers.
            config: Configuration object; ``train.training_steps``,
                ``train.offline_training_steps``, ``train.load_model_path``
                and ``ray.single_process`` are read from it.
        """
        self.rank = rank
        self.agent = agent
        self.replay_buffer = replay_buffer
        self.storage = storage
        self.config = config

        # Both models start unset; the helper methods below dereference them,
        # so they must be assigned (e.g. by a subclass) before those are used.
        self.model = None
        self.latest_model = None
        # Must be set to a positive value before ``get_recent_model`` is
        # called (it asserts on this).
        self.model_update_interval = 0
        # Index of the last refresh performed for each model; -1 guarantees
        # that the first call (index 0) triggers a refresh.
        self.last_model_index = -1
        self.last_latest_model_index = -1
        self.last_log_index = -1
        self.log_info = {}
        # Total number of steps after which ``is_finished`` reports True.
        self.total_steps = (
            self.config.train.training_steps
            + self.config.train.offline_training_steps
        )

    def run(self, **kwargs):
        """Entry point of the worker; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_recent_model(self, trained_steps, model_name):
        """Refresh ``self.model`` once per ``model_update_interval`` steps.

        Args:
            trained_steps: Number of training steps completed so far.
            model_name: Name of the weights to request from the storage.
        """
        assert self.model_update_interval > 0
        assert self.model is not None

        new_model_index = trained_steps // self.model_update_interval
        if new_model_index > self.last_model_index:
            self.last_model_index = new_model_index

            # update model
            weights = ray.get(self.storage.get_weights.remote(model_name))
            self.model.set_weights(weights)
            self.model.cuda()
            self.model.eval()
            if self.config.ray.single_process:
                print('[Update {}] get recent model at step {}'.format(model_name, trained_steps))

    def get_latest_model(self, trained_steps, model_name):
        """Refresh ``self.latest_model`` every ``LATEST_MODEL_UPDATE_INTERVAL`` steps.

        Args:
            trained_steps: Number of training steps completed so far.
            model_name: Name of the weights to request from the storage.
        """
        new_model_index = trained_steps // self.LATEST_MODEL_UPDATE_INTERVAL
        if new_model_index > self.last_latest_model_index:
            self.last_latest_model_index = new_model_index
            weights = ray.get(self.storage.get_weights.remote(model_name))
            self.latest_model.set_weights(weights)
            self.latest_model.cuda()
            self.latest_model.eval()

    def resume_model(self):
        """Load ``self.model`` weights from ``config.train.load_model_path``.

        Nothing happens when the path is unset (``None``/empty) or does not
        exist on disk.
        """
        load_path = self.config.train.load_model_path
        # Guard against an unset path: os.path.exists(None) raises TypeError.
        if load_path and os.path.exists(load_path):
            print('[Worker] resume model from path: ', load_path)
            weights = torch.load(load_path)
            self.model.load_state_dict(weights)

    def is_finished(self, trained_steps):
        """Return True once ``trained_steps`` has reached ``total_steps``.

        When finished, the call sleeps for one second before returning.
        """
        if trained_steps >= self.total_steps:
            # NOTE(review): presumably a grace period for other workers to
            # wind down -- confirm against the callers.
            time.sleep(1)
            return True
        return False

    def reset_log_info(self):
        """Discard all accumulated log values."""
        self.log_info = {}

    def log(self, key, val):
        """Append ``val`` to the list of values recorded under ``key``.

        Args:
            key: Name of the logged quantity.
            val: Value to record.
        """
        # Index by the ``key`` argument; the literal string 'key' used
        # previously merged every metric into a single list.
        self.log_info.setdefault(key, []).append(val)
|