* add makefile * bump version * add isort and yapf * update contributing.md * update PR template * spelling check
		
			
				
	
	
		
			289 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			289 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/env python3
 | |
| 
 | |
| import argparse
 | |
| import os
 | |
| import re
 | |
| 
 | |
| import matplotlib.pyplot as plt
 | |
| import matplotlib.ticker as mticker
 | |
| import numpy as np
 | |
| from tools import csv2numpy, find_all_files, group_files
 | |
| 
 | |
| 
 | |
def smooth(y, radius, mode='two_sided', valid_only=False):
    '''Smooth the 1-D signal y with a moving average sized by radius.

    mode='two_sided':
        average over the window
        [max(index - radius, 0), min(index + radius, len(y) - 1)]
    mode='causal':
        average over the trailing window of at most `radius` samples that
        ends at index (a radius of 0 or 1 leaves the signal unchanged)
    valid_only: put nan in entries where the full-sized window is not
        available: the first and last `radius` entries for 'two_sided',
        the first `radius` entries for 'causal'

    When y holds fewer than 2 * radius + 1 samples, every entry of the
    result is the mean of y, regardless of mode and valid_only.

    Returns an array with the same length as y.
    Raises ValueError if radius is negative.
    '''
    assert mode in ('two_sided', 'causal')
    if radius < 0:
        raise ValueError(f'radius must be non-negative, got {radius}')
    y = np.asarray(y)
    if y.size == 0:
        # nothing to average; also avoids numpy's mean-of-empty warning
        return y.astype(float)
    if len(y) < 2 * radius + 1:
        # too short for a single full window: fall back to a constant line
        return np.ones_like(y) * y.mean()

    # Dividing by the convolution of an all-ones signal normalises each
    # entry by the number of samples that actually fell inside its window.
    ones = np.ones_like(y)
    if mode == 'two_sided':
        convkernel = np.ones(2 * radius + 1)
        out = np.convolve(y, convkernel, mode='same') / \
            np.convolve(ones, convkernel, mode='same')
        # guard radius == 0: out[-0:] would address the whole array
        if valid_only and radius > 0:
            out[:radius] = np.nan
            out[-radius:] = np.nan
    else:
        # np.convolve rejects an empty kernel, so radius == 0 uses a
        # one-sample window, which is the identity
        convkernel = np.ones(max(radius, 1))
        out = np.convolve(y, convkernel, mode='full') / \
            np.convolve(ones, convkernel, mode='full')
        # keep only the entries aligned with y; slicing by len(y) instead of
        # -radius + 1 stays correct when radius == 1
        out = out[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
 | |
| 
 | |
| 
 | |
# Line colours used for plotted curves; plot_ax cycles through them by index.
COLORS = [
    # DeepMind-style palette ('#F0E442' is intentionally left out)
    '#0072B2', '#009E73', '#D55E00', '#CC79A7',
    '#d73027',  # red
    # matplotlib named colours
    'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow',
    'black', 'purple', 'brown', 'orange', 'teal', 'lightblue',
    'lime', 'lavender', 'turquoise', 'darkgreen', 'tan', 'salmon',
    'gold', 'darkred', 'darkblue', 'green',
    # personal palette
    '#313695',  # dark blue
    '#74add1',  # light blue
    '#f46d43',  # orange
    '#4daf4a',  # green
    '#984ea3',  # purple
    '#f781bf',  # pink
    '#ffc832',  # yellow
    '#000000',  # black
]
 | |
| 
 | |
| 
 | |
def plot_ax(
    ax,
    file_lists,
    legend_pattern=".*",
    xlabel=None,
    ylabel=None,
    title=None,
    xlim=None,
    xkey='env_step',
    ykey='rew',
    smooth_radius=0,
    shaded_std=True,
    legend_outside=False,
):
    '''Draw one curve per csv file on the given matplotlib axes.

    ax: axes to draw on.
    file_lists: csv file paths; each is loaded with csv2numpy and indexed
        with xkey / ykey (assumes csv2numpy returns a mapping from column
        name to array -- TODO confirm against tools.csv2numpy).
    legend_pattern: regular expression searched in each path; the matched
        text is the legend label. A path the pattern does not match is
        labelled with the full path.
    xlabel, ylabel, title: axes labels and title (labels are skipped if None).
    xlim: if not None, the x axis is limited to [0, xlim].
    xkey, ykey: column keys for the x and y data.
    smooth_radius: radius passed to smooth() for the curve and its band.
    shaded_std: if True and the column ykey + ':shaded' exists, shade the
        region y -/+ that column around the curve.
    legend_outside: anchor the legend's upper-left corner at the upper-right
        corner of the axes instead of using the default placement.
    '''
    pattern = re.compile(legend_pattern)

    def legend_fn(x):
        # return os.path.split(os.path.join(
        #     args.root_dir, x))[0].replace('/', '_') + " (10)"
        match = pattern.search(x)
        # fall back to the whole path instead of failing on a non-match
        return match.group(0) if match else x

    # order the files by legend label (ties broken by path) so colours and
    # legend entries are stable from run to run
    file_lists = sorted(file_lists, key=lambda f: (legend_fn(f), f))
    legends = [legend_fn(f) for f in file_lists]

    # Collect the line artists so the legend is tied to the curves only; a
    # labels-only ax.legend() call would also count the shaded bands.
    handles = []
    for index, csv_file in enumerate(file_lists):
        csv_dict = csv2numpy(csv_file)
        x, y = csv_dict[xkey], csv_dict[ykey]
        y = smooth(y, radius=smooth_radius)
        color = COLORS[index % len(COLORS)]
        handles.append(ax.plot(x, y, color=color)[0])
        shaded_key = ykey + ':shaded'
        if shaded_std and shaded_key in csv_dict:
            y_shaded = smooth(csv_dict[shaded_key], radius=smooth_radius)
            ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2)

    ax.legend(
        handles,
        legends,
        loc=2 if legend_outside else None,
        bbox_to_anchor=(1, 1) if legend_outside else None
    )
    ax.xaxis.set_major_formatter(mticker.EngFormatter())
    if xlim is not None:
        ax.set_xlim(xmin=0, xmax=xlim)
    # add title
    ax.set_title(title)
    # add labels
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
 | |
