#!/usr/bin/env python3
"""Plot training curves stored in csv files, optionally grouped into sub-figures."""

import argparse
import os
import re
from typing import Any, Literal

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

from tools import csv2numpy, find_all_files, group_files


def smooth(
    y: np.ndarray,
    radius: int,
    mode: Literal["two_sided", "causal"] = "two_sided",
    valid_only: bool = False,
) -> np.ndarray:
    """Smooth signal ``y`` with a moving average whose window size is set by ``radius``.

    mode="two_sided": average over the window
        [max(index - radius, 0), min(index + radius, len(y) - 1)]
    mode="causal": average over the window [max(index - radius, 0), index]

    :param y: 1-D signal to smooth.
    :param radius: half-width of the window; 0 leaves the values unchanged.
    :param mode: "two_sided" or "causal", see above.
    :param valid_only: put NaN in entries where the full-sized window is not
        available.
    :return: smoothed float signal with the same length as ``y``. If ``y`` is
        shorter than a full two-sided window (``2 * radius + 1``) every entry
        is the mean of ``y``.
    :raises ValueError: if ``radius`` is negative or ``mode`` is unknown.
    """
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")
    if mode not in ("two_sided", "causal"):
        raise ValueError(f"unknown smoothing mode {mode!r}")
    y = np.asarray(y)
    if y.size == 0:
        # Avoid numpy's "mean of empty slice" warning; nothing to smooth.
        return y.astype(float)
    if len(y) < 2 * radius + 1:
        # Too short for a single full window: fall back to the global mean.
        return np.ones_like(y) * y.mean()

    if mode == "two_sided":
        kernel = np.ones(2 * radius + 1)
        # Dividing by the convolved ones-vector turns the windowed sum into a
        # mean and shrinks the window at the edges instead of zero-padding.
        out = np.convolve(y, kernel, mode="same") / np.convolve(
            np.ones_like(y),
            kernel,
            mode="same",
        )
        if valid_only and radius > 0:
            # radius > 0 guard: out[-0:] would select (and blank) the whole array.
            out[:radius] = out[-radius:] = np.nan
    else:  # causal
        # radius + 1 taps cover exactly [index - radius, index], as documented.
        kernel = np.ones(radius + 1)
        # The first len(y) entries of the full convolution only see past samples.
        out = (
            np.convolve(y, kernel, mode="full")
            / np.convolve(np.ones_like(y), kernel, mode="full")
        )[: len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out


COLORS = [
    # deepmind style
    "#0072B2",
    "#009E73",
    "#D55E00",
    "#CC79A7",
    # '#F0E442',
    "#d73027",  # RED
    # built-in color
    "blue",
    "red",
    "pink",
    "cyan",
    "magenta",
    "yellow",
    "black",
    "purple",
    "brown",
    "orange",
    "teal",
    "lightblue",
    "lime",
    "lavender",
    "turquoise",
    "darkgreen",
    "tan",
    "salmon",
    "gold",
    "darkred",
    "darkblue",
    "green",
    # personal color
    "#313695",  # DARK BLUE
    "#74add1",  # LIGHT BLUE
    "#f46d43",  # ORANGE
    "#4daf4a",  # GREEN
    "#984ea3",  # PURPLE
    "#f781bf",  # PINK
    "#ffc832",  # YELLOW
    "#000000",  # BLACK
]


def plot_ax(
    ax: plt.Axes,
    file_lists: list[str],
    legend_pattern: str = ".*",
    xlabel: str | None = None,
    ylabel: str | None = None,
    title: str = "",
    xlim: float | None = None,
    xkey: str = "env_step",
    ykey: str = "reward",
    smooth_radius: int = 0,
    shaded_std: bool = True,
    legend_outside: bool = False,
) -> None:
    """Draw one curve per csv file on ``ax``.

    Curves are ordered (and therefore colored) by their legend text, which is
    the first match of ``legend_pattern`` in the file path; ties are broken by
    the path itself.

    :param ax: axes to draw on.
    :param file_lists: csv file paths, each read with ``csv2numpy``.
    :param legend_pattern: regex whose first match in a path is used as that
        curve's legend entry.
    :param xlabel: x-axis label; left unset when None.
    :param ylabel: y-axis label; left unset when None.
    :param title: axes title.
    :param xlim: if given, the x-axis is limited to ``[0, xlim]``.
    :param xkey: csv column used for x values.
    :param ykey: csv column used for y values. A ``ykey + ":shaded"`` column,
        when present, gives the half-width of the shaded band.
    :param smooth_radius: radius passed to :func:`smooth` for y and the band.
    :param shaded_std: draw the shaded band when that column exists.
    :param legend_outside: anchor the legend outside the right edge of the axes.
    :raises ValueError: if ``legend_pattern`` does not match a file path.
    """

    def legend_fn(path: str) -> str:
        match = re.search(legend_pattern, path)
        if match is None:
            raise ValueError(f"legend pattern {legend_pattern!r} does not match {path!r}")
        return match.group(0)

    # Sort by (legend, path) so colors are assigned in a stable, readable order.
    labelled = sorted((legend_fn(path), path) for path in file_lists)
    for index, (label, csv_file) in enumerate(labelled):
        csv_dict = csv2numpy(csv_file)
        x, y = csv_dict[xkey], csv_dict[ykey]
        y = smooth(y, radius=smooth_radius)
        color = COLORS[index % len(COLORS)]
        # Attach the label to the line itself: passing bare labels to
        # ax.legend() pairs them with artists in creation order, so the
        # fill_between bands below would steal labels from the lines.
        ax.plot(x, y, color=color, label=label)
        if shaded_std and ykey + ":shaded" in csv_dict:
            y_shaded = smooth(csv_dict[ykey + ":shaded"], radius=smooth_radius)
            ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=0.2)
    if labelled:
        ax.legend(
            loc=2 if legend_outside else None,
            bbox_to_anchor=(1, 1) if legend_outside else None,
        )
    ax.xaxis.set_major_formatter(mticker.EngFormatter())
    if xlim is not None:
        ax.set_xlim(left=0, right=xlim)
    ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)


