feat(starrynight): integrate starrynightmodules and massive refactor
leoank committed Jan 7, 2025
1 parent c00dc24 commit 359f072
Showing 27 changed files with 628 additions and 335 deletions.
4 changes: 3 additions & 1 deletion conductor/src/conductor/models/job.py
@@ -3,6 +3,7 @@
from sqlalchemy import JSON, Enum, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql.sqltypes import String
from starrynight.modules.schema import Container

from conductor.constants import JobType
from conductor.models.base import BaseSQLModel
@@ -14,9 +15,10 @@ class Job(BaseSQLModel):

__tablename__ = "job"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
module_id: Mapped[str] = mapped_column(nullable=False, index=True)
name: Mapped[str] = mapped_column(String(100))
description: Mapped[str] = mapped_column()
type = mapped_column(Enum(JobType, create_constraint=True), nullable=False)
spec: Mapped[dict] = mapped_column(Container, nullable=False)
outputs: Mapped[dict] = mapped_column(JSON)
inputs: Mapped[dict] = mapped_column(JSON)
step_id: Mapped[int] = mapped_column(ForeignKey("step.id"), nullable=False)
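For context, the new spec column persists a starrynight Container on each job row. A minimal sketch of how a Pydantic model can back a SQLAlchemy column, assuming Container follows the usual TypeDecorator pattern (the names below are illustrative, not the actual starrynight.modules.schema implementation):

# Hypothetical sketch: persisting a Pydantic model as a JSON-backed column.
from pydantic import BaseModel
from sqlalchemy.types import JSON, TypeDecorator

class Container(BaseModel):
    name: str
    image: str

class ContainerType(TypeDecorator):
    """Serialize a Container to JSON on write, rebuild it on read."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return value.model_dump() if value is not None else None

    def process_result_value(self, value, dialect):
        return Container.model_validate(value) if value is not None else None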
8 changes: 3 additions & 5 deletions conductor/src/conductor/validators/job.py
@@ -1,19 +1,17 @@
"""Job domain related validators."""

from pydantic import BaseModel

from conductor.constants import JobInputSchema, JobOutputSchema, JobType
from starrynight.modules.schema import Container


class Job(BaseModel):
"""Job create schema."""

id: int | None = None
module_id: str
step_id: int
name: str
description: str
type: JobType
outputs: dict[str, JobOutputSchema]
inputs: dict[str, JobInputSchema]
spec: Container

model_config: dict = {"from_attributes": True}
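As a usage sketch, from_attributes lets this validator be built straight from the ORM row (the session and the JobModel alias are hypothetical):

# Hypothetical usage: validate a SQLAlchemy Job row into this schema.
orm_job = session.get(JobModel, 1)  # ORM class from conductor/models/job.py
job = Job.model_validate(orm_job)   # works because from_attributes=True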
6 changes: 4 additions & 2 deletions pipecraft/src/pipecraft/backend/snakemake.py
@@ -70,6 +70,7 @@ def __init__(
text=Path(__file__).parent.joinpath("templates/snakemake.mako").read_text(),
output_encoding="utf-8",
)
self.snakefile = None

def compile(self) -> None:
"""Compile SnakeMake pipeline.
@@ -107,8 +108,8 @@ def compile(self) -> None:
invoke_shells=invoke_shells,
)
assert type(snakefile) is bytes
snakefile = snakefile.decode("utf-8")
f.writelines(snakefile)
self.snakefile = snakefile.decode("utf-8")
f.writelines(self.snakefile)

def run(self) -> Path | CloudPath:
"""Run SankeMake.
@@ -119,6 +120,7 @@
Path to run log file.
"""
self.compile()
cwd = self.output_dir
if isinstance(self.output_dir, CloudPath):
bucket = self.output_dir.drive
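With the compiled Snakefile now cached on the backend, it can be inspected without executing anything. A sketch, assuming the attribute is intended as public API (the path arguments are placeholders):

# Hypothetical inspection of the generated Snakefile.
backend = SnakeMakeBackend(pipeline, config, output_dir, workspace_dir)
backend.compile()
print(backend.snakefile)  # decoded Snakefile text stored by compile()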
7 changes: 7 additions & 0 deletions pipecraft/src/pipecraft/pipeline.py
@@ -47,6 +47,7 @@ def __init__(self, node_list: "list[Pipeline | Node]") -> None:
"""
self.pipeline = nx.DiGraph()
self.node_list = node_list
self.is_compiled = False
self.resolve()

def resolve(self) -> None:
@@ -102,6 +103,8 @@ def compile(self: "Seq") -> Pipeline:
Compiled pipeline.
"""
if self.is_compiled:
return self
# for i in range(max(0, len(self.node_list) - 1)):
prev_root = None
for i, item in enumerate(self.node_list):
@@ -123,6 +126,7 @@
elif isinstance(prev_root, list):
self.pipeline.add_edge(prev_root[-1], subgraph_flattened[0])
prev_root = subgraph_flattened[-1]
self.is_compiled = True
return self


@@ -157,6 +161,8 @@ def compile(self: "Parallel") -> Pipeline:
Compiled pipeline.
"""
if self.is_compiled:
return self
flattened_list: list[Node] = list(flatten([self.resolved_list]))
scatter_node = flattened_list[0]
gather_node = flattened_list[-1]
@@ -171,4 +177,5 @@
subgraph_flattened = list(flatten([self.resolved_list[i]]))
self.pipeline.add_edge(scatter_node, subgraph_flattened[0])
self.pipeline.add_edge(subgraph_flattened[-1], gather_node)
self.is_compiled = True
return self
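The new is_compiled flag makes compilation idempotent. Illustrative only, with node_a and node_b as hypothetical Node instances:

# A second compile() now returns early instead of re-walking the node list.
seq = Seq([node_a, node_b])
assert seq.compile() is seq  # first call builds the graph
assert seq.compile() is seq  # second call hits the is_compiled guard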
51 changes: 51 additions & 0 deletions starrynight/notebooks/create_index.py
@@ -0,0 +1,51 @@
"""Create index for a project."""

from pathlib import Path

from pipecraft.backend.snakemake import SnakeMakeBackend, SnakeMakeConfig

from starrynight.experiments.common import DummyExperiment
from starrynight.modules.gen_inv import GenInvModule
from starrynight.pipelines.index import create_index_pipeline
from starrynight.schema import DataConfig

# Setup experiment
data = DataConfig(
dataset_path=Path("/datastore/cpg0999-merck-asma"),
storage_path=Path("./run001/workspace"),
workspace_path=Path("./run001/workspace"),
)
experiment = DummyExperiment(dataset_id="cpg0999-merck-asma")

# Create the pipeline
pipeline = create_index_pipeline(experiment, data)


# Configure execution backend
config = SnakeMakeConfig()
exec_backend = SnakeMakeBackend(
pipeline, config, data.storage_path, data.workspace_path
)

# Compile to get the generated snakemake file if
# manual execution or inspection is required.
# exec_backend.compile()

# or execute directly
# exec_backend.run()

# -------------------------------------------------------------------
# Changing default module specs
# -------------------------------------------------------------------

# Initialize module with experiment and data
gen_inv_module = GenInvModule.from_config(experiment, data)

# Inspect current specification and make changes
print(gen_inv_module.spec)
gen_inv_module.spec.inputs[0].path = str(Path("path/to/my/parser.lark").resolve())

# Add the configured module back to pipeline
pipeline = create_index_pipeline(
experiment, data, {GenInvModule.uid(): gen_inv_module.spec}
)
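
# A natural follow-up, reusing the objects defined above: rebuild the
# backend with the reconfigured pipeline and execute it
# (same SnakeMakeBackend signature as used earlier in this script).
exec_backend = SnakeMakeBackend(
    pipeline, config, data.storage_path, data.workspace_path
)
exec_backend.run()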
72 changes: 0 additions & 72 deletions starrynight/src/starrynight/experiment.py

This file was deleted.

1 change: 1 addition & 0 deletions starrynight/src/starrynight/experiments/__init__.py
@@ -0,0 +1 @@
"""Starrynight experiments."""
62 changes: 62 additions & 0 deletions starrynight/src/starrynight/experiments/common.py
@@ -0,0 +1,62 @@
"""Common experiment schemas."""

from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path

# Use a try block for backwards compatibility
try:
from typing import Self
except ImportError:
from typing_extensions import Self

from pydantic import BaseModel


class Experiment(BaseModel, ABC):
"""Experiment configuration."""

dataset_id: str
data_production_contact: str | None = None
data_processing_contact: str | None = None

@staticmethod
@abstractmethod
def from_index(index_path: Path, **kwargs) -> Self:
"""Create experiment schema from index."""
pass


class DummyExperiment(Experiment):
"""DummyExperiment to bootstrap pipeline configuration."""

@staticmethod
def from_index(index_path: Path) -> Self:
"""Configure experiment with index."""
raise NotImplementedError


class ImageMetadataGeneric(BaseModel):
"""Generic metadata for an image."""

name: str
batch_id: str
plate_id: str
site_id: str
img_format: str
assay_type: str
channel_dict: dict


class ImageFrameType(Enum):
"""Frame type for image."""

ROUND = "round"
SQUARE = "square"


class AcquisitionOrderType(Enum):
"""Acquisition order for image."""

SNAKE = "snake"
ROWS = "rows"
51 changes: 51 additions & 0 deletions starrynight/src/starrynight/experiments/pcp_generic.py
@@ -0,0 +1,51 @@
"""PCP Experiment."""

from collections.abc import Callable
from pathlib import Path
from typing import Self

import polars as pl
from pydantic import BaseModel, Field

from starrynight.experiments.common import (
AcquisitionOrderType,
Experiment,
ImageFrameType,
)
from starrynight.schema import MeasuredInventory


class SBSConfig(BaseModel):
"""SBS experiment configuration."""

im_per_well: int = Field(320)
n_cycles: int = Field(12)
img_overlap_pct: int = Field(10)
img_frame_type: ImageFrameType = Field(ImageFrameType.ROUND)
channel_dict: dict
acquisition_order: AcquisitionOrderType = Field(AcquisitionOrderType.SNAKE)


class CPConfig(BaseModel):
"""CP Experiment configuration."""

im_per_well: int = Field(1364)
img_overlap_pct: int = Field(10)
img_frame_type: ImageFrameType = Field(ImageFrameType.ROUND)
channel_dict: dict
acquisition_order: AcquisitionOrderType = Field(AcquisitionOrderType.SNAKE)


class PCPGeneric(Experiment):
"""PCP experiment configuration."""

path_parser: Callable[[str], MeasuredInventory]
sbs_config: SBSConfig
cp_config: CPConfig

@staticmethod
def from_index(index_path: Path) -> Self:
"""Create PCPGeneric experiment from index."""
if index_path.name.endswith(".csv"):
index_df = pl.scan_csv(index_path)
else:
index_df = pl.scan_parquet(index_path)
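
For illustration, overriding the arm defaults (the channel maps below are made up):

# Hypothetical configuration of the SBS and CP arms.
sbs = SBSConfig(n_cycles=10, channel_dict={"DAPI": 0, "GFP": 1})
cp = CPConfig(channel_dict={"DAPI": 0, "Phalloidin": 1})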
5 changes: 0 additions & 5 deletions starrynight/src/starrynight/metadata.py

This file was deleted.

2 changes: 1 addition & 1 deletion starrynight/src/starrynight/modules/__init__.py
@@ -2,4 +2,4 @@

from starrynight.modules.registry import MODULE_REGISTRY

MODULE_REGISTRY["gen_index"] = {}
__all__ = ["MODULE_REGISTRY"]
