Tianshou/examples/mujoco/plotter.py
ChenDRAG 333b8fbd66
add plotter (#335)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2021-04-14 14:06:36 +08:00

236 lines
8.5 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import ast
import csv
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

from tools import find_all_files
def smooth(y, radius, mode='two_sided', valid_only=False):
    """Smooth signal ``y`` with a moving average whose size is set by ``radius``.

    mode='two_sided':
        average over the window [max(index - radius, 0), min(index + radius, len(y) - 1)]
    mode='causal':
        average over the window [max(index - radius, 0), index]

    :param y: 1-D sequence of numbers to smooth.
    :param int radius: number of neighbouring samples used on each side
        (two_sided) or before the current sample (causal); 0 leaves the
        values unchanged.
    :param str mode: ``'two_sided'`` or ``'causal'``.
    :param bool valid_only: put nan in entries where the full-sized window is
        not available.
    :return: float ndarray with the same length as ``y``. If ``y`` is shorter
        than ``2 * radius + 1``, every entry is the mean of ``y``.
    """
    assert mode in ('two_sided', 'causal')
    y = np.asarray(y, dtype=float)
    if len(y) == 0:
        # Nothing to smooth; also avoids the "mean of empty slice" warning.
        return y.copy()
    if len(y) < 2 * radius + 1:
        return np.ones_like(y) * y.mean()
    if mode == 'two_sided':
        convkernel = np.ones(2 * radius + 1)
        # Dividing by the convolved ones turns the truncated windows at the
        # edges into true averages instead of partial sums.
        out = np.convolve(y, convkernel, mode='same') / \
            np.convolve(np.ones_like(y), convkernel, mode='same')
        if valid_only and radius > 0:
            # radius == 0 must be skipped: out[-0:] would select the whole array.
            out[:radius] = out[-radius:] = np.nan
    else:
        # radius + 1 taps: the current sample plus its ``radius`` predecessors,
        # matching the documented window.
        convkernel = np.ones(radius + 1)
        out = np.convolve(y, convkernel, mode='full') / \
            np.convolve(np.ones_like(y), convkernel, mode='full')
        # 'full' yields len(y) + radius values; the first len(y) are causal.
        out = out[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
# Palette cycled through by plot_ax: curve i is drawn with COLORS[i % len(COLORS)].
COLORS = [
    # deepmind style
    '#0072B2', '#009E73', '#D55E00', '#CC79A7',
    '#d73027',  # red
    # matplotlib named colors
    'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple',
    'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
    'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green',
    # personal palette
    '#313695',  # dark blue
    '#74add1',  # light blue
    '#f46d43',  # orange
    '#4daf4a',  # green
    '#984ea3',  # purple
    '#f781bf',  # pink
    '#ffc832',  # yellow
    '#000000',  # black
]
def csv2numpy(csv_file):
    """Load a csv file with a header row into a dict of numpy arrays.

    Every cell is parsed as a Python literal, so numeric cells become numbers.

    :param str csv_file: path to the csv file.
    :return: dict mapping each column name to a numpy array holding that
        column's parsed values in row order.
    :raises ValueError, SyntaxError: if a cell is not a Python literal.
    """
    csv_dict = defaultdict(list)
    # ``with`` closes the handle, which a bare open() leaked; newline='' is
    # what the csv module requires for correct line handling.
    with open(csv_file, newline='') as f:
        for row in csv.DictReader(f):
            for k, v in row.items():
                # literal_eval accepts only literals (numbers, strings, ...),
                # so a crafted csv cannot execute code the way eval() could.
                csv_dict[k].append(ast.literal_eval(v))
    return {k: np.array(v) for k, v in csv_dict.items()}
def group_files(file_list, pattern):
    """Bucket file paths by the first match of ``pattern`` in each path.

    :param file_list: iterable of file paths.
    :param pattern: regular expression (string or compiled) searched in each path.
    :return: ``defaultdict(list)`` from matched text to the paths sharing it,
        in input order; paths with no match go under the key ``''``.
    """
    regex = re.compile(pattern)
    groups = defaultdict(list)
    for path in file_list:
        found = regex.search(path)
        groups[found.group() if found else ''].append(path)
    return groups
def plot_ax(
    ax,
    file_lists,
    legend_pattern=".*",
    xlabel=None,
    ylabel=None,
    title=None,
    xlim=None,
    xkey='env_step',
    ykey='rew',
    smooth_radius=0,
    shaded_std=True,
    legend_outside=False,
):
    """Draw one curve per csv file onto a single matplotlib axes.

    :param ax: the matplotlib ``Axes`` to draw on.
    :param file_lists: csv file paths; each file becomes one curve.
    :param legend_pattern: regex whose first match in a file path is used as
        that curve's legend label (the whole path is used if it does not match).
    :param xlabel: x-axis label, or None to leave it unset.
    :param ylabel: y-axis label, or None to leave it unset.
    :param title: axes title.
    :param xlim: if not None, the x axis is limited to ``[0, xlim]``.
    :param xkey: csv column used for x values.
    :param ykey: csv column used for y values.
    :param smooth_radius: radius passed to :func:`smooth` for y (and the band).
    :param shaded_std: if True and the csv has a ``<ykey>:shaded`` column,
        shade ``y +/- shaded`` around the curve.
    :param legend_outside: place the legend to the right of the axes.
    """
    def legend_fn(path):
        match = re.search(legend_pattern, path)
        # Fall back to the full path rather than crashing on a non-matching file.
        return match.group(0) if match else path

    # Sort curves by (legend, path) so colors and legend order are stable.
    entries = sorted((legend_fn(f), f) for f in file_lists)
    legends = [legend for legend, _ in entries]
    lines = []
    for index, (_, csv_file) in enumerate(entries):
        csv_dict = csv2numpy(csv_file)
        x, y = csv_dict[xkey], csv_dict[ykey]
        y = smooth(y, radius=smooth_radius)
        color = COLORS[index % len(COLORS)]
        line, = ax.plot(x, y, color=color)
        lines.append(line)
        if shaded_std and ykey + ':shaded' in csv_dict:
            y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius)
            ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2)
    # Pass the line handles explicitly. With labels alone, recent matplotlib
    # picks handles in creation order, so the fill_between bands would be
    # interleaved with the lines and receive the wrong labels.
    ax.legend(lines, legends,
              loc='upper left' if legend_outside else None,
              bbox_to_anchor=(1, 1) if legend_outside else None)
    ax.xaxis.set_major_formatter(mticker.EngFormatter())
    if xlim is not None:
        ax.set_xlim(left=0, right=xlim)
    ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
