allow for evolutionary policy optimization
parent 46f86cd247
commit c0450359f3
dreamer4/dreamer4.py

@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):

        # learned set of latent genes

        self.agent_has_genes = num_latent_genes > 0
        self.num_latent_genes = num_latent_genes

        self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)

        # policy head
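The hunk above only adds the learned gene table; how a gene vector actually reaches the policy is outside this diff. As a rough illustration only (not the repo's code), one common pattern is to assign each rollout a gene index and add the looked-up vector to whatever embedding the policy head consumes:

import torch
from torch import nn

num_latent_genes, dim = 16, 512

# learned gene table, initialized small, mirroring the Parameter added above
latent_genes = nn.Parameter(torch.randn(num_latent_genes, dim) * 1e-2)

# hypothetical usage: one gene per rollout, added to the policy conditioning
gene_ids = torch.randint(0, num_latent_genes, (4,))    # 4 rollouts
policy_embed = torch.randn(4, dim)                     # stand-in for the real policy input
conditioned = policy_embed + latent_genes[gene_ids]    # assumption: additive conditioning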
@@ -2095,6 +2096,53 @@ class DynamicsWorldModel(Module):

        return align_dims_left(times, align_dims_left_to)

    # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037

    @torch.no_grad()
    def evolve_(
        self,
        fitness,
        select_frac = 0.5,
        tournament_frac = 0.5
    ):
        assert fitness.numel() == self.num_latent_genes

        pop = self.latent_genes

        pop_size = self.num_latent_genes
        num_selected = ceil(pop_size * select_frac)
        num_children = pop_size - num_selected

        dim_gene = pop.shape[-1]

        # natural selection is just a sort and slice

        selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
        selected = pop[selected_indices]

        # tournament selection - one tournament per child, drawn from the survivors

        tournament_size = max(2, ceil(num_selected * tournament_frac))

        tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]

        # top 2 winners of each tournament become the parents
        # topk returns positions within the tournament, so map them back to indices into `selected`

        winner_positions = selected_fitness[tournaments].topk(2, dim = -1).indices
        parent_ids = tournaments.gather(-1, winner_positions)

        parents = selected[parent_ids]

        # crossover by random interpolation from parent1 to parent2

        random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()

        parent1, parent2 = parents.unbind(dim = 1)
        children = parent1.lerp(parent2, random_uniform_mix)

        # store next population

        next_pop = cat((selected, children))

        self.latent_genes.copy_(next_pop)

    # interacting with env for experience

    @torch.no_grad()
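In words, evolve_ performs truncation selection (keep the top select_frac genes by fitness), then refills the population with children: each child gets its own tournament of randomly chosen survivors, the two fittest participants become its parents, and crossover is an element-wise random interpolation between them. The standalone sketch below (not the repo's code) replays the same steps on a toy population so the shapes are easy to follow; it uses torch.rand for the mixing coefficients instead of sigmoid of a Gaussian.

import torch
from math import ceil

def evolve(pop, fitness, select_frac = 0.5, tournament_frac = 0.5):
    pop_size, dim_gene = pop.shape
    num_selected = ceil(pop_size * select_frac)
    num_children = pop_size - num_selected

    # truncation selection: keep the fittest genes
    selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
    selected = pop[selected_indices]

    # one tournament per child, drawn from the survivors
    tournament_size = max(2, ceil(num_selected * tournament_frac))
    tournaments = torch.randn(num_children, num_selected).argsort(dim = -1)[:, :tournament_size]

    # the two fittest participants of each tournament become the parents
    winner_positions = selected_fitness[tournaments].topk(2, dim = -1).indices
    parent_ids = tournaments.gather(-1, winner_positions)
    parent1, parent2 = selected[parent_ids].unbind(dim = 1)

    # crossover: element-wise random interpolation between the parents
    mix = torch.rand(num_children, dim_gene)
    children = parent1.lerp(parent2, mix)

    return torch.cat((selected, children))

pop = torch.randn(16, 8)
fitness = torch.randn(16)
next_pop = evolve(pop, fitness)
assert next_pop.shape == pop.shape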
pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
- version = "0.0.88"
+ version = "0.0.89"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -753,3 +753,36 @@ def test_proprioception(

    assert exists(generations.proprio)
    assert generations.video.shape == video_shape

def test_epo():
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 2,
        decoder_depth = 2,
        time_block_every = 2,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        dim_proprio = 21,
        num_tasks = 4,
        num_latent_genes = 16,
        num_discrete_actions = 4,
        num_residual_streams = 1
    )

    fitness = torch.randn(16)
    dynamics.evolve_(fitness)
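Note that the test passes random noise as the fitness vector, so it only exercises the mechanics and the population shape. In an actual EPO run the per-gene fitness would presumably be an aggregate return collected from rollouts conditioned on that gene, with evolve_ called between training phases.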