BaseTrainer: Refactoring

New method training_step, which
    * collects training data (method _collect_training_data)
    * performs "test in train" (method _test_in_train)
    * performs the policy update
  The old method train_step performed only the first two steps;
  it has now been split into the two corresponding methods (see the sketch below).
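
For orientation, here is a condensed sketch of the resulting control flow. This is not the verbatim implementation (that follows in the diff below); the signature is simplified and the stats class is reduced to a placeholder:

from dataclasses import dataclass


@dataclass
class CollectStatsBase:  # stand-in for tianshou.data.CollectStatsBase
    n_collected_episodes: int = 0


def training_step(trainer) -> tuple:
    """One training step: collect data -> test in train -> policy update.

    `trainer` is assumed to expose the same attributes as BaseTrainer.
    """
    should_stop_training = False
    if trainer.train_collector is not None:
        collect_stats = trainer._collect_training_data()              # point 1
        should_stop_training = trainer._test_in_train(collect_stats)  # point 2
    else:
        # buffer-only training: nothing to collect, stats come from the buffer
        collect_stats = CollectStatsBase(n_collected_episodes=len(trainer.buffer))

    # point 3: the policy update is skipped once the stop criterion has fired
    training_stats = trainer.policy_update_fn(collect_stats) if not should_stop_training else None
    return collect_stats, training_stats, should_stop_training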
Dominik Jain 2024-05-02 18:06:01 +02:00
parent 4f16494609
commit ca4dad1139


@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
+from typing import Optional, Tuple
 
 import numpy as np
 import tqdm
@@ -303,8 +304,10 @@ class BaseTrainer(ABC):
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-                if self.train_collector is not None:
-                    train_stat, self.stop_fn_flag = self.train_step()
+                train_stat, update_stat, self.stop_fn_flag = self.training_step()
+                if isinstance(train_stat, CollectStats):
                     pbar_data_dict = {
                         "env_step": str(self.env_step),
                         "rew": f"{self.last_rew:.2f}",
@@ -313,23 +316,17 @@ class BaseTrainer(ABC):
                         "n/st": str(train_stat.n_collected_steps),
                     }
                     t.update(train_stat.n_collected_steps)
+                    if self.stop_fn_flag:
+                        t.set_postfix(**pbar_data_dict)
+                        break
                 else:
                     pbar_data_dict = {}
-                    assert self.buffer, "No train_collector or buffer specified"
-                    train_stat = CollectStatsBase(
-                        n_collected_episodes=len(self.buffer),
-                    )
                     t.update()
 
-                update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
                 pbar_data_dict["gradient_step"] = str(self._gradient_step)
                 t.set_postfix(**pbar_data_dict)
 
-                if self.stop_fn_flag:
-                    break
-
             if t.n <= t.total and not self.stop_fn_flag:
                 t.update()
@@ -410,45 +407,71 @@ class BaseTrainer(ABC):
         return test_stat, stop_fn_flag
 
-    def train_step(self) -> tuple[CollectStats, bool]:
-        """Perform one training step.
+    def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+        should_stop_training = False
+        if self.train_collector is not None:
+            collect_stats = self._collect_training_data()
+            should_stop_training = self._test_in_train(collect_stats)
+        else:
+            collect_stats = CollectStatsBase(
+                n_collected_episodes=len(self.buffer),
+            )
+
+        if not should_stop_training:
+            training_stats = self.policy_update_fn(collect_stats)
+        else:
+            training_stats = None
+
+        return collect_stats, training_stats, should_stop_training
+
+    def _collect_training_data(self) -> CollectStats:
+        """Performs training data collection
+
+        :return: the data collection stats
+        """
+        assert self.episode_per_test is not None
+        assert self.train_collector is not None
+        if self.train_fn:
+            self.train_fn(self.epoch, self.env_step)
+
+        collect_stats = self.train_collector.collect(
+            n_step=self.step_per_collect,
+            n_episode=self.episode_per_collect,
+        )
+
+        self.env_step += collect_stats.n_collected_steps
+
+        if collect_stats.n_collected_episodes > 0:
+            assert collect_stats.returns_stat is not None  # for mypy
+            assert collect_stats.lens_stat is not None  # for mypy
+            self.last_rew = collect_stats.returns_stat.mean
+            self.last_len = collect_stats.lens_stat.mean
+            if self.reward_metric:  # TODO: move inside collector
+                rew = self.reward_metric(collect_stats.returns)
+                collect_stats.returns = rew
+                collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew)
+
+            self.logger.log_train_data(asdict(collect_stats), self.env_step)
+        return collect_stats
 
+    def _test_in_train(self, collect_stats: CollectStats) -> bool:
+        """
         If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
         Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
         on it.
         Finally, if the latter is also True, will set should_stop_training to True.
 
-        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        :param collect_stats: the data collection stats
+        :return: flag indicating whether to stop training
         """
-        assert self.episode_per_test is not None
-        assert self.train_collector is not None
         should_stop_training = False
-        if self.train_fn:
-            self.train_fn(self.epoch, self.env_step)
-        result = self.train_collector.collect(
-            n_step=self.step_per_collect,
-            n_episode=self.episode_per_collect,
-        )
-        self.env_step += result.n_collected_steps
-        if result.n_collected_episodes > 0:
-            assert result.returns_stat is not None  # for mypy
-            assert result.lens_stat is not None  # for mypy
-            self.last_rew = result.returns_stat.mean
-            self.last_len = result.lens_stat.mean
-            if self.reward_metric:  # TODO: move inside collector
-                rew = self.reward_metric(result.returns)
-                result.returns = rew
-                result.returns_stat = SequenceSummaryStats.from_sequence(rew)
-            self.logger.log_train_data(asdict(result), self.env_step)
         if (
-            result.n_collected_episodes > 0
+            collect_stats.n_collected_episodes > 0
             and self.test_in_train
             and self.stop_fn
-            and self.stop_fn(result.returns_stat.mean)  # type: ignore
+            and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
         ):
             assert self.test_collector is not None
             test_result = test_episode(
@@ -464,7 +487,8 @@ class BaseTrainer(ABC):
             should_stop_training = True
             self.best_reward = test_result.returns_stat.mean
             self.best_reward_std = test_result.returns_stat.std
-        return result, should_stop_training
+
+        return should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
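
Note on the changed return type: training_step yields Optional[TrainingStats], i.e. None whenever test-in-train stops training before the policy update runs, so callers must guard against the None case. A hypothetical caller-side guard, assuming the logger exposes the log_update_data hook used elsewhere in the trainer:

from dataclasses import asdict

collect_stats, update_stat, should_stop = trainer.training_step()
if update_stat is not None:
    # a policy update actually ran, so its statistics exist and can be logged
    trainer.logger.log_update_data(asdict(update_stat), trainer._gradient_step)
if should_stop:
    # test-in-train judged the policy good enough; leave the epoch loop early
    ...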