Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hail/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ $(FAST_PYTHON_JAR_EXTRA_CLASSPATH): $(EXTRA_CLASSPATH)
cp $(EXTRA_CLASSPATH) $@

.PHONY: pytest
pytest: install-editable
pytest: install
cd python && \
$(HAIL_PYTHON3) -m pytest \
-Werror:::hail -Werror:::hailtop -Werror::ResourceWarning \
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/package.mill
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ trait RootHailModule extends CrossScalaModule, HailScalaModule:
)

override def runMvnDeps: T[Seq[Dep]] =
outer.runMvnDeps()
outer.runMvnDeps() ++ outer.compileMvnDeps()

override def assemblyRules: Seq[Rule] =
outer.assemblyRules
9 changes: 9 additions & 0 deletions hail/hail/src/is/hail/asm4s/AsmFunction.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package is.hail.asm4s

import is.hail.annotations.Region
import is.hail.types.physical.stypes.interfaces.NoBoxLongIterator

/** Function-shaped interface whose `apply` takes no arguments and returns an `R`.
  * NOTE(review): presumably implemented by asm4s-generated classes — confirm against callers.
  */
trait AsmFunction0[R] { def apply(): R }
/** Function-shaped interface whose `apply` takes a single `A` and returns an `R`. */
trait AsmFunction1[A, R] { def apply(a: A): R }
Expand Down Expand Up @@ -112,3 +113,11 @@ trait AsmFunction3RegionIteratorJLongBooleanLong {
/** Interface for a three-argument function: a [[Region]], a primitive `Long`, and an iterator of
  * boxed longs, returning a `Boolean`. The trait name spells out the argument and result types in
  * order.
  */
trait AsmFunction3RegionLongIteratorJLongBoolean {
def apply(r: Region, a: Long, b: Iterator[java.lang.Long]): Boolean
}

/** Interface for a three-argument function: a [[Region]] and two primitive `Long`s, returning an
  * iterator of boxed longs.
  * NOTE(review): the `Long` arguments and results are presumably addresses of region-allocated
  * values — confirm against the call sites.
  */
trait AsmFunction3RegionLongLongIteratorJLong {
def apply(r: Region, a: Long, b: Long): Iterator[java.lang.Long]
}

/** Interface for a three-argument function: a [[Region]], a primitive `Long`, and a
  * [[NoBoxLongIterator]], returning an iterator of boxed longs.
  */
trait AsmFunction3RegionLongNoBoxLongIteratorIteratorJLong {
def apply(r: Region, a: Long, b: NoBoxLongIterator): Iterator[java.lang.Long]
}
292 changes: 59 additions & 233 deletions hail/hail/src/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ package is.hail.expr.ir
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.{ExecuteContext, HailTaskContext}
import is.hail.collection.FastSeq
import is.hail.collection.implicits.toRichIterable
import is.hail.expr.ir.agg.AggStateSig
import is.hail.expr.ir.defs.In
import is.hail.expr.ir.lowering.LoweringPipeline
import is.hail.expr.ir.streams.EmitStream
import is.hail.io.fs.FS
import is.hail.rvd.RVDContext
import is.hail.types.physical.{PStruct, PType}
import is.hail.types.physical.stypes.{
PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType,
}
import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream}
import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType}
import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStreamConcrete}
import is.hail.types.virtual.TStream

import scala.collection.mutable

Expand All @@ -29,6 +25,28 @@ case class CompileCacheKey(
body: IR,
)

/** Presents a [[NoBoxLongIterator]] as a standard `Iterator[java.lang.Long]`.
  *
  * The underlying iterator only reports end-of-stream (`eos`) after `next()` has been called, so
  * this adapter reads one value ahead: `hasNext` pulls a value, checks `eos`, and caches both until
  * `next()` hands the value out. The underlying iterator is closed as soon as end-of-stream is
  * observed.
  *
  * @param it
  *   the unboxed iterator to adapt
  */
