Skip to content

Commit

Permalink
correct trait a marker fx
Browse files Browse the repository at this point in the history
  • Loading branch information
cjGO committed Jun 7, 2024
1 parent 07a9f84 commit 8d5a572
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 192 deletions.
37 changes: 15 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ n_founders = 50
genetic_map = create_random_genetic_map(number_chromosomes,loci_per_chromosome)
crop_genome = Genome(ploidy, number_chromosomes, loci_per_chromosome, genetic_map)
founder_pop = create_random_founder_pop(crop_genome , n_founders)
```

qtl_map = select_qtl_loci(5,crop_genome)
add_fx = generate_marker_effects(qtl_map,0,1)
``` python
qtl_map = select_qtl_loci(100,crop_genome)
add_fx = generate_marker_effects(qtl_map,founder_pop,0,1)
# Add a trait
trait_A = TraitA(
qtl_map,
add_fx,
genome=crop_genome,
founder_pop=founder_pop,
target_variance=1.0, # Example: Target genetic variance of 1.0
target_mean=10.0) # Example: Target mean genetic value of 10.0

genome=crop_genome,
founder_pop=founder_pop,
)
# Now you can use trait_A to calculate genetic values, simulate phenotypes, etc.
example_genotypes = create_random_founder_pop(crop_genome , 10) # Example genotypes
genetic_values = trait_A.calculate_genetic_value(example_genotypes)
Expand All @@ -65,27 +65,20 @@ print(genetic_values)
print(phenotypes)
```

tensor([11.7284, 11.6001, 9.7972, 8.4023, 9.8967, 10.0112, 7.7677, 10.5128,
10.1639, 11.9934])
tensor([10.5428, 11.0402, 9.3920, 7.7007, 9.8136, 9.1105, 7.1924, 9.3808,
9.9087, 12.1152])

``` python
qtl_map = select_qtl_loci(20, crop_genome)
marker_fx = generate_marker_effects(qtl_map)

founder_genetic_variance = calculate_genetic_variance(founder_pop,marker_fx,crop_genome)


traita = TraitA(qtl_map, marker_fx,crop_genome, founder_pop,1.0,0.0)
```
tensor([4.9217, 3.0676, 3.8802, 3.4133, 3.8190, 2.0305, 1.8424, 3.8215, 2.8152,
3.3127])
tensor([5.3106, 2.9171, 3.4742, 3.4601, 3.0819, 3.1966, 2.1117, 3.8294, 2.2255,
2.3835])

``` python
# recurrent truncation selection

means = []
variances = []
traita = TraitA(qtl_map, marker_fx,crop_genome, founder_pop,1.0,0.0)

marker_fx = generate_marker_effects(qtl_map,founder_pop,0,1)

traita = TraitA(qtl_map, marker_fx,crop_genome, founder_pop)


tgv = traita.calculate_genetic_value(founder_pop)
Expand Down
1 change: 0 additions & 1 deletion chewc/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
'chewc.trait.TraitA._calculate_intercept': ('trait.html#traita._calculate_intercept', 'chewc/trait.py'),
'chewc.trait.TraitA._calculate_scaled_additive_dosages': ( 'trait.html#traita._calculate_scaled_additive_dosages',
'chewc/trait.py'),
'chewc.trait.TraitA._scale_effects': ('trait.html#traita._scale_effects', 'chewc/trait.py'),
'chewc.trait.TraitA.calculate_genetic_value': ('trait.html#traita.calculate_genetic_value', 'chewc/trait.py'),
'chewc.trait.TraitA.setPheno': ('trait.html#traita.setpheno', 'chewc/trait.py'),
'chewc.trait.calculate_genetic_variance': ('trait.html#calculate_genetic_variance', 'chewc/trait.py'),
Expand Down
102 changes: 48 additions & 54 deletions chewc/trait.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_trait.ipynb.

# %% auto 0
__all__ = ['select_qtl_loci', 'generate_marker_effects', 'calculate_genetic_variance', 'TraitA']
__all__ = ['select_qtl_loci', 'calculate_genetic_variance', 'TraitA', 'generate_marker_effects']

# %% ../nbs/02_trait.ipynb 4
from .core import *
Expand Down Expand Up @@ -44,36 +44,6 @@ def select_qtl_loci(num_qtl_per_chromosome: int, genome:Genome) -> torch.Tensor:
return torch.stack(qtl_indices)

# %% ../nbs/02_trait.ipynb 5
def generate_marker_effects(qtl_map: torch.Tensor, mean: float = 0.0, variance: float = 1.0) -> torch.Tensor:
"""
Generates random marker effects for QTLs, drawn from a normal distribution.
Args:
----
qtl_map (torch.Tensor): A boolean tensor indicating which loci are QTLs.
Shape: (number_chromosomes, loci_per_chromosome)
mean (float): The mean of the normal distribution from which to draw effects. Defaults to 0.0.
variance (float): The variance of the normal distribution from which to draw effects. Defaults to 1.0.
Returns:
-------
torch.Tensor: A tensor of marker effects. Shape: (number_chromosomes, loci_per_chromosome).
Non-QTL loci will have an effect of 0.
"""
# Create a tensor of zeros with the same shape as the qtl_map
effects = torch.zeros_like(qtl_map, dtype=torch.float)

# Determine the number of QTLs
num_qtl = qtl_map.sum().item()

