diff --git a/hail/Makefile b/hail/Makefile index ab4f1c745c3..8b9488b2447 100644 --- a/hail/Makefile +++ b/hail/Makefile @@ -182,7 +182,7 @@ $(FAST_PYTHON_JAR_EXTRA_CLASSPATH): $(EXTRA_CLASSPATH) cp $(EXTRA_CLASSPATH) $@ .PHONY: pytest -pytest: install-editable +pytest: install cd python && \ $(HAIL_PYTHON3) -m pytest \ -Werror:::hail -Werror:::hailtop -Werror::ResourceWarning \ diff --git a/hail/hail/package.mill b/hail/hail/package.mill index 2a60c3d4696..3f54a9916c8 100644 --- a/hail/hail/package.mill +++ b/hail/hail/package.mill @@ -191,7 +191,7 @@ trait RootHailModule extends CrossScalaModule, HailScalaModule: ) override def runMvnDeps: T[Seq[Dep]] = - outer.runMvnDeps() + outer.runMvnDeps() ++ outer.compileMvnDeps() override def assemblyRules: Seq[Rule] = outer.assemblyRules diff --git a/hail/hail/src/is/hail/asm4s/AsmFunction.scala b/hail/hail/src/is/hail/asm4s/AsmFunction.scala index b08cde595c4..48ed2d52d19 100644 --- a/hail/hail/src/is/hail/asm4s/AsmFunction.scala +++ b/hail/hail/src/is/hail/asm4s/AsmFunction.scala @@ -1,6 +1,7 @@ package is.hail.asm4s import is.hail.annotations.Region +import is.hail.types.physical.stypes.interfaces.NoBoxLongIterator trait AsmFunction0[R] { def apply(): R } trait AsmFunction1[A, R] { def apply(a: A): R } @@ -112,3 +113,11 @@ trait AsmFunction3RegionIteratorJLongBooleanLong { trait AsmFunction3RegionLongIteratorJLongBoolean { def apply(r: Region, a: Long, b: Iterator[java.lang.Long]): Boolean } + +trait AsmFunction3RegionLongLongIteratorJLong { + def apply(r: Region, a: Long, b: Long): Iterator[java.lang.Long] +} + +trait AsmFunction3RegionLongNoBoxLongIteratorIteratorJLong { + def apply(r: Region, a: Long, b: NoBoxLongIterator): Iterator[java.lang.Long] +} diff --git a/hail/hail/src/is/hail/expr/ir/Compile.scala b/hail/hail/src/is/hail/expr/ir/Compile.scala index 5f14d92c55c..02b1153e44a 100644 --- a/hail/hail/src/is/hail/expr/ir/Compile.scala +++ b/hail/hail/src/is/hail/expr/ir/Compile.scala @@ -3,19 +3,15 @@ package is.hail.expr.ir import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailTaskContext} -import is.hail.collection.FastSeq -import is.hail.collection.implicits.toRichIterable import is.hail.expr.ir.agg.AggStateSig import is.hail.expr.ir.defs.In import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.expr.ir.streams.EmitStream import is.hail.io.fs.FS -import is.hail.rvd.RVDContext import is.hail.types.physical.{PStruct, PType} -import is.hail.types.physical.stypes.{ - PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType, -} -import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream} +import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType} +import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStreamConcrete} +import is.hail.types.virtual.TStream import scala.collection.mutable @@ -29,6 +25,28 @@ case class CompileCacheKey( body: IR, ) +private[ir] class NoBoxLongIteratorAdapter(it: NoBoxLongIterator) extends Iterator[java.lang.Long] { + private var _stepped = false + private var _hasNext = false + private var _value: Long = 0L + + override def hasNext: Boolean = { + if (!_stepped) { + _value = it.next() + _hasNext = !it.eos + _stepped = true + if (!_hasNext) it.close() + } + _hasNext + } + + override def next(): java.lang.Long = { + if (!hasNext) Iterator.empty.next(): Unit + _stepped = false + _value + } +} + private[ir] trait CompileOps { type Compiled[A] = (HailClassLoader, FS, HailTaskContext, Region) => A @@ -106,10 +124,12 @@ private[ir] trait CompileOps { ctx.CompileCache.getOrElseUpdate( key, { - val lowered = + val lowered = ForwardLets( + ctx, LoweringPipeline.compileLowerer(ctx, ir) .asInstanceOf[IR] - .noSharing(ctx) + .noSharing(ctx), + ) val fb = EmitFunctionBuilder[F]( @@ -130,231 +150,37 @@ private[ir] trait CompileOps { ) val emitContext = EmitContext.analyze(ctx, lowered) - val rt = Emit(emitContext, lowered, fb, expectedCodeReturnType, params.length, aggSigs) - (rt, fb.resultWithIndex(print)) + lowered.typ match { + case _: TStream => + var eltPType: PType = null + + fb.emitWithBuilder[Iterator[_]] { cb => + val mb = fb.apply_method + val env = EmitEnv( + Env.empty, + (0 until params.length).map(i => mb.storeEmitParamAsField(cb, i + 2)), + ) + val (ept, iterEmitCode) = EmitStream.produceIterator(emitContext, lowered, cb, env) + eltPType = ept + val noBoxIter = iterEmitCode.getOrAssert(cb).asInstanceOf[SStreamConcrete].it + cb += noBoxIter.invoke[Region, Region, Unit]( + "init", + fb.partitionRegion, + mb.getCodeParam[Region](1), + ) + Code.newInstance[NoBoxLongIteratorAdapter, NoBoxLongIterator](noBoxIter) + } + + ( + Some(PTypeReferenceSingleCodeType(eltPType.asInstanceOf[PStruct])), + fb.resultWithIndex(print).asInstanceOf[Compiled[F with Mixin]], + ) + case _ => + val rt = + Emit(emitContext, lowered, fb, expectedCodeReturnType, params.length, aggSigs) + (rt, fb.resultWithIndex(print)) + } }, ).asInstanceOf[CompiledFunction[F with Mixin]] } } - -object CompileIterator { - - private trait StepFunctionBase { - def loadAddress(): Long - } - - private trait TableStageToRVDStepFunction extends StepFunctionBase { - def apply(o: Object, a: Long, b: Long): Boolean - - def setRegions(outerRegion: Region, eltRegion: Region): Unit - } - - private trait TMPStepFunction extends StepFunctionBase { - def apply(o: Object, a: Long, b: NoBoxLongIterator): Boolean - - def setRegions(outerRegion: Region, eltRegion: Region): Unit - } - - abstract private class LongIteratorWrapper extends Iterator[java.lang.Long] { - def step(): Boolean - - protected val stepFunction: StepFunctionBase - private var _stepped = false - private var _hasNext = false - - override def hasNext: Boolean = { - if (!_stepped) { - _hasNext = step() - _stepped = true - } - _hasNext - } - - override def next(): java.lang.Long = { - if (!hasNext) Iterator.empty.next(): Unit // throw - _stepped = false - stepFunction.loadAddress() - } - } - - private def compileStepper[F >: Null <: StepFunctionBase: TypeInfo]( - ctx: ExecuteContext, - body: IR, - argTypeInfo: Array[ParamType], - printWriter: Option[PrintWriter], - ): (PType, Compiled[F]) = { - - val fb = EmitFunctionBuilder.apply[F]( - ctx, - s"stream_${body.getClass.getSimpleName}", - argTypeInfo.toFastSeq, - CodeParamType(BooleanInfo), - Some("Emit.scala"), - ) - val outerRegionField = fb.genFieldThisRef[Region]("outerRegion") - val eltRegionField = fb.genFieldThisRef[Region]("eltRegion") - val setF = fb.newEmitMethod( - "setRegions", - FastSeq(CodeParamType(typeInfo[Region]), CodeParamType(typeInfo[Region])), - CodeParamType(typeInfo[Unit]), - ) - setF.emit(Code( - outerRegionField := setF.getCodeParam[Region](1), - eltRegionField := setF.getCodeParam[Region](2), - )) - - val stepF = fb.apply_method - val stepFECB = stepF.ecb - - val outerRegion = outerRegionField - - val ir = LoweringPipeline.compileLowerer(ctx, body).asInstanceOf[IR].noSharing(ctx) - TypeCheck(ctx, ir) - - var elementAddress: Settable[Long] = null - var returnType: PType = null - - stepF.emitWithBuilder[Boolean] { cb => - val emitContext = EmitContext.analyze(ctx, ir) - val emitter = new Emit(emitContext, stepFECB) - - val env = EmitEnv( - Env.empty, - argTypeInfo.indices.filter(i => argTypeInfo(i).isInstanceOf[EmitParamType]).map(i => - stepF.getEmitParam(cb, i + 1) - ), - ) - val optStream = EmitCode.fromI(stepF)(cb => - EmitStream.produce(emitter, ir, cb, cb.emb, outerRegion, env, None) - ) - returnType = optStream.st.asInstanceOf[SStream].elementEmitType.storageType.setRequired(true) - - elementAddress = stepF.genFieldThisRef[Long]("elementAddr") - - val didSetup = stepF.genFieldThisRef[Boolean]("didSetup") - stepF.cb.emitInit(didSetup := false) - - val eosField = stepF.genFieldThisRef[Boolean]("eos") - - val producer = optStream.pv.asStream.getProducer(cb.emb) - - val ret = cb.newLocal[Boolean]("stepf_ret") - val Lreturn = CodeLabel() - - cb.if_( - !didSetup, { - optStream.toI(cb).getOrAssert(cb): Unit // handle missing, but bound stream producer above - - cb.assign(producer.elementRegion, eltRegionField) - producer.initialize(cb, outerRegion) - cb.assign(didSetup, true) - cb.assign(eosField, false) - }, - ) - - cb.if_( - eosField, { - cb.assign(ret, false) - cb.goto(Lreturn) - }, - ) - - cb.goto(producer.LproduceElement) - - stepF.implementLabel(producer.LendOfStream) { cb => - producer.close(cb) - cb.assign(eosField, true) - cb.assign(ret, false) - cb.goto(Lreturn) - } - - stepF.implementLabel(producer.LproduceElementDone) { cb => - val pc = producer.element.toI(cb).getOrAssert(cb) - cb.assign(elementAddress, returnType.store(cb, producer.elementRegion, pc, false)) - cb.assign(ret, true) - cb.goto(Lreturn) - } - - cb.define(Lreturn) - ret - } - - val getMB = fb.newEmitMethod("loadAddress", FastSeq(), LongInfo) - getMB.emit(elementAddress.load()) - - (returnType, fb.resultWithIndex(printWriter)) - } - - def forTableMapPartitions( - ctx: ExecuteContext, - typ0: PStruct, - streamElementType: PType, - ir: IR, - ): ( - PType, - (HailClassLoader, FS, HailTaskContext, RVDContext, Long, NoBoxLongIterator) => Iterator[java.lang.Long], - ) = { - assert(typ0.required) - assert(streamElementType.required) - val (eltPType, makeStepper) = compileStepper[TMPStepFunction]( - ctx, - ir, - Array[ParamType]( - CodeParamType(typeInfo[Object]), - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(typ0)), - SingleCodeEmitParamType(true, StreamSingleCodeType(true, streamElementType, true)), - ), - None, - ) - ( - eltPType, - (theHailClassLoader, fs, htc, consumerCtx, v0, part) => { - val outerStepFunction = - makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) - outerStepFunction.setRegions(consumerCtx.partitionRegion, consumerCtx.region) - new LongIteratorWrapper { - val stepFunction: TMPStepFunction = outerStepFunction - - override def step(): Boolean = stepFunction.apply(null, v0, part) - } - }, - ) - } - - def forTableStageToRVD( - ctx: ExecuteContext, - ctxType: PStruct, - bcValsType: PType, - ir: IR, - ): ( - PType, - (HailClassLoader, FS, HailTaskContext, RVDContext, Long, Long) => Iterator[java.lang.Long], - ) = { - assert(ctxType.required) - assert(bcValsType.required) - val (eltPType, makeStepper) = compileStepper[TableStageToRVDStepFunction]( - ctx, - ir, - Array[ParamType]( - CodeParamType(typeInfo[Object]), - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(ctxType)), - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(bcValsType)), - ), - None, - ) - ( - eltPType, - (theHailClassLoader, fs, htc, consumerCtx, v0, v1) => { - val outerStepFunction = - makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) - outerStepFunction.setRegions(consumerCtx.partitionRegion, consumerCtx.region) - new LongIteratorWrapper { - val stepFunction: TableStageToRVDStepFunction = outerStepFunction - - override def step(): Boolean = stepFunction.apply(null, v0, v1) - } - }, - ) - } - -} diff --git a/hail/hail/src/is/hail/expr/ir/Emit.scala b/hail/hail/src/is/hail/expr/ir/Emit.scala index a3928983a61..2a4ed63823b 100644 --- a/hail/hail/src/is/hail/expr/ir/Emit.scala +++ b/hail/hail/src/is/hail/expr/ir/Emit.scala @@ -3574,34 +3574,25 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { container: Option[AggContainer], loopEnv: Option[Env[LoopRef]], ): EmitEnv = { + assert(let.bindings.forall(!_.value.typ.isInstanceOf[TStream])) + def emitI(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, r: Value[Region]): IEmitCode = - if (ir.typ.isInstanceOf[TStream]) - EmitStream.produce(this, ir, cb, cb.emb, r, env, container) - else this.emitI(ir, cb, r, env, container, loopEnv) + this.emitI(ir, cb, r, env, container, loopEnv) def emitVoid(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, r: Value[Region]): Unit = this.emitVoid(cb, ir, r, env, container, loopEnv) - val uses: mutable.Set[Name] = - ctx.usesAndDefs.uses.get(let) match { - case Some(refs) => refs.map(_.t.name) - case None => mutable.Set.empty - } - /* Emit a sequence of bindings into a code builder. Each is added to the environment of all - * following bindings. Any bindings which is unused and has no side effects is skipped (this is - * mostly an optimization, but it is important not to emit unused streams). */ + * following bindings. */ def emitChunk(cb: EmitCodeBuilder, bindings: Seq[Binding], env: EmitEnv, r: Value[Region]) : EmitEnv = bindings.foldLeft(env) { case (newEnv, Binding(name, ir, Scope.EVAL)) => if (ir.typ == TVoid) { emitVoid(ir, cb, newEnv, r) newEnv - } else if (IsPure(ir) && !uses.contains(name)) { - newEnv } else { val value = emitI(ir, cb, newEnv, r) - val memo = cb.memoizeMaybeStreamValue(value, s"let_$name") + val memo = cb.memoizeField(value, s"let_$name") newEnv.bind(name, memo) } } @@ -3633,7 +3624,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { } def cantEmitInSeparateMethod(ir: IR): Boolean = - ir.typ.isInstanceOf[TStream] || ctx.inLoopCriticalPath.contains(ir) + ctx.inLoopCriticalPath.contains(ir) // end of bindings, emit any pending chunk and return the final environment if (pos == let.bindings.length) { @@ -3643,23 +3634,12 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { return env } - val Binding(curName, curIR, Scope.EVAL) = let.bindings(pos) + val Binding(_, curIR, Scope.EVAL) = let.bindings(pos) - // skip over unused streams - if (curIR.typ.isInstanceOf[TStream] && !uses.contains(curName)) { - go(env, chunkStart, pos + 1, chunkSize, groupIdx) - } else if (chunkSize == 16 || (chunkSize > 0 && cantEmitInSeparateMethod(curIR))) { - /* emit the current chunk if it's either max size, or broken by a stream or other control - * flow */ + if (chunkSize == 16 || (chunkSize > 0 && cantEmitInSeparateMethod(curIR))) { + // emit the current chunk if it's either max size, or broken by control flow val newEnv = emitChunkInSeparateMethod() go(newEnv, pos, pos, 0, groupIdx + 1) - } else if (curIR.typ.isInstanceOf[TStream]) { - // emit a stream, assuming we've already emitted any prior chunk - assert(chunkSize == 0) // no pending bindings - val value = emitI(curIR, cb, env, r) - val memo = cb.memoizeMaybeStreamValue(value, s"let_$curName") - val newEnv = env.bind(curName, memo) - go(newEnv, pos + 1, pos + 1, 0, groupIdx) } else { // add cur binding to pending chunk go(env, chunkStart, pos + 1, chunkSize + 1, groupIdx) diff --git a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala index 46543921c12..3d3f6fe7757 100644 --- a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala +++ b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala @@ -3,7 +3,7 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir.defs.{BaseRef, Binding, Block, In, Ref, Str} -import is.hail.types.virtual.TVoid +import is.hail.types.virtual.{TStream, TVoid} import scala.collection.Set @@ -19,6 +19,7 @@ object ForwardLets { IsPure(value) && ( value.isInstanceOf[Ref] || value.isInstanceOf[In] || + value.typ.isInstanceOf[TStream] || (IsConstant(value) && !value.isInstanceOf[Str]) || refs.isEmpty || (refs.size == 1 && diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index 8e2a362d6d8..715819be288 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -986,27 +986,22 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val rowPType = rvd.rowPType val globalPType = globals.t - val (newRowPType: PStruct, makeIterator) = CompileIterator.forTableMapPartitions( - ctx, - globalPType, - rowPType, - Subst( - body, - BindingEnv(Env( - globalName -> In( - 0, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globalPType)), - ), - partitionStreamName -> In( - 1, - SingleCodeEmitParamType( - true, - StreamSingleCodeType(requiresMemoryManagementPerElement = true, rowPType, true), - ), - ), - )), - ), + val globalEmitType = SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globalPType)) + val streamEmitType = SingleCodeEmitParamType( + true, + StreamSingleCodeType(requiresMemoryManagementPerElement = true, rowPType, true), ) + val (Some(PTypeReferenceSingleCodeType(newRowPType: PStruct)), makeStreamFn) = + Compile[AsmFunction3RegionLongNoBoxLongIteratorIteratorJLong]( + ctx, + FastSeq( + (globalName, globalEmitType), + (partitionStreamName, streamEmitType), + ), + FastSeq(classInfo[Region], LongInfo, classInfo[NoBoxLongIterator]), + classInfo[Iterator[_]], + body, + ) val globalsBc = globals.broadcast(ctx.theHailClassLoader) @@ -1028,14 +1023,18 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow override def close(): Unit = () } - makeIterator( + val f = makeStreamFn( hcl, fsBc.value, SparkTaskContext.get(), - consumerCtx, + consumerCtx.partitionRegion, + ) + f( + consumerCtx.region, globalsBc.value.readRegionValue(consumerCtx.partitionRegion, hcl), boxedPartition, - ).map(l => l.longValue()) + ) + .map(l => l.longValue()) } val newRVD = rvd.repartition(ctx, rvd.partitioner.strictify(allowedOverlap)) diff --git a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala index 8de56a85f15..76f46e94e2f 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -177,23 +177,25 @@ object TableStageToRVD { baos.toByteArray } - val (newRowPType: PStruct, makeIterator) = CompileIterator.forTableStageToRVD( - ctx, - decodedContextPType, - decodedBcValsPType, - Let( - ts.broadcastVals.map(_._1).map(bcVal => - bcVal -> GetField( - In(1, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedBcValsPType))), - bcVal.str, - ) - ), - ts.partition(In( - 0, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedContextPType)), - )), + val ctxParam = Ref(freshName(), decodedContextPType.virtualType) + val bcParam = Ref(freshName(), decodedBcValsPType.virtualType) + val ctxEmitType = + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedContextPType)) + val bcEmitType = SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedBcValsPType)) + val bodyIR = Let( + ts.broadcastVals.map(_._1).map(bcVal => + bcVal -> GetField(bcParam, bcVal.str) ), + ts.partition(ctxParam), ) + val (Some(PTypeReferenceSingleCodeType(newRowPType: PStruct)), makeStreamFn) = + Compile[AsmFunction3RegionLongLongIteratorJLong]( + ctx, + FastSeq((ctxParam.name, ctxEmitType), (bcParam.name, bcEmitType)), + FastSeq(classInfo[Region], LongInfo, LongInfo), + classInfo[Iterator[_]], + bodyIR, + ) val fsBc = ctx.fsBc @@ -215,14 +217,13 @@ object TableStageToRVD { hcl, ) .readRegionValue(rvdContext.partitionRegion) - makeIterator( + val f = makeStreamFn( hcl, fsBc.value, SparkTaskContext.get(), - rvdContext, - decodedContext, - decodedBroadcastVals, + rvdContext.partitionRegion, ) + f(rvdContext.region, decodedContext, decodedBroadcastVals) .map(_.longValue()) } diff --git a/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala b/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala index d066e5ce1ef..085521f146a 100644 --- a/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala +++ b/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala @@ -142,6 +142,174 @@ abstract class StreamProducer { } object EmitStream { + + private[ir] def produceIterator( + emitCtx: EmitContext, + streamIR: IR, + cb: EmitCodeBuilder, + env: EmitEnv, + ): (PType, IEmitCode) = { + val ecb = cb.emb.genEmitClass[NoBoxLongIterator]("stream_to_iter") + ecb.cb.addInterface(typeInfo[MissingnessAsMethod].iname) + + val fv = FreeVariables(streamIR, false, false).eval + val (envParamTypes, envParams, restoreEnv) = env.asParams(fv) + + val isMissing = ecb.genFieldThisRef[Boolean]("isMissing") + val eosField = ecb.genFieldThisRef[Boolean]("eos") + val outerRegionField = ecb.genFieldThisRef[Region]("outer") + val eltRegionField = ecb.genFieldThisRef[Region]("eltRegion") + + ecb.makeAddPartitionRegion() + + var producer: StreamProducer = null + var producerRequired: Boolean = false + var elementPType: PType = null + + val next = ecb.newEmitMethod("next", FastSeq[ParamType](), LongInfo) + val ctor = ecb.newEmitMethod( + "", + FastSeq[ParamType](typeInfo[Region], arrayInfo[Long]) ++ envParamTypes, + UnitInfo, + ) + + // Constructor: stores env params as fields, sets up partition region and literals. + // Does NOT produce the stream — that happens in next() to keep all stream code in one method. + var savedEnv: EmitEnv = null + ctor.voidWithBuilder { cb => + val L = new lir.Block() + L.append( + lir.methodStmt( + INVOKESPECIAL, + "java/lang/Object", + "", + "()V", + false, + UnitInfo, + FastSeq(lir.load(ctor.mb.this_.asInstanceOf[LocalRef[_]].l)), + ) + ) + cb += new VCode(L, L, null) + + savedEnv = restoreEnv(cb, 3) + + val self = + cb.memoize( + Code.checkcast[FunctionWithPartitionRegion]((ctor.getCodeParam(0)(ecb.cb.ti))) + ) + + ecb.setLiteralsArray(cb, ctor.getCodeParam[Array[Long]](2)) + val partitionRegion = cb.memoize(ctor.getCodeParam[Region](1)) + cb += self.invoke[Region, Unit]("addPartitionRegion", partitionRegion) + cb += self.invoke[RegionPool, Unit]("setPool", partitionRegion.getPool()) + } + + // All stream production and iteration code lives in next(), avoiding method-boundary issues + // with SStreamControlFlow producers (which assert producer.method == consuming method). + next.emitWithBuilder { cb => + val optStream = EmitCode.fromI(next)(cb => + EmitStream.produce( + new Emit(emitCtx, ecb), streamIR, cb, cb.emb, outerRegionField, savedEnv, None, + ) + ) + producerRequired = optStream.required + producer = optStream.pv.asStream.getProducer(cb.emb) + elementPType = producer.element.emitType.storageType.setRequired(true) + + val didSetup = next.genFieldThisRef[Boolean]("didSetup") + val ret = cb.newLocal[Long]("ret") + val Lret = CodeLabel() + + cb.if_( + !didSetup, { + optStream.toI(cb).consume( + cb, + if (!producerRequired) cb.assign(isMissing, true), + { stream => + if (!producerRequired) cb.assign(isMissing, false) + }, + ) + cb.assign(producer.elementRegion, eltRegionField) + producer.initialize(cb, outerRegionField) + cb.assign(didSetup, true) + cb.assign(eosField, false) + }, + ) + + cb.if_( + eosField, { + cb.goto(Lret) + }, + ) + + cb.goto(producer.LproduceElement) + + next.implementLabel(producer.LendOfStream) { cb => + cb.assign(eosField, true) + cb.goto(Lret) + } + + next.implementLabel(producer.LproduceElementDone) { cb => + producer.element.toI(cb) + .consume( + cb, + cb.assign(ret, 0L), + value => + cb.assign(ret, elementPType.store(cb, producer.elementRegion, value, false)), + ) + cb.goto(Lret) + } + + cb.define(Lret) + ret + } + + val init = + ecb.newEmitMethod( + "init", + FastSeq[ParamType](typeInfo[Region], typeInfo[Region]), + UnitInfo, + ) + init.voidWithBuilder { cb => + cb.assign(outerRegionField, init.getCodeParam[Region](1)) + cb.assign(eltRegionField, init.getCodeParam[Region](2)) + } + + val isEOS = ecb.newEmitMethod("eos", FastSeq[ParamType](), BooleanInfo) + isEOS.emitWithBuilder[Boolean](cb => eosField) + + val isMissingMethod = ecb.newEmitMethod("isMissing", FastSeq[ParamType](), BooleanInfo) + isMissingMethod.emitWithBuilder[Boolean](cb => isMissing) + + val close = ecb.newEmitMethod("close", FastSeq[ParamType](), UnitInfo) + close.voidWithBuilder(cb => producer.close(cb)) + + val obj = cb.memoize(Code.newInstance( + ecb.cb, + ctor.mb, + FastSeq(cb.emb.partitionRegion.get, cb.emb.ecb.literalsArray().get) ++ envParams.map( + _.get + ), + )) + + val iter = cb.emb.genFieldThisRef[NoBoxLongIterator]("iter") + cb.assign(iter, Code.checkcast[NoBoxLongIterator](obj)) + val iEmitCode = IEmitCode( + cb, + if (producerRequired) false + else Code.checkcast[MissingnessAsMethod](obj).invoke[Boolean]("isMissing"), + new SStreamConcrete( + SStreamIteratorLong( + producer.element.required, + elementPType, + producer.requiresMemoryManagementPerElement, + ), + iter, + ), + ) + (elementPType, iEmitCode) + } + private[ir] def produce( emitter: Emit[_], streamIR: IR, @@ -151,7 +319,6 @@ object EmitStream { env: EmitEnv, container: Option[AggContainer], ): IEmitCode = { - @nowarn("cat=unused-locals&msg=local default argument") def emitVoid( ir: IR, @@ -170,153 +337,10 @@ object EmitStream { container: Option[AggContainer] = container, ): IEmitCode = ir.typ match { - case _: TStream => produce(ir, cb, cb.emb, region, env, container) + case _: TStream => produce(ir, cb, region = region, env = env, container = container) case _ => emitter.emitI(ir, cb, region, env, container, None) } - // returns IEmitCode of SStreamConcrete - def produceIterator( - streamIR: IR, - elementPType: PType, - cb: EmitCodeBuilder, - env: EmitEnv, - ): IEmitCode = { - val ecb = cb.emb.genEmitClass[NoBoxLongIterator]("stream_to_iter") - ecb.cb.addInterface(typeInfo[MissingnessAsMethod].iname) - - val fv = FreeVariables(streamIR, false, false).eval - val (envParamTypes, envParams, restoreEnv) = env.asParams(fv) - - val isMissing = ecb.genFieldThisRef[Boolean]("isMissing") - val eosField = ecb.genFieldThisRef[Boolean]("eos") - val outerRegionField = ecb.genFieldThisRef[Region]("outer") - - ecb.makeAddPartitionRegion() - - var producer: StreamProducer = null - var producerRequired: Boolean = false - - val next = ecb.newEmitMethod("next", FastSeq[ParamType](), LongInfo) - val ctor = ecb.newEmitMethod( - "", - FastSeq[ParamType](typeInfo[Region], arrayInfo[Long]) ++ envParamTypes, - UnitInfo, - ) - ctor.voidWithBuilder { cb => - val L = new lir.Block() - L.append( - lir.methodStmt( - INVOKESPECIAL, - "java/lang/Object", - "", - "()V", - false, - UnitInfo, - FastSeq(lir.load(ctor.mb.this_.asInstanceOf[LocalRef[_]].l)), - ) - ) - cb += new VCode(L, L, null) - - val newEnv = restoreEnv(cb, 3) - val s = EmitStream.produce( - new Emit(emitter.ctx, ecb), - streamIR, - cb, - next, - outerRegionField, - newEnv, - None, - ) - producerRequired = s.required - s.consume( - cb, - if (!producerRequired) cb.assign(isMissing, true), - { stream => - if (!producerRequired) cb.assign(isMissing, false) - producer = stream.asStream.getProducer(next) - }, - ) - - val self = - cb.memoize( - Code.checkcast[FunctionWithPartitionRegion]((ctor.getCodeParam(0)(ecb.cb.ti))) - ) - - ecb.setLiteralsArray(cb, ctor.getCodeParam[Array[Long]](2)) - val partitionRegion = cb.memoize(ctor.getCodeParam[Region](1)) - cb += self.invoke[Region, Unit]("addPartitionRegion", partitionRegion) - cb += self.invoke[RegionPool, Unit]("setPool", partitionRegion.getPool()) - } - - next.emitWithBuilder { cb => - val ret = cb.newLocal[Long]("ret") - val Lret = CodeLabel() - cb.goto(producer.LproduceElement) - cb.define(producer.LproduceElementDone) - producer.element.toI(cb) - .consume( - cb, - cb.assign(ret, 0L), - value => cb.assign(ret, elementPType.store(cb, producer.elementRegion, value, false)), - ) - cb.goto(Lret) - cb.define(producer.LendOfStream) - cb.assign(eosField, true) - - cb.define(Lret) - ret - } - - val init = - ecb.newEmitMethod( - "init", - FastSeq[ParamType](typeInfo[Region], typeInfo[Region]), - UnitInfo, - ) - init.voidWithBuilder { cb => - val outerRegion = init.getCodeParam[Region](1) - val eltRegion = init.getCodeParam[Region](2) - - cb.assign(producer.elementRegion, eltRegion) - cb.assign(outerRegionField, outerRegion) - producer.initialize(cb, outerRegion) - cb.assign(eosField, false) - } - - val isEOS = ecb.newEmitMethod("eos", FastSeq[ParamType](), BooleanInfo) - isEOS.emitWithBuilder[Boolean](cb => eosField) - - val isMissingMethod = ecb.newEmitMethod("isMissing", FastSeq[ParamType](), BooleanInfo) - isMissingMethod.emitWithBuilder[Boolean](cb => isMissing) - - val close = ecb.newEmitMethod("close", FastSeq[ParamType](), UnitInfo) - close.voidWithBuilder(cb => producer.close(cb)) - - val obj = cb.memoize(Code.newInstance( - ecb.cb, - ctor.mb, - FastSeq(cb.emb.partitionRegion.get, cb.emb.ecb.literalsArray().get) ++ envParams.map( - _.get - ), - )) - - val iter = cb.emb.genFieldThisRef[NoBoxLongIterator]("iter") - cb.assign(iter, Code.checkcast[NoBoxLongIterator](obj)) - IEmitCode( - cb, - if (producerRequired) false - else Code.checkcast[MissingnessAsMethod](obj).invoke[Boolean]("isMissing"), - new SStreamConcrete( - SStreamIteratorLong( - producer.element.required, - elementPType, - producer.requiresMemoryManagementPerElement, - ), - iter, - ), - ) - } - def produce( streamIR: IR, cb: EmitCodeBuilder, @@ -3407,12 +3431,13 @@ object EmitStream { var streamRequiresMemoryManagement = false cb.while_( idx < nStreams, { - val iter = produceIterator( + val (_, iterCode) = EmitStream.produceIterator( + emitter.ctx, makeProducer, - eltType, cb, env.bind(ctxName, cb.memoize(contextsArray.loadElement(cb, idx))), ) + val iter = iterCode .getOrFatal(cb, "streams in zipJoinProducers cannot be missing") .asInstanceOf[SStreamConcrete] streamRequiresMemoryManagement = iter.st.requiresMemoryManagement diff --git a/hail/hail/src/is/hail/types/physical/stypes/SingleCodeSCode.scala b/hail/hail/src/is/hail/types/physical/stypes/SingleCodeSCode.scala index 6035c35fae4..694d5bfe72a 100644 --- a/hail/hail/src/is/hail/types/physical/stypes/SingleCodeSCode.scala +++ b/hail/hail/src/is/hail/types/physical/stypes/SingleCodeSCode.scala @@ -18,6 +18,7 @@ object SingleCodeType { case TFloat64 => DoubleInfo case TBoolean => BooleanInfo case TVoid => UnitInfo + case _: TStream => classInfo[Iterator[_]] case _ => LongInfo // all others passed as ptype references } diff --git a/hail/hail/test/resources/log4j2.properties b/hail/hail/test/resources/log4j2-test.properties similarity index 100% rename from hail/hail/test/resources/log4j2.properties rename to hail/hail/test/resources/log4j2-test.properties