Skip to content

Commit

Permalink
traitA nnModule
Browse files Browse the repository at this point in the history
  • Loading branch information
cjGO committed Jun 11, 2024
1 parent be22e36 commit fdb9f12
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 186 deletions.
9 changes: 1 addition & 8 deletions chewc/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,4 @@
'chewc.core.PopulationDataset.__init__': ('core.html#populationdataset.__init__', 'chewc/core.py'),
'chewc.core.PopulationDataset.__len__': ('core.html#populationdataset.__len__', 'chewc/core.py'),
'chewc.core.create_population_dataloader': ('core.html#create_population_dataloader', 'chewc/core.py')},
'chewc.trait': { 'chewc.trait.TraitA': ('trait.html#traita', 'chewc/trait.py'),
'chewc.trait.TraitA.__init__': ('trait.html#traita.__init__', 'chewc/trait.py'),
'chewc.trait.TraitA._calculate_scaled_additive_dosages': ( 'trait.html#traita._calculate_scaled_additive_dosages',
'chewc/trait.py'),
'chewc.trait.TraitA.forward': ('trait.html#traita.forward', 'chewc/trait.py'),
'chewc.trait.TraitA.sample_initial_effects': ('trait.html#traita.sample_initial_effects', 'chewc/trait.py'),
'chewc.trait.TraitA.scale_effects': ('trait.html#traita.scale_effects', 'chewc/trait.py'),
'chewc.trait.select_qtl_loci': ('trait.html#select_qtl_loci', 'chewc/trait.py')}}}
'chewc.trait': {'chewc.trait.select_qtl_loci': ('trait.html#select_qtl_loci', 'chewc/trait.py')}}}
70 changes: 1 addition & 69 deletions chewc/trait.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_trait.ipynb.

# %% auto 0
__all__ = ['ta', 'tb', 'tc', 'ta_values', 'tb_values', 'tc_values', 'select_qtl_loci', 'TraitA']
__all__ = ['select_qtl_loci']

# %% ../nbs/02_trait.ipynb 3
from .core import *
Expand Down Expand Up @@ -45,72 +45,4 @@ def select_qtl_loci(num_qtl_per_chromosome: int, genome: Genome) -> torch.Tensor

return torch.stack(qtl_indices)

class TraitA(nn.Module):
    """Additive trait whose per-locus effects are calibrated on a founder population.

    A random 0/1 QTL map and standard-normal effects are drawn at construction
    time; the effects and an intercept are then chosen so that the founder
    population's genetic values have mean ``target_mean`` and variance
    ``target_variance`` (variance as computed by ``torch.var``).

    NOTE(review): dosage tensors are assumed to be 4-D with chromosome and
    locus as the last two axes (the sums below reduce over dims 2 and 3) --
    confirm against ``Population.get_dosages``.
    """

    def __init__(self, genome: Genome, founder_population: Population, target_mean: float, target_variance: float):
        """Draw the QTL map and effects, then calibrate them on ``founder_population``.

        Args:
            genome: Supplies ``device``, ``ploidy``, ``n_chromosomes`` and
                ``n_loci_per_chromosome``.
            founder_population: Population whose ``get_dosages()`` output is
                used to calibrate the effects and the intercept.
            target_mean: Desired mean genetic value of the founders.
            target_variance: Desired variance of the founders' genetic values.
        """
        super().__init__()
        self.genome = genome
        self.target_mean = torch.tensor(target_mean, device=self.genome.device)
        self.target_variance = torch.tensor(target_variance, device=self.genome.device)
        self.effects = None  # QTL effects, set by sample_initial_effects()
        self.intercept = None  # set by scale_effects()
        # Random 0/1 mask marking which loci contribute to the trait.
        self.qtl_map = torch.randint(
            0,
            2,
            (self.genome.n_chromosomes, self.genome.n_loci_per_chromosome),
            device=self.genome.device,
        )

        # Calibration is plain bookkeeping; do not record it for autograd.
        with torch.no_grad():
            self.sample_initial_effects()
            self.scale_effects(founder_population)

    def _calculate_scaled_additive_dosages(self, genotypes: torch.Tensor) -> torch.Tensor:
        """Centre dosages on ``ploidy / 2`` and rescale by ``2 / ploidy``.

        For dosages in ``[0, ploidy]`` the result lies in ``[-1, 1]``.
        """
        return (genotypes - self.genome.ploidy / 2) * (2 / self.genome.ploidy)

    def sample_initial_effects(self):
        """Draw one standard-normal effect per (chromosome, locus)."""
        self.effects = torch.randn(
            (self.genome.n_chromosomes, self.genome.n_loci_per_chromosome),
            device=self.genome.device,
        )

    def scale_effects(self, founder_pop: Population):
        """Rescale the effects and set the intercept from the founder population.

        After this call the founders' genetic values (as returned by
        ``forward``) have variance ``target_variance`` and mean ``target_mean``.
        """
        founder_genotypes = founder_pop.get_dosages().float().to(self.genome.device)
        scaled_dosages = self._calculate_scaled_additive_dosages(founder_genotypes)
        # Only loci flagged in the QTL map contribute.
        scaled_dosages = scaled_dosages * self.qtl_map[None, None, :, :]
        genetic_values = torch.sum(scaled_dosages * self.effects[None, None, :, :], dim=(2, 3))

        initial_variance = torch.var(genetic_values)
        scaling_constant = torch.sqrt(self.target_variance / initial_variance)
        self.effects = self.effects * scaling_constant

        # Scaling the effects scales the founder mean by the same constant, so
        # the intercept must be derived from the *rescaled* mean; using the
        # unscaled mean would leave the founders off target.
        self.intercept = self.target_mean - genetic_values.mean() * scaling_constant

    def forward(self, genotypes: torch.Tensor) -> torch.Tensor:
        """Return the genetic value for each entry of ``genotypes``.

        The dosages are scaled, masked by the QTL map, weighted by the
        effects, summed over dims 2 and 3, and shifted by the intercept.
        """
        genotypes = genotypes.to(self.genome.device)  # same device as the effects
        scaled_dosages = self._calculate_scaled_additive_dosages(genotypes)
        # Only loci flagged in the QTL map contribute.
        scaled_dosages = scaled_dosages * self.qtl_map
        genetic_values = torch.sum(scaled_dosages * self.effects[None, None, :, :], dim=(2, 3)) + self.intercept
        return genetic_values