| 
 | |
| 
 | |
def plot_figure(
    file_lists,
    group_pattern=None,
    fig_length=6,
    fig_width=6,
    sharex=False,
    sharey=False,
    title=None,
    **kwargs,
):
    '''Create a figure and plot the csv files on it via plot_ax.

    Without a group_pattern a single axes of size (fig_length, fig_width)
    is created, holds every file, and uses title as its axes title.
    With a group_pattern the files are partitioned by group_files and each
    group gets its own sub-figure, titled with the group key, on a grid of
    at most three columns; sharex / sharey are forwarded to plt.subplots.

    A truthy title is additionally set as the figure-level suptitle.
    Remaining keyword arguments are forwarded unchanged to plot_ax.
    '''
    if group_pattern:
        grouped = group_files(file_lists, group_pattern)
        n_groups = len(grouped)
        row_n = int(np.ceil(n_groups / 3))
        col_n = min(n_groups, 3)
        grid_size = (fig_length * col_n, fig_width * row_n)
        fig, axes = plt.subplots(
            row_n,
            col_n,
            sharex=sharex,
            sharey=sharey,
            figsize=grid_size,
            squeeze=False
        )
        # squeeze=False always yields a 2-D array; walk it in row-major order
        flat_axes = axes.flatten()
        for sub_ax, (group_key, group_paths) in zip(flat_axes, grouped.items()):
            plot_ax(sub_ax, group_paths, title=group_key, **kwargs)
    else:
        fig, ax = plt.subplots(figsize=(fig_length, fig_width))
        plot_ax(ax, file_lists, title=title, **kwargs)

    if title:  # figure-wide title
        fig.suptitle(title, fontsize=20)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: collect csv files under --root-dir, plot
    # them, then save and/or show the resulting figure.
    parser = argparse.ArgumentParser(description='plotter')
    parser.add_argument(
        '--fig-length',
        type=int,
        default=6,
        help='matplotlib figure length (default: 6)'
    )
    parser.add_argument(
        '--fig-width',
        type=int,
        default=6,
        help='matplotlib figure width (default: 6)'
    )
    parser.add_argument(
        '--style',
        default='seaborn',
        help='matplotlib figure style (default: seaborn)'
    )
    parser.add_argument(
        '--title', default=None, help='matplotlib figure title (default: None)'
    )
    parser.add_argument(
        '--xkey',
        default='env_step',
        help='x-axis key in csv file (default: env_step)'
    )
    parser.add_argument(
        '--ykey', default='rew', help='y-axis key in csv file (default: rew)'
    )
    parser.add_argument(
        '--smooth', type=int, default=0, help='smooth radius of y axis (default: 0)'
    )
    parser.add_argument(
        '--xlabel', default='Timesteps', help='matplotlib figure xlabel'
    )
    parser.add_argument(
        '--ylabel', default='Episode Reward', help='matplotlib figure ylabel'
    )
    parser.add_argument(
        '--shaded-std',
        action='store_true',
        help='shaded region corresponding to standard deviation of the group'
    )
    parser.add_argument(
        '--sharex',
        action='store_true',
        help='whether to share x axis within multiple sub-figures'
    )
    parser.add_argument(
        '--sharey',
        action='store_true',
        help='whether to share y axis within multiple sub-figures'
    )
    parser.add_argument(
        '--legend-outside',
        action='store_true',
        help='place the legend outside of the figure'
    )
    parser.add_argument(
        '--xlim', type=int, default=None, help='x-axis limitation (default: None)'
    )
    parser.add_argument('--root-dir', default='./', help='root dir (default: ./)')
    parser.add_argument(
        '--file-pattern',
        type=str,
        default=r".*/test_rew_\d+seeds.csv$",
        help='regular expression to determine whether or not to include target csv '
        'file, default to including all test_rew_{num}seeds.csv file under rootdir'
    )
    parser.add_argument(
        '--group-pattern',
        type=str,
        default=r"(/|^)\w*?\-v(\d|$)",
        help='regular expression to group files in sub-figure, default to grouping '
        'according to env_name dir, "" means no grouping'
    )
    parser.add_argument(
        '--legend-pattern',
        type=str,
        default=r".*",
        help='regular expression to extract legend from csv file path, default to '
        'using file path as legend name.'
    )
    parser.add_argument('--show', action='store_true', help='show figure')
    parser.add_argument(
        '--output-path', type=str, help='figure save path', default="./figure.png"
    )
    parser.add_argument(
        '--dpi', type=int, default=200, help='figure dpi (default: 200)'
    )
    args = parser.parse_args()
    file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
    # make the paths relative to root_dir; the chdir below keeps them valid
    file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
    if args.style:
        style = args.style
        # Newer matplotlib releases ship the seaborn styles only under
        # 'seaborn-v0_8*' names; map an unavailable legacy name (including
        # the default) to its renamed counterpart when that one exists.
        if style not in plt.style.available and style.startswith('seaborn'):
            renamed = style.replace('seaborn', 'seaborn-v0_8', 1)
            if renamed in plt.style.available:
                style = renamed
        plt.style.use(style)
    # NOTE: after this chdir a relative --output-path is resolved against
    # root_dir, not the directory the script was started from.
    os.chdir(args.root_dir)
    plot_figure(
        file_lists,
        group_pattern=args.group_pattern,
        legend_pattern=args.legend_pattern,
        fig_length=args.fig_length,
        fig_width=args.fig_width,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        xkey=args.xkey,
        ykey=args.ykey,
        xlim=args.xlim,
        sharex=args.sharex,
        sharey=args.sharey,
        smooth_radius=args.smooth,
        shaded_std=args.shaded_std,
        legend_outside=args.legend_outside
    )
    if args.output_path:
        plt.savefig(args.output_path, dpi=args.dpi, bbox_inches='tight')
    if args.show:
        plt.show()
 |