def plot_figure(
    file_lists: list[str],
    group_pattern: str | None = None,
    fig_length: int = 6,
    fig_width: int = 6,
    sharex: bool = False,
    sharey: bool = False,
    title: str = "",
    **kwargs: Any,
) -> None:
    """Plot ``file_lists`` on a single axes, or on a grid of up to 3 columns.

    :param file_lists: csv file paths to plot.
    :param group_pattern: regex passed to ``group_files``; each resulting group
        gets its own sub-figure titled with the group key. Falsy means a
        single axes with every file.
    :param fig_length: width of one sub-figure (inches).
    :param fig_width: height of one sub-figure (inches).
    :param sharex: share the x-axis across sub-figures.
    :param sharey: share the y-axis across sub-figures.
    :param title: axes title in single-axes mode, figure suptitle otherwise.
    :param kwargs: forwarded to :func:`plot_ax`.
    :raises ValueError: if grouping produces no groups.
    """
    if not group_pattern:
        _, ax = plt.subplots(figsize=(fig_length, fig_width))
        # The axes title already shows ``title``; a suptitle would repeat it.
        plot_ax(ax, file_lists, title=title, **kwargs)
        return

    groups = group_files(file_lists, group_pattern)
    if not groups:
        raise ValueError(
            f"group pattern {group_pattern!r} produced no groups from {len(file_lists)} files",
        )
    n_cols = min(len(groups), 3)
    n_rows = -(-len(groups) // 3)  # ceil(len(groups) / 3)
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        sharex=sharex,
        sharey=sharey,
        figsize=(fig_length * n_cols, fig_width * n_rows),
        squeeze=False,
    )
    axes = axes.flatten()
    for ax, (group_name, group_members) in zip(axes, groups.items()):
        plot_ax(ax, group_members, title=group_name, **kwargs)
    # Hide the unused cells at the end of the last row.
    for ax in axes[len(groups) :]:
        ax.set_axis_off()
    if title:
        fig.suptitle(title, fontsize=20)


