Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Expose references via ExecuteContext #14686

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/HailFeatureFlags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class HailFeatureFlags private (
flags.update(flag, value)
}

/** Builds a new `HailFeatureFlags` from this instance's `flags` map extended with one entry.
  *
  * @param feature a `(flag name, value)` pair; it is re-associated as `name -> value` before
  *   being added to the map
  * @return a freshly constructed `HailFeatureFlags` wrapping the result of `flags + entry`
  *
  * NOTE(review): `flags` is also written with `update` elsewhere in this class, so it is
  * presumably a mutable map; confirm that `+` yields an independent copy here.
  */
def +(feature: (String, String)): HailFeatureFlags =
new HailFeatureFlags(flags + (feature._1 -> feature._2))

/** Returns the value stored for `flag` by applying the `flags` map directly.
  *
  * NOTE(review): this is a plain `apply` with no default visible here, so an unregistered
  * flag presumably throws; the `lookup` method declared next returns `Option[String]`.
  */
def get(flag: String): String = flags(flag)

def lookup(flag: String): Option[String] =
Expand Down
8 changes: 2 additions & 6 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ abstract class Backend extends Closeable {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def references: mutable.Map[String, ReferenceGenome]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -181,10 +179,8 @@ abstract class Backend extends Closeable {
): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
Extraction.decompose {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}(defaultJSONFormats)
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

Expand Down
9 changes: 6 additions & 3 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ object ExecuteContext {
tmpdir: String,
localTmpdir: String,
backend: Backend,
references: Map[String, ReferenceGenome],
fs: FS,
timer: ExecutionTimer,
tempFileManager: TempFileManager,
Expand All @@ -79,6 +80,7 @@ object ExecuteContext {
tmpdir,
localTmpdir,
backend,
references,
fs,
region,
timer,
Expand Down Expand Up @@ -107,6 +109,7 @@ class ExecuteContext(
val tmpdir: String,
val localTmpdir: String,
val backend: Backend,
val references: Map[String, ReferenceGenome],
val fs: FS,
val r: Region,
val timer: ExecutionTimer,
Expand All @@ -128,7 +131,7 @@ class ExecuteContext(
)
}

def stateManager = HailStateManager(backend.references.toMap)
val stateManager = HailStateManager(references)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
Expand All @@ -154,8 +157,6 @@ class ExecuteContext(

/** Returns the value of the feature flag `name`, delegating to `flags.get`.
  *
  * NOTE(review): the neighbouring predicates compare this result against `null`, so an
  * unset flag is presumably represented as `null` rather than an exception — confirm.
  */
def getFlag(name: String): String = flags.get(name)

def getReference(name: String): ReferenceGenome = backend.references(name)

/** True when the `write_ir_files` feature flag has a non-null value. */
def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null

/** True when the `no_ir_logging` feature flag has a non-null value. */
def shouldNotLogIR(): Boolean = flags.get("no_ir_logging") != null
Expand All @@ -174,6 +175,7 @@ class ExecuteContext(
tmpdir: String = this.tmpdir,
localTmpdir: String = this.localTmpdir,
backend: Backend = this.backend,
references: Map[String, ReferenceGenome] = this.references,
fs: FS = this.fs,
r: Region = this.r,
timer: ExecutionTimer = this.timer,
Expand All @@ -189,6 +191,7 @@ class ExecuteContext(
tmpdir,
localTmpdir,
backend,
references,
fs,
r,
timer,
Expand Down
2 changes: 2 additions & 0 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class LocalBackend(
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

Expand All @@ -102,6 +103,7 @@ class LocalBackend(
tmpdir,
tmpdir,
this,
references.toMap,
fs,
timer,
null,
Expand Down
38 changes: 25 additions & 13 deletions hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import is.hail.expr.ir.{
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx}
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.RowMatrix
import is.hail.types.physical.PStruct
import is.hail.types.virtual.{TArray, TInterval}
import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval}
import is.hail.variant.ReferenceGenome

import scala.collection.mutable
import scala.jdk.CollectionConverters.{
asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter,
}
Expand All @@ -29,7 +31,10 @@ import org.json4s.Formats
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

trait Py4JBackendExtensions { this: Backend =>
trait Py4JBackendExtensions {
def backend: Backend
def references: mutable.Map[String, ReferenceGenome]
def persistedIR: mutable.Map[Int, BaseIR]
def flags: HailFeatureFlags
def longLifeTempFileManager: TempFileManager

Expand Down Expand Up @@ -59,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend =>
persistedIR.remove(id)

def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))
backend.withExecuteContext { ctx =>
references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile))
}

/** Calls `removeSequence()` on the reference genome registered under `name`.
  *
  * The genome is fetched with `references(name)`, a direct application of the mutable map.
  * NOTE(review): an unregistered `name` presumably surfaces as the map's missing-key
  * exception — confirm.
  */
def pyRemoveSequence(name: String): Unit =
references(name).removeSequence()
Expand All @@ -74,7 +81,7 @@ trait Py4JBackendExtensions { this: Backend =>
partitionSize: java.lang.Integer,
entries: String,
): Unit =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
entries match {
case "full" =>
Expand Down Expand Up @@ -112,7 +119,7 @@ trait Py4JBackendExtensions { this: Backend =>
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
Expand All @@ -126,10 +133,10 @@ trait Py4JBackendExtensions { this: Backend =>
}

def pyExecuteLiteral(irStr: String): Int =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
assert(ir.typ.isRealizable)
execute(ctx, ir) match {
backend.execute(ctx, ir) match {
case Left(_) => throw new HailException("Can't create literal")
case Right((pt, addr)) =>
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
Expand Down Expand Up @@ -158,13 +165,13 @@ trait Py4JBackendExtensions { this: Backend =>
}

def pyToDF(s: String): DataFrame =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
Interpret(tir, ctx).toDF()
}

def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
log.info("pyReadMultipleMatrixTables: got query")
val kvs = JsonMethods.parse(jsonQuery) match {
case json4s.JObject(values) => values.toMap
Expand Down Expand Up @@ -193,19 +200,24 @@ trait Py4JBackendExtensions { this: Backend =>
addReference(ReferenceGenome.fromJSON(jsonConfig))

def pyRemoveReference(name: String): Unit =
references.remove(name)
removeReference(name)

def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))
backend.withExecuteContext { ctx =>
references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile))
}

/** Calls `removeLiftover(destRGName)` on the reference genome registered under `name`.
  *
  * @param name key of the source genome in `references`
  * @param destRGName name of the destination genome, passed through unchanged
  *
  * NOTE(review): an unregistered `name` presumably surfaces as the map's missing-key
  * exception — confirm.
  */
def pyRemoveLiftover(name: String, destRGName: String): Unit =
references(name).removeLiftover(destRGName)

/** Registers `rg` in `references` via `ReferenceGenome.addFatalOnCollision`, passing it as a
  * one-element sequence.
  *
  * NOTE(review): going by the helper's name, a clash with an existing entry is presumably
  * reported as a fatal error; the helper's body is not visible here.
  */
private[this] def addReference(rg: ReferenceGenome): Unit =
ReferenceGenome.addFatalOnCollision(references, FastSeq(rg))

/** Removes the entry keyed by `name` from the mutable `references` map in place (`-=`). */
private[this] def removeReference(name: String): Unit =
references -= name

def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
Expand All @@ -231,7 +243,7 @@ trait Py4JBackendExtensions { this: Backend =>
}

def loadReferencesFromDataset(path: String): Array[Byte] =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
ReferenceGenome.addFatalOnCollision(references, rgs)

Expand All @@ -245,7 +257,7 @@ trait Py4JBackendExtensions { this: Backend =>
f: ExecuteContext => T
)(implicit E: Enclosing
): T =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val tempFileManager = longLifeTempFileManager
if (selfContainedExecution && tempFileManager != null) f(ctx)
else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)
Expand Down
28 changes: 13 additions & 15 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
Expand Down Expand Up @@ -93,7 +94,16 @@ object ServiceBackend {
rpcConfig.custom_references.map(ReferenceGenome.fromJSON),
)

val backend = new ServiceBackend(
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
liftoversForSource.foreach { case (destGenome, chainFile) =>
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile))
}
}
rpcConfig.sequences.foreach { case (rg, seq) =>
references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index))
}

new ServiceBackend(
JarUrl(jarLocation),
name,
theHailClassLoader,
Expand All @@ -106,27 +116,14 @@ object ServiceBackend {
backendContext,
scratchDir,
)

backend.withExecuteContext { ctx =>
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
liftoversForSource.foreach { case (destGenome, chainFile) =>
references(sourceGenome).addLiftover(ctx, chainFile, destGenome)
}
}
rpcConfig.sequences.foreach { case (rg, seq) =>
references(rg).addSequence(ctx, seq.fasta, seq.index)
}
}

backend
}
}

