BaseTrainer: Refactoring

Added a new method training_step, which
  * collects training data (via the new method _collect_training_data)
  * performs "test in train" (via the new method _test_in_train)
  * performs the policy update
The old method train_step performed only the first two of these steps; it has now been split into the two separate methods _collect_training_data and _test_in_train, and the policy update has moved from the epoch loop into training_step.
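In outline, the new control flow reads as follows (a condensed sketch of the code added in the diff below, with type annotations, asserts, and logging trimmed; not a verbatim copy):

    def training_step(self):
        should_stop_training = False
        if self.train_collector is not None:
            # gather fresh transitions, then maybe stop early via test-in-train
            collect_stats = self._collect_training_data()
            should_stop_training = self._test_in_train(collect_stats)
        else:
            # no collector: train from the pre-filled buffer
            collect_stats = CollectStatsBase(n_collected_episodes=len(self.buffer))
        # skip the policy update if test-in-train already decided to stop
        training_stats = self.policy_update_fn(collect_stats) if not should_stop_training else None
        return collect_stats, training_stats, should_stop_training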
parent 4f16494609
commit ca4dad1139
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
+from typing import Optional, Tuple
 
 import numpy as np
 import tqdm
@@ -303,8 +304,10 @@ class BaseTrainer(ABC):
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-                if self.train_collector is not None:
-                    train_stat, self.stop_fn_flag = self.train_step()
+                train_stat, update_stat, self.stop_fn_flag = self.training_step()
+
+                if isinstance(train_stat, CollectStats):
                     pbar_data_dict = {
                         "env_step": str(self.env_step),
                         "rew": f"{self.last_rew:.2f}",
@@ -313,23 +316,17 @@ class BaseTrainer(ABC):
                         "n/st": str(train_stat.n_collected_steps),
                     }
                     t.update(train_stat.n_collected_steps)
-                    if self.stop_fn_flag:
-                        t.set_postfix(**pbar_data_dict)
-                        break
                 else:
                     pbar_data_dict = {}
-                    assert self.buffer, "No train_collector or buffer specified"
-                    train_stat = CollectStatsBase(
-                        n_collected_episodes=len(self.buffer),
-                    )
                     t.update()
 
-                update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
                 pbar_data_dict["gradient_step"] = str(self._gradient_step)
 
                 t.set_postfix(**pbar_data_dict)
 
+                if self.stop_fn_flag:
+                    break
+
             if t.n <= t.total and not self.stop_fn_flag:
                 t.update()
 
@@ -410,45 +407,71 @@ class BaseTrainer(ABC):
 
         return test_stat, stop_fn_flag
 
-    def train_step(self) -> tuple[CollectStats, bool]:
-        """Perform one training step.
+    def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+        should_stop_training = False
+
+        if self.train_collector is not None:
+            collect_stats = self._collect_training_data()
+            should_stop_training = self._test_in_train(collect_stats)
+        else:
+            collect_stats = CollectStatsBase(
+                n_collected_episodes=len(self.buffer),
+            )
+
+        if not should_stop_training:
+            training_stats = self.policy_update_fn(collect_stats)
+        else:
+            training_stats = None
+
+        return collect_stats, training_stats, should_stop_training
+
+    def _collect_training_data(self) -> CollectStats:
+        """Performs training data collection
+
+        :return: the data collection stats
+        """
+        assert self.episode_per_test is not None
+        assert self.train_collector is not None
+        if self.train_fn:
+            self.train_fn(self.epoch, self.env_step)
+        collect_stats = self.train_collector.collect(
+            n_step=self.step_per_collect,
+            n_episode=self.episode_per_collect,
+        )
+
+        self.env_step += collect_stats.n_collected_steps
+
+        if collect_stats.n_collected_episodes > 0:
+            assert collect_stats.returns_stat is not None  # for mypy
+            assert collect_stats.lens_stat is not None  # for mypy
+            self.last_rew = collect_stats.returns_stat.mean
+            self.last_len = collect_stats.lens_stat.mean
+            if self.reward_metric:  # TODO: move inside collector
+                rew = self.reward_metric(collect_stats.returns)
+                collect_stats.returns = rew
+                collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew)
+
+            self.logger.log_train_data(asdict(collect_stats), self.env_step)
+
+        return collect_stats
 
+    def _test_in_train(self, collect_stats: CollectStats) -> bool:
+        """
         If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
         Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
         on it.
         Finally, if the latter is also True, will set should_stop_training to True.
 
-        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        :param collect_stats: the data collection stats
+        :return: flag indicating whether to stop training
         """
-        assert self.episode_per_test is not None
-        assert self.train_collector is not None
         should_stop_training = False
-        if self.train_fn:
-            self.train_fn(self.epoch, self.env_step)
-        result = self.train_collector.collect(
-            n_step=self.step_per_collect,
-            n_episode=self.episode_per_collect,
-        )
-
-        self.env_step += result.n_collected_steps
-
-        if result.n_collected_episodes > 0:
-            assert result.returns_stat is not None  # for mypy
-            assert result.lens_stat is not None  # for mypy
-            self.last_rew = result.returns_stat.mean
-            self.last_len = result.lens_stat.mean
-            if self.reward_metric:  # TODO: move inside collector
-                rew = self.reward_metric(result.returns)
-                result.returns = rew
-                result.returns_stat = SequenceSummaryStats.from_sequence(rew)
-
-            self.logger.log_train_data(asdict(result), self.env_step)
-
         if (
-            result.n_collected_episodes > 0
+            collect_stats.n_collected_episodes > 0
             and self.test_in_train
             and self.stop_fn
-            and self.stop_fn(result.returns_stat.mean)  # type: ignore
+            and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
         ):
             assert self.test_collector is not None
             test_result = test_episode(
@@ -464,7 +487,8 @@ class BaseTrainer(ABC):
             should_stop_training = True
             self.best_reward = test_result.returns_stat.mean
             self.best_reward_std = test_result.returns_stat.std
-        return result, should_stop_training
+
+        return should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
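As a reading aid, the two-stage gate that _test_in_train encapsulates boils down to the following standalone sketch (illustrative only; should_stop, mean_training_return, stop_fn, and collect_test_returns are hypothetical stand-ins for the trainer's attributes and for test_episode):

    def should_stop(mean_training_return, stop_fn, collect_test_returns):
        # first gate: the mean training return already satisfies the stop criterion
        if not stop_fn(mean_training_return):
            return False
        # second gate: confirm the criterion on freshly collected test returns
        test_returns = collect_test_returns()
        return stop_fn(sum(test_returns) / len(test_returns))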