allow for evolutionary policy optimization

lucidrains 2025-10-28 10:11:13 -07:00
parent 46f86cd247
commit c0450359f3
3 changed files with 82 additions and 1 deletions

dreamer4/dreamer4.py

@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
        # learned set of latent genes

        self.agent_has_genes = num_latent_genes > 0
        self.num_latent_genes = num_latent_genes
        self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)

        # policy head
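The added `latent_genes` is a learned table of gene vectors, one row per population member, initialized near zero. A minimal sketch of the idea outside this repo, assuming genes are looked up by index to condition an agent; the `gene_ids` lookup below is illustrative, not the model's actual conditioning path:

    import torch
    from torch import nn

    num_latent_genes, dim = 16, 512
    latent_genes = nn.Parameter(torch.randn(num_latent_genes, dim) * 1e-2)   # near-zero init

    gene_ids = torch.randint(0, num_latent_genes, (4,))   # hypothetical: one gene per agent instance
    gene_cond = latent_genes[gene_ids]                    # (4, 512) vectors to condition the policy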
@@ -2095,6 +2096,53 @@
        return align_dims_left(times, align_dims_left_to)

    # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037

    @torch.no_grad()
    def evolve_(
        self,
        fitness,
        select_frac = 0.5,
        tournament_frac = 0.5
    ):
        assert fitness.numel() == self.num_latent_genes

        pop = self.latent_genes
        pop_size = self.num_latent_genes

        num_selected = ceil(pop_size * select_frac)
        num_children = pop_size - num_selected
        dim_gene = pop.shape[-1]

        # natural selection is just a sort and slice

        selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
        selected = pop[selected_indices]

        # tournament selection - one tournament per child

        tournament_size = max(2, ceil(num_selected * tournament_frac))

        # each row randomly draws `tournament_size` distinct survivors

        tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]

        # top 2 fitness within each tournament become the parents
        # (topk returns positions within the tournament, so map back to survivor indices through `tournaments`)

        winner_positions = selected_fitness[tournaments].topk(2, dim = -1).indices
        parent_ids = tournaments.gather(-1, winner_positions)
        parents = selected[parent_ids]

        # crossover by random interpolation from parent1 to parent2
        # (interpolation weights are sigmoid of gaussian noise, so in (0, 1) but not uniform)

        random_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()

        parent1, parent2 = parents.unbind(dim = 1)
        children = parent1.lerp(parent2, random_mix)

        # store next population

        next_pop = cat((selected, children))
        self.latent_genes.copy_(next_pop)

    # interacting with env for experience

    @torch.no_grad()
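To make the select → tournament → crossover flow concrete, here is a self-contained replay of one `evolve_` step on a toy population with the default fractions. It mirrors the method above, including the gather that maps tournament-local winner positions back to survivor indices; only the toy sizes are made up:

    import torch
    from math import ceil

    pop_size, dim_gene = 16, 8
    pop = torch.randn(pop_size, dim_gene)          # stands in for self.latent_genes
    fitness = torch.randn(pop_size)                # one fitness score per gene

    num_selected = ceil(pop_size * 0.5)            # 8 survivors
    num_children = pop_size - num_selected         # 8 children to breed

    selected_fitness, selected_indices = fitness.topk(num_selected)
    selected = pop[selected_indices]               # survivors, sorted by descending fitness

    tournament_size = max(2, ceil(num_selected * 0.5))   # 4 entrants per tournament
    tournaments = torch.randn(num_children, num_selected).argsort(dim = -1)[:, :tournament_size]

    winner_positions = selected_fitness[tournaments].topk(2, dim = -1).indices
    parent_ids = tournaments.gather(-1, winner_positions)

    parent1, parent2 = selected[parent_ids].unbind(dim = 1)
    children = parent1.lerp(parent2, torch.randn(num_children, dim_gene).sigmoid())

    next_pop = torch.cat((selected, children))
    assert next_pop.shape == pop.shape             # population size is preserved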

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.88"
version = "0.0.89"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -753,3 +753,36 @@ def test_proprioception(
    assert exists(generations.proprio)

    assert generations.video.shape == video_shape

def test_epo():
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 2,
        decoder_depth = 2,
        time_block_every = 2,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        dim_proprio = 21,
        num_tasks = 4,
        num_latent_genes = 16,
        num_discrete_actions = 4,
        num_residual_streams = 1
    )

    fitness = torch.randn(16)

    dynamics.evolve_(fitness)
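For context, a hypothetical outer loop showing where `evolve_` could sit during training, reusing `dynamics` from the test above. `rollout_return` is a stand-in for however per-gene episode returns are actually collected; it is not an API of this repo:

    def rollout_return(dynamics, gene_id):
        # stand-in: a real loop would roll out the policy conditioned on gene `gene_id`
        # in the environment and return its achieved episode return
        return torch.randn(())

    for generation in range(10):
        fitness = torch.stack([rollout_return(dynamics, i) for i in range(dynamics.num_latent_genes)])
        dynamics.evolve_(fitness)   # keep the fittest half, replace the rest with crossovers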