| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  | from torch import nn | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  | from abc import ABC, abstractmethod | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  | class BasePolicy(ABC, nn.Module): | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |     """docstring for BasePolicy""" | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |     def __init__(self): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |     def process_fn(self, batch, buffer, indice): | 
					
						
							|  |  |  |         return batch | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |     def __call__(self, batch, state=None): | 
					
						
							| 
									
										
										
										
											2020-03-18 21:45:41 +08:00
										 |  |  |         # return Batch(logits=..., act=..., state=None, ...) | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |     def learn(self, batch, batch_size=None): | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         # return a dict which includes loss and its name | 
					
						
							| 
									
										
										
										
											2020-03-12 22:20:33 +08:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-15 17:41:00 +08:00
										 |  |  |     def sync_weight(self): | 
					
						
							| 
									
										
										
										
											2020-03-14 21:48:31 +08:00
										 |  |  |         pass |