* add makefile * bump version * add isort and yapf * update contributing.md * update PR template * spelling check
289 lines
8.3 KiB
Python
Executable File
289 lines
8.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import re
|
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as mticker
|
|
import numpy as np
|
|
from tools import csv2numpy, find_all_files, group_files
|
|
|
|
|
|
def smooth(y, radius, mode='two_sided', valid_only=False):
    '''Smooth signal y with a moving average whose window size is set by radius.

    mode='two_sided':
        average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
    mode='causal':
        average over the window [max(index - radius, 0), index]

    valid_only: put nan in entries where the full-sized window is not available

    If y has fewer than 2 * radius + 1 elements, every entry is replaced by
    the mean of y. Returns a new float array with the same length as y.
    Raises ValueError for a negative radius.
    '''
    assert mode in ('two_sided', 'causal')
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")
    y = np.asarray(y, dtype=float)
    if y.size == 0:
        # nothing to smooth; avoids the "mean of empty slice" warning below
        return y.copy()
    if len(y) < 2 * radius + 1:
        return np.ones_like(y) * y.mean()
    if mode == 'two_sided':
        convkernel = np.ones(2 * radius + 1)
        # dividing by the convolution of ones normalises the partial windows
        # at both edges by the number of samples they actually cover
        out = np.convolve(y, convkernel, mode='same') / \
            np.convolve(np.ones_like(y), convkernel, mode='same')
        # radius == 0 always has a full window; also out[-0:] would be the
        # whole array, so guard it explicitly
        if valid_only and radius > 0:
            out[:radius] = out[-radius:] = np.nan
    else:
        # radius + 1 taps cover [index - radius, index] as documented
        convkernel = np.ones(radius + 1)
        full = np.convolve(y, convkernel, mode='full') / \
            np.convolve(np.ones_like(y), convkernel, mode='full')
        # the first len(y) entries of a full convolution are the causal sums;
        # slicing by length also works for radius 0 and 1
        out = full[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
|
|
|
|
|
|
# Palette cycled through when drawing curves: curve i uses COLORS[i % len(COLORS)].
COLORS = [
    # DeepMind-style palette ('#F0E442' is excluded)
    '#0072B2', '#009E73', '#D55E00', '#CC79A7',
    '#d73027',  # red
    # matplotlib named colors
    'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple',
    'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
    'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green',
    # personal palette
    '#313695',  # dark blue
    '#74add1',  # light blue
    '#f46d43',  # orange
    '#4daf4a',  # green
    '#984ea3',  # purple
    '#f781bf',  # pink
    '#ffc832',  # yellow
    '#000000',  # black
]
|
|
|
|
|
|
def plot_ax(
    ax,
    file_lists,
    legend_pattern=".*",
    xlabel=None,
    ylabel=None,
    title=None,
    xlim=None,
    xkey='env_step',
    ykey='rew',
    smooth_radius=0,
    shaded_std=True,
    legend_outside=False,
):
    """Draw one curve per csv file onto a single matplotlib axis.

    :param ax: the matplotlib ``Axes`` to draw on.
    :param file_lists: paths of csv files; each is loaded with ``csv2numpy``.
    :param legend_pattern: regular expression searched in each file path; the
        matched text becomes that curve's legend entry. A path that does not
        match falls back to the full path.
    :param xlabel: optional x-axis label.
    :param ylabel: optional y-axis label.
    :param title: axis title.
    :param xlim: if given, the x axis is limited to ``[0, xlim]``.
    :param xkey: csv column used for x values.
    :param ykey: csv column used for y values.
    :param smooth_radius: radius passed to ``smooth`` for y and its band.
    :param shaded_std: if True and the csv has a ``<ykey>:shaded`` column,
        draw a band of ``y +/- shaded`` around each curve.
    :param legend_outside: anchor the legend outside the axis, to the right.
    """
    pattern = re.compile(legend_pattern)

    def legend_fn(path):
        match = pattern.search(path)
        # fall back to the whole path instead of crashing on a non-match
        return match.group(0) if match is not None else path

    # sort by legend text (ties broken by path) so color assignment is stable
    file_lists = sorted(file_lists, key=lambda f: (legend_fn(f), f))
    legends = [legend_fn(f) for f in file_lists]

    handles = []
    for index, csv_file in enumerate(file_lists):
        csv_dict = csv2numpy(csv_file)
        x, y = csv_dict[xkey], csv_dict[ykey]
        y = smooth(y, radius=smooth_radius)
        color = COLORS[index % len(COLORS)]
        handles.append(ax.plot(x, y, color=color)[0])
        if shaded_std and ykey + ':shaded' in csv_dict:
            y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius)
            ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2)

    # Pass the line handles explicitly: given labels only, matplotlib pairs
    # them with the axis children in creation order, which interleaves the
    # fill_between bands with the lines and mislabels the curves.
    ax.legend(
        handles,
        legends,
        loc=2 if legend_outside else None,
        bbox_to_anchor=(1, 1) if legend_outside else None
    )
    # engineering notation on the x axis (e.g. 1M instead of 1000000)
    ax.xaxis.set_major_formatter(mticker.EngFormatter())
    if xlim is not None:
        ax.set_xlim(left=0, right=xlim)
    ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
