| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | from abc import ABC, abstractmethod | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BasePolicy(ABC): | 
					
						
							|  |  |  |     """docstring for BasePolicy""" | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |     def __init__(self): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |         self.model = None | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |     def __call__(self, batch, hidden_state=None): | 
					
						
							| 
									
										
										
										
											2020-03-16 15:04:58 +08:00
										 |  |  |         # return Batch(act=np.array(), state=None, ...) | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def learn(self, batch): | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |     def process_fn(self, batch, buffer, indice): | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |         return batch | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |     def sync_weight(self): | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |         pass |