# %% ../nbs/02_trait.ipynb 6
class TraitA(nn.Module):
    """Trait module that maps a dosage tensor to per-individual trait values.

    NOTE(review): ``marker_fx`` and ``intercept`` are stored but not used by
    ``forward``, which still reads the module-level ``correlated_effects``
    tensor defined outside this class. Confirm whether ``forward`` should use
    ``self.marker_fx`` / ``self.intercept`` instead; their shapes are not
    visible here, so that change is left for review.
    """

    def __init__(self, marker_fx, intercept):
        """Store the marker effects and intercept for this trait."""
        super().__init__()
        self.marker_fx = marker_fx
        self.intercept = intercept

    def forward(self, dosages):
        """Contract ``correlated_effects`` (``ijk``) with ``dosages`` (``lij``) into ``lk``.

        The dosages now come from the argument rather than from the
        module-level ``init_pop``, so the result depends on what the caller
        passes in.
        """
        return torch.einsum('ijk,lij->lk', correlated_effects, dosages)

# One TraitA per trait, built from the first three entries of `scaled_bvs`
# and `trait_intercepts` (both defined outside this excerpt).
ta, tb, tc = (TraitA(scaled_bvs[i], trait_intercepts[i]) for i in range(3))

# Evaluate each trait on `init_pop` (defined outside this excerpt).
ta_values, tb_values, tc_values = (trait(init_pop) for trait in (ta, tb, tc))
201 changes: 116 additions & 85 deletions nbs/02_trait.ipynb

Large diffs are not rendered by default.