# Sample random effects from a normal distribution
qtl_effects = torch.randn(num_qtl) * (variance ** 0.5) + mean

# Assign the sampled effects to the QTL positions in the effects tensor
effects[qtl_map] = qtl_effects

return effects

def calculate_genetic_variance(founder_pop: torch.Tensor, marker_effects: torch.Tensor, genome: Genome) -> float:
"""
Calculates the additive genetic variance in the founder population.
Expand Down Expand Up @@ -135,16 +105,13 @@ class TraitA:
additive_effects: torch.Tensor
genome: Genome
founder_pop: torch.Tensor
target_variance: float
target_mean: float
intercept: float = attr.ib(init=False)

def __attrs_post_init__(self):
"""
Calculate the intercept and scale the effects after initialization.
"""
self.intercept = self._calculate_intercept()
self._scale_effects()

def _calculate_intercept(self) -> float:
"""
Expand All @@ -156,7 +123,7 @@ def _calculate_intercept(self) -> float:
# Calculate the mean genetic value of the founder population (without scaling)
founder_genetic_values = (self.founder_pop.float() * self.additive_effects).sum(dim=(1, 2, 3))
mean_founder_gv = founder_genetic_values.mean().item()
return self.target_mean - mean_founder_gv
return - mean_founder_gv

def _calculate_scaled_additive_dosages(self, genotypes: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -171,24 +138,7 @@ def _calculate_scaled_additive_dosages(self, genotypes: torch.Tensor) -> torch.T
Shape: (n_individuals, ploidy, number_chromosomes, loci_per_chromosome).
"""
return (genotypes - self.genome.ploidy / 2) * (2 / self.genome.ploidy)

def _scale_effects(self) -> None:
"""
Scales the additive effects to achieve the target genetic variance and
calculates the intercept to achieve the target mean.
"""
# Calculate the initial genetic variance in the founder population
founder_gvs = self.calculate_genetic_value(self.founder_pop)
initial_variance = founder_gvs.var().item()

# Calculate the scaling factor
scaling_factor = (self.target_variance / initial_variance) ** 0.5

# Scale the additive effects
self.additive_effects = self.additive_effects * scaling_factor

# Recalculate the intercept after scaling
# self.intercept = self._calculate_intercept()


def calculate_genetic_value(self, genotypes: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -208,6 +158,7 @@ def calculate_genetic_value(self, genotypes: torch.Tensor) -> torch.Tensor:
scaled_dosages = self._calculate_scaled_additive_dosages(genotypes)

# Apply the additive effects to the scaled dosages, only at QTL positions
# dim(1,2,3) grabs the total matrix for each individual ignoring the 0 axes which is individual index
additive_genetic_values = (scaled_dosages * self.additive_effects).sum(dim=(1, 2, 3))

# Add the intercept to adjust the mean genetic value
Expand Down Expand Up @@ -239,11 +190,54 @@ def setPheno(self, genotypes: torch.Tensor,

if h2 is not None:
# Calculate environmental variance based on heritability
varE = self.target_variance * (1 - h2) / h2
varE = (1 - h2) / h2

# Simulate environmental effects
environmental_effects = torch.randn(genotypes.shape[0]) * torch.sqrt(varE)

# Calculate phenotypes
phenotypes = genetic_values + environmental_effects
return phenotypes

# %% ../nbs/02_trait.ipynb 7
def generate_marker_effects(qtl_map: torch.Tensor, founder_pop: torch.Tensor, mean: float = 0.0, variance: float = 1.0,) -> torch.Tensor:
"""
Generates random marker effects for QTLs, drawn from a normal distribution and scaled to match the desired genetic variance.
Args:
----
qtl_map (torch.Tensor): A boolean tensor indicating which loci are QTLs.
Shape: (number_chromosomes, loci_per_chromosome)
mean (float): The mean of the normal distribution from which to draw effects. Defaults to 0.0.
variance (float): The desired genetic variance of the trait in the founder population. Defaults to 1.0.
founder_pop (torch.Tensor): A tensor containing the genotypes of the founder population.
Shape: (number_founders, polyploid, number_chromosomes, loci_per_chromosome)
Returns:
-------
torch.Tensor: A tensor of marker effects, scaled to achieve the desired genetic variance.
Shape: (number_chromosomes, loci_per_chromosome).
Non-QTL loci will have an effect of 0.
"""
# Create empty vector to store marker effects
effects = torch.zeros_like(qtl_map, dtype=torch.float)
# Get total number of QTLs with non-zero marker effects
num_qtl = qtl_map.sum().item()
# Sample from a normal distribution and scale by variance and add mean
qtl_effects = torch.randn(num_qtl) * (variance ** 0.5) + mean
# Store these effects in the vector
effects[qtl_map] = qtl_effects

# Sum over the ploidy (dim=1) to get genotypes encoded as 0/1/2
founder_qtl_genotypes = founder_pop.sum(dim=1) # Shape: [500, 10, 100]
# Use torch.einsum to multiply and sum over the appropriate dimensions
summed_result = torch.einsum('bij,ij->b', founder_qtl_genotypes.float(), effects.float())

# Calculate the initial genetic variance in the founder population
initial_variance = torch.var(summed_result)
# Calculate scaling factor
scaling_factor = (variance / initial_variance) ** 0.5
# Scale the effects
effects[qtl_map] *= scaling_factor

return effects
Binary file modified index_files/figure-commonmark/cell-7-output-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 8d5a572

Please sign in to comment.