add plotter (#335)

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
ChenDRAG 2021-04-14 14:06:36 +08:00 committed by GitHub
parent dd4a01132c
commit 333b8fbd66
32 changed files with 357 additions and 12 deletions

examples/mujoco/README.md

@@ -36,17 +36,26 @@ $ tensorboard --logdir log
 You can also reproduce the benchmark (e.g. SAC in Ant-v3) with the example script we provide under `examples/mujoco/`:
 ```bash
-$ ./run_experiments.sh Ant-v3
+$ ./run_experiments.sh Ant-v3 sac
 ```
 This will start 10 experiments with different seeds.
+Once all the experiments are finished, we can convert the tfevent files into CSV files and plot the results (the figure is saved to `./figure.png` by default):
+```bash
+$ ./tools.py --root-dir ./results/Ant-v3/sac
+$ ./plotter.py --root-dir ./results/Ant-v3 --shaded-std --legend-pattern "\\w+"
+```
+#### Example benchmark
+<img src="./benchmark/Ant-v3/offpolicy.png" width="500" height="450">
+Other graphs can be found under `examples/mujoco/benchmark/`.
+For pretrained agents, detailed graphs (single agent, single game) and log details, please refer to [https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/](https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/).
 ## Offpolicy algorithms
 #### Notes
@@ -236,7 +245,7 @@ Other graphs can be found under `examples/mujoco/benchmark/`
 <a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are rarely seen in the literature.
-<a name="footnote2">[2]</a> Pretrained agents, detailed graphs (single agent, single game) and log details can all be found [here](https://cloud.tsinghua.edu.cn/d/356e0f5d1e66426b9828/).
+<a name="footnote2">[2]</a> Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/](https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/).
 <a name="footnote3">[3]</a> We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but this is often not the case for other benchmarks; please check the original papers for details. (Results are usually similar across versions, though.)

(Binary image files added or updated; not shown.)

examples/mujoco/benchmark/README.md

@@ -2,36 +2,36 @@
 ## Ant-v3
-![](Ant-v3/offpolicy.png)
+![](Ant-v3/all.png)
 ## HalfCheetah-v3
-![](HalfCheetah-v3/offpolicy.png)
+![](HalfCheetah-v3/all.png)
 ## Hopper-v3
-![](Hopper-v3/offpolicy.png)
+![](Hopper-v3/all.png)
 ## Walker2d-v3
-![](Walker2d-v3/offpolicy.png)
+![](Walker2d-v3/all.png)
 ## Swimmer-v3
-![](Swimmer-v3/offpolicy.png)
+![](Swimmer-v3/all.png)
 ## Humanoid-v3
-![](Humanoid-v3/offpolicy.png)
+![](Humanoid-v3/all.png)
 ## Reacher-v2
-![](Reacher-v2/offpolicy.png)
+![](Reacher-v2/all.png)
 ## InvertedPendulum-v2
-![](InvertedPendulum-v2/offpolicy.png)
+![](InvertedPendulum-v2/all.png)
 ## InvertedDoublePendulum-v2
-![](InvertedDoublePendulum-v2/offpolicy.png)
+![](InvertedDoublePendulum-v2/all.png)

(Binary image files added or updated; not shown.)

examples/mujoco/plotter.py (new executable file, 235 lines)

@@ -0,0 +1,235 @@
#!/usr/bin/env python3
import re
import os
import csv
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from collections import defaultdict
from tools import find_all_files
def smooth(y, radius, mode='two_sided', valid_only=False):
    '''Smooth signal y, where radius determines the size of the window.

    mode='two_sided':
        average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
    mode='causal':
        average over the window [max(index - radius, 0), index]
    valid_only: put nan in entries where the full-sized window is not available
    '''
    assert mode in ('two_sided', 'causal')
    if len(y) < 2 * radius + 1:
        return np.ones_like(y) * y.mean()
    elif mode == 'two_sided':
        convkernel = np.ones(2 * radius + 1)
        out = np.convolve(y, convkernel, mode='same') / \
            np.convolve(np.ones_like(y), convkernel, mode='same')
        if valid_only:
            out[:radius] = out[-radius:] = np.nan
    elif mode == 'causal':
        convkernel = np.ones(radius)
        out = np.convolve(y, convkernel, mode='full') / \
            np.convolve(np.ones_like(y), convkernel, mode='full')
        # trim the convolution tail; equivalent to out[:-radius + 1] for
        # radius >= 2, but also correct when radius == 1
        out = out[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
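# A minimal sketch of the two-sided mode (values checked by hand):
# smooth(np.array([0., 1., 2., 3., 4.]), radius=1) averages each entry with its
# neighbours, giving array([0.5, 1., 2., 3., 3.5]).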
COLORS = ([
# deepmind style
'#0072B2',
'#009E73',
'#D55E00',
'#CC79A7',
# '#F0E442',
'#d73027', # RED
# built-in color
'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple',
'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green',
# personal color
'#313695', # DARK BLUE
'#74add1', # LIGHT BLUE
'#f46d43', # ORANGE
'#4daf4a', # GREEN
'#984ea3', # PURPLE
'#f781bf', # PINK
'#ffc832', # YELLOW
'#000000', # BLACK
])
def csv2numpy(csv_file):
    csv_dict = defaultdict(list)
    reader = csv.DictReader(open(csv_file))
    for row in reader:
        for k, v in row.items():
            # eval() parses the numeric strings written by tools.py;
            # this assumes the csv files are trusted input
            csv_dict[k].append(eval(v))
    return {k: np.array(v) for k, v in csv_dict.items()}
def group_files(file_list, pattern):
res = defaultdict(list)
for f in file_list:
match = re.search(pattern, f)
key = match.group() if match else ''
res[key].append(f)
return res
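# Example with the default --group-pattern r"(/|^)\w*?\-v(\d|$)" (hypothetical
# paths): "Ant-v3/sac/test_rew_10seeds.csv" and "Ant-v3/td3/test_rew_10seeds.csv"
# both yield the key "Ant-v3", so they are grouped into the same sub-figure.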
def plot_ax(
ax,
file_lists,
legend_pattern=".*",
xlabel=None,
ylabel=None,
title=None,
xlim=None,
xkey='env_step',
ykey='rew',
smooth_radius=0,
shaded_std=True,
legend_outside=False,
):
    def legend_fn(x):
        # return os.path.split(os.path.join(
        #     args.root_dir, x))[0].replace('/', '_') + " (10)"
        return re.search(legend_pattern, x).group(0)

    legends = map(legend_fn, file_lists)
    # sort the file list according to the legends
    file_lists = [f for _, f in sorted(zip(legends, file_lists))]
    legends = list(map(legend_fn, file_lists))
for index, csv_file in enumerate(file_lists):
csv_dict = csv2numpy(csv_file)
x, y = csv_dict[xkey], csv_dict[ykey]
y = smooth(y, radius=smooth_radius)
color = COLORS[index % len(COLORS)]
ax.plot(x, y, color=color)
if shaded_std and ykey + ':shaded' in csv_dict:
y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius)
ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2)
    ax.legend(legends, loc=2 if legend_outside else None,
              bbox_to_anchor=(1, 1) if legend_outside else None)
ax.xaxis.set_major_formatter(mticker.EngFormatter())
if xlim is not None:
ax.set_xlim(xmin=0, xmax=xlim)
# add title
ax.set_title(title)
# add labels
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
def plot_figure(
file_lists,
group_pattern=None,
fig_length=6,
fig_width=6,
sharex=False,
sharey=False,
title=None,
**kwargs,
):
if not group_pattern:
fig, ax = plt.subplots(figsize=(fig_length, fig_width))
plot_ax(ax, file_lists, title=title, **kwargs)
else:
res = group_files(file_lists, group_pattern)
row_n = int(np.ceil(len(res) / 3))
col_n = min(len(res), 3)
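        # e.g. with 5 groups: row_n = ceil(5 / 3) = 2 and col_n = 3, so the
        # groups fill a 2 x 3 grid and the last axis simply stays empty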
fig, axes = plt.subplots(row_n, col_n, sharex=sharex, sharey=sharey, figsize=(
fig_length * col_n, fig_width * row_n), squeeze=False)
axes = axes.flatten()
for i, (k, v) in enumerate(res.items()):
plot_ax(axes[i], v, title=k, **kwargs)
if title: # add title
fig.suptitle(title, fontsize=20)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='plotter')
parser.add_argument('--fig-length', type=int, default=6,
help='matplotlib figure length (default: 6)')
parser.add_argument('--fig-width', type=int, default=6,
help='matplotlib figure width (default: 6)')
parser.add_argument('--style', default='seaborn',
help='matplotlib figure style (default: seaborn)')
parser.add_argument('--title', default=None,
help='matplotlib figure title (default: None)')
parser.add_argument('--xkey', default='env_step',
help='x-axis key in csv file (default: env_step)')
parser.add_argument('--ykey', default='rew',
help='y-axis key in csv file (default: rew)')
parser.add_argument('--smooth', type=int, default=0,
help='smooth radius of y axis (default: 0)')
parser.add_argument('--xlabel', default='Timesteps',
help='matplotlib figure xlabel')
parser.add_argument('--ylabel', default='Episode Reward',
help='matplotlib figure ylabel')
parser.add_argument(
'--shaded-std', action='store_true',
help='shaded region corresponding to standard deviation of the group')
parser.add_argument('--sharex', action='store_true',
help='whether to share x axis within multiple sub-figures')
parser.add_argument('--sharey', action='store_true',
help='whether to share y axis within multiple sub-figures')
parser.add_argument('--legend-outside', action='store_true',
help='place the legend outside of the figure')
parser.add_argument('--xlim', type=int, default=None,
help='x-axis limitation (default: None)')
parser.add_argument('--root-dir', default='./', help='root dir (default: ./)')
    parser.add_argument(
        '--file-pattern', type=str, default=r".*/test_rew_\d+seeds.csv$",
        help='regular expression to determine whether or not to include target csv '
             'file, defaults to including all test_rew_{num}seeds.csv files under root-dir')
    parser.add_argument(
        '--group-pattern', type=str, default=r"(/|^)\w*?\-v(\d|$)",
        help='regular expression to group files into sub-figures, defaults to '
             'grouping by the env_name directory; "" means no grouping')
    parser.add_argument(
        '--legend-pattern', type=str, default=r".*",
        help='regular expression to extract the legend from the csv file path, '
             'defaults to using the file path as the legend name')
parser.add_argument('--show', action='store_true', help='show figure')
parser.add_argument('--output-path', type=str,
help='figure save path', default="./figure.png")
parser.add_argument('--dpi', type=int, default=200,
help='figure dpi (default: 200)')
args = parser.parse_args()
file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
if args.style:
plt.style.use(args.style)
os.chdir(args.root_dir)
plot_figure(
file_lists,
group_pattern=args.group_pattern,
legend_pattern=args.legend_pattern,
fig_length=args.fig_length,
fig_width=args.fig_width,
title=args.title,
xlabel=args.xlabel,
ylabel=args.ylabel,
xkey=args.xkey,
ykey=args.ykey,
xlim=args.xlim,
sharex=args.sharex,
sharey=args.sharey,
smooth_radius=args.smooth,
shaded_std=args.shaded_std,
legend_outside=args.legend_outside)
if args.output_path:
plt.savefig(args.output_path,
dpi=args.dpi, bbox_inches='tight')
if args.show:
plt.show()

examples/mujoco/run_experiments.sh

@@ -2,10 +2,11 @@
 LOGDIR="results"
 TASK=$1
+ALGO=$2
 echo "Experiments started."
 for seed in $(seq 0 9)
 do
-python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 &
+python mujoco_${ALGO}.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 &
 done
 echo "Experiments ended."

examples/mujoco/tools.py (new executable file, 100 lines)

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
import os
import re
import csv
import tqdm
import argparse
import numpy as np
from typing import Dict, List, Union
from tensorboard.backend.event_processing import event_accumulator
def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
    """Find all files under root_dir whose path matches the given pattern."""
file_list = []
for dirname, _, files in os.walk(root_dir):
for f in files:
absolute_path = os.path.join(dirname, f)
if re.match(pattern, absolute_path):
file_list.append(absolute_path)
return file_list
def convert_tfevents_to_csv(
    root_dir: str, refresh: bool = False
) -> Dict[str, List[List[Union[str, int, float]]]]:
    """Recursively convert test/rew in all tfevent files under root_dir to csv.

    This function assumes that there is at most one tfevents file in each
    directory, and writes the resulting test_rew.csv next to it.

    :param bool refresh: re-create the csv files even if they already exist.
    """
tfevent_files = find_all_files(root_dir, re.compile(r"^.*tfevents.*$"))
print(f"Converting {len(tfevent_files)} tfevents files under {root_dir} ...")
result = {}
with tqdm.tqdm(tfevent_files) as t:
for tfevent_file in t:
t.set_postfix(file=tfevent_file)
output_file = os.path.join(os.path.split(tfevent_file)[0], "test_rew.csv")
if os.path.exists(output_file) and not refresh:
content = list(csv.reader(open(output_file, "r")))
if content[0] == ["env_step", "rew", "time"]:
for i in range(1, len(content)):
content[i] = list(map(eval, content[i]))
result[output_file] = content
continue
            ea = event_accumulator.EventAccumulator(tfevent_file)
            ea.Reload()
            # _first_event_timestamp is a private attribute, but it is the
            # simplest way to recover the wall-clock time of the first event
            initial_time = ea._first_event_timestamp
content = [["env_step", "rew", "time"]]
for test_rew in ea.scalars.Items("test/rew"):
content.append([
round(test_rew.step, 4),
round(test_rew.value, 4),
round(test_rew.wall_time - initial_time, 4),
])
csv.writer(open(output_file, 'w')).writerows(content)
result[output_file] = content
return result
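# Example of the returned mapping (hypothetical paths):
# {"results/Ant-v3/sac/0/test_rew.csv":
#     [["env_step", "rew", "time"], [10000, 891.2, 32.1], ...], ...}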
def merge_csv(
    csv_files: Dict[str, List[List[Union[str, int, float]]]],
    root_dir: str,
    remove_zero: bool = False,
) -> None:
    """Merge the per-seed csv files in csv_files into a single csv file."""
assert len(csv_files) > 0
if remove_zero:
for k, v in csv_files.items():
if v[1][0] == 0:
v.pop(1)
sorted_keys = sorted(csv_files.keys())
sorted_values = [csv_files[k][1:] for k in sorted_keys]
content = [["env_step", "rew", "rew:shaded"] + list(map(
lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys))]
for rows in zip(*sorted_values):
array = np.array(rows)
assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0])
line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)]
line += array[:, 1].tolist()
content.append(line)
output_path = os.path.join(root_dir, f"test_rew_{len(csv_files)}seeds.csv")
print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.")
csv.writer(open(output_path, "w")).writerows(content)
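# Sketch of the merged layout: one row per env_step, holding the mean reward,
# its standard deviation ("rew:shaded"), and one "rew:<seed csv path>" column
# per seed; plotter.py reads the "rew" and "rew:shaded" columns.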
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--root-dir', type=str)
parser.add_argument(
'--refresh', action="store_true",
help="Re-generate all csv files instead of using existing one.")
parser.add_argument(
'--remove-zero', action="store_true",
help="Remove the data point of env_step == 0.")
args = parser.parse_args()
csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh)
merge_csv(csv_files, args.root_dir, args.remove_zero)