From 1903a72ecb4b6d78dc08f7240844ad1046a9e37d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 14 Dec 2023 11:53:05 +0100 Subject: [PATCH] Improve logging --- tianshou/highlevel/agent.py | 3 +++ tianshou/highlevel/experiment.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 772c9d1..727bfb2 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -115,6 +115,9 @@ class AgentFactory(ABC, ToStringMixin): train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, envs.test_envs) if self.sampling_config.start_timesteps > 0: + log.info( + f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", + ) train_collector.collect( n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 5dd5179..361c3a9 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -251,7 +251,9 @@ class Experiment(ToStringMixin): logger = LazyLogger() # create policy and collectors + log.info("Creating policy") policy = self.agent_factory.create_policy(envs, self.config.device) + log.info("Creating collectors") train_collector, test_collector = self.agent_factory.create_train_test_collector( policy, envs, @@ -277,15 +279,17 @@ class Experiment(ToStringMixin): ) # train policy + log.info("Starting training") trainer_result: dict[str, Any] | None = None if self.config.train: trainer = self.agent_factory.create_trainer(world, policy_persistence) world.trainer = trainer trainer_result = trainer.run() - log.info(f"Trainer result:\n{pformat(trainer_result)}") + log.info(f"Training result:\n{pformat(trainer_result)}") # watch agent performance if self.config.watch: + log.info("Watching agent performance") self._watch_agent( self.config.watch_num_episodes, policy,