|
|
|
|
|
|
def plot_figure(
    file_lists,
    group_pattern=None,
    fig_length=6,
    fig_width=6,
    sharex=False,
    sharey=False,
    title=None,
    **kwargs,
):
    """Plot csv files on one axis, or one sub-plot per group of files.

    :param file_lists: csv file paths to plot.
    :param group_pattern: if falsy, all files go on a single axis. Otherwise
        files are split with ``group_files(file_lists, group_pattern)`` and
        each group is drawn on its own sub-plot (at most 3 per row), titled
        with the group key.
    :param fig_length: horizontal size (inches) of one sub-plot.
    :param fig_width: vertical size (inches) of one sub-plot.
    :param sharex: share the x axis between sub-plots.
    :param sharey: share the y axis between sub-plots.
    :param title: figure title. It is the axis title for a single plot, or a
        figure-level suptitle when grouping.
    :param kwargs: forwarded to ``plot_ax``.
    :return: the created matplotlib ``Figure``.
    :raises ValueError: if grouping yields no groups.
    """
    if not group_pattern:
        fig, ax = plt.subplots(figsize=(fig_length, fig_width))
        # the single axis already shows the title; no extra suptitle,
        # which would print it twice
        plot_ax(ax, file_lists, title=title, **kwargs)
        return fig

    res = group_files(file_lists, group_pattern)
    if not res:
        raise ValueError(
            f"no files to plot: grouping with {group_pattern!r} gave no groups"
        )
    row_n = int(np.ceil(len(res) / 3))
    col_n = min(len(res), 3)
    fig, axes = plt.subplots(
        row_n,
        col_n,
        sharex=sharex,
        sharey=sharey,
        figsize=(fig_length * col_n, fig_width * row_n),
        squeeze=False
    )
    axes = axes.flatten()
    for ax, (key, files) in zip(axes, res.items()):
        plot_ax(ax, files, title=key, **kwargs)
    # hide leftover cells of the last row instead of showing empty axes
    for ax in axes[len(res):]:
        ax.set_visible(False)
    if title:
        fig.suptitle(title, fontsize=20)
    return fig
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(description='plotter')
|
|
parser.add_argument(
|
|
'--fig-length',
|
|
type=int,
|
|
default=6,
|
|
help='matplotlib figure length (default: 6)'
|
|
)
|
|
parser.add_argument(
|
|
'--fig-width',
|
|
type=int,
|
|
default=6,
|
|
help='matplotlib figure width (default: 6)'
|
|
)
|
|
parser.add_argument(
|
|
'--style',
|
|
default='seaborn',
|
|
help='matplotlib figure style (default: seaborn)'
|
|
)
|
|
parser.add_argument(
|
|
'--title', default=None, help='matplotlib figure title (default: None)'
|
|
)
|
|
parser.add_argument(
|
|
'--xkey',
|
|
default='env_step',
|
|
help='x-axis key in csv file (default: env_step)'
|
|
)
|
|
parser.add_argument(
|
|
'--ykey', default='rew', help='y-axis key in csv file (default: rew)'
|
|
)
|
|
parser.add_argument(
|
|
'--smooth', type=int, default=0, help='smooth radius of y axis (default: 0)'
|
|
)
|
|
parser.add_argument(
|
|
'--xlabel', default='Timesteps', help='matplotlib figure xlabel'
|
|
)
|
|
parser.add_argument(
|
|
'--ylabel', default='Episode Reward', help='matplotlib figure ylabel'
|
|
)
|
|
parser.add_argument(
|
|
'--shaded-std',
|
|
action='store_true',
|
|
help='shaded region corresponding to standard deviation of the group'
|
|
)
|
|
parser.add_argument(
|
|
'--sharex',
|
|
action='store_true',
|
|
help='whether to share x axis within multiple sub-figures'
|
|
)
|
|
parser.add_argument(
|
|
'--sharey',
|
|
action='store_true',
|
|
help='whether to share y axis within multiple sub-figures'
|
|
)
|
|
parser.add_argument(
|
|
'--legend-outside',
|
|
action='store_true',
|
|
help='place the legend outside of the figure'
|
|
)
|
|
parser.add_argument(
|
|
'--xlim', type=int, default=None, help='x-axis limitation (default: None)'
|
|
)
|
|
parser.add_argument('--root-dir', default='./', help='root dir (default: ./)')
|
|
parser.add_argument(
|
|
'--file-pattern',
|
|
type=str,
|
|
default=r".*/test_rew_\d+seeds.csv$",
|
|
help='regular expression to determine whether or not to include target csv '
|
|
'file, default to including all test_rew_{num}seeds.csv file under rootdir'
|
|
)
|
|
parser.add_argument(
|
|
'--group-pattern',
|
|
type=str,
|
|
default=r"(/|^)\w*?\-v(\d|$)",
|
|
help='regular expression to group files in sub-figure, default to grouping '
|
|
'according to env_name dir, "" means no grouping'
|
|
)
|
|
parser.add_argument(
|
|
'--legend-pattern',
|
|
type=str,
|
|
default=r".*",
|
|
help='regular expression to extract legend from csv file path, default to '
|
|
'using file path as legend name.'
|
|
)
|
|
parser.add_argument('--show', action='store_true', help='show figure')
|
|
parser.add_argument(
|
|
'--output-path', type=str, help='figure save path', default="./figure.png"
|
|
)
|
|
parser.add_argument(
|
|
'--dpi', type=int, default=200, help='figure dpi (default: 200)'
|
|
)
|
|
args = parser.parse_args()
|
|
file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
|
|
file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
|
|
if args.style:
|
|
plt.style.use(args.style)
|
|
os.chdir(args.root_dir)
|
|
plot_figure(
|
|
file_lists,
|
|
group_pattern=args.group_pattern,
|
|
legend_pattern=args.legend_pattern,
|
|
fig_length=args.fig_length,
|
|
fig_width=args.fig_width,
|
|
title=args.title,
|
|
xlabel=args.xlabel,
|
|
ylabel=args.ylabel,
|
|
xkey=args.xkey,
|
|
ykey=args.ykey,
|
|
xlim=args.xlim,
|
|
sharex=args.sharex,
|
|
sharey=args.sharey,
|
|
smooth_radius=args.smooth,
|
|
shaded_std=args.shaded_std,
|
|
legend_outside=args.legend_outside
|
|
)
|
|
if args.output_path:
|
|
plt.savefig(args.output_path, dpi=args.dpi, bbox_inches='tight')
|
|
if args.show:
|
|
plt.show()
|