private[ir] class NoBoxLongIteratorAdapter(it: NoBoxLongIterator) extends Iterator[java.lang.Long] {
  // True while `pending` / `available` describe a pull that `next()` has not consumed yet.
  private var primed = false
  // Whether the most recent pull yielded a real element (i.e. the source was not at eos).
  private var available = false
  // The value returned by the most recent pull; meaningful only when `available` is true.
  private var pending: Long = 0L

  /** Pulls once from the source, records the outcome, and closes the source at end-of-stream. */
  private def pull(): Unit = {
    pending = it.next()
    available = !it.eos
    primed = true
    if (!available) it.close()
  }

  /** Returns true if another element can be returned, pulling from the source when necessary. */
  override def hasNext: Boolean = {
    if (!primed) pull()
    available
  }

  /** Returns the look-ahead element, or throws `NoSuchElementException` once exhausted. */
  override def next(): java.lang.Long =
    if (hasNext) {
      // Consume the cached element so that the following `hasNext` pulls a fresh one.
      primed = false
      pending
    } else {
      // Exhausted: `primed` stays true, so the source is never pulled (or closed) again.
      Iterator.empty.next()
    }
}

private[ir] trait CompileOps {

type Compiled[A] = (HailClassLoader, FS, HailTaskContext, Region) => A
Expand Down Expand Up @@ -106,10 +124,12 @@ private[ir] trait CompileOps {

ctx.CompileCache.getOrElseUpdate(
key, {
val lowered =
val lowered = ForwardLets(
ctx,
LoweringPipeline.compileLowerer(ctx, ir)
.asInstanceOf[IR]
.noSharing(ctx)
.noSharing(ctx),
)

val fb =
EmitFunctionBuilder[F](
Expand All @@ -130,231 +150,37 @@ private[ir] trait CompileOps {
)

val emitContext = EmitContext.analyze(ctx, lowered)
val rt = Emit(emitContext, lowered, fb, expectedCodeReturnType, params.length, aggSigs)
(rt, fb.resultWithIndex(print))
lowered.typ match {
case _: TStream =>
var eltPType: PType = null

fb.emitWithBuilder[Iterator[_]] { cb =>
val mb = fb.apply_method
val env = EmitEnv(
Env.empty,
(0 until params.length).map(i => mb.storeEmitParamAsField(cb, i + 2)),
)
val (ept, iterEmitCode) = EmitStream.produceIterator(emitContext, lowered, cb, env)
eltPType = ept
val noBoxIter = iterEmitCode.getOrAssert(cb).asInstanceOf[SStreamConcrete].it
cb += noBoxIter.invoke[Region, Region, Unit](
"init",
fb.partitionRegion,
mb.getCodeParam[Region](1),
)
Code.newInstance[NoBoxLongIteratorAdapter, NoBoxLongIterator](noBoxIter)
}

(
Some(PTypeReferenceSingleCodeType(eltPType.asInstanceOf[PStruct])),
fb.resultWithIndex(print).asInstanceOf[Compiled[F with Mixin]],
)
case _ =>
val rt =
Emit(emitContext, lowered, fb, expectedCodeReturnType, params.length, aggSigs)
(rt, fb.resultWithIndex(print))
}
},
).asInstanceOf[CompiledFunction[F with Mixin]]
}
}

/** Compiles stream-producing IR into "step" functions and exposes them as
  * `Iterator[java.lang.Long]`.
  *
  * A compiled step function's `apply` returns `true` when it produced an element and `false` at
  * end of stream; after a successful step, `loadAddress()` returns the address that the step
  * stored for that element.
  */
object CompileIterator {

// Common part of every compiled step function: read back the `elementAddr` field written by the
// most recent successful step.
private trait StepFunctionBase {
def loadAddress(): Long
}

// Step function shape used by `forTableStageToRVD`: an unused Object slot plus two Long arguments.
private trait TableStageToRVDStepFunction extends StepFunctionBase {
def apply(o: Object, a: Long, b: Long): Boolean

// Stores the two regions in fields that the generated step method reads.
def setRegions(outerRegion: Region, eltRegion: Region): Unit
}

// Step function shape used by `forTableMapPartitions`: an unused Object slot, a Long argument,
// and an unboxed input iterator.
private trait TMPStepFunction extends StepFunctionBase {
def apply(o: Object, a: Long, b: NoBoxLongIterator): Boolean

// Stores the two regions in fields that the generated step method reads.
def setRegions(outerRegion: Region, eltRegion: Region): Unit
}

// Adapts a step function to the `Iterator` protocol by stepping once ahead in `hasNext` and
// caching the result until `next()` consumes it.
abstract private class LongIteratorWrapper extends Iterator[java.lang.Long] {
// Advances the underlying step function; returns true iff an element was produced.
def step(): Boolean

protected val stepFunction: StepFunctionBase
// True while `_hasNext` reflects a step that `next()` has not consumed yet.
private var _stepped = false
private var _hasNext = false

override def hasNext: Boolean = {
if (!_stepped) {
_hasNext = step()
_stepped = true
}
_hasNext
}

override def next(): java.lang.Long = {
if (!hasNext) Iterator.empty.next(): Unit // throw
// Mark the cached step as consumed so the next `hasNext` call steps again.
_stepped = false
stepFunction.loadAddress()
}
}

/** Builds a class implementing `F` whose `apply` advances the stream described by `body` by one
  * element.
  *
  * @param ctx
  *   execution context used for lowering, type checking and code generation
  * @param body
  *   the IR to lower and emit as a stream
  * @param argTypeInfo
  *   parameter types of the generated `apply` method
  * @param printWriter
  *   passed through to `resultWithIndex`
  * @return
  *   the storage type of one stream element (set to required) and a factory for the compiled
  *   step function
  */
private def compileStepper[F >: Null <: StepFunctionBase: TypeInfo](
ctx: ExecuteContext,
body: IR,
argTypeInfo: Array[ParamType],
printWriter: Option[PrintWriter],
): (PType, Compiled[F]) = {

val fb = EmitFunctionBuilder.apply[F](
ctx,
s"stream_${body.getClass.getSimpleName}",
argTypeInfo.toFastSeq,
CodeParamType(BooleanInfo),
Some("Emit.scala"),
)
// Fields populated by the generated `setRegions` method below.
val outerRegionField = fb.genFieldThisRef[Region]("outerRegion")
val eltRegionField = fb.genFieldThisRef[Region]("eltRegion")
val setF = fb.newEmitMethod(
"setRegions",
FastSeq(CodeParamType(typeInfo[Region]), CodeParamType(typeInfo[Region])),
CodeParamType(typeInfo[Unit]),
)
setF.emit(Code(
outerRegionField := setF.getCodeParam[Region](1),
eltRegionField := setF.getCodeParam[Region](2),
))

val stepF = fb.apply_method
val stepFECB = stepF.ecb

val outerRegion = outerRegionField

// Lower the body and type-check the result before emitting any code for it.
val ir = LoweringPipeline.compileLowerer(ctx, body).asInstanceOf[IR].noSharing(ctx)
TypeCheck(ctx, ir)

// Both are assigned inside the `emitWithBuilder` closure below and read after it returns.
var elementAddress: Settable[Long] = null
var returnType: PType = null

stepF.emitWithBuilder[Boolean] { cb =>
val emitContext = EmitContext.analyze(ctx, ir)
val emitter = new Emit(emitContext, stepFECB)

// Only parameters declared as EmitParamType are bound in the emit environment.
// NOTE(review): the `i + 1` offset presumably skips the receiver slot — confirm.
val env = EmitEnv(
Env.empty,
argTypeInfo.indices.filter(i => argTypeInfo(i).isInstanceOf[EmitParamType]).map(i =>
stepF.getEmitParam(cb, i + 1)
),
)
val optStream = EmitCode.fromI(stepF)(cb =>
EmitStream.produce(emitter, ir, cb, cb.emb, outerRegion, env, None)
)
// Elements are stored using the stream element's storage type, forced to required.
returnType = optStream.st.asInstanceOf[SStream].elementEmitType.storageType.setRequired(true)

elementAddress = stepF.genFieldThisRef[Long]("elementAddr")

// Tracks whether the producer has been initialised by an earlier call to the step method.
val didSetup = stepF.genFieldThisRef[Boolean]("didSetup")
stepF.cb.emitInit(didSetup := false)

// Set once the producer reaches end of stream, so later calls return false immediately.
val eosField = stepF.genFieldThisRef[Boolean]("eos")

val producer = optStream.pv.asStream.getProducer(cb.emb)

val ret = cb.newLocal[Boolean]("stepf_ret")
val Lreturn = CodeLabel()

// First call only: bind the element region and initialise the producer.
cb.if_(
!didSetup, {
optStream.toI(cb).getOrAssert(cb): Unit // handle missing, but bound stream producer above

cb.assign(producer.elementRegion, eltRegionField)
producer.initialize(cb, outerRegion)
cb.assign(didSetup, true)
cb.assign(eosField, false)
},
)

// Already exhausted: report "no element" without jumping back into the producer.
cb.if_(
eosField, {
cb.assign(ret, false)
cb.goto(Lreturn)
},
)

// Jump into the producer to request the next element.
cb.goto(producer.LproduceElement)

// Producer signalled end of stream: close it, remember that, and return false.
stepF.implementLabel(producer.LendOfStream) { cb =>
producer.close(cb)
cb.assign(eosField, true)
cb.assign(ret, false)
cb.goto(Lreturn)
}

// Producer yielded an element: store it in the element region, keep its address, return true.
stepF.implementLabel(producer.LproduceElementDone) { cb =>
val pc = producer.element.toI(cb).getOrAssert(cb)
cb.assign(elementAddress, returnType.store(cb, producer.elementRegion, pc, false))
cb.assign(ret, true)
cb.goto(Lreturn)
}

cb.define(Lreturn)
ret
}

// `loadAddress` simply returns the address stored by the last successful step.
val getMB = fb.newEmitMethod("loadAddress", FastSeq(), LongInfo)
getMB.emit(elementAddress.load())

(returnType, fb.resultWithIndex(printWriter))
}

/** Compiles `ir` into a factory of element-address iterators that take a `Long` argument and a
  * [[NoBoxLongIterator]] input.
  *
  * @param typ0
  *   type of the `Long`-addressed first argument; must be required
  * @param streamElementType
  *   element type of the input stream argument; must be required
  * @param ir
  *   the stream IR to compile
  * @return
  *   the element storage type and a function that builds an iterator for one consumer context
  */
def forTableMapPartitions(
ctx: ExecuteContext,
typ0: PStruct,
streamElementType: PType,
ir: IR,
): (
PType,
(HailClassLoader, FS, HailTaskContext, RVDContext, Long, NoBoxLongIterator) => Iterator[java.lang.Long],
) = {
assert(typ0.required)
assert(streamElementType.required)
val (eltPType, makeStepper) = compileStepper[TMPStepFunction](
ctx,
ir,
Array[ParamType](
CodeParamType(typeInfo[Object]),
SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(typ0)),
SingleCodeEmitParamType(true, StreamSingleCodeType(true, streamElementType, true)),
),
None,
)
(
eltPType,
(theHailClassLoader, fs, htc, consumerCtx, v0, part) => {
// Instantiate the step function against the consumer's partition region, then hand it the
// partition region as the outer region and the consumer's region as the element region.
val outerStepFunction =
makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion)
outerStepFunction.setRegions(consumerCtx.partitionRegion, consumerCtx.region)
new LongIteratorWrapper {
val stepFunction: TMPStepFunction = outerStepFunction

// The Object slot is unused, so null is passed for it.
override def step(): Boolean = stepFunction.apply(null, v0, part)
}
},
)
}

