allow for evolutionary policy optimization
This commit is contained in:
parent
46f86cd247
commit
c0450359f3
@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
|
|||||||
# learned set of latent genes
|
# learned set of latent genes
|
||||||
|
|
||||||
self.agent_has_genes = num_latent_genes > 0
|
self.agent_has_genes = num_latent_genes > 0
|
||||||
|
self.num_latent_genes = num_latent_genes
|
||||||
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
||||||
|
|
||||||
# policy head
|
# policy head
|
||||||
@ -2095,6 +2096,53 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
return align_dims_left(times, align_dims_left_to)
|
return align_dims_left(times, align_dims_left_to)
|
||||||
|
|
||||||
|
# evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evolve_(
|
||||||
|
self,
|
||||||
|
fitness,
|
||||||
|
select_frac = 0.5,
|
||||||
|
tournament_frac = 0.5
|
||||||
|
):
|
||||||
|
assert fitness.numel() == self.num_latent_genes
|
||||||
|
|
||||||
|
pop = self.latent_genes
|
||||||
|
|
||||||
|
pop_size = self.num_latent_genes
|
||||||
|
num_selected = ceil(pop_size * select_frac)
|
||||||
|
num_children = pop_size - num_selected
|
||||||
|
|
||||||
|
dim_gene = pop.shape[-1]
|
||||||
|
|
||||||
|
# natural selection just a sort and slice
|
||||||
|
|
||||||
|
selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
|
||||||
|
selected = pop[selected_indices]
|
||||||
|
|
||||||
|
# use tournament - one tournament per child
|
||||||
|
|
||||||
|
tournament_size = max(2, ceil(num_selected * tournament_frac))
|
||||||
|
|
||||||
|
tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]
|
||||||
|
|
||||||
|
parent_ids = selected_fitness[tournaments].topk(2, dim = -1).indices # get top 2 winners as parents
|
||||||
|
|
||||||
|
parents = selected[parent_ids]
|
||||||
|
|
||||||
|
# crossover by random interpolation from parent1 to parent2
|
||||||
|
|
||||||
|
random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()
|
||||||
|
|
||||||
|
parent1, parent2 = parents.unbind(dim = 1)
|
||||||
|
children = parent1.lerp(parent2, random_uniform_mix)
|
||||||
|
|
||||||
|
# store next population
|
||||||
|
|
||||||
|
next_pop = cat((selected, children))
|
||||||
|
|
||||||
|
self.latent_genes.copy_(next_pop)
|
||||||
|
|
||||||
# interacting with env for experience
|
# interacting with env for experience
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.88"
|
version = "0.0.89"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -753,3 +753,36 @@ def test_proprioception(
|
|||||||
|
|
||||||
assert exists(generations.proprio)
|
assert exists(generations.proprio)
|
||||||
assert generations.video.shape == video_shape
|
assert generations.video.shape == video_shape
|
||||||
|
|
||||||
|
def test_epo():
|
||||||
|
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||||
|
|
||||||
|
tokenizer = VideoTokenizer(
|
||||||
|
512,
|
||||||
|
dim_latent = 32,
|
||||||
|
patch_size = 32,
|
||||||
|
encoder_depth = 2,
|
||||||
|
decoder_depth = 2,
|
||||||
|
time_block_every = 2,
|
||||||
|
attn_heads = 8,
|
||||||
|
image_height = 256,
|
||||||
|
image_width = 256,
|
||||||
|
attn_kwargs = dict(
|
||||||
|
query_heads = 16
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dynamics = DynamicsWorldModel(
|
||||||
|
512,
|
||||||
|
num_agents = 1,
|
||||||
|
video_tokenizer = tokenizer,
|
||||||
|
dim_latent = 32,
|
||||||
|
dim_proprio = 21,
|
||||||
|
num_tasks = 4,
|
||||||
|
num_latent_genes = 16,
|
||||||
|
num_discrete_actions = 4,
|
||||||
|
num_residual_streams = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
fitness = torch.randn(16,)
|
||||||
|
dynamics.evolve_(fitness)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user