Skip to content

Commit

Permalink
packaged multi_trait generator into function
Browse files Browse the repository at this point in the history
  • Loading branch information
cjGO committed Jun 11, 2024
1 parent 239adb2 commit 1d01bf5
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 207 deletions.
73 changes: 20 additions & 53 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,64 +33,31 @@ pip install chewc
First, define the genome of your crop

``` python
# import random
import torch

# ploidy = 2
# number_chromosomes = 10
# loci_per_chromosome = 100
# genetic_map = create_random_genetic_map(number_chromosomes,loci_per_chromosome)
# crop_genome = Genome(ploidy, number_chromosomes, loci_per_chromosome, genetic_map)
ploidy = 2
n_chr = 10
n_loci = 100
n_Ind = 333
g = Genome(ploidy, n_chr, n_loci)
population = Population()
population.create_random_founder_population(g, n_founders=n_Ind)
init_pop = population.get_dosages().float() # gets allele dosage for calculating trait values

# n_founders = 500
# founder_pop = create_random_founder_pop(crop_genome , n_founders)
# sim_param = SimParam
# sim_param.founder_pop = founder_pop
# sim_param.genome = crop_genome
# multi_traits


# #add a single additive trait
# qtl_loci = 20
# qtl_map = select_qtl_loci(qtl_loci,sim_param.genome)
target_means = torch.tensor([0, 5])
target_vars = torch.tensor([1, 1]) # target genetic variance for each trait
correlation_values = [
[1.0, 0.2],
[0.2, 1.0],
]

# ta = TraitA(qtl_map,sim_param,0, 1)
# ta.sample_initial_effects()
# ta.scale_genetic_effects()
# ta.calculate_intercept()

correlated_traits = corr_traits(g, init_pop, target_means, target_vars, correlation_values)
```

Created genetic map




# # Ensure sim_param.device is defined and correct
# device = sim_param.device

# years = 20
# current_pop = founder_pop.to(device)
# pmean = []
# pvar = []

# for _ in range(years):
# # phenotype current pop
# TOPK = 10
# new_pop = []
# pheno = ta.phenotype(current_pop, h2=0.14).to(device)
# topk = torch.topk(pheno, TOPK).indices.to(device)

# for _ in range(200):
# sampled_indices = torch.multinomial(torch.ones(topk.size(0), device=device), 2, replacement=False)
# sampled_parents = topk[sampled_indices]
# m, f = current_pop[sampled_parents[0]], current_pop[sampled_parents[1]]
# new_pop.append(make_cross(sim_param, m, f).to(device))

# current_pop = torch.stack(new_pop).to(device)
# pmean.append(ta.calculate_genetic_values(current_pop).mean().item())
# pvar.append(ta.calculate_genetic_values(current_pop).var().item())

# pmean_normalized = torch.tensor(pmean, device=device) / max(pmean)
# pvar_normalized = torch.tensor(pvar, device=device) / max(pvar)