/** Compiles `ir` into a factory of element-address iterators that take two `Long` arguments.
  *
  * @param ctxType
  *   type of the first `Long`-addressed argument; must be required
  * @param bcValsType
  *   type of the second `Long`-addressed argument; must be required
  * @param ir
  *   the stream IR to compile
  * @return
  *   the element storage type and a function that builds an iterator for one consumer context
  */
def forTableStageToRVD(
ctx: ExecuteContext,
ctxType: PStruct,
bcValsType: PType,
ir: IR,
): (
PType,
(HailClassLoader, FS, HailTaskContext, RVDContext, Long, Long) => Iterator[java.lang.Long],
) = {
assert(ctxType.required)
assert(bcValsType.required)
val (eltPType, makeStepper) = compileStepper[TableStageToRVDStepFunction](
ctx,
ir,
Array[ParamType](
CodeParamType(typeInfo[Object]),
SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(ctxType)),
SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(bcValsType)),
),
None,
)
(
eltPType,
(theHailClassLoader, fs, htc, consumerCtx, v0, v1) => {
// Instantiate the step function against the consumer's partition region, then hand it the
// partition region as the outer region and the consumer's region as the element region.
val outerStepFunction =
makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion)
outerStepFunction.setRegions(consumerCtx.partitionRegion, consumerCtx.region)
new LongIteratorWrapper {
val stepFunction: TableStageToRVDStepFunction = outerStepFunction

// The Object slot is unused, so null is passed for it.
override def step(): Boolean = stepFunction.apply(null, v0, v1)
}
},
)
}

}
Loading