Minor simplification in train_step (#1019)
parent 522f7fbf98
commit 789340f8d6
@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any
 
 import numpy as np
 import tqdm
@@ -312,7 +311,14 @@ class BaseTrainer(ABC):
             while t.n < t.total and not self.stop_fn_flag:
                 train_stat: CollectStatsBase
                 if self.train_collector is not None:
-                    pbar_data_dict, train_stat, self.stop_fn_flag = self.train_step()
+                    train_stat, self.stop_fn_flag = self.train_step()
+                    pbar_data_dict = {
+                        "env_step": str(self.env_step),
+                        "rew": f"{self.last_rew:.2f}",
+                        "len": str(int(self.last_len)),
+                        "n/ep": str(train_stat.n_collected_episodes),
+                        "n/st": str(train_stat.n_collected_steps),
+                    }
                     t.update(train_stat.n_collected_steps)
                     if self.stop_fn_flag:
                         t.set_postfix(**pbar_data_dict)
@@ -322,13 +328,12 @@ class BaseTrainer(ABC):
                     assert self.buffer, "No train_collector or buffer specified"
                     train_stat = CollectStatsBase(
                         n_collected_episodes=len(self.buffer),
                         n_collected_steps=int(self._gradient_step),
                     )
                     t.update()
 
                 update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
-                pbar_data_dict["gradient_step"] = self._gradient_step
+                pbar_data_dict["gradient_step"] = str(self._gradient_step)
 
                 t.set_postfix(**pbar_data_dict)
@@ -413,11 +418,19 @@ class BaseTrainer(ABC):
 
         return test_stat, stop_fn_flag
 
-    def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]:
-        """Perform one training step."""
+    def train_step(self) -> tuple[CollectStats, bool]:
+        """Perform one training step.
+
+        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
+        Then, if the stop_fn is True there, will collect test data and also compute the stop_fn of the mean
+        return on it.
+        Finally, if the latter is also True, will set should_stop_training to True.
+
+        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        """
         assert self.episode_per_test is not None
         assert self.train_collector is not None
-        stop_fn_flag = False
+        should_stop_training = False
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
         result = self.train_collector.collect(
|
||||
@ -439,13 +452,6 @@ class BaseTrainer(ABC):
|
||||
|
||||
self.logger.log_train_data(asdict(result), self.env_step)
|
||||
|
||||
data = {
|
||||
"env_step": str(self.env_step),
|
||||
"rew": f"{self.last_rew:.2f}",
|
||||
"len": str(int(self.last_len)),
|
||||
"n/ep": str(result.n_collected_episodes),
|
||||
"n/st": str(result.n_collected_steps),
|
||||
}
|
||||
if (
|
||||
result.n_collected_episodes > 0
|
||||
and self.test_in_train
|
||||
@@ -464,12 +470,12 @@ class BaseTrainer(ABC):
             )
             assert test_result.returns_stat is not None  # for mypy
             if self.stop_fn(test_result.returns_stat.mean):
-                stop_fn_flag = True
+                should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
             else:
                 self.policy.train()
-        return data, result, stop_fn_flag
+        return result, should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
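
For illustration, here is a minimal, self-contained sketch of the training loop after this change: train_step now returns only the collect stats plus a stop flag, and the caller assembles the progress-bar dict itself. ToyTrainer, run_epoch, and the fake collection numbers are hypothetical stand-ins, not Tianshou's real classes; only the loop shape and the dict keys mirror the diff above.

    # Hypothetical sketch of the post-change control flow; not Tianshou code.
    from dataclasses import dataclass

    import tqdm


    @dataclass
    class CollectStatsBase:
        n_collected_episodes: int = 0
        n_collected_steps: int = 0


    class ToyTrainer:
        def __init__(self, step_per_epoch: int = 50) -> None:
            self.step_per_epoch = step_per_epoch
            self.env_step = 0
            self.last_rew = 0.0
            self.last_len = 0
            self._gradient_step = 0
            self.stop_fn_flag = False

        def train_step(self) -> tuple[CollectStatsBase, bool]:
            # A real trainer collects transitions here; we fake 10 env steps.
            stat = CollectStatsBase(n_collected_episodes=1, n_collected_steps=10)
            self.env_step += stat.n_collected_steps
            self.last_rew, self.last_len = 1.0, 10
            return stat, False

        def run_epoch(self) -> None:
            with tqdm.tqdm(total=self.step_per_epoch) as t:
                while t.n < t.total and not self.stop_fn_flag:
                    # New shape: stats + stop flag only; the caller builds the
                    # progress-bar dict (previously returned by train_step).
                    train_stat, self.stop_fn_flag = self.train_step()
                    pbar_data_dict = {
                        "env_step": str(self.env_step),
                        "rew": f"{self.last_rew:.2f}",
                        "len": str(int(self.last_len)),
                        "n/ep": str(train_stat.n_collected_episodes),
                        "n/st": str(train_stat.n_collected_steps),
                    }
                    t.update(train_stat.n_collected_steps)
                    self._gradient_step += 1  # stands in for policy_update_fn
                    pbar_data_dict["gradient_step"] = str(self._gradient_step)
                    t.set_postfix(**pbar_data_dict)


    if __name__ == "__main__":
        ToyTrainer().run_epoch()

Building pbar_data_dict next to t.set_postfix keeps presentation concerns in the loop and leaves train_step responsible only for collecting data and deciding whether to stop, which is the simplification the commit title refers to.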
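The expanded docstring describes a two-stage early-stopping check: a cheap test on the mean training return, confirmed on freshly collected test episodes before training actually stops. A hedged sketch of just that logic, with hypothetical stand-ins (should_stop, collect_test_returns_mean) rather than Tianshou's actual API:

    # Hypothetical sketch of the two-stage stop check from the new docstring.
    from collections.abc import Callable


    def should_stop(
        train_returns_mean: float,
        collect_test_returns_mean: Callable[[], float],
        stop_fn: Callable[[float], bool],
    ) -> bool:
        # Stage 1: does the training return already satisfy the criterion?
        if not stop_fn(train_returns_mean):
            return False
        # Stage 2: confirm on test episodes before setting the stop flag.
        return stop_fn(collect_test_returns_mean())


    def stop_fn(mean_return: float) -> bool:
        # Example criterion: stop once mean return reaches 195.
        return mean_return >= 195.0


    print(should_stop(200.0, lambda: 198.0, stop_fn))  # True: both stages pass
    print(should_stop(200.0, lambda: 150.0, stop_fn))  # False: test data disagrees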