Closes #914 Additional changes: - Deprecate python below 11 - Remove 3rd party and throughput tests. This simplifies install and test pipeline - Remove gym compatibility and shimmy - Format with 3.11 conventions. In particular, add `zip(..., strict=True/False)` where possible Since the additional tests and gym were complicating the CI pipeline (flaky and dist-dependent), it didn't make sense to work on fixing the current tests in this PR to then just delete them in the next one. So this PR changes the build and removes these tests at the same time.
281 lines
8.2 KiB
Python
Executable File
281 lines
8.2 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import re
|
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as mticker
|
|
import numpy as np
|
|
from tools import csv2numpy, find_all_files, group_files
|
|
|
|
|
|
def smooth(y, radius, mode="two_sided", valid_only=False):
|
|
"""Smooth signal y, where radius is determines the size of the window.
|
|
|
|
mode='twosided':
|
|
average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
|
|
mode='causal':
|
|
average over the window [max(index - radius, 0), index]
|
|
valid_only: put nan in entries where the full-sized window is not available
|
|
"""
|
|
assert mode in ("two_sided", "causal")
|
|
if len(y) < 2 * radius + 1:
|
|
return np.ones_like(y) * y.mean()
|
|
if mode == "two_sided":
|
|
convkernel = np.ones(2 * radius + 1)
|
|
out = np.convolve(y, convkernel, mode="same") / np.convolve(
|
|
np.ones_like(y),
|
|
convkernel,
|
|
mode="same",
|
|
)
|
|
if valid_only:
|
|
out[:radius] = out[-radius:] = np.nan
|
|
elif mode == "causal":
|
|
convkernel = np.ones(radius)
|
|
out = np.convolve(y, convkernel, mode="full") / np.convolve(
|
|
np.ones_like(y),
|
|
convkernel,
|
|
mode="full",
|
|
)
|
|
out = out[: -radius + 1]
|
|
if valid_only:
|
|
out[:radius] = np.nan
|
|
return out
|
|
|
|
|
|
COLORS = [
|
|
# deepmind style
|
|
"#0072B2",
|
|
"#009E73",
|
|
"#D55E00",
|
|
"#CC79A7",
|
|
# '#F0E442',
|
|
"#d73027", # RED
|
|
# built-in color
|
|
"blue",
|
|
"red",
|
|
"pink",
|
|
"cyan",
|
|
"magenta",
|
|
"yellow",
|
|
"black",
|
|
"purple",
|
|
"brown",
|
|
"orange",
|
|
"teal",
|
|
"lightblue",
|
|
"lime",
|
|
"lavender",
|
|
"turquoise",
|
|
"darkgreen",
|
|
"tan",
|
|
"salmon",
|
|
"gold",
|
|
"darkred",
|
|
"darkblue",
|
|
"green",
|
|
# personal color
|
|
"#313695", # DARK BLUE
|
|
"#74add1", # LIGHT BLUE
|
|
"#f46d43", # ORANGE
|
|
"#4daf4a", # GREEN
|
|
"#984ea3", # PURPLE
|
|
"#f781bf", # PINK
|
|
"#ffc832", # YELLOW
|
|
"#000000", # BLACK
|
|
]
|
|
|
|
|
|
def plot_ax(
|
|
ax,
|
|
file_lists,
|
|
legend_pattern=".*",
|
|
xlabel=None,
|
|
ylabel=None,
|
|
title=None,
|
|
xlim=None,
|
|
xkey="env_step",
|
|
ykey="reward",
|
|
smooth_radius=0,
|
|
shaded_std=True,
|
|
legend_outside=False,
|
|
):
|
|
def legend_fn(x):
|
|
# return os.path.split(os.path.join(
|
|
# args.root_dir, x))[0].replace('/', '_') + " (10)"
|
|
return re.search(legend_pattern, x).group(0)
|
|
|
|
legneds = map(legend_fn, file_lists)
|
|
# sort filelist according to legends
|
|
file_lists = [f for _, f in sorted(zip(legneds, file_lists, strict=True))]
|
|
legneds = list(map(legend_fn, file_lists))
|
|
|
|
for index, csv_file in enumerate(file_lists):
|
|
csv_dict = csv2numpy(csv_file)
|
|
x, y = csv_dict[xkey], csv_dict[ykey]
|
|
y = smooth(y, radius=smooth_radius)
|
|
color = COLORS[index % len(COLORS)]
|
|
ax.plot(x, y, color=color)
|
|
if shaded_std and ykey + ":shaded" in csv_dict:
|
|
y_shaded = smooth(csv_dict[ykey + ":shaded"], radius=smooth_radius)
|
|
ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=0.2)
|
|
|
|
ax.legend(
|
|
legneds,
|
|
loc=2 if legend_outside else None,
|
|
bbox_to_anchor=(1, 1) if legend_outside else None,
|
|
)
|
|
ax.xaxis.set_major_formatter(mticker.EngFormatter())
|
|
if xlim is not None:
|
|
ax.set_xlim(xmin=0, xmax=xlim)
|
|
# add title
|
|
ax.set_title(title)
|
|
# add labels
|
|
if xlabel is not None:
|
|
ax.set_xlabel(xlabel)
|
|
if ylabel is not None:
|
|
ax.set_ylabel(ylabel)
|
|
|
|
|
|
def plot_figure(
|
|
file_lists,
|
|
group_pattern=None,
|
|
fig_length=6,
|
|
fig_width=6,
|
|
sharex=False,
|
|
sharey=False,
|
|
title=None,
|
|
**kwargs,
|
|
):
|
|
if not group_pattern:
|
|
fig, ax = plt.subplots(figsize=(fig_length, fig_width))
|
|
plot_ax(ax, file_lists, title=title, **kwargs)
|
|
else:
|
|
res = group_files(file_lists, group_pattern)
|
|
row_n = int(np.ceil(len(res) / 3))
|
|
col_n = min(len(res), 3)
|
|
fig, axes = plt.subplots(
|
|
row_n,
|
|
col_n,
|
|
sharex=sharex,
|
|
sharey=sharey,
|
|
figsize=(fig_length * col_n, fig_width * row_n),
|
|
squeeze=False,
|
|
)
|
|
axes = axes.flatten()
|
|
for i, (k, v) in enumerate(res.items()):
|
|
plot_ax(axes[i], v, title=k, **kwargs)
|
|
if title: # add title
|
|
fig.suptitle(title, fontsize=20)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(description="plotter")
|
|
parser.add_argument(
|
|
"--fig-length",
|
|
type=int,
|
|
default=6,
|
|
help="matplotlib figure length (default: 6)",
|
|
)
|
|
parser.add_argument(
|
|
"--fig-width",
|
|
type=int,
|
|
default=6,
|
|
help="matplotlib figure width (default: 6)",
|
|
)
|
|
parser.add_argument(
|
|
"--style",
|
|
default="seaborn",
|
|
help="matplotlib figure style (default: seaborn)",
|
|
)
|
|
parser.add_argument("--title", default=None, help="matplotlib figure title (default: None)")
|
|
parser.add_argument(
|
|
"--xkey",
|
|
default="env_step",
|
|
help="x-axis key in csv file (default: env_step)",
|
|
)
|
|
parser.add_argument("--ykey", default="rew", help="y-axis key in csv file (default: rew)")
|
|
parser.add_argument(
|
|
"--smooth",
|
|
type=int,
|
|
default=0,
|
|
help="smooth radius of y axis (default: 0)",
|
|
)
|
|
parser.add_argument("--xlabel", default="Timesteps", help="matplotlib figure xlabel")
|
|
parser.add_argument("--ylabel", default="Episode Reward", help="matplotlib figure ylabel")
|
|
parser.add_argument(
|
|
"--shaded-std",
|
|
action="store_true",
|
|
help="shaded region corresponding to standard deviation of the group",
|
|
)
|
|
parser.add_argument(
|
|
"--sharex",
|
|
action="store_true",
|
|
help="whether to share x axis within multiple sub-figures",
|
|
)
|
|
parser.add_argument(
|
|
"--sharey",
|
|
action="store_true",
|
|
help="whether to share y axis within multiple sub-figures",
|
|
)
|
|
parser.add_argument(
|
|
"--legend-outside",
|
|
action="store_true",
|
|
help="place the legend outside of the figure",
|
|
)
|
|
parser.add_argument("--xlim", type=int, default=None, help="x-axis limitation (default: None)")
|
|
parser.add_argument("--root-dir", default="./", help="root dir (default: ./)")
|
|
parser.add_argument(
|
|
"--file-pattern",
|
|
type=str,
|
|
default=r".*/test_rew_\d+seeds.csv$",
|
|
help="regular expression to determine whether or not to include target csv "
|
|
"file, default to including all test_rew_{num}seeds.csv file under rootdir",
|
|
)
|
|
parser.add_argument(
|
|
"--group-pattern",
|
|
type=str,
|
|
default=r"(/|^)\w*?\-v(\d|$)",
|
|
help="regular expression to group files in sub-figure, default to grouping "
|
|
'according to env_name dir, "" means no grouping',
|
|
)
|
|
parser.add_argument(
|
|
"--legend-pattern",
|
|
type=str,
|
|
default=r".*",
|
|
help="regular expression to extract legend from csv file path, default to "
|
|
"using file path as legend name.",
|
|
)
|
|
parser.add_argument("--show", action="store_true", help="show figure")
|
|
parser.add_argument("--output-path", type=str, help="figure save path", default="./figure.png")
|
|
parser.add_argument("--dpi", type=int, default=200, help="figure dpi (default: 200)")
|
|
args = parser.parse_args()
|
|
file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
|
|
file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
|
|
if args.style:
|
|
plt.style.use(args.style)
|
|
os.chdir(args.root_dir)
|
|
plot_figure(
|
|
file_lists,
|
|
group_pattern=args.group_pattern,
|
|
legend_pattern=args.legend_pattern,
|
|
fig_length=args.fig_length,
|
|
fig_width=args.fig_width,
|
|
title=args.title,
|
|
xlabel=args.xlabel,
|
|
ylabel=args.ylabel,
|
|
xkey=args.xkey,
|
|
ykey=args.ykey,
|
|
xlim=args.xlim,
|
|
sharex=args.sharex,
|
|
sharey=args.sharey,
|
|
smooth_radius=args.smooth,
|
|
shaded_std=args.shaded_std,
|
|
legend_outside=args.legend_outside,
|
|
)
|
|
if args.output_path:
|
|
plt.savefig(args.output_path, dpi=args.dpi, bbox_inches="tight")
|
|
if args.show:
|
|
plt.show()
|