add plotter (#335)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
@ -36,17 +36,26 @@ $ tensorboard --logdir log
|
||||
You can also reproduce the benchmark (e.g. SAC in Ant-v3) with the example script we provide under `examples/mujoco/`:
|
||||
|
||||
```bash
|
||||
$ ./run_experiments.sh Ant-v3
|
||||
$ ./run_experiments.sh Ant-v3 sac
|
||||
```
|
||||
|
||||
This will start 10 experiments with different seeds.
|
||||
|
||||
Now that all the experiments are finished, we can convert all tfevent files into csv files and then try plotting the results.
|
||||
|
||||
```bash
|
||||
$ ./tools.py --root-dir ./results/Ant-v3/sac
|
||||
$ ./plotter.py --root-dir ./results/Ant-v3 --shaded-std --legend-pattern "\\w+"
|
||||
```
|
||||
|
||||
#### Example benchmark
|
||||
|
||||
<img src="./benchmark/Ant-v3/offpolicy.png" width="500" height="450">
|
||||
|
||||
Other graphs can be found under `/examples/mujoco/benchmark/`
|
||||
|
||||
For pretrained agents, detailed graphs (single agent, single game) and log details, please refer to [https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/](https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/).
|
||||
|
||||
## Offpolicy algorithms
|
||||
#### Notes
|
||||
|
||||
@ -236,7 +245,7 @@ Other graphs can be found under `/examples/mujuco/benchmark/`
|
||||
|
||||
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in the literature.
|
||||
|
||||
<a name="footnote2">[2]</a> Pretrained agents, detailed graphs (single agent, single game) and log details can all be found [here](https://cloud.tsinghua.edu.cn/d/356e0f5d1e66426b9828/).
|
||||
<a name="footnote2">[2]</a> Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/](https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/).
|
||||
|
||||
<a name="footnote3">[3]</a> We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but this is often not the case with other benchmarks. Please check the details yourself in the original paper. (Different versions' outcomes are usually similar, though)
|
||||
|
||||
|
BIN
examples/mujoco/benchmark/Ant-v3/all.png
Normal file
After Width: | Height: | Size: 292 KiB |
Before Width: | Height: | Size: 344 KiB After Width: | Height: | Size: 252 KiB |
BIN
examples/mujoco/benchmark/Ant-v3/onpolicy.png
Normal file
After Width: | Height: | Size: 203 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/all.png
Normal file
After Width: | Height: | Size: 241 KiB |
Before Width: | Height: | Size: 342 KiB After Width: | Height: | Size: 204 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/onpolicy.png
Normal file
After Width: | Height: | Size: 156 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/all.png
Normal file
After Width: | Height: | Size: 374 KiB |
Before Width: | Height: | Size: 423 KiB After Width: | Height: | Size: 378 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/onpolicy.png
Normal file
After Width: | Height: | Size: 232 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/all.png
Normal file
After Width: | Height: | Size: 289 KiB |
Before Width: | Height: | Size: 304 KiB After Width: | Height: | Size: 240 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/onpolicy.png
Normal file
After Width: | Height: | Size: 183 KiB |
BIN
examples/mujoco/benchmark/InvertedDoublePendulum-v2/all.png
Normal file
After Width: | Height: | Size: 368 KiB |
Before Width: | Height: | Size: 328 KiB After Width: | Height: | Size: 226 KiB |
BIN
examples/mujoco/benchmark/InvertedDoublePendulum-v2/onpolicy.png
Normal file
After Width: | Height: | Size: 281 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/all.png
Normal file
After Width: | Height: | Size: 314 KiB |
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 271 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/onpolicy.png
Normal file
After Width: | Height: | Size: 223 KiB |
@ -2,36 +2,36 @@
|
||||
|
||||
## Ant-v3
|
||||
|
||||

|
||||

|
||||
|
||||
## HalfCheetah-v3
|
||||
|
||||

|
||||

|
||||
|
||||
## Hopper-v3
|
||||
|
||||

|
||||

|
||||
|
||||
## Walker2d-v3
|
||||
|
||||

|
||||

|
||||
|
||||
## Swimmer-v3
|
||||
|
||||

|
||||

|
||||
|
||||
## Humanoid-v3
|
||||
|
||||

|
||||

|
||||
|
||||
## Reacher-v2
|
||||
|
||||

|
||||

|
||||
|
||||
## InvertedPendulum-v2
|
||||
|
||||

|
||||

|
||||
|
||||
## InvertedDoublePendulum-v2
|
||||
|
||||

|
||||

|
||||
|
BIN
examples/mujoco/benchmark/Reacher-v2/all.png
Normal file
After Width: | Height: | Size: 206 KiB |
Before Width: | Height: | Size: 232 KiB After Width: | Height: | Size: 126 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/onpolicy.png
Normal file
After Width: | Height: | Size: 163 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/all.png
Normal file
After Width: | Height: | Size: 238 KiB |
Before Width: | Height: | Size: 302 KiB After Width: | Height: | Size: 210 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/onpolicy.png
Normal file
After Width: | Height: | Size: 144 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/all.png
Normal file
After Width: | Height: | Size: 340 KiB |
Before Width: | Height: | Size: 356 KiB After Width: | Height: | Size: 302 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/onpolicy.png
Normal file
After Width: | Height: | Size: 208 KiB |
235
examples/mujoco/plotter.py
Executable file
@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python3

import argparse
import ast
import csv
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

from tools import find_all_files
|
||||
|
||||
|
||||
def smooth(y, radius, mode='two_sided', valid_only=False):
    '''Smooth signal y, where radius determines the size of the window.

    mode='two_sided':
        average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
    mode='causal':
        average over the (up to) ``radius`` samples ending at index
    valid_only: put nan in entries where the full-sized window is not available

    :param y: 1-D signal; any sequence convertible to a float ndarray.
    :param int radius: half-window (two_sided) or window (causal) size.
    :return: np.ndarray of the same length as y.
    '''
    assert mode in ('two_sided', 'causal')
    # accept plain lists/tuples, not just ndarrays
    y = np.asarray(y, dtype=float)
    if len(y) < 2 * radius + 1:
        # signal shorter than the window: fall back to a constant (global mean)
        return np.ones_like(y) * y.mean()
    elif mode == 'two_sided':
        convkernel = np.ones(2 * radius + 1)
        # denominator counts the valid samples, so boundaries are averaged
        # over the partial window instead of being biased toward zero
        out = np.convolve(y, convkernel, mode='same') / \
            np.convolve(np.ones_like(y), convkernel, mode='same')
        if valid_only:
            out[:radius] = out[-radius:] = np.nan
    elif mode == 'causal':
        convkernel = np.ones(radius)
        out = np.convolve(y, convkernel, mode='full') / \
            np.convolve(np.ones_like(y), convkernel, mode='full')
        # keep the first len(y) entries; the old `out[:-radius + 1]` slice
        # returned an empty array when radius == 1
        out = out[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
|
||||
|
||||
|
||||
# Palette used to cycle curve colors in plot_ax: a few DeepMind-style hex
# colors first, then matplotlib's named colors, then hand-picked hex values.
COLORS = [
    # deepmind style
    '#0072B2',
    '#009E73',
    '#D55E00',
    '#CC79A7',
    # '#F0E442',
    '#d73027',  # RED
    # built-in color
    'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple',
    'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
    'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green',
    # personal color
    '#313695',  # DARK BLUE
    '#74add1',  # LIGHT BLUE
    '#f46d43',  # ORANGE
    '#4daf4a',  # GREEN
    '#984ea3',  # PURPLE
    '#f781bf',  # PINK
    '#ffc832',  # YELLOW
    '#000000',  # BLACK
]
|
||||
|
||||
|
||||
def csv2numpy(csv_file):
    """Load a csv file into a dict mapping column name -> np.ndarray.

    Cells are parsed with ``ast.literal_eval``, so numeric columns become
    int/float arrays without executing arbitrary expressions from disk
    (the previous ``eval`` would run any code embedded in the file).

    :param str csv_file: path to a csv file with a header row.
    :return: dict of column name -> np.ndarray of parsed values.
    """
    csv_dict = defaultdict(list)
    # context manager guarantees the file handle is closed
    with open(csv_file) as f:
        for row in csv.DictReader(f):
            for k, v in row.items():
                csv_dict[k].append(ast.literal_eval(v))
    return {k: np.array(v) for k, v in csv_dict.items()}
|
||||
|
||||
|
||||
def group_files(file_list, pattern):
    """Group file paths by the first occurrence of ``pattern`` in each path.

    Paths that do not contain the pattern are collected under the
    empty-string key.
    """
    groups = defaultdict(list)
    for filename in file_list:
        found = re.search(pattern, filename)
        groups[found.group() if found is not None else ''].append(filename)
    return groups
|
||||
|
||||
|
||||
def plot_ax(
    ax,
    file_lists,
    legend_pattern=".*",
    xlabel=None,
    ylabel=None,
    title=None,
    xlim=None,
    xkey='env_step',
    ykey='rew',
    smooth_radius=0,
    shaded_std=True,
    legend_outside=False,
):
    """Plot one curve per csv file in ``file_lists`` onto ``ax``.

    :param ax: matplotlib axis to draw on.
    :param file_lists: list of csv file paths.
    :param str legend_pattern: regex; its first match in a file path becomes
        that curve's legend label (falls back to the full path if no match —
        the old code raised AttributeError on a non-matching pattern).
    :param xlim: if given, clip the x-axis to [0, xlim].
    :param str xkey: csv column used for the x axis.
    :param str ykey: csv column used for the y axis.
    :param int smooth_radius: window radius passed to :func:`smooth`.
    :param bool shaded_std: draw the ``ykey + ':shaded'`` column (std) as a
        shaded band when that column exists.
    :param bool legend_outside: place the legend outside of the axis.
    """
    def legend_fn(x):
        match = re.search(legend_pattern, x)
        return match.group(0) if match else x

    # sort files by their legend label so curve order (and color
    # assignment) is deterministic
    legends = map(legend_fn, file_lists)
    file_lists = [f for _, f in sorted(zip(legends, file_lists))]
    legends = list(map(legend_fn, file_lists))

    for index, csv_file in enumerate(file_lists):
        csv_dict = csv2numpy(csv_file)
        x, y = csv_dict[xkey], csv_dict[ykey]
        y = smooth(y, radius=smooth_radius)
        color = COLORS[index % len(COLORS)]
        ax.plot(x, y, color=color)
        if shaded_std and ykey + ':shaded' in csv_dict:
            y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius)
            ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2)

    ax.legend(legends, loc=2 if legend_outside else None,
              bbox_to_anchor=(1, 1) if legend_outside else None)
    # engineering notation on the x axis (1M instead of 1e6)
    ax.xaxis.set_major_formatter(mticker.EngFormatter())
    if xlim is not None:
        ax.set_xlim(xmin=0, xmax=xlim)
    # add title
    ax.set_title(title)
    # add labels
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
|
||||
|
||||
|
||||
def plot_figure(
    file_lists,
    group_pattern=None,
    fig_length=6,
    fig_width=6,
    sharex=False,
    sharey=False,
    title=None,
    **kwargs,
):
    """Plot all csv files, optionally grouped into sub-figures.

    When ``group_pattern`` is falsy, everything is drawn on a single axis.
    Otherwise files are grouped via :func:`group_files` and laid out on a
    grid with at most 3 columns. Remaining keyword arguments are forwarded
    to :func:`plot_ax`.
    """
    if group_pattern:
        groups = group_files(file_lists, group_pattern)
        n_col = min(len(groups), 3)
        n_row = int(np.ceil(len(groups) / 3))
        fig, axes = plt.subplots(
            n_row, n_col, sharex=sharex, sharey=sharey,
            figsize=(fig_length * n_col, fig_width * n_row), squeeze=False)
        # one sub-figure per group, titled with the group key
        for ax, (group_key, group_files_) in zip(axes.flatten(), groups.items()):
            plot_ax(ax, group_files_, title=group_key, **kwargs)
    else:
        fig, ax = plt.subplots(figsize=(fig_length, fig_width))
        plot_ax(ax, file_lists, title=title, **kwargs)
    if title:  # add title
        fig.suptitle(title, fontsize=20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: find merged csv files under --root-dir, group them into
    # sub-figures, and save/show the resulting matplotlib figure.
    parser = argparse.ArgumentParser(description='plotter')
    parser.add_argument('--fig-length', type=int, default=6,
                        help='matplotlib figure length (default: 6)')
    parser.add_argument('--fig-width', type=int, default=6,
                        help='matplotlib figure width (default: 6)')
    parser.add_argument('--style', default='seaborn',
                        help='matplotlib figure style (default: seaborn)')
    parser.add_argument('--title', default=None,
                        help='matplotlib figure title (default: None)')
    parser.add_argument('--xkey', default='env_step',
                        help='x-axis key in csv file (default: env_step)')
    parser.add_argument('--ykey', default='rew',
                        help='y-axis key in csv file (default: rew)')
    parser.add_argument('--smooth', type=int, default=0,
                        help='smooth radius of y axis (default: 0)')
    parser.add_argument('--xlabel', default='Timesteps',
                        help='matplotlib figure xlabel')
    parser.add_argument('--ylabel', default='Episode Reward',
                        help='matplotlib figure ylabel')
    parser.add_argument(
        '--shaded-std', action='store_true',
        help='shaded region corresponding to standard deviation of the group')
    parser.add_argument('--sharex', action='store_true',
                        help='whether to share x axis within multiple sub-figures')
    parser.add_argument('--sharey', action='store_true',
                        help='whether to share y axis within multiple sub-figures')
    parser.add_argument('--legend-outside', action='store_true',
                        help='place the legend outside of the figure')
    parser.add_argument('--xlim', type=int, default=None,
                        help='x-axis limitation (default: None)')
    parser.add_argument('--root-dir', default='./', help='root dir (default: ./)')
    parser.add_argument(
        '--file-pattern', type=str, default=r".*/test_rew_\d+seeds.csv$",
        help='regular expression to determine whether or not to include target csv '
             'file, default to including all test_rew_{num}seeds.csv file under rootdir')
    parser.add_argument(
        '--group-pattern', type=str, default=r"(/|^)\w*?\-v(\d|$)",
        help='regular expression to group files in sub-figure, default to grouping '
             'according to env_name dir, "" means no grouping')
    parser.add_argument(
        '--legend-pattern', type=str, default=r".*",
        help='regular expression to extract legend from csv file path, default to '
             'using file path as legend name.')
    parser.add_argument('--show', action='store_true', help='show figure')
    parser.add_argument('--output-path', type=str,
                        help='figure save path', default="./figure.png")
    parser.add_argument('--dpi', type=int, default=200,
                        help='figure dpi (default: 200)')
    args = parser.parse_args()
    # collect candidate csv files, then strip the root prefix so the
    # group/legend regexes match against short relative paths
    file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
    file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
    if args.style:
        plt.style.use(args.style)
    # chdir into root-dir so the relative paths computed above stay openable
    os.chdir(args.root_dir)
    plot_figure(
        file_lists,
        group_pattern=args.group_pattern,
        legend_pattern=args.legend_pattern,
        fig_length=args.fig_length,
        fig_width=args.fig_width,
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        xkey=args.xkey,
        ykey=args.ykey,
        xlim=args.xlim,
        sharex=args.sharex,
        sharey=args.sharey,
        smooth_radius=args.smooth,
        shaded_std=args.shaded_std,
        legend_outside=args.legend_outside)
    if args.output_path:
        # NOTE(review): after os.chdir above, a relative --output-path is
        # resolved under --root-dir, not the original cwd — confirm intended
        plt.savefig(args.output_path,
                    dpi=args.dpi, bbox_inches='tight')
    if args.show:
        plt.show()
|
@ -2,10 +2,11 @@
|
||||
|
||||
# Launch 10 seeds of the given algorithm on the given task in parallel,
# redirecting each run's stdout/stderr to a timestamped log file.
# Usage: ./run_experiments.sh <task> <algo>, e.g. ./run_experiments.sh Ant-v3 sac
LOGDIR="results"
TASK=$1
ALGO=$2

echo "Experiments started."
for seed in $(seq 0 9)
do
    # quote expansions so values containing spaces/globs don't word-split;
    # $(...) replaces the legacy backtick command substitution
    python "mujoco_${ALGO}.py" --task "$TASK" --epoch 200 --seed "$seed" --logdir "$LOGDIR" > "${TASK}_$(date '+%m-%d-%H-%M-%S')_seed_${seed}.txt" 2>&1 &
done
# note: runs are backgrounded, so this prints while jobs are still running
echo "Experiments ended."
|
||||
|
100
examples/mujoco/tools.py
Executable file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3

import argparse
import ast
import csv
import os
import re
from typing import Dict, List, Union

import numpy as np
import tqdm
from tensorboard.backend.event_processing import event_accumulator
|
||||
|
||||
|
||||
def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
    """Walk ``root_dir`` recursively and return every file whose joined path
    matches ``pattern`` (tested with ``re.match``)."""
    return [
        os.path.join(dirname, fname)
        for dirname, _, files in os.walk(root_dir)
        for fname in files
        if re.match(pattern, os.path.join(dirname, fname))
    ]
|
||||
|
||||
|
||||
def convert_tfevents_to_csv(
    root_dir: str, refresh: bool = False
) -> Dict[str, List[List[Union[str, int, float]]]]:
    """Recursively convert test/rew from all tfevent file under root_dir to csv.

    This function assumes that there is at most one tfevents file in each directory
    and will add suffix to that directory.

    :param str root_dir: directory scanned for tfevents files.
    :param bool refresh: re-create csv file under any condition.
    :return: dict mapping csv path -> rows: a header row followed by
        ``[env_step, rew, time]`` rows (the old ``Dict[str, np.ndarray]``
        annotation did not match what is actually returned).
    """
    tfevent_files = find_all_files(root_dir, re.compile(r"^.*tfevents.*$"))
    print(f"Converting {len(tfevent_files)} tfevents files under {root_dir} ...")
    result = {}
    with tqdm.tqdm(tfevent_files) as t:
        for tfevent_file in t:
            t.set_postfix(file=tfevent_file)
            output_file = os.path.join(os.path.split(tfevent_file)[0], "test_rew.csv")
            if os.path.exists(output_file) and not refresh:
                # reuse a previously generated csv if its header looks right
                with open(output_file, "r") as f:
                    content = list(csv.reader(f))
                if content[0] == ["env_step", "rew", "time"]:
                    for i in range(1, len(content)):
                        # literal_eval (not eval): parse numbers without
                        # executing arbitrary expressions read from disk
                        content[i] = [ast.literal_eval(cell) for cell in content[i]]
                    result[output_file] = content
                    continue
            ea = event_accumulator.EventAccumulator(tfevent_file)
            ea.Reload()
            # NOTE(review): relies on a private tensorboard attribute for the
            # start timestamp — may break across tensorboard versions
            initial_time = ea._first_event_timestamp
            content = [["env_step", "rew", "time"]]
            for test_rew in ea.scalars.Items("test/rew"):
                content.append([
                    round(test_rew.step, 4),
                    round(test_rew.value, 4),
                    round(test_rew.wall_time - initial_time, 4),
                ])
            # context manager closes the handle (the old bare open() leaked it)
            with open(output_file, 'w') as f:
                csv.writer(f).writerows(content)
            result[output_file] = content
    return result
|
||||
|
||||
|
||||
def merge_csv(
    csv_files: Dict[str, List[List[Union[str, int, float]]]],
    root_dir: str,
    remove_zero: bool = False,
) -> None:
    """Merge result in csv_files into a single csv file.

    :param csv_files: dict mapping csv path -> rows (header followed by
        ``[env_step, rew, time]`` rows), as returned by
        :func:`convert_tfevents_to_csv`. (The old ``List[...]`` annotation
        was wrong — the body calls ``.items()``/``.keys()``.)
    :param str root_dir: directory the merged csv is written into.
    :param bool remove_zero: drop each file's env_step == 0 data point
        (mutates the caller's row lists in place).
    """
    assert len(csv_files) > 0
    if remove_zero:
        for v in csv_files.values():  # keys were unused in the old loop
            if v[1][0] == 0:
                v.pop(1)
    sorted_keys = sorted(csv_files.keys())
    sorted_values = [csv_files[k][1:] for k in sorted_keys]
    content = [["env_step", "rew", "rew:shaded"] + list(map(
        lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys))]
    for rows in zip(*sorted_values):
        array = np.array(rows)
        # all seeds must be aligned on the same env_step grid
        assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0])
        line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)]
        line += array[:, 1].tolist()
        content.append(line)
    output_path = os.path.join(root_dir, f"test_rew_{len(csv_files)}seeds.csv")
    print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.")
    # context manager closes the handle (the old bare open() leaked it)
    with open(output_path, "w") as f:
        csv.writer(f).writerows(content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: convert every tfevents file below --root-dir to a
    # per-seed csv, then merge them into a single summary csv.
    parser = argparse.ArgumentParser()
    parser.add_argument('--root-dir', type=str)
    parser.add_argument(
        '--refresh', action="store_true",
        help="Re-generate all csv files instead of using existing one.")
    parser.add_argument(
        '--remove-zero', action="store_true",
        help="Remove the data point of env_step == 0.")
    args = parser.parse_args()
    csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh)
    merge_csv(csv_files, args.root_dir, args.remove_zero)
|