Skip to content

Commit

Permalink
allow single traits
Browse files Browse the repository at this point in the history
  • Loading branch information
cjGO committed Jun 12, 2024
1 parent 6e9b2b8 commit c1efc58
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 52 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ correlation_matrix = [
]
correlation_matrix = torch.tensor(correlation_matrix)

# multi_traits
target_means = torch.tensor([0, 5, 20])
target_vars = torch.tensor([1, 1, 0.5]) # target variance for each of the three traits
correlation_matrix = [
[1.0, 0.2, 0.58],
[0.2, 1.0, -0.37],
[0.58, -0.37, 1.0],
]
correlation_matrix = torch.tensor(correlation_matrix)

ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100)
ta(population.get_dosages()).shape
```
Expand Down
41 changes: 19 additions & 22 deletions chewc/trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,34 @@ def select_qtl_loci(num_qtl_per_chromosome: int, genome: Genome) -> torch.Tensor

return torch.stack(qtl_indices).to(genome.device)



class TraitModule(nn.Module):
"""
Module for managing and simulating one or more additive traits (optionally correlated).
"""
def __init__(self, genome: Genome,founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor,
correlation_matrix: torch.Tensor, n_qtl_per_chromosome: int):
def __init__(self, genome: Genome, founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor,
correlation_matrix: Optional[torch.Tensor], n_qtl_per_chromosome: int):
"""
Initializes the TraitModule.
Args:
genome (Genome): The genome object.
target_means (torch.Tensor): Target means for each trait (n_traits).
target_vars (torch.Tensor): Target variances for each trait (n_traits).
correlation_matrix (torch.Tensor): Correlation matrix between traits (n_traits, n_traits).
target_means (torch.Tensor): Target means for each trait (n_traits or 1 for single trait).
target_vars (torch.Tensor): Target variances for each trait (n_traits or 1 for single trait).
correlation_matrix (Optional[torch.Tensor]): Correlation matrix between traits (n_traits, n_traits) or None for single trait.
n_qtl_per_chromosome (int): Number of QTLs per chromosome for each trait.
"""
super().__init__()
self.genome = genome
self.founder_pop = founder_pop
self.n_traits = len(target_means)
self.target_means = target_means.to(genome.device)
self.target_vars = target_vars.to(genome.device)
self.correlation_matrix = correlation_matrix.to(genome.device)
self.n_traits = 1 if target_means.dim() == 0 else len(target_means)
self.target_means = target_means.to(genome.device).view(-1)
self.target_vars = target_vars.to(genome.device).view(-1)
self.correlation_matrix = correlation_matrix.to(genome.device) if correlation_matrix is not None else None
self.n_qtl_per_chromosome = n_qtl_per_chromosome

self.qtl_loci = select_qtl_loci(n_qtl_per_chromosome, genome)
self.effects = self._initialize_correlated_effects()
self.intercepts = self._calculate_intercepts()


def _initialize_correlated_effects(self) -> torch.Tensor:
"""
Expand All @@ -80,13 +77,14 @@ def _initialize_correlated_effects(self) -> torch.Tensor:
"""
n_chr, n_loci = self.genome.genetic_map.shape

L = torch.linalg.cholesky(self.correlation_matrix)

uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)
uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits)

correlated_effects = torch.matmul(L, uncorrelated_effects.T).T
return correlated_effects.reshape(n_chr, n_loci, self.n_traits)
if self.correlation_matrix is not None:
L = torch.linalg.cholesky(self.correlation_matrix)
uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)
uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits)
correlated_effects = torch.matmul(L, uncorrelated_effects.T).T
return correlated_effects.reshape(n_chr, n_loci, self.n_traits)
else:
return torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)

def _calculate_intercepts(self) -> torch.Tensor:
"""
Expand All @@ -99,15 +97,13 @@ def _calculate_intercepts(self) -> torch.Tensor:
Returns:
torch.Tensor: Trait intercepts (n_traits).
"""
# Example: Calculate intercepts based on a founder population
dosages = self.founder_pop.get_dosages()
unscaled_bvs = self.calculate_breeding_values(dosages, scale_effects=False)
unscaled_var = unscaled_bvs.var(dim=0, unbiased=False)
unscaled_mean = unscaled_bvs.mean(dim=0)

scaling_factors = torch.sqrt(self.target_vars / unscaled_var)
# import pdb; pdb.set_trace()
self.effects *= scaling_factors.view(1, 1, 3) # Scale the effects
self.effects *= scaling_factors.view(1, 1, self.n_traits) # Scale the effects
return self.target_means - (unscaled_mean * scaling_factors)

def calculate_breeding_values(self, dosages: torch.Tensor, scale_effects: bool = True) -> torch.Tensor:
Expand Down Expand Up @@ -156,3 +152,4 @@ def forward(self, dosages: torch.Tensor, h2: Optional[Union[float, torch.Tensor]
return breeding_values + env_noise
else:
return breeding_values # No noise added

113 changes: 90 additions & 23 deletions nbs/02_trait.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,37 +74,34 @@
" \n",
" return torch.stack(qtl_indices).to(genome.device)\n",
"\n",
"\n",
"\n",
"class TraitModule(nn.Module):\n",
" \"\"\"\n",
" Module for managing and simulating one or more additive traits (optionally correlated).\n",
" \"\"\"\n",
" def __init__(self, genome: Genome,founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor, \n",
" correlation_matrix: torch.Tensor, n_qtl_per_chromosome: int):\n",
" def __init__(self, genome: Genome, founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor, \n",
" correlation_matrix: Optional[torch.Tensor], n_qtl_per_chromosome: int):\n",
" \"\"\"\n",
" Initializes the TraitModule.\n",
"\n",
" Args:\n",
" genome (Genome): The genome object.\n",
" target_means (torch.Tensor): Target means for each trait (n_traits).\n",
" target_vars (torch.Tensor): Target variances for each trait (n_traits).\n",
" correlation_matrix (torch.Tensor): Correlation matrix between traits (n_traits, n_traits).\n",
" target_means (torch.Tensor): Target means for each trait (n_traits or 1 for single trait).\n",
" target_vars (torch.Tensor): Target variances for each trait (n_traits or 1 for single trait).\n",
" correlation_matrix (Optional[torch.Tensor]): Correlation matrix between traits (n_traits, n_traits) or None for single trait.\n",
" n_qtl_per_chromosome (int): Number of QTLs per chromosome for each trait.\n",
" \"\"\"\n",
" super().__init__()\n",
" self.genome = genome\n",
" self.founder_pop = founder_pop\n",
" self.n_traits = len(target_means)\n",
" self.target_means = target_means.to(genome.device)\n",
" self.target_vars = target_vars.to(genome.device)\n",
" self.correlation_matrix = correlation_matrix.to(genome.device)\n",
" self.n_traits = 1 if target_means.dim() == 0 else len(target_means)\n",
" self.target_means = target_means.to(genome.device).view(-1)\n",
" self.target_vars = target_vars.to(genome.device).view(-1)\n",
" self.correlation_matrix = correlation_matrix.to(genome.device) if correlation_matrix is not None else None\n",
" self.n_qtl_per_chromosome = n_qtl_per_chromosome\n",
" \n",
" self.qtl_loci = select_qtl_loci(n_qtl_per_chromosome, genome)\n",
" self.effects = self._initialize_correlated_effects()\n",
" self.intercepts = self._calculate_intercepts()\n",
" \n",
"\n",
" def _initialize_correlated_effects(self) -> torch.Tensor:\n",
" \"\"\"\n",
Expand All @@ -115,13 +112,14 @@
" \"\"\"\n",
" n_chr, n_loci = self.genome.genetic_map.shape\n",
" \n",
" L = torch.linalg.cholesky(self.correlation_matrix)\n",
"\n",
" uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)\n",
" uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits)\n",
"\n",
" correlated_effects = torch.matmul(L, uncorrelated_effects.T).T\n",
" return correlated_effects.reshape(n_chr, n_loci, self.n_traits)\n",
" if self.correlation_matrix is not None:\n",
" L = torch.linalg.cholesky(self.correlation_matrix)\n",
" uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)\n",
" uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits)\n",
" correlated_effects = torch.matmul(L, uncorrelated_effects.T).T\n",
" return correlated_effects.reshape(n_chr, n_loci, self.n_traits)\n",
" else:\n",
" return torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)\n",
"\n",
" def _calculate_intercepts(self) -> torch.Tensor:\n",
" \"\"\"\n",
Expand All @@ -134,15 +132,13 @@
" Returns:\n",
" torch.Tensor: Trait intercepts (n_traits).\n",
" \"\"\"\n",
" # Example: Calculate intercepts based on a founder population\n",
" dosages = self.founder_pop.get_dosages()\n",
" unscaled_bvs = self.calculate_breeding_values(dosages, scale_effects=False)\n",
" unscaled_var = unscaled_bvs.var(dim=0, unbiased=False)\n",
" unscaled_mean = unscaled_bvs.mean(dim=0)\n",
" \n",
" scaling_factors = torch.sqrt(self.target_vars / unscaled_var)\n",
"# import pdb; pdb.set_trace()\n",
" self.effects *= scaling_factors.view(1, 1, 3) # Scale the effects\n",
" self.effects *= scaling_factors.view(1, 1, self.n_traits) # Scale the effects\n",
" return self.target_means - (unscaled_mean * scaling_factors)\n",
"\n",
" def calculate_breeding_values(self, dosages: torch.Tensor, scale_effects: bool = True) -> torch.Tensor:\n",
Expand Down Expand Up @@ -190,7 +186,7 @@
" env_noise = torch.randn_like(breeding_values) * torch.sqrt(varE)\n",
" return breeding_values + env_noise\n",
" else:\n",
" return breeding_values # No noise added"
" return breeding_values # No noise added\n"
]
},
{
Expand Down Expand Up @@ -241,6 +237,77 @@
"ta(population.get_dosages()).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c87b215",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Created genetic map\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([333, 3])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ploidy = 2\n",
"n_chr = 10\n",
"n_loci = 100\n",
"n_Ind = 333\n",
"g = Genome(ploidy, n_chr, n_loci)\n",
"population = Population()\n",
"population.create_random_founder_population(g, n_founders=n_Ind)\n",
"init_pop = population.get_dosages().float() # gets allele dosage for calculating trait values\n",
"\n",
"# multi_traits\n",
"target_means = torch.tensor([0, 5, 20])\n",
"target_vars = torch.tensor([1, 1, 0.5]) # target variance for each of the three traits\n",
"correlation_matrix = [\n",
" [1.0, 0.2, 0.58],\n",
" [0.2, 1.0, -0.37],\n",
" [0.58, -0.37, 1.0],\n",
" ]\n",
"correlation_matrix = torch.tensor(correlation_matrix)\n",
"\n",
"ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100)\n",
"ta(population.get_dosages()).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27260d24",
"metadata": {},
"outputs": [],
"source": [
"# Define single trait parameters\n",
"target_mean = torch.tensor(0.5) # Single trait target mean\n",
"target_var = torch.tensor(0.2) # Single trait target variance\n",
"correlation_matrix = None # No correlation matrix for a single trait\n",
"n_qtl_per_chromosome = 10 \n",
"# Initialize the TraitModule for a single trait\n",
"trait_module = TraitModule(\n",
" genome=g,\n",
" founder_pop=population,\n",
" target_means=target_mean,\n",
" target_vars=target_var,\n",
" correlation_matrix=correlation_matrix,\n",
" n_qtl_per_chromosome=n_qtl_per_chromosome\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
17 changes: 10 additions & 7 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@
" ]\n",
"correlation_matrix = torch.tensor(correlation_matrix)\n",
"\n",
"# multi_traits\n",
"target_means = torch.tensor([0, 5, 20])\n",
"target_vars = torch.tensor([1, 1, 0.5]) # target variance for each of the three traits\n",
"correlation_matrix = [\n",
" [1.0, 0.2, 0.58],\n",
" [0.2, 1.0, -0.37],\n",
" [0.58, -0.37, 1.0],\n",
" ]\n",
"correlation_matrix = torch.tensor(correlation_matrix)\n",
"\n",
"ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100)\n",
"ta(population.get_dosages()).shape\n"
]
Expand All @@ -152,13 +162,6 @@
"source": [
"random_crosses(g, population, 10, reps = 6).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit c1efc58

Please sign in to comment.