From c0450359f350b627bc7b6b281c01a6d00d15dcb4 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 28 Oct 2025 10:11:13 -0700
Subject: [PATCH] allow for evolutionary policy optimization

---
 dreamer4/dreamer4.py  | 48 +++++++++++++++++++++++++++++++++++++++++++
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 33 +++++++++++++++++++++++++++++
 3 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 091e4d9..bf1dc59 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
         # learned set of latent genes
 
         self.agent_has_genes = num_latent_genes > 0
+        self.num_latent_genes = num_latent_genes
         self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
 
         # policy head
@@ -2095,6 +2096,53 @@
 
         return align_dims_left(times, align_dims_left_to)
 
+    # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037
+
+    @torch.no_grad()
+    def evolve_(
+        self,
+        fitness,
+        select_frac = 0.5,
+        tournament_frac = 0.5
+    ):
+        assert fitness.numel() == self.num_latent_genes
+
+        pop = self.latent_genes
+
+        pop_size = self.num_latent_genes
+        num_selected = ceil(pop_size * select_frac)
+        num_children = pop_size - num_selected
+
+        dim_gene = pop.shape[-1]
+
+        # natural selection just a sort and slice
+
+        selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
+        selected = pop[selected_indices]
+
+        # use tournament - one tournament per child
+
+        tournament_size = max(2, ceil(num_selected * tournament_frac))
+
+        tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]
+
+        parent_ids = tournaments.gather(-1, selected_fitness[tournaments].topk(2, dim = -1).indices) # top 2 winners within each tournament, mapped back to indices into the selected population
+
+        parents = selected[parent_ids]
+
+        # crossover by random interpolation from parent1 to parent2
+
+        random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()
+
+        parent1, parent2 = parents.unbind(dim = 1)
+        children = parent1.lerp(parent2, random_uniform_mix)
+
+        # store next population
+
+        next_pop = cat((selected, children))
+
+        self.latent_genes.copy_(next_pop)
+
     # interacting with env for experience
 
     @torch.no_grad()
diff --git a/pyproject.toml b/pyproject.toml
index c4328f8..5034a63 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.88"
+version = "0.0.89"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 518e8e3..9fb9d12 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -753,3 +753,36 @@ def test_proprioception(
 
     assert exists(generations.proprio)
     assert generations.video.shape == video_shape
+
+def test_epo():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 2,
+        decoder_depth = 2,
+        time_block_every = 2,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        dim_proprio = 21,
+        num_tasks = 4,
+        num_latent_genes = 16,
+        num_discrete_actions = 4,
+        num_residual_streams = 1
+    )
+
+    fitness = torch.randn(16,)
+    dynamics.evolve_(fitness)
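
Note: below is a minimal, self-contained sketch of the update that evolve_ performs, restated on a plain tensor so the selection and crossover shapes can be sanity-checked without constructing the world model. The population size, gene dimension, and fitness values are illustrative stand-ins for latent_genes, the per-gene fitness passed to evolve_, and its select_frac / tournament_frac defaults; only the procedure (truncation selection, one tournament per child, crossover by random interpolation) mirrors the method in the patch.

    import torch
    from math import ceil

    pop_size, dim_gene = 16, 512                 # illustrative sizes
    select_frac, tournament_frac = 0.5, 0.5      # same defaults as evolve_

    pop = torch.randn(pop_size, dim_gene) * 1e-2 # stand-in for latent_genes
    fitness = torch.randn(pop_size)              # stand-in per-gene fitness

    num_selected = ceil(pop_size * select_frac)
    num_children = pop_size - num_selected

    # natural selection - keep the fittest genes, sorted by descending fitness

    selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
    selected = pop[selected_indices]

    # one tournament per child - a random subset of the survivors competes

    tournament_size = max(2, ceil(num_selected * tournament_frac))
    tournaments = torch.randn(num_children, num_selected).argsort(dim = -1)[:, :tournament_size]

    # the two fittest participants in each tournament become the parents

    winner_positions = selected_fitness[tournaments].topk(2, dim = -1).indices
    parent_ids = tournaments.gather(-1, winner_positions)
    parent1, parent2 = selected[parent_ids].unbind(dim = 1)

    # crossover - each child interpolates between its parents with a random per-dimension mix

    mix = torch.randn(num_children, dim_gene).sigmoid()
    children = parent1.lerp(parent2, mix)

    # next population keeps the survivors and appends the children

    next_pop = torch.cat((selected, children))
    assert next_pop.shape == pop.shape

In use, the caller would accumulate a fitness estimate per latent gene (for example, the mean return of episodes rolled out with that gene, per the paper linked in the patch) and periodically call dynamics.evolve_(fitness), as test_epo does above with random fitness values.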