Add numerical analysis tool and interactive plot (#341)

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
ChenDRAG 2021-04-22 12:49:54 +08:00 committed by GitHub
parent 844d7703c3
commit bbc3c3e32d
21 changed files with 246 additions and 33 deletions

docs/_static/js/atari vendored Symbolic link

@@ -0,0 +1 @@
../../../examples/atari

docs/_static/js/benchmark.js vendored Normal file

@@ -0,0 +1,67 @@
var envs = [
"Ant-v3",
"HalfCheetah-v3",
"Hopper-v3",
"Humanoid-v3",
"InvertedDoublePendulum-v2",
"InvertedPendulum-v2",
"Reacher-v2",
"Swimmer-v3",
"Walker2d-v3",
];
function showEnv(elem) {
var selectEnv = elem.value || envs[0];
var dataSource = {
$schema: "https://vega.github.io/schema/vega-lite/v5.json",
data: {
url: "/_static/js/mujoco/benchmark/" + selectEnv + "/result.json"
},
mark: "line",
height: 400,
width: 800,
params: [{name: "Range", value: 1000000, bind: {input: "range", min: 10000, max: 10000000}}],
transform: [
{calculate: "datum.rew - datum.rew_std", as: "rew_std0"},
{calculate: "datum.rew + datum.rew_std", as: "rew_std1"},
{calculate: "datum.rew + ' ± ' + datum.rew_std", as: "tooltip_str"},
{filter: "datum.env_step <= Range"},
],
encoding: {
color: {"field": "Agent", "type": "nominal"},
x: {field: "env_step", type: "quantitative", title: "Env step"},
},
layer: [{
"encoding": {
"opacity": {"value": 0.3},
"y": {
"title": "Return",
"field": "rew_std0",
"type": "quantitative",
},
"y2": {"field": "rew_std1"},
tooltip: [
{field: "env_step", type: "quantitative", title: "Env step"},
{field: "Agent", type: "nominal"},
{field: "tooltip_str", type: "nominal", title: "Return"},
]
},
"mark": "area"
}, {
"encoding": {
"y": {
"field": "rew",
"type": "quantitative"
}
},
"mark": "line"
}]
};
vegaEmbed("#vis-mujoco", dataSource);
}
$(document).ready(function() {
var envSelect = $("#env-mujoco");
if (envSelect.length) {
$.each(envs, function(idx, env) {envSelect.append($("<option></option>").val(env).html(env));})
showEnv(envSelect);
}
});
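For reference, each record in the `result.json` files loaded above carries the four fields the spec refers to (`env_step`, `rew`, `rew_std`, `Agent`); these are exactly the fields written by `gen_json.py` later in this commit. A minimal illustrative record (values made up):

```python
# One record of result.json as consumed by the Vega-Lite spec above.
# Field values here are illustrative; real files come from gen_json.py.
record = {
    "env_step": 10000,  # x axis ("Env step")
    "rew": 123.4,       # mean return, drawn by the line layer
    "rew_std": 5.6,     # std of return, drawn as the shaded area
    "Agent": "SAC",     # legend / color key
}
```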

docs/_static/js/mujoco vendored Symbolic link

@@ -0,0 +1 @@
../../../examples/mujoco

docs/conf.py

@@ -90,7 +90,14 @@ html_logo = "_static/images/tianshou-logo.png"
def setup(app):
app.add_js_file(
"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.0/jquery.min.js")
app.add_js_file("https://cdn.jsdelivr.net/npm/vega@5.20.2")
app.add_js_file("https://cdn.jsdelivr.net/npm/vega-lite@5.1.0")
app.add_js_file("https://cdn.jsdelivr.net/npm/vega-embed@6.17.0")
app.add_js_file("js/copybutton.js")
app.add_js_file("js/benchmark.js")
app.add_css_file("css/style.css")


@@ -8,6 +8,16 @@ Tianshou's Mujoco benchmark contains state-of-the-art results (even better than
Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/mujoco
.. raw:: html
<center>
<select id="env-mujoco" onchange="showEnv(this)"></select>
<br>
<div id="vis-mujoco"></div>
<br>
</center>
Atari Benchmark
---------------

examples/mujoco/README.md

@@ -16,7 +16,7 @@ Supported algorithms are listed below:
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
- [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/1dcf65fe21dc7636966796b6099ede1f4bd775e1)
- [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/844d7703c313009c4c364edb4018c91de93439ca)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
@@ -46,8 +46,12 @@ This will start 10 experiments with different seeds.
Now that all the experiments have finished, we can convert the tfevent files into csv files and then plot the results.
```bash
# generate csv
$ ./tools.py --root-dir ./results/Ant-v3/sac
# generate figures
$ ./plotter.py --root-dir ./results/Ant-v3 --shaded-std --legend-pattern "\\w+"
# generate numerical results (supports multiple groups: pass `--root-dir ./` instead of a single dir)
$ ./analysis.py --root-dir ./results --norm
```
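These commands assume the layout produced by the training scripts. As a rough sketch (the `./results/<Env>/<algo>/test_rew_<N>seeds.csv` layout is an assumption inferred from the commands above; the regex is copied from `analysis.py`), this is how the merged csv files get picked up:

```python
import os
import re

# Illustrative only: list the merged csv files that analysis.py will find.
pattern = re.compile(r".*/test_rew_\d+seeds.csv$")
for dirname, _, files in os.walk("./results"):
    for f in files:
        path = os.path.join(dirname, f)
        if re.search(pattern, path):
            print(path)
```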
#### Example benchmark

examples/mujoco/analysis.py Executable file

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
import re
import argparse
import numpy as np
from tabulate import tabulate
from collections import defaultdict
from tools import find_all_files, group_files, csv2numpy
def numerical_analysis(root_dir, xlim, norm=False):
file_pattern = re.compile(r".*/test_rew_\d+seeds.csv$")
norm_group_pattern = re.compile(r"(/|^)\w+?\-v(\d|$)")
output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)")
csv_files = find_all_files(root_dir, file_pattern)
norm_group = group_files(csv_files, norm_group_pattern)
output_group = group_files(csv_files, output_group_pattern)
# calculate numerical outcomes for each csv file: final reward, max reward, and reward integration (mean and std)
results = defaultdict(list)
for f in csv_files:
result = csv2numpy(f)
if norm:
result = np.stack([
result['env_step'],
result['rew'] - result['rew'][0],
result['rew:shaded']])
else:
result = np.stack([
result['env_step'], result['rew'], result['rew:shaded']])
if result[0, -1] < xlim:
continue
final_rew = np.interp(xlim, result[0], result[1])
final_rew_std = np.interp(xlim, result[0], result[2])
result = result[:, result[0] <= xlim]
if result.shape[1] == 0:  # no data points within xlim
continue
if result[0, -1] < xlim:
last_line = np.array([xlim, final_rew, final_rew_std]).reshape(3, 1)
result = np.concatenate([result, last_line], axis=-1)
max_id = np.argmax(result[1])
results['name'].append(f)
results['final_reward'].append(result[1, -1])
results['final_reward_std'].append(result[2, -1])
results['max_reward'].append(result[1, max_id])
results['max_std'].append(result[2, max_id])
results['reward_integration'].append(np.trapz(result[1], x=result[0]))
results['reward_std_integration'].append(np.trapz(result[2], x=result[0]))
results = {k: np.array(v) for k, v in results.items()}
print(tabulate(results, headers="keys"))
if norm:
# calculate normalized numerical outcome for each csv_file group
for _, fs in norm_group.items():
mask = np.isin(results['name'], fs)
for k, v in results.items():
if k == 'name':
continue
v[mask] = v[mask] / max(v[mask])
# Add all numerical results for each outcome group
group_results = defaultdict(list)
for g, fs in output_group.items():
group_results['name'].append(g)
mask = np.isin(results['name'], fs)
group_results['num'].append(sum(mask))
for k in results.keys():
if k == 'name':
continue
group_results[k + ":norm"].append(results[k][mask].mean())
# print all outputs for each csv_file and each outcome group
print()
print(tabulate(group_results, headers="keys"))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--xlim', type=int, default=1000000,
help='x-axis limitation (default: 1000000)')
parser.add_argument('--root-dir', type=str)
parser.add_argument(
'--norm', action="store_true",
help="Normalize all results according to environment.")
args = parser.parse_args()
numerical_analysis(args.root_dir, args.xlim, norm=args.norm)
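As a quick sanity check of the metrics above, here is a toy example (all numbers made up) of how the final reward, max reward, and reward integration are derived with `np.interp` and `np.trapz`:

```python
import numpy as np

# Toy learning curve (values made up).
env_step = np.array([0, 250000, 500000, 750000, 1000000])
rew = np.array([0.0, 900.0, 2100.0, 2800.0, 3200.0])

xlim = 1000000
final_rew = np.interp(xlim, env_step, rew)   # reward at the step limit
max_rew = rew[env_step <= xlim].max()        # best reward within the limit
rew_integration = np.trapz(rew, x=env_step)  # area under the learning curve
print(final_rew, max_rew, rew_integration)
```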

9 file diffs suppressed because one or more lines are too long

examples/mujoco/gen_json.py Executable file

@@ -0,0 +1,32 @@
#!/usr/bin/env python3
import os
import csv
import sys
import json
def merge(rootdir):
"""format: $rootdir/$algo/*.csv"""
result = []
for path, dirnames, filenames in os.walk(rootdir):
filenames = [f for f in filenames if f.endswith('.csv')]
if len(filenames) == 0:
continue
elif len(filenames) != 1:
print(f'More than 1 csv found in {path}!')
continue
algo = os.path.relpath(path, rootdir).upper()
reader = csv.DictReader(open(os.path.join(path, filenames[0])))
for row in reader:
result.append({
'env_step': int(row['env_step']),
'rew': float(row['rew']),
'rew_std': float(row['rew:shaded']),
'Agent': algo,
})
open(os.path.join(rootdir, 'result.json'), 'w').write(json.dumps(result))
if __name__ == "__main__":
merge(sys.argv[-1])
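A hypothetical invocation, following the `$rootdir/$algo/*.csv` layout from the docstring (directory and file names are illustrative; the output path matches what `benchmark.js` loads):

```python
# Assumed layout (illustrative):
#   benchmark/Ant-v3/sac/test_rew_10seeds.csv
#   benchmark/Ant-v3/td3/test_rew_10seeds.csv
# Writes benchmark/Ant-v3/result.json with one record per csv row.
merge("benchmark/Ant-v3")
```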

examples/mujoco/mujoco_npg.py Normal file → Executable file

examples/mujoco/mujoco_trpo.py Normal file → Executable file

examples/mujoco/plotter.py

@@ -2,14 +2,12 @@
import re
import os
import csv
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from collections import defaultdict
from tools import find_all_files
from tools import find_all_files, group_files, csv2numpy
def smooth(y, radius, mode='two_sided', valid_only=False):
@@ -64,24 +62,6 @@ COLORS = ([
])
def csv2numpy(csv_file):
csv_dict = defaultdict(list)
reader = csv.DictReader(open(csv_file))
for row in reader:
for k, v in row.items():
csv_dict[k].append(eval(v))
return {k: np.array(v) for k, v in csv_dict.items()}
def group_files(file_list, pattern):
res = defaultdict(list)
for f in file_list:
match = re.search(pattern, f)
key = match.group() if match else ''
res[key].append(f)
return res
def plot_ax(
ax,
file_lists,

examples/mujoco/tools.py

@@ -6,11 +6,11 @@ import csv
import tqdm
import argparse
import numpy as np
from typing import Dict, List, Union
from collections import defaultdict
from tensorboard.backend.event_processing import event_accumulator
def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
def find_all_files(root_dir, pattern):
"""Find all files under root_dir according to relative pattern."""
file_list = []
for dirname, _, files in os.walk(root_dir):
@@ -21,9 +21,25 @@ def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
return file_list
def convert_tfevents_to_csv(
root_dir: str, refresh: bool = False
) -> Dict[str, np.ndarray]:
def group_files(file_list, pattern):
res = defaultdict(list)
for f in file_list:
match = re.search(pattern, f)
key = match.group() if match else ''
res[key].append(f)
return res
def csv2numpy(csv_file):
csv_dict = defaultdict(list)
reader = csv.DictReader(open(csv_file))
for row in reader:
for k, v in row.items():
csv_dict[k].append(eval(v))
return {k: np.array(v) for k, v in csv_dict.items()}
def convert_tfevents_to_csv(root_dir, refresh=False):
"""Recursively convert test/rew from all tfevent file under root_dir to csv.
This function assumes that there is at most one tfevents file in each directory
@@ -60,11 +76,7 @@
return result
def merge_csv(
csv_files: List[List[Union[str, int, float]]],
root_dir: str,
remove_zero: bool = False,
) -> None:
def merge_csv(csv_files, root_dir, remove_zero=False):
"""Merge result in csv_files into a single csv file."""
assert len(csv_files) > 0
if remove_zero:
@@ -88,13 +100,14 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--root-dir', type=str)
parser.add_argument(
'--refresh', action="store_true",
help="Re-generate all csv files instead of using existing one.")
parser.add_argument(
'--remove-zero', action="store_true",
help="Remove the data point of env_step == 0.")
parser.add_argument('--root-dir', type=str)
args = parser.parse_args()
csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh)
merge_csv(csv_files, args.root_dir, args.remove_zero)
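
For illustration, here is how the `group_files` helper moved into `tools.py` groups csv paths by environment, using the same regex as `analysis.py`'s norm_group_pattern (the paths are made up):

```python
import re
from tools import group_files  # as defined above

# Made-up paths; the pattern is copied from analysis.py.
files = [
    "results/Ant-v3/sac/test_rew_10seeds.csv",
    "results/Ant-v3/td3/test_rew_10seeds.csv",
    "results/Hopper-v3/sac/test_rew_10seeds.csv",
]
groups = group_files(files, re.compile(r"(/|^)\w+?\-v(\d|$)"))
# -> {"/Ant-v3": [two Ant files], "/Hopper-v3": [one Hopper file]}
print(groups)
```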