#!/usr/bin/env python3
"""Convert test/rew scalars from tfevents files to csv and merge them across runs."""

import argparse
import csv
import os
import re
from typing import Dict, List, Union

import numpy as np
import tqdm
from tensorboard.backend.event_processing import event_accumulator

def find_all_files(root_dir: str, pattern: re.Pattern) -> List[str]:
    """Find all files under root_dir whose absolute path matches pattern."""
    file_list = []
    for dirname, _, files in os.walk(root_dir):
        for f in files:
            absolute_path = os.path.join(dirname, f)
            if re.match(pattern, absolute_path):
                file_list.append(absolute_path)
    return file_list
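
# Hypothetical usage, assuming runs live under "log/":
#   find_all_files("log", re.compile(r"^.*tfevents.*$"))
# would return paths such as "log/0/events.out.tfevents.1620000000.host".
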
def convert_tfevents_to_csv(
    root_dir: str, refresh: bool = False
) -> Dict[str, List[List[Union[str, int, float]]]]:
    """Recursively convert test/rew from all tfevents files under root_dir to csv.

    This function assumes that there is at most one tfevents file in each
    directory and writes a test_rew.csv file into that directory.

    :param bool refresh: re-create the csv file even if it already exists.
    """
    tfevent_files = find_all_files(root_dir, re.compile(r"^.*tfevents.*$"))
    print(f"Converting {len(tfevent_files)} tfevents files under {root_dir} ...")
    result = {}
    with tqdm.tqdm(tfevent_files) as t:
        for tfevent_file in t:
            t.set_postfix(file=tfevent_file)
            output_file = os.path.join(os.path.split(tfevent_file)[0], "test_rew.csv")
            if os.path.exists(output_file) and not refresh:
                # Reuse an existing csv if its header matches; eval converts the
                # stored strings back into numbers.
                with open(output_file, "r") as f:
                    content = list(csv.reader(f))
                if content[0] == ["env_step", "rew", "time"]:
                    for i in range(1, len(content)):
                        content[i] = list(map(eval, content[i]))
                    result[output_file] = content
                    continue
            ea = event_accumulator.EventAccumulator(tfevent_file)
            ea.Reload()
            # _first_event_timestamp is a private attribute of EventAccumulator;
            # it marks the wall time of the first event in the file.
            initial_time = ea._first_event_timestamp
            content = [["env_step", "rew", "time"]]
            for test_rew in ea.scalars.Items("test/rew"):
                content.append([
                    round(test_rew.step, 4),
                    round(test_rew.value, 4),
                    round(test_rew.wall_time - initial_time, 4),
                ])
            with open(output_file, "w", newline="") as f:
                csv.writer(f).writerows(content)
            result[output_file] = content
    return result
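
# Sketch of the return value for a single hypothetical run directory "log/0":
#   {"log/0/test_rew.csv": [["env_step", "rew", "time"],
#                           [10000, 123.4, 56.7], ...]}
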
def merge_csv(
    csv_files: Dict[str, List[List[Union[str, int, float]]]],
    root_dir: str,
    remove_zero: bool = False,
) -> None:
    """Merge the results in csv_files into a single csv file.

    :param bool remove_zero: remove the data point of env_step == 0.
    """
    assert len(csv_files) > 0
    if remove_zero:
        for v in csv_files.values():
            if v[1][0] == 0:
                v.pop(1)
    sorted_keys = sorted(csv_files.keys())
    sorted_values = [csv_files[k][1:] for k in sorted_keys]
    content = [["env_step", "rew", "rew:shaded"] +
               ["rew:" + os.path.relpath(f, root_dir) for f in sorted_keys]]
    for rows in zip(*sorted_values):
        array = np.array(rows)
        # Every run must be logged at the same env_step values.
        assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0])
        line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)]
        line += array[:, 1].tolist()
        content.append(line)
    output_path = os.path.join(root_dir, f"test_rew_{len(csv_files)}seeds.csv")
    print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.")
    with open(output_path, "w", newline="") as f:
        csv.writer(f).writerows(content)
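
# The merged csv carries one column per run plus the mean and standard deviation,
# e.g. (hypothetical run directories "0" and "1"):
#   env_step, rew, rew:shaded, rew:0/test_rew.csv, rew:1/test_rew.csv
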
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root-dir", type=str)
    parser.add_argument(
        "--refresh", action="store_true",
        help="Re-generate all csv files instead of using existing ones.")
    parser.add_argument(
        "--remove-zero", action="store_true",
        help="Remove the data point of env_step == 0.")
    args = parser.parse_args()
    csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh)
    merge_csv(csv_files, args.root_dir, args.remove_zero)
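
# Example invocation (script name and log directory are hypothetical):
#   python tools.py --root-dir log --refresh --remove-zero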