Add numerical analysis tool and interactive plot (#341)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
parent 844d7703c3
commit bbc3c3e32d
docs/_static/js/atari (vendored symbolic link, 1 line)
@@ -0,0 +1 @@
../../../examples/atari
docs/_static/js/benchmark.js (vendored, new file, 67 lines)
@@ -0,0 +1,67 @@
var envs = [
  "Ant-v3",
  "HalfCheetah-v3",
  "Hopper-v3",
  "Humanoid-v3",
  "InvertedDoublePendulum-v2",
  "InvertedPendulum-v2",
  "Reacher-v2",
  "Swimmer-v3",
  "Walker2d-v3",
];

function showEnv(elem) {
  var selectEnv = elem.value || envs[0];
  var dataSource = {
    $schema: "https://vega.github.io/schema/vega-lite/v5.json",
    data: {
      url: "/_static/js/mujoco/benchmark/" + selectEnv + "/result.json"
    },
    mark: "line",
    height: 400,
    width: 800,
    params: [{name: "Range", value: 1000000, bind: {input: "range", min: 10000, max: 10000000}}],
    transform: [
      {calculate: "datum.rew - datum.rew_std", as: "rew_std0"},
      {calculate: "datum.rew + datum.rew_std", as: "rew_std1"},
      {calculate: "datum.rew + ' ± ' + datum.rew_std", as: "tooltip_str"},
      {filter: "datum.env_step <= Range"},
    ],
    encoding: {
      color: {field: "Agent", type: "nominal"},
      x: {field: "env_step", type: "quantitative", title: "Env step"},
    },
    layer: [{
      encoding: {
        opacity: {value: 0.3},
        y: {
          title: "Return",
          field: "rew_std0",
          type: "quantitative",
        },
        y2: {field: "rew_std1"},
        tooltip: [
          {field: "env_step", type: "quantitative", title: "Env step"},
          {field: "Agent", type: "nominal"},
          {field: "tooltip_str", type: "nominal", title: "Return"},
        ]
      },
      mark: "area"
    }, {
      encoding: {
        y: {
          field: "rew",
          type: "quantitative"
        }
      },
      mark: "line"
    }]
  };
  vegaEmbed("#vis-mujoco", dataSource);
}

$(document).ready(function() {
  var envSelect = $("#env-mujoco");
  if (envSelect.length) {
    $.each(envs, function(idx, env) {envSelect.append($("<option></option>").val(env).html(env));});
    showEnv(envSelect);
  }
});
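For reference, the spec above pulls one `result.json` per environment. Going by `gen_json.py` later in this commit, each record carries `env_step`, `rew`, `rew_std`, and `Agent`; a minimal sketch of compatible data (the numbers are invented):

```python
import json

# Hypothetical example of the records benchmark.js plots; the field names
# match what gen_json.py emits, but these numbers are made up.
records = [
    {"env_step": 0, "rew": -50.0, "rew_std": 12.3, "Agent": "SAC"},
    {"env_step": 5000, "rew": 310.5, "rew_std": 40.1, "Agent": "SAC"},
    {"env_step": 5000, "rew": 120.7, "rew_std": 35.9, "Agent": "TD3"},
]

# benchmark.js expects this list serialized as a single JSON array:
print(json.dumps(records))
```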
docs/_static/js/mujoco (vendored symbolic link, 1 line)
@@ -0,0 +1 @@
../../../examples/mujoco
docs/conf.py

@@ -90,7 +90,14 @@ html_logo = "_static/images/tianshou-logo.png"


 def setup(app):
+    app.add_js_file(
+        "https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.0/jquery.min.js")
+    app.add_js_file("https://cdn.jsdelivr.net/npm/vega@5.20.2")
+    app.add_js_file("https://cdn.jsdelivr.net/npm/vega-lite@5.1.0")
+    app.add_js_file("https://cdn.jsdelivr.net/npm/vega-embed@6.17.0")
+
     app.add_js_file("js/copybutton.js")
+    app.add_js_file("js/benchmark.js")
     app.add_css_file("css/style.css")
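The `setup()` hook loads jQuery, Vega, Vega-Lite, and Vega-Embed from CDNs on every generated page; the `if (envSelect.length)` guard in `benchmark.js` keeps the plot inert on pages without an `#env-mujoco` element. To eyeball the result locally, something like this sketch should work (paths are assumptions, relative to the repository root):

```python
# A minimal local-preview sketch; assumes Sphinx and the docs requirements
# are installed, and that it is run from the repository root.
from sphinx.cmd.build import build_main

exit_code = build_main(["-b", "html", "docs", "docs/_build/html"])
print("sphinx-build exit code:", exit_code)
```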
@@ -8,6 +8,16 @@ Tianshou's Mujoco benchmark contains state-of-the-art results (even better than

 Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/mujoco

+.. raw:: html
+
+   <center>
+       <select id="env-mujoco" onchange="showEnv(this)"></select>
+       <br>
+       <div id="vis-mujoco"></div>
+       <br>
+   </center>
+
+
 Atari Benchmark
 ---------------
examples/mujoco/README.md

@@ -16,7 +16,7 @@ Supported algorithms are listed below:
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
 - [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
-- [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/1dcf65fe21dc7636966796b6099ede1f4bd775e1)
+- [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/844d7703c313009c4c364edb4018c91de93439ca)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)

@@ -46,8 +46,12 @@ This will start 10 experiments with different seeds.
 Now that all the experiments are finished, we can convert all tfevent files into csv files and then try plotting the results.

 ```bash
+# generate csv
 $ ./tools.py --root-dir ./results/Ant-v3/sac
+# generate figures
 $ ./plotter.py --root-dir ./results/Ant-v3 --shaded-std --legend-pattern "\\w+"
+# generate numerical result (supports multiple groups: `--root-dir ./` instead of a single dir)
+$ ./analysis.py --root-dir ./results --norm
 ```

 #### Example benchmark
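The `--norm` flag, judging from `analysis.py` (added below), rescales each metric within an environment group by that group's maximum, so environments with very different reward scales become comparable. A toy illustration with invented numbers:

```python
import numpy as np

# Invented final rewards for two algorithms on two environments.
names = np.array(["Ant-v3/sac", "Ant-v3/td3", "Hopper-v3/sac", "Hopper-v3/td3"])
final_reward = np.array([5000.0, 4000.0, 3300.0, 3600.0])

# Per-environment normalization, mirroring the v[mask] / max(v[mask]) step.
for env in ("Ant-v3", "Hopper-v3"):
    mask = np.char.startswith(names, env)
    final_reward[mask] /= final_reward[mask].max()

print(final_reward)  # [1.0, 0.8, 0.9166..., 1.0]
```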
examples/mujoco/analysis.py (new executable file, 89 lines)
@@ -0,0 +1,89 @@
#!/usr/bin/env python3

import re
import argparse
import numpy as np
from tabulate import tabulate
from collections import defaultdict
from tools import find_all_files, group_files, csv2numpy


def numerical_analysis(root_dir, xlim, norm=False):
    file_pattern = re.compile(r".*/test_rew_\d+seeds.csv$")
    norm_group_pattern = re.compile(r"(/|^)\w+?\-v(\d|$)")
    output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)")
    csv_files = find_all_files(root_dir, file_pattern)
    norm_group = group_files(csv_files, norm_group_pattern)
    output_group = group_files(csv_files, output_group_pattern)
    # calculate numerical outcome for each csv_file (y/std integration, max_y, final_y)
    results = defaultdict(list)
    for f in csv_files:
        result = csv2numpy(f)
        if norm:
            result = np.stack([
                result['env_step'],
                result['rew'] - result['rew'][0],
                result['rew:shaded']])
        else:
            result = np.stack([
                result['env_step'], result['rew'], result['rew:shaded']])

        if result[0, -1] < xlim:
            continue

        final_rew = np.interp(xlim, result[0], result[1])
        final_rew_std = np.interp(xlim, result[0], result[2])
        result = result[:, result[0] <= xlim]

        if len(result) == 0:
            continue

        if result[0, -1] < xlim:
            last_line = np.array([xlim, final_rew, final_rew_std]).reshape(3, 1)
            result = np.concatenate([result, last_line], axis=-1)

        max_id = np.argmax(result[1])
        results['name'].append(f)
        results['final_reward'].append(result[1, -1])
        results['final_reward_std'].append(result[2, -1])
        results['max_reward'].append(result[1, max_id])
        results['max_std'].append(result[2, max_id])
        results['reward_integration'].append(np.trapz(result[1], x=result[0]))
        results['reward_std_integration'].append(np.trapz(result[2], x=result[0]))

    results = {k: np.array(v) for k, v in results.items()}
    print(tabulate(results, headers="keys"))

    if norm:
        # calculate normalized numerical outcome for each csv_file group
        for _, fs in norm_group.items():
            mask = np.isin(results['name'], fs)
            for k, v in results.items():
                if k == 'name':
                    continue
                v[mask] = v[mask] / max(v[mask])
    # add all numerical results for each outcome group
    group_results = defaultdict(list)
    for g, fs in output_group.items():
        group_results['name'].append(g)
        mask = np.isin(results['name'], fs)
        group_results['num'].append(sum(mask))
        for k in results.keys():
            if k == 'name':
                continue
            group_results[k + ":norm"].append(results[k][mask].mean())
    # print all outputs for each csv_file and each outcome group
    print()
    print(tabulate(group_results, headers="keys"))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--xlim', type=int, default=1000000,
                        help='x-axis limitation (default: 1000000)')
    parser.add_argument('--root-dir', type=str)
    parser.add_argument(
        '--norm', action="store_true",
        help="Normalize all results according to environment.")
    args = parser.parse_args()
    numerical_analysis(args.root_dir, args.xlim, norm=args.norm)
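To make the per-file metrics concrete: the script interpolates the reward at exactly `--xlim`, truncates the curve there, and reports the final/max rewards plus a trapezoidal integral of reward over env steps (a rough proxy for sample efficiency). A toy rerun of that arithmetic with made-up numbers:

```python
import numpy as np

# Made-up learning curve: reward sampled at a few env steps, extending past xlim.
env_step = np.array([0.0, 3e5, 7e5, 1.2e6])
rew = np.array([0.0, 900.0, 2500.0, 3200.0])
xlim = 1_000_000

# Same trick as analysis.py: interpolate the reward at exactly xlim, drop
# points beyond xlim, then append the interpolated endpoint so np.trapz
# integrates over precisely [0, xlim].
final_rew = np.interp(xlim, env_step, rew)          # 2920.0
steps = np.append(env_step[env_step <= xlim], xlim)
rews = np.append(rew[env_step <= xlim], final_rew)

print(final_rew, np.trapz(rews, x=steps))           # 2920.0  1.628e+09
```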
examples/mujoco/benchmark/Ant-v3/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/HalfCheetah-v3/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/Hopper-v3/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/Humanoid-v3/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/InvertedDoublePendulum-v2/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/InvertedPendulum-v2/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/Reacher-v2/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/Swimmer-v3/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/benchmark/Walker2d-v3/result.json (new file, 1 line; diff suppressed because the line is too long)
examples/mujoco/gen_json.py (new executable file, 32 lines)
@@ -0,0 +1,32 @@
#!/usr/bin/env python3

import os
import csv
import sys
import json


def merge(rootdir):
    """format: $rootdir/$algo/*.csv"""
    result = []
    for path, dirnames, filenames in os.walk(rootdir):
        filenames = [f for f in filenames if f.endswith('.csv')]
        if len(filenames) == 0:
            continue
        elif len(filenames) != 1:
            print(f'More than 1 csv found in {path}!')
            continue
        algo = os.path.relpath(path, rootdir).upper()
        reader = csv.DictReader(open(os.path.join(path, filenames[0])))
        for row in reader:
            result.append({
                'env_step': int(row['env_step']),
                'rew': float(row['rew']),
                'rew_std': float(row['rew:shaded']),
                'Agent': algo,
            })
    open(os.path.join(rootdir, 'result.json'), 'w').write(json.dumps(result))


if __name__ == "__main__":
    merge(sys.argv[-1])
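A usage sketch for `gen_json.py`; the directory names here are hypothetical, but follow the `$rootdir/$algo/*.csv` layout from the docstring:

```python
# Hypothetical layout, matching the docstring's $rootdir/$algo/*.csv:
#   benchmark/Ant-v3/sac/test_rew_10seeds.csv
#   benchmark/Ant-v3/td3/test_rew_10seeds.csv
# Running merge() then writes benchmark/Ant-v3/result.json, which is
# what docs/_static/js/benchmark.js loads for the interactive plot.
from gen_json import merge

merge("benchmark/Ant-v3")
```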
examples/mujoco/mujoco_npg.py (0 changes; mode Normal file → Executable file)
examples/mujoco/mujoco_trpo.py (0 changes; mode Normal file → Executable file)
examples/mujoco/plotter.py

@@ -2,14 +2,12 @@

 import re
 import os
-import csv
 import argparse
 import numpy as np
 import matplotlib.pyplot as plt
 import matplotlib.ticker as mticker
-from collections import defaultdict

-from tools import find_all_files
+from tools import find_all_files, group_files, csv2numpy


 def smooth(y, radius, mode='two_sided', valid_only=False):

@@ -64,24 +62,6 @@ COLORS = ([
 ])


-def csv2numpy(csv_file):
-    csv_dict = defaultdict(list)
-    reader = csv.DictReader(open(csv_file))
-    for row in reader:
-        for k, v in row.items():
-            csv_dict[k].append(eval(v))
-    return {k: np.array(v) for k, v in csv_dict.items()}
-
-
-def group_files(file_list, pattern):
-    res = defaultdict(list)
-    for f in file_list:
-        match = re.search(pattern, f)
-        key = match.group() if match else ''
-        res[key].append(f)
-    return res
-
-
 def plot_ax(
     ax,
     file_lists,
examples/mujoco/tools.py

@@ -6,11 +6,11 @@ import csv
 import tqdm
 import argparse
 import numpy as np
-from typing import Dict, List, Union
+from collections import defaultdict
 from tensorboard.backend.event_processing import event_accumulator


-def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
+def find_all_files(root_dir, pattern):
     """Find all files under root_dir according to relative pattern."""
     file_list = []
     for dirname, _, files in os.walk(root_dir):

@@ -21,9 +21,25 @@ def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
     return file_list


-def convert_tfevents_to_csv(
-    root_dir: str, refresh: bool = False
-) -> Dict[str, np.ndarray]:
+def group_files(file_list, pattern):
+    res = defaultdict(list)
+    for f in file_list:
+        match = re.search(pattern, f)
+        key = match.group() if match else ''
+        res[key].append(f)
+    return res
+
+
+def csv2numpy(csv_file):
+    csv_dict = defaultdict(list)
+    reader = csv.DictReader(open(csv_file))
+    for row in reader:
+        for k, v in row.items():
+            csv_dict[k].append(eval(v))
+    return {k: np.array(v) for k, v in csv_dict.items()}
+
+
+def convert_tfevents_to_csv(root_dir, refresh=False):
     """Recursively convert test/rew from all tfevent file under root_dir to csv.

     This function assumes that there is at most one tfevents file in each directory

@@ -60,11 +76,7 @@ def convert_tfevents_to_csv(
     return result


-def merge_csv(
-    csv_files: List[List[Union[str, int, float]]],
-    root_dir: str,
-    remove_zero: bool = False,
-) -> None:
+def merge_csv(csv_files, root_dir, remove_zero=False):
     """Merge result in csv_files into a single csv file."""
     assert len(csv_files) > 0
     if remove_zero:

@@ -88,13 +100,14 @@ def merge_csv(

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--root-dir', type=str)
     parser.add_argument(
         '--refresh', action="store_true",
         help="Re-generate all csv files instead of using existing one.")
     parser.add_argument(
         '--remove-zero', action="store_true",
         help="Remove the data point of env_step == 0.")
+    parser.add_argument('--root-dir', type=str)
     args = parser.parse_args()

     csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh)
     merge_csv(csv_files, args.root_dir, args.remove_zero)
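One design note on the `csv2numpy` helper that moved into `tools.py`: it parses every cell with `eval`, which conveniently yields ints or floats but will execute whatever expression a csv cell contains. A stricter variant (a sketch of an alternative, not what this commit does) could coerce cells explicitly:

```python
import csv
from collections import defaultdict

import numpy as np


def csv2numpy_strict(csv_file):
    """Like csv2numpy, but parses cells as floats instead of using eval()."""
    csv_dict = defaultdict(list)
    with open(csv_file) as f:
        for row in csv.DictReader(f):
            for k, v in row.items():
                csv_dict[k].append(float(v))
    return {k: np.array(v) for k, v in csv_dict.items()}
```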