class ServiceBackend(
val jarSpec: JarSpec,
var name: String,
val theHailClassLoader: HailClassLoader,
override val references: mutable.Map[String, ReferenceGenome],
val references: mutable.Map[String, ReferenceGenome],
val batchClient: BatchClient,
val batchConfig: BatchConfig,
val flags: HailFeatureFlags,
Expand Down Expand Up @@ -397,6 +394,7 @@ class ServiceBackend(
tmpdir,
"file:///tmp",
this,
references.toMap,
fs,
timer,
null,
Expand Down
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class SparkBackend(
new HadoopFS(new SerializableHadoopConfiguration(conf))
}

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()

override val longLifeTempFileManager: TempFileManager =
Expand Down Expand Up @@ -375,6 +376,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references.toMap,
fs,
region,
timer,
Expand All @@ -394,6 +396,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references.toMap,
fs,
timer,
null,
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) {
}

def referenceGenomes(): IndexedSeq[ReferenceGenome] =
rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name)
rgContainers.keys.map(ctx.references(_)).toIndexedSeq.sortBy(_.name)

/** Returns the static fields held in `rgContainers`, ordered by their map key.
  *
  * The `(key, field)` pairs are sorted on the key and then projected to the field, so the
  * result order is determined by the keys alone.
  */