58 changes: 34 additions & 24 deletions nbs/03_meiosis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
"metadata": {},
"outputs": [],
"source": [
"def gamma_interference_model(length, rate, shape, sim_param):\n",
"\n",
"\n",
"def gamma_interference_model(length, rate, shape, device):\n",
" \"\"\"\n",
" Simulate crossover events using a gamma interference model.\n",
" \n",
Expand All @@ -61,14 +63,14 @@
" Returns:\n",
" torch.Tensor: Positions of crossover events.\n",
" \"\"\"\n",
" num_crossovers = torch.poisson(torch.tensor([rate * length], device=sim_param.device))\n",
" num_crossovers = torch.poisson(torch.tensor([rate * length], device=device))\n",
" intervals = torch.distributions.gamma.Gamma(shape, rate).sample((int(num_crossovers.item()),)).to(device)\n",
" crossover_positions = torch.cumsum(intervals, dim=0)\n",
" crossover_positions = crossover_positions[crossover_positions < length]\n",
" return crossover_positions\n",
"\n",
"\n",
"def simulate_meiosis(num_chromosomes, map_length, num_individuals, num_crossovers, sim_param):\n",
"def simulate_meiosis(num_chromosomes, map_length, num_individuals, num_crossovers, device):\n",
" \"\"\"\n",
" This function simulates random crossover events across chromosomes.\n",
" \n",
Expand All @@ -80,40 +82,48 @@
" device (torch.device): Device to perform computations on.\n",
" \n",
" Returns:\n",
" List[torch.Tensor]: List of crossover positions for each chromosome.\n",
" torch.Tensor: Tensor of crossover positions for each individual and chromosome. \n",
" Shape: (num_individuals, num_chromosomes, num_crossovers)\n",
" \"\"\"\n",
" return [torch.sort(torch.rand((num_individuals, num_crossovers), device=sim_param.device) * map_length, dim=-1)[0] for _ in range(num_chromosomes)]\n",
" return torch.sort(torch.rand((num_individuals, num_chromosomes, num_crossovers), device=device) * map_length, dim=-1)[0]\n",
"\n",
"\n",
"def simulate_gametes(sim_param, parent_genome):\n",
"def simulate_gametes(genetic_map, parent_genomes, device):\n",
" \"\"\"\n",
" Simulate the formation of a single gamete given crossover positions, genetic map, and parent genomes.\n",
" Simulate the formation of gametes for multiple parents given crossover positions, genetic map, and parent genomes.\n",
"\n",
" Parameters:\n",
" genetic_map (list of torch.Tensor): List of positions of genetic markers on the chromosomes.\n",
" parent_genome (torch.Tensor): Genomes of the parents. SHAPE: (ploidy, num_chromosomes, num_loci)\n",
" genetic_map (torch.Tensor): Positions of genetic markers on the chromosomes. \n",
" Shape: (num_chromosomes, num_loci)\n",
" parent_genomes (torch.Tensor): Genomes of the parents. \n",
" Shape: (num_individuals, ploidy, num_chromosomes, num_loci)\n",
" device (torch.device): Device to perform computations on.\n",
"\n",
" Returns:\n",
" torch.Tensor: The resultant single gamete. SHAPE: (ploidy//2, num_chromosomes, num_loci)\n",
" torch.Tensor: The resultant gametes. \n",
" Shape: (num_individuals, ploidy//2, num_chromosomes, num_loci)\n",
" \"\"\"\n",
" ploidy, num_chromosomes, num_loci = parent_genome.shape\n",
" gamete_genome = torch.zeros((ploidy // 2, num_chromosomes, num_loci), dtype=parent_genome.dtype, device=device)\n",
" crossover_positions = simulate_meiosis(num_chromosomes, 100.0, 1, 1, sim_param)\n",
" num_individuals, ploidy, num_chromosomes, num_loci = parent_genomes.shape\n",
" gamete_genomes = torch.zeros((num_individuals, ploidy // 2, num_chromosomes, num_loci), \n",
" dtype=parent_genomes.dtype, device=device)\n",
" \n",
" # Simulate crossover positions for all individuals\n",
" crossover_positions = simulate_meiosis(num_chromosomes, genetic_map.max(), num_individuals, 1, device)\n",
"\n",
" for individual in range(num_individuals):\n",
" for chrom in range(num_chromosomes):\n",
" crossover_mask = torch.zeros(num_loci, dtype=torch.bool, device=device)\n",
" crossover_site = crossover_positions[individual, chrom, 0] \n",
"\n",
" for chrom in range(num_chromosomes):\n",
" # Create a crossover mask for each chromosome\n",
" crossover_mask = torch.zeros(num_loci, dtype=torch.bool, device=sim_param.device)\n",
" crossover_sites = crossover_positions[chrom].flatten()\n",
" for position in crossover_sites:\n",
" # Find the nearest marker index for each crossover position\n",
" index = torch.argmin(torch.abs(genetic_map[chrom] - position))\n",
" crossover_mask[index:] = ~crossover_mask[index:] # Flip values at and after the crossover index\n",
" # Efficiently find the crossover index using vectorized operations\n",
" index = torch.argmin(torch.abs(genetic_map[chrom] - crossover_site))\n",
" crossover_mask[index:] = ~crossover_mask[index:]\n",
"\n",
" # Use the mask to select genes from parent 1 or parent 2\n",
" gamete_genome[0, chrom] = torch.where(crossover_mask, parent_genome[1, chrom], parent_genome[0, chrom])\n",
" gamete_genomes[individual, 0, chrom] = torch.where(crossover_mask, \n",
" parent_genomes[individual, 1, chrom], \n",
" parent_genomes[individual, 0, chrom])\n",
"\n",
" return gamete_genome\n"
" return gamete_genomes"
]
},
{
Expand Down

0 comments on commit fdb9f12

Please sign in to comment.