def _use_style(style: str) -> None:
    """Apply a matplotlib style, mapping legacy ``seaborn*`` names when needed.

    matplotlib 3.6 renamed the bundled seaborn styles to ``seaborn-v0_8*``,
    and 3.8 removed the old names.
    """
    try:
        plt.style.use(style)
    except OSError:
        if style.startswith("seaborn") and not style.startswith("seaborn-v0_8"):
            plt.style.use(style.replace("seaborn", "seaborn-v0_8", 1))
        else:
            raise


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="plotter")
    parser.add_argument(
        "--fig-length",
        type=int,
        default=6,
        help="matplotlib figure length (default: 6)",
    )
    parser.add_argument(
        "--fig-width",
        type=int,
        default=6,
        help="matplotlib figure width (default: 6)",
    )
    parser.add_argument(
        "--style",
        default="seaborn",
        help="matplotlib figure style (default: seaborn)",
    )
    parser.add_argument("--title", default=None, help="matplotlib figure title (default: None)")
    parser.add_argument(
        "--xkey",
        default="env_step",
        help="x-axis key in csv file (default: env_step)",
    )
    parser.add_argument("--ykey", default="rew", help="y-axis key in csv file (default: rew)")
    parser.add_argument(
        "--smooth",
        type=int,
        default=0,
        help="smooth radius of y axis (default: 0)",
    )
    parser.add_argument("--xlabel", default="Timesteps", help="matplotlib figure xlabel")
    parser.add_argument("--ylabel", default="Episode Reward", help="matplotlib figure ylabel")
    parser.add_argument(
        "--shaded-std",
        action="store_true",
        help="shaded region corresponding to standard deviation of the group",
    )
    parser.add_argument(
        "--sharex",
        action="store_true",
        help="whether to share x axis within multiple sub-figures",
    )
    parser.add_argument(
        "--sharey",
        action="store_true",
        help="whether to share y axis within multiple sub-figures",
    )
    parser.add_argument(
        "--legend-outside",
        action="store_true",
        help="place the legend outside of the figure",
    )
    parser.add_argument("--xlim", type=float, default=None, help="x-axis limitation (default: None)")
    parser.add_argument("--root-dir", default="./", help="root dir (default: ./)")
    parser.add_argument(
        "--file-pattern",
        type=str,
        default=r".*/test_rew_\d+seeds.csv$",
        help="regular expression to determine whether or not to include target csv "
        "file, default to including all test_rew_{num}seeds.csv file under rootdir",
    )
    parser.add_argument(
        "--group-pattern",
        type=str,
        default=r"(/|^)\w*?\-v(\d|$)",
        help="regular expression to group files in sub-figure, default to grouping "
        'according to env_name dir, "" means no grouping',
    )
    parser.add_argument(
        "--legend-pattern",
        type=str,
        default=r".*",
        help="regular expression to extract legend from csv file path, default to "
        "using file path as legend name.",
    )
    parser.add_argument("--show", action="store_true", help="show figure")
    parser.add_argument("--output-path", type=str, help="figure save path", default="./figure.png")
    parser.add_argument("--dpi", type=int, default=200, help="figure dpi (default: 200)")
    return parser


def main() -> None:
    """Parse arguments, collect the csv files, plot them and save/show the figure."""
    parser = _build_parser()
    args = parser.parse_args()
    file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
    file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
    if not file_lists:
        parser.error(
            f"no file under {args.root_dir!r} matches --file-pattern {args.file_pattern!r}",
        )
    if args.style:
        _use_style(args.style)
    # Resolve the output path before chdir so a relative --output-path is
    # interpreted against the caller's working directory, not --root-dir.
    output_path = os.path.abspath(args.output_path) if args.output_path else None
    # The collected paths are relative to root_dir, so reading them needs this chdir.
    os.chdir(args.root_dir)
    plot_figure(
        file_lists,
        group_pattern=args.group_pattern,
        legend_pattern=args.legend_pattern,
        fig_length=args.fig_length,
        fig_width=args.fig_width,
        title=args.title or "",
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        xkey=args.xkey,
        ykey=args.ykey,
        xlim=args.xlim,
        sharex=args.sharex,
        sharey=args.sharey,
        smooth_radius=args.smooth,
        shaded_std=args.shaded_std,
        legend_outside=args.legend_outside,
    )
    if output_path:
        plt.savefig(output_path, dpi=args.dpi, bbox_inches="tight")
    if args.show:
        plt.show()


if __name__ == "__main__":
    main()