Improve logging

This commit is contained in:
Dominik Jain 2023-12-14 11:53:05 +01:00
parent 3caa3805f0
commit 1903a72ecb
2 changed files with 8 additions and 1 deletions

View File

@ -115,6 +115,9 @@ class AgentFactory(ABC, ToStringMixin):
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
if self.sampling_config.start_timesteps > 0:
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
)
train_collector.collect(
n_step=self.sampling_config.start_timesteps,
random=self.sampling_config.start_timesteps_random,

View File

@ -251,7 +251,9 @@ class Experiment(ToStringMixin):
logger = LazyLogger()
# create policy and collectors
log.info("Creating policy")
policy = self.agent_factory.create_policy(envs, self.config.device)
log.info("Creating collectors")
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
@ -277,15 +279,17 @@ class Experiment(ToStringMixin):
)
# train policy
log.info("Starting training")
trainer_result: dict[str, Any] | None = None
if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer
trainer_result = trainer.run()
log.info(f"Trainer result:\n{pformat(trainer_result)}")
log.info(f"Training result:\n{pformat(trainer_result)}")
# watch agent performance
if self.config.watch:
log.info("Watching agent performance")
self._watch_agent(
self.config.watch_num_episodes,
policy,