diff --git a/README.md b/README.md index b0e76f4..e2a2b49 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,16 @@ correlation_matrix = [ ] correlation_matrix = torch.tensor(correlation_matrix) +# multi_traits +target_means = torch.tensor([0, 5, 20]) +target_vars = torch.tensor([1, 1, 0.5]) # Note: I'm assuming you want a variance of 1 for the second trait +correlation_matrix = [ + [1.0, 0.2, 0.58], + [0.2, 1.0, -0.37], + [0.58, -0.37, 1.0], + ] +correlation_matrix = torch.tensor(correlation_matrix) + ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100) ta(population.get_dosages()).shape ``` diff --git a/chewc/trait.py b/chewc/trait.py index 63f4116..1cb3a8f 100644 --- a/chewc/trait.py +++ b/chewc/trait.py @@ -39,37 +39,34 @@ def select_qtl_loci(num_qtl_per_chromosome: int, genome: Genome) -> torch.Tensor return torch.stack(qtl_indices).to(genome.device) - - class TraitModule(nn.Module): """ Module for managing and simulating multiple correlated additive traits. """ - def __init__(self, genome: Genome,founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor, - correlation_matrix: torch.Tensor, n_qtl_per_chromosome: int): + def __init__(self, genome: Genome, founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor, + correlation_matrix: Optional[torch.Tensor], n_qtl_per_chromosome: int): """ Initializes the TraitModule. Args: genome (Genome): The genome object. - target_means (torch.Tensor): Target means for each trait (n_traits). - target_vars (torch.Tensor): Target variances for each trait (n_traits). - correlation_matrix (torch.Tensor): Correlation matrix between traits (n_traits, n_traits). + target_means (torch.Tensor): Target means for each trait (n_traits or 1 for single trait). + target_vars (torch.Tensor): Target variances for each trait (n_traits or 1 for single trait). + correlation_matrix (Optional[torch.Tensor]): Correlation matrix between traits (n_traits, n_traits) or None for single trait. n_qtl_per_chromosome (int): Number of QTLs per chromosome for each trait. """ super().__init__() self.genome = genome self.founder_pop = founder_pop - self.n_traits = len(target_means) - self.target_means = target_means.to(genome.device) - self.target_vars = target_vars.to(genome.device) - self.correlation_matrix = correlation_matrix.to(genome.device) + self.n_traits = 1 if target_means.dim() == 0 else len(target_means) + self.target_means = target_means.to(genome.device).view(-1) + self.target_vars = target_vars.to(genome.device).view(-1) + self.correlation_matrix = correlation_matrix.to(genome.device) if correlation_matrix is not None else None self.n_qtl_per_chromosome = n_qtl_per_chromosome self.qtl_loci = select_qtl_loci(n_qtl_per_chromosome, genome) self.effects = self._initialize_correlated_effects() self.intercepts = self._calculate_intercepts() - def _initialize_correlated_effects(self) -> torch.Tensor: """ @@ -80,13 +77,14 @@ def _initialize_correlated_effects(self) -> torch.Tensor: """ n_chr, n_loci = self.genome.genetic_map.shape - L = torch.linalg.cholesky(self.correlation_matrix) - - uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device) - uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits) - - correlated_effects = torch.matmul(L, uncorrelated_effects.T).T - return correlated_effects.reshape(n_chr, n_loci, self.n_traits) + if self.correlation_matrix is not None: + L = torch.linalg.cholesky(self.correlation_matrix) + uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device) + uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits) + correlated_effects = torch.matmul(L, uncorrelated_effects.T).T + return correlated_effects.reshape(n_chr, n_loci, self.n_traits) + else: + return torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device) def _calculate_intercepts(self) -> torch.Tensor: """ @@ -99,15 +97,13 @@ def _calculate_intercepts(self) -> torch.Tensor: Returns: torch.Tensor: Trait intercepts (n_traits). """ - # Example: Calculate intercepts based on a founder population dosages = self.founder_pop.get_dosages() unscaled_bvs = self.calculate_breeding_values(dosages, scale_effects=False) unscaled_var = unscaled_bvs.var(dim=0, unbiased=False) unscaled_mean = unscaled_bvs.mean(dim=0) scaling_factors = torch.sqrt(self.target_vars / unscaled_var) -# import pdb; pdb.set_trace() - self.effects *= scaling_factors.view(1, 1, 3) # Scale the effects + self.effects *= scaling_factors.view(1, 1, self.n_traits) # Scale the effects return self.target_means - (unscaled_mean * scaling_factors) def calculate_breeding_values(self, dosages: torch.Tensor, scale_effects: bool = True) -> torch.Tensor: @@ -156,3 +152,4 @@ def forward(self, dosages: torch.Tensor, h2: Optional[Union[float, torch.Tensor] return breeding_values + env_noise else: return breeding_values # No noise added + diff --git a/nbs/02_trait.ipynb b/nbs/02_trait.ipynb index 6c60217..3aca9ec 100644 --- a/nbs/02_trait.ipynb +++ b/nbs/02_trait.ipynb @@ -74,37 +74,34 @@ " \n", " return torch.stack(qtl_indices).to(genome.device)\n", "\n", - "\n", - "\n", "class TraitModule(nn.Module):\n", " \"\"\"\n", " Module for managing and simulating multiple correlated additive traits.\n", " \"\"\"\n", - " def __init__(self, genome: Genome,founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor, \n", - " correlation_matrix: torch.Tensor, n_qtl_per_chromosome: int):\n", + " def __init__(self, genome: Genome, founder_pop, target_means: torch.Tensor, target_vars: torch.Tensor, \n", + " correlation_matrix: Optional[torch.Tensor], n_qtl_per_chromosome: int):\n", " \"\"\"\n", " Initializes the TraitModule.\n", "\n", " Args:\n", " genome (Genome): The genome object.\n", - " target_means (torch.Tensor): Target means for each trait (n_traits).\n", - " target_vars (torch.Tensor): Target variances for each trait (n_traits).\n", - " correlation_matrix (torch.Tensor): Correlation matrix between traits (n_traits, n_traits).\n", + " target_means (torch.Tensor): Target means for each trait (n_traits or 1 for single trait).\n", + " target_vars (torch.Tensor): Target variances for each trait (n_traits or 1 for single trait).\n", + " correlation_matrix (Optional[torch.Tensor]): Correlation matrix between traits (n_traits, n_traits) or None for single trait.\n", " n_qtl_per_chromosome (int): Number of QTLs per chromosome for each trait.\n", " \"\"\"\n", " super().__init__()\n", " self.genome = genome\n", " self.founder_pop = founder_pop\n", - " self.n_traits = len(target_means)\n", - " self.target_means = target_means.to(genome.device)\n", - " self.target_vars = target_vars.to(genome.device)\n", - " self.correlation_matrix = correlation_matrix.to(genome.device)\n", + " self.n_traits = 1 if target_means.dim() == 0 else len(target_means)\n", + " self.target_means = target_means.to(genome.device).view(-1)\n", + " self.target_vars = target_vars.to(genome.device).view(-1)\n", + " self.correlation_matrix = correlation_matrix.to(genome.device) if correlation_matrix is not None else None\n", " self.n_qtl_per_chromosome = n_qtl_per_chromosome\n", " \n", " self.qtl_loci = select_qtl_loci(n_qtl_per_chromosome, genome)\n", " self.effects = self._initialize_correlated_effects()\n", " self.intercepts = self._calculate_intercepts()\n", - " \n", "\n", " def _initialize_correlated_effects(self) -> torch.Tensor:\n", " \"\"\"\n", @@ -115,13 +112,14 @@ " \"\"\"\n", " n_chr, n_loci = self.genome.genetic_map.shape\n", " \n", - " L = torch.linalg.cholesky(self.correlation_matrix)\n", - "\n", - " uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)\n", - " uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits)\n", - "\n", - " correlated_effects = torch.matmul(L, uncorrelated_effects.T).T\n", - " return correlated_effects.reshape(n_chr, n_loci, self.n_traits)\n", + " if self.correlation_matrix is not None:\n", + " L = torch.linalg.cholesky(self.correlation_matrix)\n", + " uncorrelated_effects = torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)\n", + " uncorrelated_effects = uncorrelated_effects.reshape(n_chr * n_loci, self.n_traits)\n", + " correlated_effects = torch.matmul(L, uncorrelated_effects.T).T\n", + " return correlated_effects.reshape(n_chr, n_loci, self.n_traits)\n", + " else:\n", + " return torch.randn(n_chr, n_loci, self.n_traits, device=self.genome.device)\n", "\n", " def _calculate_intercepts(self) -> torch.Tensor:\n", " \"\"\"\n", @@ -134,15 +132,13 @@ " Returns:\n", " torch.Tensor: Trait intercepts (n_traits).\n", " \"\"\"\n", - " # Example: Calculate intercepts based on a founder population\n", " dosages = self.founder_pop.get_dosages()\n", " unscaled_bvs = self.calculate_breeding_values(dosages, scale_effects=False)\n", " unscaled_var = unscaled_bvs.var(dim=0, unbiased=False)\n", " unscaled_mean = unscaled_bvs.mean(dim=0)\n", " \n", " scaling_factors = torch.sqrt(self.target_vars / unscaled_var)\n", - "# import pdb; pdb.set_trace()\n", - " self.effects *= scaling_factors.view(1, 1, 3) # Scale the effects\n", + " self.effects *= scaling_factors.view(1, 1, self.n_traits) # Scale the effects\n", " return self.target_means - (unscaled_mean * scaling_factors)\n", "\n", " def calculate_breeding_values(self, dosages: torch.Tensor, scale_effects: bool = True) -> torch.Tensor:\n", @@ -190,7 +186,7 @@ " env_noise = torch.randn_like(breeding_values) * torch.sqrt(varE)\n", " return breeding_values + env_noise\n", " else:\n", - " return breeding_values # No noise added" + " return breeding_values # No noise added\n" ] }, { @@ -241,6 +237,77 @@ "ta(population.get_dosages()).shape" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c87b215", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created genetic map\n" + ] + }, + { + "data": { + "text/plain": [ + "torch.Size([333, 3])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ploidy = 2\n", + "n_chr = 10\n", + "n_loci = 100\n", + "n_Ind = 333\n", + "g = Genome(ploidy, n_chr, n_loci)\n", + "population = Population()\n", + "population.create_random_founder_population(g, n_founders=n_Ind)\n", + "init_pop = population.get_dosages().float() # gets allele dosage for calculating trait values\n", + "\n", + "# multi_traits\n", + "target_means = torch.tensor([0, 5, 20])\n", + "target_vars = torch.tensor([1, 1, 0.5]) # Note: I'm assuming you want a variance of 1 for the second trait\n", + "correlation_matrix = [\n", + " [1.0, 0.2, 0.58],\n", + " [0.2, 1.0, -0.37],\n", + " [0.58, -0.37, 1.0],\n", + " ]\n", + "correlation_matrix = torch.tensor(correlation_matrix)\n", + "\n", + "ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100)\n", + "ta(population.get_dosages()).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27260d24", + "metadata": {}, + "outputs": [], + "source": [ + "# Define single trait parameters\n", + "target_mean = torch.tensor(0.5) # Single trait target mean\n", + "target_var = torch.tensor(0.2) # Single trait target variance\n", + "correlation_matrix = None # No correlation matrix for a single trait\n", + "n_qtl_per_chromosome = 10 \n", + "# Initialize the TraitModule for a single trait\n", + "trait_module = TraitModule(\n", + " genome=g,\n", + " founder_pop=population,\n", + " target_means=target_mean,\n", + " target_vars=target_var,\n", + " correlation_matrix=correlation_matrix,\n", + " n_qtl_per_chromosome=n_qtl_per_chromosome\n", + ")\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/index.ipynb b/nbs/index.ipynb index 1c73fb5..9a57745 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -129,6 +129,16 @@ " ]\n", "correlation_matrix = torch.tensor(correlation_matrix)\n", "\n", + "# multi_traits\n", + "target_means = torch.tensor([0, 5, 20])\n", + "target_vars = torch.tensor([1, 1, 0.5]) # Note: I'm assuming you want a variance of 1 for the second trait\n", + "correlation_matrix = [\n", + " [1.0, 0.2, 0.58],\n", + " [0.2, 1.0, -0.37],\n", + " [0.58, -0.37, 1.0],\n", + " ]\n", + "correlation_matrix = torch.tensor(correlation_matrix)\n", + "\n", "ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100)\n", "ta(population.get_dosages()).shape\n" ] @@ -152,13 +162,6 @@ "source": [ "random_crosses(g, population, 10, reps = 6).shape" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {