Improve logging

This commit is contained in:
Dominik Jain 2023-12-14 11:53:05 +01:00
parent 3caa3805f0
commit 1903a72ecb
2 changed files with 8 additions and 1 deletions

View File

@ -115,6 +115,9 @@ class AgentFactory(ABC, ToStringMixin):
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs) test_collector = Collector(policy, envs.test_envs)
if self.sampling_config.start_timesteps > 0: if self.sampling_config.start_timesteps > 0:
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
)
train_collector.collect( train_collector.collect(
n_step=self.sampling_config.start_timesteps, n_step=self.sampling_config.start_timesteps,
random=self.sampling_config.start_timesteps_random, random=self.sampling_config.start_timesteps_random,

View File

@ -251,7 +251,9 @@ class Experiment(ToStringMixin):
logger = LazyLogger() logger = LazyLogger()
# create policy and collectors # create policy and collectors
log.info("Creating policy")
policy = self.agent_factory.create_policy(envs, self.config.device) policy = self.agent_factory.create_policy(envs, self.config.device)
log.info("Creating collectors")
train_collector, test_collector = self.agent_factory.create_train_test_collector( train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy, policy,
envs, envs,
@ -277,15 +279,17 @@ class Experiment(ToStringMixin):
) )
# train policy # train policy
log.info("Starting training")
trainer_result: dict[str, Any] | None = None trainer_result: dict[str, Any] | None = None
if self.config.train: if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence) trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer world.trainer = trainer
trainer_result = trainer.run() trainer_result = trainer.run()
log.info(f"Trainer result:\n{pformat(trainer_result)}") log.info(f"Training result:\n{pformat(trainer_result)}")
# watch agent performance # watch agent performance
if self.config.watch: if self.config.watch:
log.info("Watching agent performance")
self._watch_agent( self._watch_agent(
self.config.watch_num_episodes, self.config.watch_num_episodes,
policy, policy,