2021-04-22 12:49:54 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
import csv
|
|
|
|
import json
|
2021-09-03 05:05:04 +08:00
|
|
|
import os
|
|
|
|
import sys
|
2024-04-03 18:07:51 +02:00
|
|
|
from os import PathLike
|
2021-04-22 12:49:54 +08:00
|
|
|
|
|
|
|
|
2024-04-03 18:07:51 +02:00
|
|
|
def merge(rootdir: str | PathLike[str]) -> None:
|
2023-08-25 23:40:56 +02:00
|
|
|
"""format: $rootdir/$algo/*.csv."""
|
2021-04-22 12:49:54 +08:00
|
|
|
result = []
|
2021-09-03 05:05:04 +08:00
|
|
|
for path, _, filenames in os.walk(rootdir):
|
2023-08-25 23:40:56 +02:00
|
|
|
filtered_filenames = [f for f in filenames if f.endswith(".csv")]
|
|
|
|
if len(filtered_filenames) == 0:
|
2021-04-22 12:49:54 +08:00
|
|
|
continue
|
2023-08-25 23:40:56 +02:00
|
|
|
if len(filtered_filenames) != 1:
|
|
|
|
print(f"More than 1 csv found in {path}!")
|
2021-04-22 12:49:54 +08:00
|
|
|
continue
|
|
|
|
algo = os.path.relpath(path, rootdir).upper()
|
2023-08-25 23:40:56 +02:00
|
|
|
with open(os.path.join(path, filtered_filenames[0])) as f:
|
|
|
|
reader = csv.DictReader(f)
|
|
|
|
for row in reader:
|
|
|
|
result.append(
|
|
|
|
{
|
|
|
|
"env_step": int(row["env_step"]),
|
|
|
|
"rew": float(row["reward"]),
|
|
|
|
"rew_std": float(row["reward:shaded"]),
|
|
|
|
"Agent": algo,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
with open(os.path.join(rootdir, "result.json"), "w") as f:
|
|
|
|
f.write(json.dumps(result))
|
2021-04-22 12:49:54 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The results root directory is taken from the last CLI argument.
    # Without any argument, sys.argv[-1] would be this script's own path,
    # and merge() would fail trying to write "<script>/result.json".
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} ROOTDIR")
    merge(sys.argv[-1])
|