Fix tuple support. (#88)
parent ec270759ab
commit d7dd3105bc
@@ -54,6 +54,9 @@ def test_batch_over_batch():
     assert batch3.c == [6, 7, 8, 6, 7, 8]
     assert batch3.b.a == [3, 4, 5, 3, 4, 5]
     assert batch3.b.b == [4, 5, 6, 4, 5, 6]
+    batch4 = Batch(({'a': {'b': np.array([1.0])}},))
+    assert batch4.a.b.ndim == 2
+    assert batch4.a.b[0, 0] == 1.0


 def test_batch_cat_and_stack():
@@ -3,7 +3,7 @@ import copy
 import pprint
 import warnings
 import numpy as np
-from typing import Any, List, Union, Iterator, Optional
+from typing import Any, List, Tuple, Union, Iterator, Optional

 # Disable pickle warning related to torch, since it has been removed
 # on torch master branch. See Pull Request #39003 for details:
@@ -76,29 +76,28 @@ class Batch:

     def __init__(self,
                  batch_dict: Optional[
-                     Union[dict, List[dict], np.ndarray]] = None,
+                     Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
                  **kwargs) -> None:
-        if isinstance(batch_dict, (list, np.ndarray)) \
+        if isinstance(batch_dict, (list, tuple, np.ndarray)) \
                 and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
             for k, v in zip(batch_dict[0].keys(),
                             zip(*[e.values() for e in batch_dict])):
-                if isinstance(v, (list, np.ndarray)) \
-                        and len(v) > 0 and isinstance(v[0], dict):
-                    self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
+                if isinstance(v[0], dict) \
+                        or (isinstance(v, (list, tuple, np.ndarray))
+                            and len(v) > 0 and isinstance(v[0], dict)):
+                    self.__dict__[k] = Batch(v)
                 elif isinstance(v[0], np.ndarray):
                     self.__dict__[k] = np.stack(v, axis=0)
                 elif isinstance(v[0], torch.Tensor):
                     self.__dict__[k] = torch.stack(v, dim=0)
                 elif isinstance(v[0], Batch):
                     self.__dict__[k] = Batch.stack(v)
-                elif isinstance(v[0], dict):
-                    self.__dict__[k] = Batch(v)
                 else:
                     self.__dict__[k] = list(v)
         elif isinstance(batch_dict, dict):
             for k, v in batch_dict.items():
                 if isinstance(v, dict) \
-                        or (isinstance(v, (list, np.ndarray))
+                        or (isinstance(v, (list, tuple, np.ndarray))
                             and len(v) > 0 and isinstance(v[0], dict)):
                     self.__dict__[k] = Batch(v)
                 else: