import argparse
import os
from typing import Callable, Optional, Tuple

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BaseLogger, TensorboardLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE

try:
    import wandb
except ImportError:
    pass
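# NOTE: wandb is an optional dependency, so the import above is allowed to
# fail silently; without wandb installed, constructing WandbLogger raises a
# NameError once wandb.init is reached below.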


class WandbLogger(BaseLogger):
    """Weights and Biases logger that sends data to https://wandb.ai/.

    This logger creates three panels with plots: train, test, and update.
    Make sure to select the correct access mode for each panel in Weights & Biases.

    Example of usage:
    ::

        logger = WandbLogger()
        logger.load(SummaryWriter(log_path))
        result = onpolicy_trainer(policy, train_collector, test_collector,
                                  logger=logger)
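
    A run can also be resumed from a previously saved checkpoint by passing
    the id of the W&B run to continue; a sketch, where ``"abc123"`` stands in
    for a real run id:
    ::

        logger = WandbLogger(run_id="abc123")
        logger.load(SummaryWriter(log_path))
        epoch, env_step, gradient_step = logger.restore_data()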

    :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data().
        Default to 1000.
    :param int save_interval: the save interval in save_data(), counted in
        epochs. Default to 1000.
    :param bool write_flush: whether to flush tensorboard result after each
        add_scalar operation. Default to True.
    :param str project: W&B project name. Default to "tianshou".
    :param str name: W&B run name. Default to None. If None, a random name is
        assigned.
    :param str entity: W&B team/organization name. Default to None.
    :param str run_id: run id of W&B run to be resumed. Default to None.
    :param argparse.Namespace config: experiment configurations. Default to None.
    """

    def __init__(
        self,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        save_interval: int = 1000,
        write_flush: bool = True,
        project: Optional[str] = None,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        run_id: Optional[str] = None,
        config: Optional[argparse.Namespace] = None,
    ) -> None:
        super().__init__(train_interval, test_interval, update_interval)
        self.last_save_step = -1
        self.save_interval = save_interval
        self.write_flush = write_flush
        self.restored = False
        if project is None:
            project = os.getenv("WANDB_PROJECT", "tianshou")

        # Reuse the already-active wandb run if there is one; otherwise start
        # a new run that mirrors everything written to tensorboard.
        self.wandb_run = wandb.init(
            project=project,
            name=name,
            id=run_id,
            resume="allow",
            entity=entity,
            sync_tensorboard=True,
            monitor_gym=True,
            config=config,  # type: ignore
        ) if not wandb.run else wandb.run
        self.wandb_run._label(repo="tianshou")  # type: ignore
        self.tensorboard_logger: Optional[TensorboardLogger] = None

    def load(self, writer: SummaryWriter) -> None:
        """Bind a tensorboard ``SummaryWriter`` to this logger.

        All subsequent writes go through the wrapped
        :class:`~tianshou.utils.TensorboardLogger` and are synced to W&B via
        ``sync_tensorboard``.
        """
        self.writer = writer
        self.tensorboard_logger = TensorboardLogger(
            writer, self.train_interval, self.test_interval, self.update_interval,
            self.save_interval, self.write_flush
        )

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        if self.tensorboard_logger is None:
            raise RuntimeError(
                "The logger needs a tensorboard SummaryWriter before it can "
                "write data. Call `logger.load(SummaryWriter(log_path))` first."
            )
        else:
            self.tensorboard_logger.write(step_type, step, data)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
    ) -> None:
        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

        :param int epoch: the epoch in trainer.
        :param int env_step: the env_step in trainer.
        :param int gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a hook defined by user, see trainer
            documentation for detail. It must return the path of the saved
            checkpoint file.
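
        A minimal sketch of such a hook, assuming ``log_path``, ``torch`` and
        ``policy`` exist in the surrounding scope:
        ::

            def save_checkpoint_fn(epoch: int, env_step: int,
                                   gradient_step: int) -> str:
                path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
                torch.save(policy.state_dict(), path)
                return path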
        """
        if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
            self.last_save_step = epoch
            checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)

            # Upload the checkpoint file as a W&B artifact together with the
            # trainer progress, so that restore_data() can pick it up later.
            checkpoint_artifact = wandb.Artifact(
                'run_' + self.wandb_run.id + '_checkpoint',  # type: ignore
                type='model',
                metadata={
                    "save/epoch": epoch,
                    "save/env_step": env_step,
                    "save/gradient_step": gradient_step,
                    "checkpoint_path": str(checkpoint_path),
                }
            )
            checkpoint_artifact.add_file(str(checkpoint_path))
            self.wandb_run.log_artifact(checkpoint_artifact)  # type: ignore

    def restore_data(self) -> Tuple[int, int, int]:
        """Download the latest checkpoint artifact and restore trainer progress.

        :return: a tuple of (epoch, env_step, gradient_step) read from the
            artifact metadata; entries missing from the metadata default to 0.
        """
        checkpoint_artifact = self.wandb_run.use_artifact(  # type: ignore
            f"run_{self.wandb_run.id}_checkpoint:latest"  # type: ignore
        )
        assert checkpoint_artifact is not None, "W&B checkpoint artifact doesn't exist"

        # Download the checkpoint file next to where it was originally saved.
        checkpoint_artifact.download(
            os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])
        )

        try:  # epoch / gradient_step
            epoch = checkpoint_artifact.metadata["save/epoch"]
            self.last_save_step = self.last_log_test_step = epoch
            gradient_step = checkpoint_artifact.metadata["save/gradient_step"]
            self.last_log_update_step = gradient_step
        except KeyError:
            epoch, gradient_step = 0, 0
        try:  # offline trainer doesn't have env_step
            env_step = checkpoint_artifact.metadata["save/env_step"]
            self.last_log_train_step = env_step
        except KeyError:
            env_step = 0
        return epoch, env_step, gradient_step