def plot_figure(
    file_lists,
    group_pattern=None,
    fig_length=6,
    fig_width=6,
    sharex=False,
    sharey=False,
    title=None,
    max_cols=3,
    **kwargs,
):
    """Plot csv files into one figure, optionally with one sub-plot per group.

    :param file_lists: csv file paths to plot.
    :param group_pattern: regex passed to :func:`group_files`; files with the
        same match share a sub-plot titled by that match. Falsy means a single
        plot containing every file.
    :param fig_length: width of one sub-plot (first entry of matplotlib figsize).
    :param fig_width: height of one sub-plot (second entry of matplotlib figsize).
    :param sharex: share the x axis among sub-plots (grouped mode).
    :param sharey: share the y axis among sub-plots (grouped mode).
    :param title: axes title in single-plot mode, figure suptitle in grouped mode.
    :param max_cols: maximum number of sub-plots per row in grouped mode.
    :param kwargs: forwarded to :func:`plot_ax`.
    :return: the created matplotlib ``Figure``.
    """
    groups = group_files(file_lists, group_pattern) if group_pattern else {}
    if not groups:
        # Single plot (also used when there are no files, where a 0x0 grid
        # would make plt.subplots fail). The title goes on the axes only,
        # so it is not drawn twice.
        fig, ax = plt.subplots(figsize=(fig_length, fig_width))
        plot_ax(ax, file_lists, title=title, **kwargs)
        return fig
    col_n = min(len(groups), max(max_cols, 1))
    row_n = -(-len(groups) // col_n)  # ceiling division
    fig, axes = plt.subplots(
        row_n, col_n, sharex=sharex, sharey=sharey,
        figsize=(fig_length * col_n, fig_width * row_n), squeeze=False)
    axes = axes.flatten()
    for ax, (key, files) in zip(axes, groups.items()):
        plot_ax(ax, files, title=key, **kwargs)
    # The grid can have more cells than groups; hide the empty trailing ones.
    for ax in axes[len(groups):]:
        ax.set_visible(False)
    if title:
        fig.suptitle(title, fontsize=20)
    return fig
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='plotter')
    parser.add_argument('--fig-length', type=int, default=6,
                        help='matplotlib figure length (default: 6)')
    parser.add_argument('--fig-width', type=int, default=6,
                        help='matplotlib figure width (default: 6)')
    parser.add_argument('--style', default='seaborn',
                        help='matplotlib figure style (default: seaborn)')
    parser.add_argument('--title', default=None,
                        help='matplotlib figure title (default: None)')
    parser.add_argument('--xkey', default='env_step',
                        help='x-axis key in csv file (default: env_step)')
    parser.add_argument('--ykey', default='rew',
                        help='y-axis key in csv file (default: rew)')
    parser.add_argument('--smooth', type=int, default=0,
                        help='smooth radius of y axis (default: 0)')
    parser.add_argument('--xlabel', default='Timesteps',
                        help='matplotlib figure xlabel')
    parser.add_argument('--ylabel', default='Episode Reward',
                        help='matplotlib figure ylabel')
    parser.add_argument(
        '--shaded-std', action='store_true',
        help='shaded region corresponding to standard deviation of the group')
    parser.add_argument('--sharex', action='store_true',
                        help='whether to share x axis within multiple sub-figures')
    parser.add_argument('--sharey', action='store_true',
                        help='whether to share y axis within multiple sub-figures')
    parser.add_argument('--legend-outside', action='store_true',
                        help='place the legend outside of the figure')
    parser.add_argument('--xlim', type=int, default=None,
                        help='x-axis limitation (default: None)')
    parser.add_argument('--root-dir', default='./', help='root dir (default: ./)')
    parser.add_argument(
        '--file-pattern', type=str, default=r".*/test_rew_\d+seeds\.csv$",
        help='regular expression to determine whether or not to include target csv '
        'file, default to including all test_rew_{num}seeds.csv file under rootdir')
    parser.add_argument(
        '--group-pattern', type=str, default=r"(/|^)\w*?\-v(\d|$)",
        help='regular expression to group files in sub-figure, default to grouping '
        'according to env_name dir, "" means no grouping')
    parser.add_argument(
        '--legend-pattern', type=str, default=r".*",
        help='regular expression to extract legend from csv file path, default to '
        'using file path as legend name.')
    parser.add_argument('--show', action='store_true', help='show figure')
    parser.add_argument('--output-path', type=str,
                        help='figure save path', default="./figure.png")
    parser.add_argument('--dpi', type=int, default=200,
                        help='figure dpi (default: 200)')
    args = parser.parse_args()

    file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
    # Paths relative to root_dir are what the group/legend regexes see, and
    # they still resolve correctly after the chdir below.
    file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
    if args.style:
        try:
            plt.style.use(args.style)
        except OSError:
            # Newer matplotlib renamed the bundled 'seaborn*' styles to
            # 'seaborn-v0_8*' and later dropped the old names; retry with the
            # new name (non-seaborn styles re-raise the same error).
            plt.style.use(args.style.replace('seaborn', 'seaborn-v0_8', 1))
    os.chdir(args.root_dir)
    plot_figure(
        file_lists,
        group_pattern=args.group_pattern,
        legend_pattern=args.legend_pattern,
        fig_length=args.fig_length,
        fig_width=args.fig_width,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        xkey=args.xkey,
        ykey=args.ykey,
        xlim=args.xlim,
        sharex=args.sharex,
        sharey=args.sharey,
        smooth_radius=args.smooth,
        shaded_std=args.shaded_std,
        legend_outside=args.legend_outside)
    if args.output_path:
        # NOTE: because of the chdir above, a relative output path is
        # resolved against root_dir, not the directory the script was run from.
        plt.savefig(args.output_path,
                    dpi=args.dpi, bbox_inches='tight')
    if args.show:
        plt.show()