| 
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 |  |  | #!/usr/bin/env python3 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import csv | 
					
						
							|  |  |  | import json | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import os | 
					
						
							|  |  |  | import sys | 
					
						
							| 
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def merge(rootdir): | 
					
						
							|  |  |  |     """format: $rootdir/$algo/*.csv""" | 
					
						
							|  |  |  |     result = [] | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     for path, _, filenames in os.walk(rootdir): | 
					
						
							| 
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 |  |  |         filenames = [f for f in filenames if f.endswith('.csv')] | 
					
						
							|  |  |  |         if len(filenames) == 0: | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |         elif len(filenames) != 1: | 
					
						
							|  |  |  |             print(f'More than 1 csv found in {path}!') | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |         algo = os.path.relpath(path, rootdir).upper() | 
					
						
							|  |  |  |         reader = csv.DictReader(open(os.path.join(path, filenames[0]))) | 
					
						
							|  |  |  |         for row in reader: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             result.append( | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     'env_step': int(row['env_step']), | 
					
						
							| 
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 |  |  |                     'rew': float(row['reward']), | 
					
						
							|  |  |  |                     'rew_std': float(row['reward:shaded']), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                     'Agent': algo, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-04-22 12:49:54 +08:00
										 |  |  |     open(os.path.join(rootdir, 'result.json'), 'w').write(json.dumps(result)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     merge(sys.argv[-1]) |