fix append batch over batch
This commit is contained in:
		
							parent
							
								
									268f9d0533
								
							
						
					
					
						commit
						aff0f9aee0
					
				
							
								
								
									
										
											BIN
										
									
								
								docs/_static/images/concepts_arch2.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/_static/images/concepts_arch2.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 30 KiB  | 
@ -8,6 +8,12 @@ Tianshou splits a Reinforcement Learning agent training procedure into these par
 | 
			
		||||
    :height: 300
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network:
 | 
			
		||||
 | 
			
		||||
.. image:: /_static/images/concepts_arch2.png
 | 
			
		||||
    :align: center
 | 
			
		||||
    :height: 300
 | 
			
		||||
 | 
			
		||||
Data Batch
 | 
			
		||||
----------
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,16 @@ def test_batch_over_batch():
 | 
			
		||||
    print(batch2)
 | 
			
		||||
    assert batch2.values()[-1] == batch2.c
 | 
			
		||||
    assert batch2[-1].b.b == 0
 | 
			
		||||
    batch2.append(Batch(c=[6, 7, 8], b=batch))
 | 
			
		||||
    assert batch2.c == [6, 7, 8, 6, 7, 8]
 | 
			
		||||
    assert batch2.b.a == [3, 4, 5, 3, 4, 5]
 | 
			
		||||
    assert batch2.b.b == [4, 5, 0, 4, 5, 0]
 | 
			
		||||
    d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
 | 
			
		||||
    batch3 = Batch(c=[6, 7, 8], b=d)
 | 
			
		||||
    batch3.append(Batch(c=[6, 7, 8], b=d))
 | 
			
		||||
    assert batch3.c == [6, 7, 8, 6, 7, 8]
 | 
			
		||||
    assert batch3.b.a == [3, 4, 5, 3, 4, 5]
 | 
			
		||||
    assert batch3.b.b == [4, 5, 6, 4, 5, 6]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_batch_over_batch_to_torch():
 | 
			
		||||
 | 
			
		||||
@ -234,6 +234,8 @@ class Batch:
 | 
			
		||||
                self.__dict__[k] = torch.cat([self.__dict__[k], v])
 | 
			
		||||
            elif isinstance(v, list):
 | 
			
		||||
                self.__dict__[k] += v
 | 
			
		||||
            elif isinstance(v, Batch):
 | 
			
		||||
                self.__dict__[k].append(v)
 | 
			
		||||
            else:
 | 
			
		||||
                s = f'No support for append with type \
 | 
			
		||||
                      {type(v)} in class Batch.'
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user