Skip to content

Commit

Permalink
polished core and trait
Browse files Browse the repository at this point in the history
  • Loading branch information
cjGO committed Jun 12, 2024
1 parent c0af82e commit 1be15b9
Show file tree
Hide file tree
Showing 10 changed files with 909 additions and 411 deletions.
36 changes: 15 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,35 +37,29 @@ import torch
```

``` python
g = Genome()
founder_pop = Population()
founder_pop.create_random_founder_population(g, 30)
ploidy = 2
n_chr = 10
n_loci = 100
n_Ind = 333
g = Genome(ploidy, n_chr, n_loci)
population = Population()
population.create_random_founder_population(g, n_founders=n_Ind)
init_pop = population.get_dosages().float() # gets allele dosage for calculating trait values

# multi_traits
target_means = torch.tensor([0, 5, 20])
target_vars = torch.tensor([1, 1, 0.5]) # Note: I'm assuming you want a variance of 1 for the second trait
correlation_values = [
correlation_matrix = [
[1.0, 0.2, 0.58],
[0.2, 1.0, -0.37],
[0.58, -0.37, 1.0],
]
traits = corr_traits(g, founder_pop.get_dosages().float(), target_means, target_vars, correlation_values)
```

Created genetic map

``` python
f1 = x_random(g, founder_pop.get_genotypes().float(), 50)
f1.shape
```

torch.Size([50, 2, 10, 5])
correlation_matrix = torch.tensor(correlation_matrix)

``` python
DH = x_DH(g,f1)
ta = TraitModule(g, population, target_means, target_vars, correlation_matrix,100)
ta(population.get_dosages()).shape
```

``` python
DH.shape
```
Created genetic map

torch.Size([50, 2, 10, 5])
torch.Size([333, 3])
15 changes: 9 additions & 6 deletions chewc/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
'chewc.core.Individual.__init__': ('core.html#individual.__init__', 'chewc/core.py'),
'chewc.core.Individual.create_random_individual': ( 'core.html#individual.create_random_individual',
'chewc/core.py'),
'chewc.core.Individual.to': ('core.html#individual.to', 'chewc/core.py'),
'chewc.core.Population': ('core.html#population', 'chewc/core.py'),
'chewc.core.Population.__init__': ('core.html#population.__init__', 'chewc/core.py'),
'chewc.core.Population.add_individual': ('core.html#population.add_individual', 'chewc/core.py'),
Expand All @@ -27,7 +26,6 @@
'chewc.core.Population.get_dosages': ('core.html#population.get_dosages', 'chewc/core.py'),
'chewc.core.Population.get_genotypes': ('core.html#population.get_genotypes', 'chewc/core.py'),
'chewc.core.Population.size': ('core.html#population.size', 'chewc/core.py'),
'chewc.core.Population.to': ('core.html#population.to', 'chewc/core.py'),
'chewc.core.PopulationDataset': ('core.html#populationdataset', 'chewc/core.py'),
'chewc.core.PopulationDataset.__getitem__': ('core.html#populationdataset.__getitem__', 'chewc/core.py'),
'chewc.core.PopulationDataset.__init__': ('core.html#populationdataset.__init__', 'chewc/core.py'),
Expand All @@ -39,8 +37,13 @@
'chewc.crossing.random_crosses': ('crossing.html#random_crosses', 'chewc/crossing.py')},
'chewc.meiosis': { 'chewc.meiosis.poisson_crossing_over': ('meiosis.html#poisson_crossing_over', 'chewc/meiosis.py'),
'chewc.meiosis.simulate_gametes': ('meiosis.html#simulate_gametes', 'chewc/meiosis.py')},
'chewc.trait': { 'chewc.trait.TraitA': ('trait.html#traita', 'chewc/trait.py'),
'chewc.trait.TraitA.__init__': ('trait.html#traita.__init__', 'chewc/trait.py'),
'chewc.trait.TraitA.forward': ('trait.html#traita.forward', 'chewc/trait.py'),
'chewc.trait.corr_traits': ('trait.html#corr_traits', 'chewc/trait.py'),
'chewc.trait': { 'chewc.trait.TraitModule': ('trait.html#traitmodule', 'chewc/trait.py'),
'chewc.trait.TraitModule.__init__': ('trait.html#traitmodule.__init__', 'chewc/trait.py'),
'chewc.trait.TraitModule._calculate_intercepts': ( 'trait.html#traitmodule._calculate_intercepts',
'chewc/trait.py'),
'chewc.trait.TraitModule._initialize_correlated_effects': ( 'trait.html#traitmodule._initialize_correlated_effects',
'chewc/trait.py'),
'chewc.trait.TraitModule.calculate_breeding_values': ( 'trait.html#traitmodule.calculate_breeding_values',
'chewc/trait.py'),
'chewc.trait.TraitModule.forward': ('trait.html#traitmodule.forward', 'chewc/trait.py'),
'chewc.trait.select_qtl_loci': ('trait.html#select_qtl_loci', 'chewc/trait.py')}}}
86 changes: 39 additions & 47 deletions chewc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ def __init__(self, ploidy: int = 2, n_chromosomes: int = 10, n_loci_per_chromoso
self.chromosome_length = chromosome_length

self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.genetic_map = None
self.create_genetic_map()


self.create_genetic_map()

def shape(self) -> Tuple[int, int, int]:
"""Returns the shape of the genome (ploidy, chromosomes, loci)."""
return self.ploidy, self.n_chromosomes, self.n_loci_per_chromosome
Expand Down Expand Up @@ -72,34 +71,35 @@ class Individual:
Args:
genome (Genome): Reference to the shared Genome object.
haplotypes (torch.Tensor): Tensor representing the individual's haplotypes.
haplotypes (torch.Tensor): Tensor representing the individual's haplotypes.
Shape: (ploidy, n_chromosomes, n_loci_per_chromosome).
id (Optional[str]): Unique identifier. Defaults to None.
mother_id (Optional[str]): Mother's identifier. Defaults to None.
father_id (Optional[str]): Father's identifier. Defaults to None.
breeding_values (Optional[torch.Tensor]): Breeding values for traits. Shape: (n_traits,). Defaults to None.
breeding_values (Optional[torch.Tensor]): Breeding values for traits.
Shape: (n_traits,). Defaults to None.
phenotypes (Optional[torch.Tensor]): Phenotype for traits. Shape: (n_traits,). Defaults to None.
"""

def __init__(self,
genome: Genome,
genome: 'Genome',
haplotypes: torch.Tensor,
id: Optional[str] = None,
mother_id: Optional[str] = None,
father_id: Optional[str] = None,
breeding_values: Optional[torch.Tensor] = None,
phenotypes: Optional[torch.Tensor] = None):
self.genome: Genome = genome
self.haplotypes: torch.Tensor = haplotypes.to('cpu')
self.id: Optional[str] = id
self.mother_id: Optional[str] = mother_id
self.father_id: Optional[str] = father_id
self.breeding_values: Optional[torch.Tensor] = breeding_values
self.phenotypes: Optional[torch.Tensor] = phenotypes

self.genome = genome
self.haplotypes = haplotypes.to(self.genome.device)
self.id = id
self.mother_id = mother_id
self.father_id = father_id
self.breeding_values = breeding_values
self.phenotypes = phenotypes

@classmethod
def create_random_individual(cls, genome: Genome, id: Optional[str] = None):
def create_random_individual(cls, genome: 'Genome', id: Optional[str] = None) -> 'Individual':
"""
Creates a random individual with the specified genome.
Expand All @@ -110,18 +110,9 @@ def create_random_individual(cls, genome: Genome, id: Optional[str] = None):
Returns:
Individual: A new Individual object with random haplotypes.
"""
haplotypes = torch.randint(0, 2, genome.shape(), device='cpu')
haplotypes = torch.randint(0, 2, genome.shape(), device=genome.device)
return cls(genome=genome, haplotypes=haplotypes, id=id)

def to(self, device: torch.device):
"""Moves the individual's data to the specified device."""
self.haplotypes = self.haplotypes.to(device)
if self.breeding_values is not None:
self.breeding_values = self.breeding_values.to(device)
if self.phenotypes is not None:
self.phenotypes = self.phenotypes.to(device)
return self


class Population:
"""
Expand All @@ -131,20 +122,21 @@ class Population:
individuals (List[Individual], optional): List of Individual objects in the population. Defaults to None.
id (Optional[str]): Unique identifier for the population. Defaults to None.
"""

def __init__(self, individuals: Optional[List[Individual]] = None, id: Optional[str] = None):
self.individuals = individuals if individuals is not None else []
self.id = id
self.id = id

def create_random_founder_population(self, genome: Genome, n_founders: int):
def create_random_founder_population(self, genome: 'Genome', n_founders: int):
"""
Creates a founder population with random haplotypes.
Args:
genome (Genome): The genome object.
n_founders (int): The number of founder individuals to create.
"""
self.individuals = [Individual.create_random_individual(genome, id=str(i)) for i in range(n_founders)]
self.individuals = [Individual.create_random_individual(genome, id=str(i))
for i in range(n_founders)]

def size(self) -> int:
"""Returns the number of individuals in the population."""
Expand All @@ -159,7 +151,7 @@ def get_genotypes(self) -> torch.Tensor:
(population_size, ploidy, n_chromosomes, n_loci_per_chromosome).
"""
return torch.stack([individual.haplotypes for individual in self.individuals])

def get_dosages(self) -> torch.Tensor:
"""
Calculates the allele dosage for each locus in the population by summing over the ploidy.
Expand All @@ -168,32 +160,32 @@ def get_dosages(self) -> torch.Tensor:
torch.Tensor: Allele dosage tensor with shape
(population_size, n_chromosomes, n_loci_per_chromosome).
"""
genotypes = self.get_genotypes()
allele_dosage = genotypes.sum(dim=1) # Sum over the ploidy dimension
return allele_dosage
return self.get_genotypes().sum(dim=1) # Sum over the ploidy dimension

def add_individual(self, individual: Individual):
"""Adds an individual to the population."""
self.individuals.append(individual)

def to(self, device: torch.device):
"""Moves all individuals in the population to the specified device."""
for individual in self.individuals:
individual.to(device)
return self

def calculate_allele_frequencies(self) -> torch.Tensor:
"""Calculates allele frequencies for each locus in the population."""
genotypes = self.get_genotypes().float()
return genotypes.mean(dim=(0, 1)) # Average over ploidy and individuals
"""
Calculates allele frequencies for each locus in the population.
Returns:
torch.Tensor: Allele frequencies (n_chromosomes, n_loci_per_chromosome).
"""
return self.get_genotypes().float().mean(dim=(0, 1)) # Average over ploidy and individuals

def calculate_genetic_diversity(self) -> torch.Tensor:
"""Calculates a measure of genetic diversity (e.g., heterozygosity)."""
# Example implementation (you can customize this based on your needs)
"""
Calculates a measure of genetic diversity (e.g., heterozygosity).
Returns:
torch.Tensor: Genetic diversity (n_chromosomes, n_loci_per_chromosome).
"""
allele_frequencies = self.calculate_allele_frequencies()
return 1.0 - (allele_frequencies**2 + (1 - allele_frequencies)**2)
return 1.0 - (allele_frequencies**2 + (1 - allele_frequencies)**2)

# %% ../nbs/01_core.ipynb 9
# %% ../nbs/01_core.ipynb 8
from torch.utils.data import Dataset, DataLoader

class PopulationDataset(Dataset):
Expand Down
27 changes: 19 additions & 8 deletions chewc/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,33 @@ def x_random( genome: Genome, parent_haplotypes: torch.Tensor, n_crosses: int) -
return progeny_haplotypes

# %% ../nbs/04_cross.ipynb 7
def x_DH(genome: Genome, parent_haplotypes: torch.Tensor) -> torch.Tensor:
def x_DH(genome: Genome, parent_haplotypes: torch.Tensor, reps: int) -> torch.Tensor:
"""
Generate doubled haploid individuals from a set of parent haplotypes.
Generate doubled haploid individuals from a set of parent haplotypes with distinct samples for each repetition.
Args:
----
parent_haplotypes (torch.Tensor): Haplotypes of the parents.
Shape: (n_parents, ploidy, chr, loci)
genome (Genome): Genome object.
parent_haplotypes (torch.Tensor): Haplotypes of the parents.
Shape: (n_parents, ploidy, chr, loci)
reps (int): Number of times to repeat the DH process with distinct samples.
Returns:
-------
torch.Tensor: Haplotypes of the doubled haploid progeny.
Shape: (n_parents, ploidy, chr, loci)
Shape: (n_parents, reps, ploidy, chr, loci)
"""
gametes = simulate_gametes(genome, parent_haplotypes)
dh_haplotypes = gametes.repeat(1, 2, 1, 1) # Duplicate the gametes along ploidy dimension
ploidy, n_chr, n_loci = genome.shape() # Assuming genome.shape() returns (ploidy, n_chr, n_loci)
n_parents = parent_haplotypes.shape[0]
# Each parent represents a new family created by this script
all_dh_haplotypes_by_family = torch.zeros((n_parents, reps, ploidy, n_chr, n_loci), device=parent_haplotypes.device)

# Need to loop through the reps/n_parents to fill the all_dh_haplotypes_by_family
for rep in range(reps):
gametes = simulate_gametes(genome, parent_haplotypes) # Returns (n_parents, ploidy//2, n_chr, n_loci)
# Create doubled haploid by copying the gametes into the ploidy axis
doubled_haploids = torch.cat([gametes, gametes], dim=1) # Concatenate gametes to double the haploid number
all_dh_haplotypes_by_family[:, rep, :, :, :] = doubled_haploids

return all_dh_haplotypes_by_family

return dh_haplotypes
Loading

0 comments on commit 1be15b9

Please sign in to comment.