Minor simplification in train_step (#1019)

Michael Panchenko 2024-01-09 17:51:49 +01:00 committed by GitHub
parent 522f7fbf98
commit 789340f8d6
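In short, train_step no longer builds and returns the progress-bar data dict; it now returns only the collection stats and the stop flag, and the epoch loop assembles the dict itself from the returned CollectStats. A minimal sketch of the new calling convention, distilled from the hunks below (`trainer` is an assumed BaseTrainer instance, not a name from the diff):

    # Sketch only: how the epoch loop consumes train_step after this commit.
    # Before: pbar_data_dict, train_stat, stop_flag = trainer.train_step()
    train_stat, stop_flag = trainer.train_step()
    pbar_data_dict = {
        "env_step": str(trainer.env_step),
        "rew": f"{trainer.last_rew:.2f}",
        "len": str(int(trainer.last_len)),
        "n/ep": str(train_stat.n_collected_episodes),
        "n/st": str(train_stat.n_collected_steps),
    }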


@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any

 import numpy as np
 import tqdm
@@ -312,7 +311,14 @@ class BaseTrainer(ABC):
             while t.n < t.total and not self.stop_fn_flag:
                 train_stat: CollectStatsBase
                 if self.train_collector is not None:
-                    pbar_data_dict, train_stat, self.stop_fn_flag = self.train_step()
+                    train_stat, self.stop_fn_flag = self.train_step()
+                    pbar_data_dict = {
+                        "env_step": str(self.env_step),
+                        "rew": f"{self.last_rew:.2f}",
+                        "len": str(int(self.last_len)),
+                        "n/ep": str(train_stat.n_collected_episodes),
+                        "n/st": str(train_stat.n_collected_steps),
+                    }
                     t.update(train_stat.n_collected_steps)
                     if self.stop_fn_flag:
                         t.set_postfix(**pbar_data_dict)
@@ -322,13 +328,12 @@ class BaseTrainer(ABC):
                     assert self.buffer, "No train_collector or buffer specified"
                     train_stat = CollectStatsBase(
                         n_collected_episodes=len(self.buffer),
-                        n_collected_steps=int(self._gradient_step),
                     )
                     t.update()

                 update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
-                pbar_data_dict["gradient_step"] = self._gradient_step
+                pbar_data_dict["gradient_step"] = str(self._gradient_step)
                 t.set_postfix(**pbar_data_dict)
@@ -413,11 +418,19 @@ class BaseTrainer(ABC):
         return test_stat, stop_fn_flag

-    def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]:
-        """Perform one training step."""
+    def train_step(self) -> tuple[CollectStats, bool]:
+        """Perform one training step.
+
+        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
+        Then, if the stop_fn is True there, will collect test data and also compute the stop_fn on its mean return.
+        Finally, if the latter is also True, will set should_stop_training to True.
+
+        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        """
         assert self.episode_per_test is not None
         assert self.train_collector is not None
-        stop_fn_flag = False
+        should_stop_training = False
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
         result = self.train_collector.collect(
@@ -439,13 +452,6 @@ class BaseTrainer(ABC):
         self.logger.log_train_data(asdict(result), self.env_step)

-        data = {
-            "env_step": str(self.env_step),
-            "rew": f"{self.last_rew:.2f}",
-            "len": str(int(self.last_len)),
-            "n/ep": str(result.n_collected_episodes),
-            "n/st": str(result.n_collected_steps),
-        }
         if (
             result.n_collected_episodes > 0
             and self.test_in_train
@@ -464,12 +470,12 @@ class BaseTrainer(ABC):
             )
             assert test_result.returns_stat is not None  # for mypy
             if self.stop_fn(test_result.returns_stat.mean):
-                stop_fn_flag = True
+                should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
         else:
             self.policy.train()
-        return data, result, stop_fn_flag
+        return result, should_stop_training

     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
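For context, the early-stopping flow that the new docstring summarizes (and that the last two hunks touch) looks roughly like the sketch below. It is not the real method body: collect_training_data and run_test_episodes are hypothetical stand-ins for the collector and test-episode calls that sit outside the shown hunks.

    def train_step_sketch(trainer) -> tuple["CollectStats", bool]:
        # Simplified sketch of the docstring's description; not the actual trainer code.
        should_stop_training = False
        result = collect_training_data(trainer)  # hypothetical helper
        if (
            result.n_collected_episodes > 0
            and trainer.test_in_train
            and trainer.stop_fn
            and trainer.stop_fn(result.returns_stat.mean)  # stop_fn on training returns
        ):
            test_result = run_test_episodes(trainer)  # hypothetical helper
            if trainer.stop_fn(test_result.returns_stat.mean):  # confirm on test returns
                should_stop_training = True
                trainer.best_reward = test_result.returns_stat.mean
                trainer.best_reward_std = test_result.returns_stat.std
        else:
            trainer.policy.train()  # switch the policy back to training mode
        return result, should_stop_training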