def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] =
rgContainers.toFastSeq.sortBy(_._1).map(_._2)
Expand Down
4 changes: 2 additions & 2 deletions hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
BoolValue.fromComparison(l, op).restrict(keySet)
case Contig(rgStr) =>
// locus contig equality comparison
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match {
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.references(rgStr)) match {
case Some(i) =>
val b = BoolValue(
KeySet(i),
Expand All @@ -671,7 +671,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
case Position(rgStr) =>
// locus position comparison
val posBoolValue = BoolValue.fromComparison(l, op)
val rg = ctx.getReference(rgStr)
val rg = ctx.references(rgStr)
val b = BoolValue(
KeySet(liftPosIntervalsToLocus(posBoolValue.trueBound, rg, ctx)),
KeySet(liftPosIntervalsToLocus(posBoolValue.falseBound, rg, ctx)),
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/MatrixValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ case class MatrixValue(
ReferenceGenome.exportReferences(
fs,
refPath,
ReferenceGenome.getReferences(t).map(ctx.getReference(_)),
ReferenceGenome.getReferences(t).map(ctx.references(_)),
)
}

Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/io/plink/LoadPlink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ object MatrixPLINKReader {
implicit val formats: Formats = DefaultFormats
val params = jv.extract[MatrixPLINKReaderParameters]

val referenceGenome = params.rg.map(ctx.getReference)
val referenceGenome = params.rg.map(ctx.references)
referenceGenome.foreach(_.validateContigRemap(params.contigRecoding))

val locusType = TLocus.schemaFromRG(params.rg)
Expand Down
Loading