2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#!/usr/bin/env python3
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import argparse
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								import re
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from collections import defaultdict
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import numpy as np
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from tabulate import tabulate
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from tools import csv2numpy, find_all_files, group_files
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								def numerical_analysis(root_dir, xlim, norm=False):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    file_pattern = re.compile(r".*/test_reward_\d+seeds.csv$")
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    norm_group_pattern = re.compile(r"(/|^)\w+?\-v(\d|$)")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    csv_files = find_all_files(root_dir, file_pattern)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    norm_group = group_files(csv_files, norm_group_pattern)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    output_group = group_files(csv_files, output_group_pattern)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # calculate numerical outcome for each csv_file (y/std integration max_y, final_y)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    results = defaultdict(list)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    for f in csv_files:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        result = csv2numpy(f)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if norm:
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            result = np.stack(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                [
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                    result["env_step"],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    result["reward"] - result["reward"][0],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    result["reward:shaded"],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                ],
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            )
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        else:
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            result = np.stack([result["env_step"], result["reward"], result["reward:shaded"]])
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if result[0, -1] < xlim:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            continue
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        final_rew = np.interp(xlim, result[0], result[1])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        final_rew_std = np.interp(xlim, result[0], result[2])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        result = result[:, result[0] <= xlim]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if len(result) == 0:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            continue
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if result[0, -1] < xlim:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            last_line = np.array([xlim, final_rew, final_rew_std]).reshape(3, 1)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            result = np.concatenate([result, last_line], axis=-1)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        max_id = np.argmax(result[1])
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        results["name"].append(f)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        results["final_reward"].append(result[1, -1])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        results["final_reward_std"].append(result[2, -1])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        results["max_reward"].append(result[1, max_id])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        results["max_std"].append(result[2, max_id])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        results["reward_integration"].append(np.trapz(result[1], x=result[0]))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        results["reward_std_integration"].append(np.trapz(result[2], x=result[0]))
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    results = {k: np.array(v) for k, v in results.items()}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    print(tabulate(results, headers="keys"))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if norm:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # calculate normalized numerical outcome for each csv_file group
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        for _, fs in norm_group.items():
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            mask = np.isin(results["name"], fs)
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            for k, v in results.items():
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                if k == "name":
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    continue
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                v[mask] = v[mask] / max(v[mask])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Add all numerical results for each outcome group
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        group_results = defaultdict(list)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        for g, fs in output_group.items():
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            group_results["name"].append(g)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            mask = np.isin(results["name"], fs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            group_results["num"].append(sum(mask))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            for k in results:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                if k == "name":
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    continue
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                group_results[k + ":norm"].append(results[k][mask].mean())
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # print all outputs for each csv_file and each outcome group
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        print()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        print(tabulate(group_results, headers="keys"))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								if __name__ == "__main__":
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    parser = argparse.ArgumentParser()
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    parser.add_argument(
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        "--xlim",
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        type=int,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        default=1000000,
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        help="x-axis limitation (default: 1000000)",
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    )
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    parser.add_argument("--root-dir", type=str)
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    parser.add_argument(
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        "--norm",
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        action="store_true",
							 | 
						
					
						
							
								
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        help="Normalize all results according to environment.",
							 | 
						
					
						
							
								
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    )
							 | 
						
					
						
							
								
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    args = parser.parse_args()
							 | 
						
					
						
							
								
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    numerical_analysis(args.root_dir, args.xlim, norm=args.norm)
							 |