# plt.scatter(range(len(pmean_normalized)), pmean_normalized.cpu())
# plt.scatter(range(len(pvar_normalized)), pvar_normalized.cpu())
# plt.show()
```
NameError: name 'init_pop' is not defined
5 changes: 3 additions & 2 deletions chewc/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
'chewc.core.PopulationDataset.__init__': ('core.html#populationdataset.__init__', 'chewc/core.py'),
'chewc.core.PopulationDataset.__len__': ('core.html#populationdataset.__len__', 'chewc/core.py'),
'chewc.core.create_population_dataloader': ('core.html#create_population_dataloader', 'chewc/core.py')},
'chewc.cross': { 'chewc.cross.double_haploid': ('crossing.html#double_haploid', 'chewc/cross.py'),
'chewc.cross.random_crosses': ('crossing.html#random_crosses', 'chewc/cross.py')},
'chewc.cross': { 'chewc.cross.double_haploid': ('cross.html#double_haploid', 'chewc/cross.py'),
'chewc.cross.random_crosses': ('cross.html#random_crosses', 'chewc/cross.py')},
'chewc.crossing': { 'chewc.crossing.double_haploid': ('crossing.html#double_haploid', 'chewc/crossing.py'),
'chewc.crossing.random_crosses': ('crossing.html#random_crosses', 'chewc/crossing.py')},
'chewc.meiosis': { 'chewc.meiosis.poisson_crossing_over': ('meiosis.html#poisson_crossing_over', 'chewc/meiosis.py'),
'chewc.meiosis.simulate_gametes': ('meiosis.html#simulate_gametes', 'chewc/meiosis.py')},
'chewc.trait': { 'chewc.trait.TraitA': ('trait.html#traita', 'chewc/trait.py'),
'chewc.trait.TraitA.__init__': ('trait.html#traita.__init__', 'chewc/trait.py'),
'chewc.trait.TraitA.forward': ('trait.html#traita.forward', 'chewc/trait.py'),
'chewc.trait.corr_traits': ('trait.html#corr_traits', 'chewc/trait.py'),
'chewc.trait.select_qtl_loci': ('trait.html#select_qtl_loci', 'chewc/trait.py')}}}
10 changes: 7 additions & 3 deletions chewc/cross.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_crossing.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_cross.ipynb.

# %% auto 0
__all__ = ['random_crosses', 'double_haploid']

# %% ../nbs/04_crossing.ipynb 4
# %% ../nbs/04_cross.ipynb 3
from .core import *
from .trait import *
from .meiosis import *
import torch
def random_crosses( genome: Genome, parent_haplotypes: torch.Tensor, n_crosses: int) -> torch.Tensor:
"""
Generate random crosses from a set of parent haplotypes.
Expand Down Expand Up @@ -41,7 +45,7 @@ def random_crosses( genome: Genome, parent_haplotypes: torch.Tensor, n_crosses:
return progeny_haplotypes


# %% ../nbs/04_crossing.ipynb 7
# %% ../nbs/04_cross.ipynb 6
def double_haploid(genome: Genome, parent_haplotypes: torch.Tensor) -> torch.Tensor:
"""
Generate doubled haploid individuals from a set of parent haplotypes.
Expand Down
42 changes: 41 additions & 1 deletion chewc/trait.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_trait.ipynb.

# %% auto 0
__all__ = ['select_qtl_loci', 'TraitA']
__all__ = ['select_qtl_loci', 'TraitA', 'corr_traits']

# %% ../nbs/02_trait.ipynb 3
from .core import *
Expand Down Expand Up @@ -88,3 +88,43 @@ def forward(self, dosages: torch.Tensor, h2: Optional[float] = None, varE: Optio
breeding_values += env_noise

return breeding_values

# %% ../nbs/02_trait.ipynb 4
def corr_traits(genome, init_pop, target_means, target_vars, correlation_matrix):
    """Create one additive trait per target with correlated marker effects.

    Marker effects are drawn from a standard normal distribution, mixed with
    the Cholesky factor of ``correlation_matrix`` so that effects are
    correlated across traits, and then rescaled per trait so that the
    breeding values computed on ``init_pop`` have the requested mean and
    variance.

    Args:
        genome: Object whose ``genetic_map`` has shape ``(n_chr, n_loci)``.
        init_pop: Dosage tensor contracted against the effects over its last
            two axes, i.e. shaped ``(n_individuals, n_chr, n_loci)``.
        target_means: One target mean per trait (tensor or sequence).
        target_vars: One target variance per trait (tensor or sequence).
        correlation_matrix: ``(n_traits, n_traits)`` positive-definite
            correlation matrix (nested list or tensor).

    Returns:
        A list with one ``TraitA`` per trait, each built from the genome, that
        trait's ``(n_chr, n_loci)`` effect slice and its scalar intercept.

    Note:
        If a trait has zero breeding-value variance in ``init_pop`` the
        scaling factor is not finite, and neither are the resulting effects.
    """
    n_chr, n_loci = genome.genetic_map.shape

    # Work in the default floating dtype throughout: ``torch.randn`` produces
    # it, and ``cholesky`` / ``matmul`` / ``einsum`` reject integer or
    # mixed-precision operands.  ``as_tensor`` also avoids re-wrapping inputs
    # that are already tensors.
    dtype = torch.get_default_dtype()
    target_means = torch.as_tensor(target_means, dtype=dtype)
    target_vars = torch.as_tensor(target_vars, dtype=dtype)
    n_traits = target_means.shape[0]

    cor_a = torch.as_tensor(correlation_matrix, dtype=dtype)
    chol = torch.linalg.cholesky(cor_a)

    # Sample independent effects, then flatten the loci so the Cholesky
    # factor can mix them across the trait axis.
    uncorrelated_effects = torch.randn(n_chr, n_loci, n_traits)
    uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, n_traits)

    # Introduce the correlation BEFORE scaling; a per-trait rescale afterwards
    # leaves the correlation between traits unchanged.
    correlated_effects = torch.matmul(chol, uncorrelated_effects.T).T
    correlated_effects = correlated_effects.reshape(n_chr, n_loci, n_traits)

    # Unscaled breeding values, one row per individual and one column per trait.
    unscaled_bvs = torch.einsum(
        'ijk,lij->lk', correlated_effects, init_pop.to(dtype=dtype)
    )

    # Per-trait scale factors and intercepts, broadcast over the trait axis.
    scaling_factors = torch.sqrt(target_vars / unscaled_bvs.var(dim=0))
    scaled_effects = correlated_effects * scaling_factors
    trait_intercepts = target_means - unscaled_bvs.mean(dim=0) * scaling_factors

    return [
        TraitA(genome, scaled_effects[:, :, i], trait_intercepts[i])
        for i in range(n_traits)
    ]
Loading

0 comments on commit 1d01bf5

Please sign in to comment.