diff --git a/hail/hail/src/is/hail/annotations/RegionPool.scala b/hail/hail/src/is/hail/annotations/RegionPool.scala index 83489931ad1..218f5a55e1f 100644 --- a/hail/hail/src/is/hail/annotations/RegionPool.scala +++ b/hail/hail/src/is/hail/annotations/RegionPool.scala @@ -118,9 +118,7 @@ final class RegionPool private (strictMemoryCheck: Boolean, threadName: String, } } - def getRegion(): Region = getRegion(Region.REGULAR) - - def getRegion(size: Int): Region = { + def getRegion(size: Int = Region.REGULAR): Region = { val r = new Region(size, this) r.memory = getMemory(size) r @@ -154,11 +152,29 @@ final class RegionPool private (strictMemoryCheck: Boolean, threadName: String, def report(context: String): Unit = { val inBlocks = bytesInBlocks() + val (chunksAllocated, cacheHits) = chunkCache.getUsage() logger.info( - s"RegionPool: $context: ${readableBytes(totalAllocatedBytes)} allocated (${readableBytes(inBlocks)} blocks / " + - s"${readableBytes(totalAllocatedBytes - inBlocks)} chunks), regions.size = ${regions.size}, " + - s"$numJavaObjects current java objects, thread $threadID: $threadName" + s"""RegionPool: $context + | thread: + | id: $threadID + | name: $threadName + | objects: $numJavaObjects + | allocations: + | peak: $getHighestTotalUsage + | total: ${readableBytes(totalAllocatedBytes)} + | blocks: ${readableBytes(inBlocks)} + | chunks: ${readableBytes(totalAllocatedBytes - inBlocks)} + | regions: + | total: ${regions.size} + | free: ${freeRegions.size} + | blocks: + | total: ${blocks.sum} + | free: ${freeBlocks.view.map(_.size).sum} + | chunks: + | total: $chunksAllocated + | reused: $cacheHits + | """.stripMargin ) // logger.info("-----------STACK_TRACES---------") // val stacks: String = regions.result().toIndexedSeq.flatMap(r => r.stackTrace.map((r.getTotalChunkMemory(), _))).foldLeft("")((a: String, b) => a + "\n" + b.toString()) @@ -170,8 +186,6 @@ final class RegionPool private (strictMemoryCheck: Boolean, threadName: String, def scopedSmallRegion[T](f: Region => T): T = using(Region(Region.SMALL, pool = this))(f) def scopedTinyRegion[T](f: Region => T): T = using(Region(Region.TINY, pool = this))(f) - override def finalize(): Unit = close() - private[this] var closed: Boolean = false override def close(): Unit = { diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index 1fa78c7e6d1..dda25b66d0b 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -3,7 +3,6 @@ package is.hail.backend import is.hail.HailFeatureFlags import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader -import is.hail.backend.local.LocalTaskContext import is.hail.expr.ir.{BaseIR, CompileCache, Compiled} import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.lowering.IrMetadata @@ -72,7 +71,7 @@ object ExecuteContext { coercerCache: mutable.Map[Any, LoweredTableReaderCoercer], )( f: ExecuteContext => T - ): T = { + ): T = RegionPool.scoped { pool => pool.scopedRegion { region => using(new ExecuteContext( @@ -94,7 +93,6 @@ object ExecuteContext { ))(f(_)) } } - } def createTmpPathNoCleanup(tmpdir: String, prefix: String, extension: String = null): String = { val random = new SecureRandom() @@ -113,7 +111,7 @@ class ExecuteContext( val backend: Backend, val references: Map[String, ReferenceGenome], val fs: FS, - val r: Region, + override val r: Region, val timer: ExecutionTimer, val tempFileManager: TempFileManager, val theHailClassLoader: HailClassLoader, @@ -123,7 +121,7 @@ class ExecuteContext( val CompileCache: CompileCache, val PersistedIrCache: mutable.Map[Int, BaseIR], val PersistedCoercerCache: mutable.Map[Any, LoweredTableReaderCoercer], -) extends Closeable { +) extends HailTaskContext with Closeable { val rngNonce: Long = try @@ -142,10 +140,14 @@ class ExecuteContext( val memo: mutable.Map[Any, Any] = new mutable.HashMap[Any, Any]() - val taskContext: HailTaskContext = new LocalTaskContext(0, 0) + private[this] val onCloseTasks = mutable.ArrayBuffer.empty[() => Unit] + override def onClose(f: () => Unit): Unit = onCloseTasks += f + + def run[A](f: Compiled[A])(implicit E: Enclosing): A = + time(f(theHailClassLoader, fs, this, r)) def scopedExecution[T](f: Compiled[T])(implicit E: Enclosing): T = - using(new LocalTaskContext(0, 0))(tc => time(f(theHailClassLoader, fs, tc, r))) + r.pool.scopedRegion(r => local(r = r)(_.run(f))) def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String = tempFileManager.newTmpPath(if (local) localTmpdir else tmpdir, prefix, extension) @@ -159,8 +161,8 @@ class ExecuteContext( def shouldLogIR(): Boolean = !shouldNotLogIR() override def close(): Unit = { + onCloseTasks.foreach(_()) tempFileManager.close() - taskContext.close() } def time[A](block: => A)(implicit E: Enclosing): A = @@ -179,7 +181,7 @@ class ExecuteContext( flags: HailFeatureFlags = this.flags, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, - codeCache: CompileCache = this.CompileCache, + compileCache: CompileCache = this.CompileCache, persistedIrCache: mutable.Map[Int, BaseIR] = this.PersistedIrCache, persistedCoercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.PersistedCoercerCache, )( @@ -198,7 +200,7 @@ class ExecuteContext( flags, irMetadata, blockMatrixCache, - codeCache, + compileCache, persistedIrCache, persistedCoercerCache, ))(f) diff --git a/hail/hail/src/is/hail/backend/HailTaskContext.scala b/hail/hail/src/is/hail/backend/HailTaskContext.scala index 01d662044d1..441c29de4d5 100644 --- a/hail/hail/src/is/hail/backend/HailTaskContext.scala +++ b/hail/hail/src/is/hail/backend/HailTaskContext.scala @@ -1,50 +1,35 @@ package is.hail.backend -import is.hail.annotations.RegionPool -import is.hail.utils._ +import is.hail.annotations.{Region, RegionPool} +import is.hail.utils.using import scala.collection.mutable -import java.io.Closeable +trait HailTaskContext { -class TaskFinalizer { - val closeables = mutable.ArrayBuffer.empty[Closeable] + /** region whose lifetime is at least as long as this task */ + def r: Region - def clear(): Unit = - closeables.clear() - - def addCloseable(c: Closeable): Unit = - closeables += c - - def closeAll(): Unit = closeables.foreach(_.close()) + /** register an action that will be called when this task completes */ + def onClose(f: () => Unit): Unit } -abstract class HailTaskContext extends AutoCloseable with Logging { - def stageId(): Int - - def partitionId(): Int - - def attemptNumber(): Int - - private lazy val thePool = RegionPool() - - def getRegionPool(): RegionPool = thePool +object HailTaskContext { + def runPartition[A](partId: Int)(f: HailTaskContext => A): A = + using(new PartitionContext(partId))(f) +} - val finalizers = mutable.ArrayBuffer.empty[TaskFinalizer] +class PartitionContext(partId: Int) extends HailTaskContext with AutoCloseable { + private[this] val onCloseTasks = mutable.ArrayBuffer.empty[() => Unit] - def newFinalizer(): TaskFinalizer = { - val f = new TaskFinalizer - finalizers += f - f - } + private[this] val pool = RegionPool() + override val r: Region = Region(pool = pool) + override def onClose(f: () => Unit): Unit = onCloseTasks += f override def close(): Unit = { - logger.info( - s"TaskReport: stage=${stageId()}, partition=${partitionId()}, attempt=${attemptNumber()}, " + - s"peakBytes=${thePool.getHighestTotalUsage}, peakBytesReadable=${formatSpace(thePool.getHighestTotalUsage)}, " + - s"chunks requested=${thePool.getUsage._1}, cache hits=${thePool.getUsage._2}" - ) - finalizers.foreach(_.closeAll()) - thePool.close() + onCloseTasks.foreach(_()) + r.close() + pool.logStats(s"Partition $partId") + pool.close() } } diff --git a/hail/hail/src/is/hail/backend/local/LocalBackend.scala b/hail/hail/src/is/hail/backend/local/LocalBackend.scala index c6a47ae39dd..8ccdaeea72f 100644 --- a/hail/hail/src/is/hail/backend/local/LocalBackend.scala +++ b/hail/hail/src/is/hail/backend/local/LocalBackend.scala @@ -20,10 +20,6 @@ import com.fasterxml.jackson.core.StreamReadConstraints class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable -class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskContext { - override def attemptNumber(): Int = 0 -} - object LocalBackend extends Backend with Logging { // From https://github.com/hail-is/hail/issues/14580 : @@ -43,16 +39,7 @@ object LocalBackend extends Backend with Logging { override def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value) - private[this] var stageIdx: Int = 0 - - private[this] def nextStageId(): Int = - synchronized { - val current = stageIdx - stageIdx += 1 - current - } - - override def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext = { + override def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext = new DriverRuntimeContext { override val executionCache: ExecutionCache = @@ -77,14 +64,10 @@ object LocalBackend extends Backend with Logging { var failure: Option[Throwable] = None - val stageId = nextStageId() - try for (idx <- todo) - results += using(new LocalTaskContext(idx, stageId)) { htc => - htc.getRegionPool().scopedRegion { r => - f(ctx.theHailClassLoader, ctx.fs, htc, r)(globals, contexts(idx)) -> idx - } + results += ctx.scopedExecution { (hcl, fs, ctx, r) => + (f(hcl, fs, ctx, r)(globals, contexts(idx)), idx) } catch { case NonFatal(t) => @@ -94,12 +77,10 @@ object LocalBackend extends Backend with Logging { (failure, results.result()) } } - } override def defaultParallelism: Int = 1 - override def close(): Unit = - synchronized { stageIdx = 0 } + override def close(): Unit = {} private[this] def _jvmLowerAndExecute( ctx: ExecuteContext, diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index 150de69a006..f277d1fe31f 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -4,7 +4,6 @@ import is.hail.Revision import is.hail.backend._ import is.hail.backend.Backend.PartitionFn import is.hail.backend.ExecutionCache.Flags.UseFastRestarts -import is.hail.backend.local.LocalTaskContext import is.hail.backend.service.ServiceBackend.Flags._ import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq @@ -361,10 +360,8 @@ class ServiceBackend( partitions.getOrElse(contexts.indices) match { case Seq(k) => try - using(new LocalTaskContext(k, stageCount)) { htc => - None -> htc.getRegionPool().scopedRegion { r => - FastSeq(f(ctx.theHailClassLoader, ctx.fs, htc, r)(globals, contexts(k)) -> k) - } + ctx.scopedExecution { (hcl, fs, htc, r) => + (None, FastSeq(f(hcl, fs, htc, r)(globals, contexts(k)) -> k)) } catch { case NonFatal(t) => Some(t) -> ArraySeq.empty diff --git a/hail/hail/src/is/hail/backend/service/Worker.scala b/hail/hail/src/is/hail/backend/service/Worker.scala index 4d3cb1231d1..62568a82c5f 100644 --- a/hail/hail/src/is/hail/backend/service/Worker.scala +++ b/hail/hail/src/is/hail/backend/service/Worker.scala @@ -19,12 +19,6 @@ import java.util import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger -class ServiceTaskContext(val partitionId: Int) extends HailTaskContext { - override def stageId(): Int = 0 - - override def attemptNumber(): Int = 0 -} - class WorkerTimer extends Logging { var startTimes: mutable.Map[String, Long] = mutable.Map() @@ -250,12 +244,11 @@ object Worker extends Logging { inputs.flatMap { case (globals, context, f) => timer.enter("execute") { try - using(new ServiceTaskContext(partition)) { htc => - retryTransientErrors { - htc.getRegionPool().scopedRegion { r => - Right(f(hcl, fs, htc, r)(globals, context)) - } - } + HailTaskContext.runPartition(partition) { htc => + retryTransientErrors( + Right(f(hcl, fs, htc, htc.r)(globals, context)), + Some(() => htc.r.clear()), + ) } catch { case t: Throwable => Left(t) diff --git a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala index 58fc732448e..56080b09293 100644 --- a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala +++ b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala @@ -35,31 +35,23 @@ class SparkBroadcastValue[T](bc: Broadcast[T]) extends BroadcastValue[T] with Se } object SparkTaskContext { - def get(): SparkTaskContext = taskContext.get + def get: HailTaskContext = taskContext.get - private[this] val taskContext: ThreadLocal[SparkTaskContext] = - new ThreadLocal[SparkTaskContext]() { - override def initialValue(): SparkTaskContext = { + private[this] val taskContext: ThreadLocal[HailTaskContext] = + new ThreadLocal[HailTaskContext]() { + override def initialValue(): HailTaskContext = { val sparkTC = TaskContext.get() assert(sparkTC != null, "Spark Task Context was null, maybe this ran on the driver?") - sparkTC.addTaskCompletionListener[Unit]((_: TaskContext) => SparkTaskContext.finish()): Unit - // this must be the only place where SparkTaskContext classes are created - new SparkTaskContext(sparkTC) + val htc = new PartitionContext(sparkTC.stageId()) + sparkTC.addTaskCompletionListener[Unit] { _ => htc.close(); remove(); }: Unit + + htc } } - def finish(): Unit = { - taskContext.get().close() + def finish(): Unit = taskContext.remove() - } -} - -class SparkTaskContext private[spark] (ctx: TaskContext) extends HailTaskContext { - self => - override def stageId(): Int = ctx.stageId() - override def partitionId(): Int = ctx.partitionId() - override def attemptNumber(): Int = ctx.attemptNumber() } object SparkBackend extends Logging { @@ -267,11 +259,9 @@ class SparkBackend(val spark: SparkSession) extends Backend with Logging { override def compute(partition: Partition, context: TaskContext) : Iterator[Array[Byte]] = { - val htc = SparkTaskContext.get() - htc.getRegionPool().scopedRegion { r => - val g = f(unsafeHailClassLoaderForSparkWorkers, new HadoopFS(fsConfig), htc, r) - Iterator.single(g(globals, partition.asInstanceOf[RDDPartition].data)) - } + val ctx = SparkTaskContext.get + val g = f(unsafeHailClassLoaderForSparkWorkers, new HadoopFS(fsConfig), ctx, ctx.r) + Iterator.single(g(globals, partition.asInstanceOf[RDDPartition].data)) } } diff --git a/hail/hail/src/is/hail/expr/ir/Compile.scala b/hail/hail/src/is/hail/expr/ir/Compile.scala index 5f14d92c55c..f836b317bdf 100644 --- a/hail/hail/src/is/hail/expr/ir/Compile.scala +++ b/hail/hail/src/is/hail/expr/ir/Compile.scala @@ -292,7 +292,7 @@ object CompileIterator { ir: IR, ): ( PType, - (HailClassLoader, FS, HailTaskContext, RVDContext, Long, NoBoxLongIterator) => Iterator[java.lang.Long], + (HailClassLoader, FS, RVDContext, Long, NoBoxLongIterator) => Iterator[java.lang.Long], ) = { assert(typ0.required) assert(streamElementType.required) @@ -308,10 +308,9 @@ object CompileIterator { ) ( eltPType, - (theHailClassLoader, fs, htc, consumerCtx, v0, part) => { - val outerStepFunction = - makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) - outerStepFunction.setRegions(consumerCtx.partitionRegion, consumerCtx.region) + (theHailClassLoader, fs, ctx, v0, part) => { + val outerStepFunction = makeStepper(theHailClassLoader, fs, ctx, ctx.r) + outerStepFunction.setRegions(ctx.r, ctx.region) new LongIteratorWrapper { val stepFunction: TMPStepFunction = outerStepFunction @@ -328,7 +327,7 @@ object CompileIterator { ir: IR, ): ( PType, - (HailClassLoader, FS, HailTaskContext, RVDContext, Long, Long) => Iterator[java.lang.Long], + (HailClassLoader, FS, RVDContext, Long, Long) => Iterator[java.lang.Long], ) = { assert(ctxType.required) assert(bcValsType.required) @@ -344,10 +343,9 @@ object CompileIterator { ) ( eltPType, - (theHailClassLoader, fs, htc, consumerCtx, v0, v1) => { - val outerStepFunction = - makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) - outerStepFunction.setRegions(consumerCtx.partitionRegion, consumerCtx.region) + (theHailClassLoader, fs, ctx, v0, v1) => { + val outerStepFunction = makeStepper(theHailClassLoader, fs, ctx, ctx.r) + outerStepFunction.setRegions(ctx.r, ctx.region) new LongIteratorWrapper { val stepFunction: TableStageToRVDStepFunction = outerStepFunction diff --git a/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala b/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala index c1503c54d3c..6990d715103 100644 --- a/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala +++ b/hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala @@ -4,7 +4,7 @@ import is.hail.annotations.{Region, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq -import is.hail.expr.ir.defs.{Begin, EncodedLiteral, Literal, MakeTuple, NA} +import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -78,11 +78,8 @@ object CompileAndEvaluate { print, ) - val res = ctx.scopedExecution { (hcl, fs, htc, r) => - val execute = f(hcl, fs, htc, r) - ctx.time(execute(r)) - } - + val execute = f(ctx.theHailClassLoader, ctx.fs, ctx, ctx.r) + val res = ctx.time(execute(ctx.r)) Right((resType, res)) } } diff --git a/hail/hail/src/is/hail/expr/ir/Interpret.scala b/hail/hail/src/is/hail/expr/ir/Interpret.scala index 12a81c5abd6..544501f39e2 100644 --- a/hail/hail/src/is/hail/expr/ir/Interpret.scala +++ b/hail/hail/src/is/hail/expr/ir/Interpret.scala @@ -2,19 +2,19 @@ package is.hail.expr.ir import is.hail.annotations._ import is.hail.asm4s._ -import is.hail.backend.{ExecuteContext, HailTaskContext} -import is.hail.backend.spark.SparkTaskContext +import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq -import is.hail.collection.implicits.{toRichIterable, toRichIterator} +import is.hail.collection.implicits._ +import is.hail.expr.ir.agg.AggExecuteContextExtensions import is.hail.expr.ir.analyses.PartitionCounts import is.hail.expr.ir.defs._ +import is.hail.expr.ir.implicits._ import is.hail.expr.ir.lowering.{ExecuteRelational, LoweringPipeline} import is.hail.io.BufferSpec import is.hail.linalg.BlockMatrix -import is.hail.rvd.RVDContext import is.hail.types.physical.{PTuple, PType} -import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType} +import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.tcoerce import is.hail.types.virtual._ import is.hail.utils._ @@ -60,13 +60,13 @@ object Interpret extends Logging { }.toFastSeq val lowered = LoweringPipeline.relationalLowerer(ctx, Let(bindings, ir0)).asInstanceOf[IR] - val result = run(ctx, lowered, Env.empty[Any], args, Memo.empty).asInstanceOf[T] + val result = run(ctx, lowered, Env.empty[Any], args).asInstanceOf[T] result } def alreadyLowered(ctx: ExecuteContext, ir: IR): Any = ctx.local(flags = ctx.flags - Optimize.Flags.Optimize) { ctx => - run(ctx, ir, Env.empty, FastSeq(), Memo.empty) + run(ctx, ir, Env.empty, FastSeq()) } private def run( @@ -74,1025 +74,1019 @@ object Interpret extends Logging { ir: IR, env: Env[Any], args: IndexedSeq[(Any, Type)], - functionMemo: Memo[(SingleCodeType, AsmFunction2RegionLongLong)], ): Any = { - def interpret(ir: IR, env: Env[Any] = env, args: IndexedSeq[(Any, Type)] = args): Any = - run(ctx, ir, env, args, functionMemo) - - ir match { - case I32(x) => x - case I64(x) => x - case F32(x) => x - case F64(x) => x - case Str(x) => x - case True() => true - case False() => false - case Literal(_, value) => value - case x @ EncodedLiteral(codec, value) => - ctx.r.getPool().scopedRegion { r => - val (pt, addr) = codec.decodeArrays(ctx, x.typ, value.ba, ctx.r) - SafeRow.read(pt, addr) - } - case Void() => () - case Cast(v, t) => - val vValue = interpret(v, env, args) - if (vValue == null) - null - else - (v.typ, t) match { - case (TInt32, TInt32) => vValue - case (TInt32, TInt64) => vValue.asInstanceOf[Int].toLong - case (TInt32, TFloat32) => vValue.asInstanceOf[Int].toFloat - case (TInt32, TFloat64) => vValue.asInstanceOf[Int].toDouble - case (TInt64, TInt64) => vValue - case (TInt64, TInt32) => vValue.asInstanceOf[Long].toInt - case (TInt64, TFloat32) => vValue.asInstanceOf[Long].toFloat - case (TInt64, TFloat64) => vValue.asInstanceOf[Long].toDouble - case (TFloat32, TFloat32) => vValue - case (TFloat32, TInt32) => vValue.asInstanceOf[Float].toInt - case (TFloat32, TInt64) => vValue.asInstanceOf[Float].toLong - case (TFloat32, TFloat64) => vValue.asInstanceOf[Float].toDouble - case (TFloat64, TFloat64) => vValue - case (TFloat64, TInt32) => vValue.asInstanceOf[Double].toInt - case (TFloat64, TInt64) => vValue.asInstanceOf[Double].toLong - case (TFloat64, TFloat32) => vValue.asInstanceOf[Double].toFloat - case (TInt32, TCall) => vValue + def interpret(ir: IR, env: Env[Any], args: IndexedSeq[(Any, Type)]): Any = + ir match { + case I32(x) => x + case I64(x) => x + case F32(x) => x + case F64(x) => x + case Str(x) => x + case True() => true + case False() => false + case Literal(_, value) => value + case x @ EncodedLiteral(codec, value) => + ctx.scopedExecution { (_, _, _, r) => + val (pt, addr) = codec.decodeArrays(ctx, x.typ, value.ba, r) + SafeRow.read(pt, addr) } - case CastRename(v, _) => interpret(v) - case NA(_) => null - case IsNA(value) => interpret(value, env, args) == null - case Coalesce(values) => - values.iterator - .flatMap(x => Option(interpret(x, env, args))) - .headOption - .orNull - case If(cond, cnsq, altr) => - assert(cnsq.typ == altr.typ) - val condValue = interpret(cond, env, args) - if (condValue == null) - null - else if (condValue.asInstanceOf[Boolean]) - interpret(cnsq, env, args) - else - interpret(altr, env, args) - case Switch(x_, default, cases) => - interpret(x_, env, args) match { - case x: Int => - interpret(if (x >= 0 && x < cases.length) cases(x) else default, env, args) - case null => + case Void() => () + case Cast(v, t) => + val vValue = interpret(v, env, args) + if (vValue == null) null - } - case Block(bindings, body) => - val newEnv = bindings.foldLeft(env) { case (env, Binding(name, value, Scope.EVAL)) => - env.bind(name -> interpret(value, env, args)) - } - interpret(body, newEnv, args) - case Ref(name, _) => env.lookup(name) - case ApplyBinaryPrimOp(op, l, r) => - val lValue = interpret(l, env, args) - val rValue = interpret(r, env, args) - if (lValue == null || rValue == null) - null - else - (l.typ, r.typ) match { - case (TInt32, TInt32) => - val ll = lValue.asInstanceOf[Int] - val rr = rValue.asInstanceOf[Int] - (op: @unchecked) match { - case Add() => ll + rr - case Subtract() => ll - rr - case Multiply() => ll * rr - case FloatingPointDivide() => ll.toDouble / rr.toDouble - case RoundToNegInfDivide() => java.lang.Math.floorDiv(ll, rr) - case BitAnd() => ll & rr - case BitOr() => ll | rr - case BitXOr() => ll ^ rr - case LeftShift() => ll << rr - case RightShift() => ll >> rr - case LogicalRightShift() => ll >>> rr - } - case (TInt64, TInt32) => - val ll = lValue.asInstanceOf[Long] - val rr = rValue.asInstanceOf[Int] - (op: @unchecked) match { - case LeftShift() => ll << rr - case RightShift() => ll >> rr - case LogicalRightShift() => ll >>> rr - } - case (TInt64, TInt64) => - val ll = lValue.asInstanceOf[Long] - val rr = rValue.asInstanceOf[Long] - (op: @unchecked) match { - case Add() => ll + rr - case Subtract() => ll - rr - case Multiply() => ll * rr - case FloatingPointDivide() => ll.toDouble / rr.toDouble - case RoundToNegInfDivide() => java.lang.Math.floorDiv(ll, rr) - case BitAnd() => ll & rr - case BitOr() => ll | rr - case BitXOr() => ll ^ rr - case LeftShift() => ll << rr - case RightShift() => ll >> rr + else + (v.typ, t) match { + case (TInt32, TInt32) => vValue + case (TInt32, TInt64) => vValue.asInstanceOf[Int].toLong + case (TInt32, TFloat32) => vValue.asInstanceOf[Int].toFloat + case (TInt32, TFloat64) => vValue.asInstanceOf[Int].toDouble + case (TInt64, TInt64) => vValue + case (TInt64, TInt32) => vValue.asInstanceOf[Long].toInt + case (TInt64, TFloat32) => vValue.asInstanceOf[Long].toFloat + case (TInt64, TFloat64) => vValue.asInstanceOf[Long].toDouble + case (TFloat32, TFloat32) => vValue + case (TFloat32, TInt32) => vValue.asInstanceOf[Float].toInt + case (TFloat32, TInt64) => vValue.asInstanceOf[Float].toLong + case (TFloat32, TFloat64) => vValue.asInstanceOf[Float].toDouble + case (TFloat64, TFloat64) => vValue + case (TFloat64, TInt32) => vValue.asInstanceOf[Double].toInt + case (TFloat64, TInt64) => vValue.asInstanceOf[Double].toLong + case (TFloat64, TFloat32) => vValue.asInstanceOf[Double].toFloat + case (TInt32, TCall) => vValue + } + case CastRename(v, _) => interpret(v, env, args) + case NA(_) => null + case IsNA(value) => interpret(value, env, args) == null + case Coalesce(values) => + values.iterator + .flatMap(x => Option(interpret(x, env, args))) + .headOption + .orNull + case If(cond, cnsq, altr) => + assert(cnsq.typ == altr.typ) + val condValue = interpret(cond, env, args) + if (condValue == null) + null + else if (condValue.asInstanceOf[Boolean]) + interpret(cnsq, env, args) + else + interpret(altr, env, args) + case Switch(x_, default, cases) => + interpret(x_, env, args) match { + case x: Int => + interpret(if (x >= 0 && x < cases.length) cases(x) else default, env, args) + case null => + null + } + case Block(bindings, body) => + val newEnv = bindings.foldLeft(env) { case (env, Binding(name, value, Scope.EVAL)) => + env.bind(name -> interpret(value, env, args)) + } + interpret(body, newEnv, args) + case Ref(name, _) => env.lookup(name) + case ApplyBinaryPrimOp(op, l, r) => + val lValue = interpret(l, env, args) + val rValue = interpret(r, env, args) + if (lValue == null || rValue == null) + null + else + (l.typ, r.typ) match { + case (TInt32, TInt32) => + val ll = lValue.asInstanceOf[Int] + val rr = rValue.asInstanceOf[Int] + (op: @unchecked) match { + case Add() => ll + rr + case Subtract() => ll - rr + case Multiply() => ll * rr + case FloatingPointDivide() => ll.toDouble / rr.toDouble + case RoundToNegInfDivide() => java.lang.Math.floorDiv(ll, rr) + case BitAnd() => ll & rr + case BitOr() => ll | rr + case BitXOr() => ll ^ rr + case LeftShift() => ll << rr + case RightShift() => ll >> rr + case LogicalRightShift() => ll >>> rr + } + case (TInt64, TInt32) => + val ll = lValue.asInstanceOf[Long] + val rr = rValue.asInstanceOf[Int] + (op: @unchecked) match { + case LeftShift() => ll << rr + case RightShift() => ll >> rr + case LogicalRightShift() => ll >>> rr + } + case (TInt64, TInt64) => + val ll = lValue.asInstanceOf[Long] + val rr = rValue.asInstanceOf[Long] + (op: @unchecked) match { + case Add() => ll + rr + case Subtract() => ll - rr + case Multiply() => ll * rr + case FloatingPointDivide() => ll.toDouble / rr.toDouble + case RoundToNegInfDivide() => java.lang.Math.floorDiv(ll, rr) + case BitAnd() => ll & rr + case BitOr() => ll | rr + case BitXOr() => ll ^ rr + case LeftShift() => ll << rr + case RightShift() => ll >> rr + } + case (TFloat32, TFloat32) => + val ll = lValue.asInstanceOf[Float] + val rr = rValue.asInstanceOf[Float] + (op: @unchecked) match { + case Add() => ll + rr + case Subtract() => ll - rr + case Multiply() => ll * rr + case FloatingPointDivide() => ll / rr + case RoundToNegInfDivide() => math.floor(ll.toDouble / rr).toFloat + } + case (TFloat64, TFloat64) => + val ll = lValue.asInstanceOf[Double] + val rr = rValue.asInstanceOf[Double] + (op: @unchecked) match { + case Add() => ll + rr + case Subtract() => ll - rr + case Multiply() => ll * rr + case FloatingPointDivide() => ll / rr + case RoundToNegInfDivide() => math.floor(ll / rr) + } + } + case ApplyUnaryPrimOp(op, x) => + val xValue = interpret(x, env, args) + if (xValue == null) + null + else op match { + case Bang => + assert(x.typ == TBoolean) + !xValue.asInstanceOf[Boolean] + case Negate => + assert(x.typ.isInstanceOf[TNumeric]) + x.typ match { + case TInt32 => -xValue.asInstanceOf[Int] + case TInt64 => -xValue.asInstanceOf[Long] + case TFloat32 => -xValue.asInstanceOf[Float] + case TFloat64 => -xValue.asInstanceOf[Double] } - case (TFloat32, TFloat32) => - val ll = lValue.asInstanceOf[Float] - val rr = rValue.asInstanceOf[Float] - (op: @unchecked) match { - case Add() => ll + rr - case Subtract() => ll - rr - case Multiply() => ll * rr - case FloatingPointDivide() => ll / rr - case RoundToNegInfDivide() => math.floor(ll.toDouble / rr).toFloat + case BitNot => + assert(x.typ.isInstanceOf[TIntegral]) + x.typ match { + case TInt32 => ~xValue.asInstanceOf[Int] + case TInt64 => ~xValue.asInstanceOf[Long] } - case (TFloat64, TFloat64) => - val ll = lValue.asInstanceOf[Double] - val rr = rValue.asInstanceOf[Double] - (op: @unchecked) match { - case Add() => ll + rr - case Subtract() => ll - rr - case Multiply() => ll * rr - case FloatingPointDivide() => ll / rr - case RoundToNegInfDivide() => math.floor(ll / rr) + case BitCount => + assert(x.typ.isInstanceOf[TIntegral]) + x.typ match { + case TInt32 => Integer.bitCount(xValue.asInstanceOf[Int]) + case TInt64 => java.lang.Long.bitCount(xValue.asInstanceOf[Long]) } } - case ApplyUnaryPrimOp(op, x) => - val xValue = interpret(x, env, args) - if (xValue == null) - null - else op match { - case Bang => - assert(x.typ == TBoolean) - !xValue.asInstanceOf[Boolean] - case Negate => - assert(x.typ.isInstanceOf[TNumeric]) - x.typ match { - case TInt32 => -xValue.asInstanceOf[Int] - case TInt64 => -xValue.asInstanceOf[Long] - case TFloat32 => -xValue.asInstanceOf[Float] - case TFloat64 => -xValue.asInstanceOf[Double] - } - case BitNot => - assert(x.typ.isInstanceOf[TIntegral]) - x.typ match { - case TInt32 => ~xValue.asInstanceOf[Int] - case TInt64 => ~xValue.asInstanceOf[Long] - } - case BitCount => - assert(x.typ.isInstanceOf[TIntegral]) - x.typ match { - case TInt32 => Integer.bitCount(xValue.asInstanceOf[Int]) - case TInt64 => java.lang.Long.bitCount(xValue.asInstanceOf[Long]) + case ApplyComparisonOp(op, l, r) => + val lValue = interpret(l, env, args) + val rValue = interpret(r, env, args) + val t = l.typ + if (op.strict && (lValue == null || rValue == null)) + null + else + op match { + case EQ => t.ordering(ctx.stateManager).equiv(lValue, rValue) + case EQWithNA => t.ordering(ctx.stateManager).equiv(lValue, rValue) + case NEQ => !t.ordering(ctx.stateManager).equiv(lValue, rValue) + case NEQWithNA => !t.ordering(ctx.stateManager).equiv(lValue, rValue) + case LT => t.ordering(ctx.stateManager).lt(lValue, rValue) + case GT => t.ordering(ctx.stateManager).gt(lValue, rValue) + case LTEQ => t.ordering(ctx.stateManager).lteq(lValue, rValue) + case GTEQ => t.ordering(ctx.stateManager).gteq(lValue, rValue) + case Compare => t.ordering(ctx.stateManager).compare(lValue, rValue) } - } - case ApplyComparisonOp(op, l, r) => - val lValue = interpret(l, env, args) - val rValue = interpret(r, env, args) - val t = l.typ - if (op.strict && (lValue == null || rValue == null)) - null - else - op match { - case EQ => t.ordering(ctx.stateManager).equiv(lValue, rValue) - case EQWithNA => t.ordering(ctx.stateManager).equiv(lValue, rValue) - case NEQ => !t.ordering(ctx.stateManager).equiv(lValue, rValue) - case NEQWithNA => !t.ordering(ctx.stateManager).equiv(lValue, rValue) - case LT => t.ordering(ctx.stateManager).lt(lValue, rValue) - case GT => t.ordering(ctx.stateManager).gt(lValue, rValue) - case LTEQ => t.ordering(ctx.stateManager).lteq(lValue, rValue) - case GTEQ => t.ordering(ctx.stateManager).gteq(lValue, rValue) - case Compare => t.ordering(ctx.stateManager).compare(lValue, rValue) - } - case MakeArray(elements, _) => elements.map(interpret(_, env, args)).toFastSeq - case MakeStream(elements, _, _) => elements.map(interpret(_, env, args)).toFastSeq - case ArrayRef(a, i, errorId) => - val aValue = interpret(a, env, args) - val iValue = interpret(i, env, args) - if (aValue == null || iValue == null) - null - else { - val a = aValue.asInstanceOf[IndexedSeq[Any]] - val i = iValue.asInstanceOf[Int] + case MakeArray(elements, _) => elements.map(interpret(_, env, args)).toFastSeq + case MakeStream(elements, _, _) => elements.map(interpret(_, env, args)).toFastSeq + case ArrayRef(a, i, errorId) => + val aValue = interpret(a, env, args) + val iValue = interpret(i, env, args) + if (aValue == null || iValue == null) + null + else { + val a = aValue.asInstanceOf[IndexedSeq[Any]] + val i = iValue.asInstanceOf[Int] - if (i < 0 || i >= a.length) { - fatal(s"array index out of bounds: index=$i, length=${a.length}", errorId = errorId) - } else - a.apply(i) - } - case ArraySlice(a, start, stop, step, errorID) => - val aValue = interpret(a, env, args) - val startValue = interpret(start, env, args) - val stopValue = stop.map(ir => interpret(ir, env, args)) - val stepValue = interpret(step, env, args) - if ( - startValue == null || stepValue == null || aValue == null || - stopValue.getOrElse(aValue.asInstanceOf[IndexedSeq[Any]].size) == null - ) - null - else { - val a = aValue.asInstanceOf[IndexedSeq[Any]] - val requestedStart = startValue.asInstanceOf[Int] - val requestedStep = stepValue.asInstanceOf[Int] - if (requestedStep == 0) - fatal("step cannot be 0 for array slice", errorID) - val noneStop = if (requestedStep < 0) -a.size - 1 - else a.size - val maxBound = if (requestedStep > 0) a.size - else a.size - 1 - val minBound = if (requestedStep > 0) 0 - else -1 - val requestedStop = stopValue.getOrElse(noneStop).asInstanceOf[Int] - val realStart = if (requestedStart >= a.size) maxBound - else if (requestedStart >= 0) requestedStart - else if (requestedStart + a.size >= 0) requestedStart + a.size - else minBound - val realStop = if (requestedStop >= a.size) maxBound - else if (requestedStop >= 0) requestedStop - else if (requestedStop + a.size > 0) requestedStop + a.size - else minBound - (realStart until realStop by requestedStep).map(idx => a(idx)) - } - case ArrayLen(a) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else - aValue.asInstanceOf[IndexedSeq[Any]].length - case StreamLen(a) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else - aValue.asInstanceOf[IndexedSeq[Any]].length - case StreamIota(_, _, _) => - throw new UnsupportedOperationException - case StreamRange(start, stop, step, _, errorID) => - val startValue = interpret(start, env, args) - val stopValue = interpret(stop, env, args) - val stepValue = interpret(step, env, args) - if (stepValue == 0) - fatal("Array range cannot have step size 0.", errorID) - if (startValue == null || stopValue == null || stepValue == null) - null - else - startValue.asInstanceOf[Int] until stopValue.asInstanceOf[Int] by stepValue.asInstanceOf[ - Int - ] - case ArraySort(a, l, r, lessThan) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - aValue.asInstanceOf[IndexedSeq[Any]].sortWith { (left, right) => - if (left != null && right != null) { - val res = interpret(lessThan, env.bind(l, left).bind(r, right), args) - if (res == null) - fatal("Result of sorting function cannot be missing.") - res.asInstanceOf[Boolean] - } else { - right == null + if (i < 0 || i >= a.length) { + fatal(s"array index out of bounds: index=$i, length=${a.length}", errorId = errorId) + } else + a.apply(i) + } + case ArraySlice(a, start, stop, step, errorID) => + val aValue = interpret(a, env, args) + val startValue = interpret(start, env, args) + val stopValue = stop.map(ir => interpret(ir, env, args)) + val stepValue = interpret(step, env, args) + if ( + startValue == null || stepValue == null || aValue == null || + stopValue.getOrElse(aValue.asInstanceOf[IndexedSeq[Any]].size) == null + ) + null + else { + val a = aValue.asInstanceOf[IndexedSeq[Any]] + val requestedStart = startValue.asInstanceOf[Int] + val requestedStep = stepValue.asInstanceOf[Int] + if (requestedStep == 0) + fatal("step cannot be 0 for array slice", errorID) + val noneStop = if (requestedStep < 0) -a.size - 1 + else a.size + val maxBound = if (requestedStep > 0) a.size + else a.size - 1 + val minBound = if (requestedStep > 0) 0 + else -1 + val requestedStop = stopValue.getOrElse(noneStop).asInstanceOf[Int] + val realStart = if (requestedStart >= a.size) maxBound + else if (requestedStart >= 0) requestedStart + else if (requestedStart + a.size >= 0) requestedStart + a.size + else minBound + val realStop = if (requestedStop >= a.size) maxBound + else if (requestedStop >= 0) requestedStop + else if (requestedStop + a.size > 0) requestedStop + a.size + else minBound + (realStart until realStop by requestedStep).map(idx => a(idx)) + } + case ArrayLen(a) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else + aValue.asInstanceOf[IndexedSeq[Any]].length + case StreamLen(a) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else + aValue.asInstanceOf[IndexedSeq[Any]].length + case StreamIota(_, _, _) => + throw new UnsupportedOperationException + case StreamRange(start, stop, step, _, errorID) => + val startValue = interpret(start, env, args) + val stopValue = interpret(stop, env, args) + val stepValue = interpret(step, env, args) + if (stepValue == 0) fatal("Array range cannot have step size 0.", errorID) + else if (startValue == null || stopValue == null || stepValue == null) null + else Range( + startValue.asInstanceOf[Int], + stopValue.asInstanceOf[Int], + stepValue.asInstanceOf[Int], + ) + case ArraySort(a, l, r, lessThan) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + aValue.asInstanceOf[IndexedSeq[Any]].sortWith { (left, right) => + if (left != null && right != null) { + val res = interpret(lessThan, env.bind(l, left).bind(r, right), args) + if (res == null) + fatal("Result of sorting function cannot be missing.") + res.asInstanceOf[Boolean] + } else { + right == null + } } } - } - case ToSet(a) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else - aValue.asInstanceOf[IndexedSeq[Any]].toSet - case ToDict(a) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else - aValue.asInstanceOf[IndexedSeq[Row]].filter(_ != null).map { case Row(k, v) => - (k, v) - }.toMap - case _: CastToArray | _: ToArray | _: ToStream => - val c = ir.children.head.asInstanceOf[IR] - val cValue = interpret(c, env, args) - if (cValue == null) - null - else { - val ordering = tcoerce[TIterable](c.typ).elementType.ordering(ctx.stateManager).toOrdering - cValue match { - case s: Set[_] => - s.asInstanceOf[Set[Any]].toFastSeq.sorted(ordering) - case d: Map[_, _] => - d.iterator.map { case (k, v) => Row(k, v) }.toFastSeq.sorted(ordering) - case a => a + case ToSet(a) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else + aValue.asInstanceOf[IndexedSeq[Any]].toSet + case ToDict(a) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else + aValue.asInstanceOf[IndexedSeq[Row]].filter(_ != null).map { case Row(k, v) => + (k, v) + }.toMap + case _: CastToArray | _: ToArray | _: ToStream => + val c = ir.children.head.asInstanceOf[IR] + val cValue = interpret(c, env, args) + if (cValue == null) + null + else { + val ordering = + tcoerce[TIterable](c.typ).elementType.ordering(ctx.stateManager).toOrdering + cValue match { + case s: Set[_] => + s.asInstanceOf[Set[Any]].toFastSeq.sorted(ordering) + case d: Map[_, _] => + d.iterator.map { case (k, v) => Row(k, v) }.toFastSeq.sorted(ordering) + case a => a + } } - } - case LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => - val cValue = interpret(orderedCollection, env, args) - val eValue = interpret(elem, env, args) - if (cValue == null) - null - else { - cValue match { - case s: Set[_] => - assert(!onKey) - s.count(elem.typ.ordering(ctx.stateManager).lt(_, eValue)) - case d: Map[_, _] => - assert(onKey) - d.count { case (k, _) => elem.typ.ordering(ctx.stateManager).lt(k, eValue) } - case a: IndexedSeq[_] => - if (onKey) { - val (eltF, eltT) = - orderedCollection.typ.asInstanceOf[TContainer].elementType match { - case t: TBaseStruct => ( - { (x: Any) => - val r = x.asInstanceOf[Row] - if (r == null) null else r.get(0) - }, - t.types(0), - ) - case i: TInterval => ( - { (x: Any) => - val i = x.asInstanceOf[Interval] - if (i == null) null else i.start - }, - i.pointType, - ) - } - val ordering = eltT.ordering(ctx.stateManager) - val lb = a.count(elem => ordering.lt(eltF(elem), eValue)) - lb - } else - a.count(elem.typ.ordering(ctx.stateManager).lt(_, eValue)) + case LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => + val cValue = interpret(orderedCollection, env, args) + val eValue = interpret(elem, env, args) + if (cValue == null) + null + else { + cValue match { + case s: Set[_] => + assert(!onKey) + s.count(elem.typ.ordering(ctx.stateManager).lt(_, eValue)) + case d: Map[_, _] => + assert(onKey) + d.count { case (k, _) => elem.typ.ordering(ctx.stateManager).lt(k, eValue) } + case a: IndexedSeq[_] => + if (onKey) { + val (eltF, eltT) = + orderedCollection.typ.asInstanceOf[TContainer].elementType match { + case t: TBaseStruct => ( + { (x: Any) => + val r = x.asInstanceOf[Row] + if (r == null) null else r.get(0) + }, + t.types(0), + ) + case i: TInterval => ( + { (x: Any) => + val i = x.asInstanceOf[Interval] + if (i == null) null else i.start + }, + i.pointType, + ) + } + val ordering = eltT.ordering(ctx.stateManager) + val lb = a.count(elem => ordering.lt(eltF(elem), eValue)) + lb + } else + a.count(elem.typ.ordering(ctx.stateManager).lt(_, eValue)) + } } - } - case GroupByKey(collection) => - interpret(collection, env, args).asInstanceOf[IndexedSeq[Row]] - .groupBy { case Row(k, _) => k } - .view - .mapValues { elt: IndexedSeq[Row] => elt.map { case Row(_, v) => v } } - .toMap - case StreamTake(a, len) => - val aValue = interpret(a, env, args) - val lenValue = interpret(len, env, args) - if (aValue == null || lenValue == null) - null - else { - val len = lenValue.asInstanceOf[Int] - if (len < 0) fatal("stream take: negative num") - aValue.asInstanceOf[IndexedSeq[Any]].take(len) - } - case StreamDrop(a, num) => - val aValue = interpret(a, env, args) - val numValue = interpret(num, env, args) - if (aValue == null || numValue == null) - null - else { - val n = numValue.asInstanceOf[Int] - if (n < 0) fatal("stream drop: negative num") - aValue.asInstanceOf[IndexedSeq[Any]].drop(n) - } - case StreamGrouped(a, size) => - val aValue = interpret(a, env, args) - val sizeValue = interpret(size, env, args) - if (aValue == null || sizeValue == null) - null - else { - val size = sizeValue.asInstanceOf[Int] - if (size <= 0) fatal("stream grouped: non-positive size") - aValue.asInstanceOf[IndexedSeq[Any]].grouped(size).toFastSeq - } - case StreamGroupByKey(a, key, missingEqual) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - val structType = tcoerce[TStruct](tcoerce[TStream](a.typ).elementType) - val seq = aValue.asInstanceOf[IndexedSeq[Row]] - if (seq.isEmpty) - FastSeq[IndexedSeq[Row]]() + case GroupByKey(collection) => + interpret(collection, env, args).asInstanceOf[IndexedSeq[Row]] + .groupBy { case Row(k, _) => k } + .view + .mapValues { elt: IndexedSeq[Row] => elt.map { case Row(_, v) => v } } + .toMap + case StreamTake(a, len) => + val aValue = interpret(a, env, args) + val lenValue = interpret(len, env, args) + if (aValue == null || lenValue == null) + null else { - val outer = ArraySeq.newBuilder[IndexedSeq[Row]] - val inner = ArraySeq.newBuilder[Row] - val (kType, getKey) = structType.select(key) - val keyOrd = TBaseStruct.getJoinOrdering(ctx.stateManager, kType.types, missingEqual) - var curKey: Row = getKey(seq.head) + val len = lenValue.asInstanceOf[Int] + if (len < 0) fatal("stream take: negative num") + aValue.asInstanceOf[IndexedSeq[Any]].take(len) + } + case StreamDrop(a, num) => + val aValue = interpret(a, env, args) + val numValue = interpret(num, env, args) + if (aValue == null || numValue == null) + null + else { + val n = numValue.asInstanceOf[Int] + if (n < 0) fatal("stream drop: negative num") + aValue.asInstanceOf[IndexedSeq[Any]].drop(n) + } + case StreamGrouped(a, size) => + val aValue = interpret(a, env, args) + val sizeValue = interpret(size, env, args) + if (aValue == null || sizeValue == null) + null + else { + val size = sizeValue.asInstanceOf[Int] + if (size <= 0) fatal("stream grouped: non-positive size") + aValue.asInstanceOf[IndexedSeq[Any]].grouped(size).toFastSeq + } + case StreamGroupByKey(a, key, missingEqual) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + val structType = tcoerce[TStruct](tcoerce[TStream](a.typ).elementType) + val seq = aValue.asInstanceOf[IndexedSeq[Row]] + if (seq.isEmpty) + FastSeq[IndexedSeq[Row]]() + else { + val outer = ArraySeq.newBuilder[IndexedSeq[Row]] + val inner = ArraySeq.newBuilder[Row] + val (kType, getKey) = structType.select(key) + val keyOrd = TBaseStruct.getJoinOrdering(ctx.stateManager, kType.types, missingEqual) + var curKey: Row = getKey(seq.head) - seq.foreach { elt => - val nextKey = getKey(elt) - if (!keyOrd.equiv(curKey, nextKey)) { - outer += inner.result() - inner.clear() - curKey = nextKey + seq.foreach { elt => + val nextKey = getKey(elt) + if (!keyOrd.equiv(curKey, nextKey)) { + outer += inner.result() + inner.clear() + curKey = nextKey + } + inner += elt } - inner += elt - } - outer += inner.result() + outer += inner.result() - outer.result().toFastSeq - } - } - case StreamMap(a, name, body) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - aValue.asInstanceOf[IndexedSeq[Any]].map { element => - interpret(body, env.bind(name, element), args) + outer.result().toFastSeq + } } - } - case StreamZip(as, names, body, behavior, errorID) => - val aValues = as.map(interpret(_, env, args).asInstanceOf[IndexedSeq[_]]) - if (aValues.contains(null)) - null - else { - val len = behavior match { - case ArrayZipBehavior.AssertSameLength | ArrayZipBehavior.AssumeSameLength => - val lengths = aValues.map(_.length).toSet - if (lengths.size != 1) - fatal(s"zip: length mismatch: ${lengths.mkString(", ")}", errorID) - lengths.head - case ArrayZipBehavior.TakeMinLength => - aValues.map(_.length).min - case ArrayZipBehavior.ExtendNA => - aValues.map(_.length).max + case StreamMap(a, name, body) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + aValue.asInstanceOf[IndexedSeq[Any]].map { element => + interpret(body, env.bind(name, element), args) + } } - (0 until len).map { i => - val e = - env.bindIterable(names.zip(aValues.map(a => if (i >= a.length) null else a.apply(i)))) - interpret(body, e, args) + case StreamZip(as, names, body, behavior, errorID) => + val aValues = as.map(interpret(_, env, args).asInstanceOf[IndexedSeq[_]]) + if (aValues.contains(null)) + null + else { + val len = behavior match { + case ArrayZipBehavior.AssertSameLength | ArrayZipBehavior.AssumeSameLength => + val lengths = aValues.map(_.length).toSet + if (lengths.size != 1) + fatal(s"zip: length mismatch: ${lengths.mkString(", ")}", errorID) + lengths.head + case ArrayZipBehavior.TakeMinLength => + aValues.map(_.length).min + case ArrayZipBehavior.ExtendNA => + aValues.map(_.length).max + } + (0 until len).map { i => + val e = + env.bindIterable(names.zip(aValues.map(a => + if (i >= a.length) null else a.apply(i) + ))) + interpret(body, e, args) + } } - } - case StreamMultiMerge(as, key) => - val streams = as.map(interpret(_, env, args).asInstanceOf[IndexedSeq[Row]]) - if (streams.contains(null)) - null - else { - val k = as.length - val tournament = Array.fill[Int](k)(-1) - val structType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) - val (kType, getKey) = structType.select(key) - val heads = Array.fill[Int](k)(-1) - val ordering = kType.ordering(ctx.stateManager).toOrdering.on[Row](getKey) + case StreamMultiMerge(as, key) => + val streams = as.map(interpret(_, env, args).asInstanceOf[IndexedSeq[Row]]) + if (streams.contains(null)) + null + else { + val k = as.length + val tournament = Array.fill[Int](k)(-1) + val structType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) + val (kType, getKey) = structType.select(key) + val heads = Array.fill[Int](k)(-1) + val ordering = kType.ordering(ctx.stateManager).toOrdering.on[Row](getKey) - def get(i: Int): Row = streams(i)(heads(i)) - def lt(li: Int, lv: Row, ri: Int, rv: Row): Boolean = { - val c = ordering.compare(lv, rv) - c < 0 || (c == 0 && li < ri) - } + def get(i: Int): Row = streams(i)(heads(i)) + def lt(li: Int, lv: Row, ri: Int, rv: Row): Boolean = { + val c = ordering.compare(lv, rv) + c < 0 || (c == 0 && li < ri) + } - def advance(i: Int): Unit = { - heads(i) += 1 - var winner = if (heads(i) < streams(i).length) i else k - var j = (i + k) / 2 - while (j != 0 && tournament(j) != -1) { - val challenger = tournament(j) - if (challenger != k && (winner == k || lt(j, get(challenger), i, get(winner)))) { - tournament(j) = winner - winner = challenger + def advance(i: Int): Unit = { + heads(i) += 1 + var winner = if (heads(i) < streams(i).length) i else k + var j = (i + k) / 2 + while (j != 0 && tournament(j) != -1) { + val challenger = tournament(j) + if (challenger != k && (winner == k || lt(j, get(challenger), i, get(winner)))) { + tournament(j) = winner + winner = challenger + } + j = j / 2 } - j = j / 2 + tournament(j) = winner } - tournament(j) = winner - } - for (i <- 0 until k) advance(i) + for (i <- 0 until k) advance(i) - val builder = ArraySeq.newBuilder[Row] - while (tournament(0) != k) { - val i = tournament(0) - val elt = streams(i)(heads(i)) - advance(i) - builder += elt + val builder = ArraySeq.newBuilder[Row] + while (tournament(0) != k) { + val i = tournament(0) + val elt = streams(i)(heads(i)) + advance(i) + builder += elt + } + builder.result() } - builder.result() - } - case StreamZipJoin(as, key, curKeyName, curValsName, joinF) => - val streams = as.map(interpret(_, env, args).asInstanceOf[IndexedSeq[Row]]) - if (streams.contains(null)) - null - else { - val k = as.length - val tournament = Array.fill[Int](k)(-1) - val structType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) - val (kType, getKey) = structType.select(key) - val heads = Array.fill[Int](k)(-1) - val ordering = kType.ordering(ctx.stateManager).toOrdering.on[Row](getKey) - val hasKey = TBaseStruct.getJoinOrdering(ctx.stateManager, kType.types).equivNonnull _ + case StreamZipJoin(as, key, curKeyName, curValsName, joinF) => + val streams = as.map(interpret(_, env, args).asInstanceOf[IndexedSeq[Row]]) + if (streams.contains(null)) + null + else { + val k = as.length + val tournament = Array.fill[Int](k)(-1) + val structType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) + val (kType, getKey) = structType.select(key) + val heads = Array.fill[Int](k)(-1) + val ordering = kType.ordering(ctx.stateManager).toOrdering.on[Row](getKey) + val hasKey = TBaseStruct.getJoinOrdering(ctx.stateManager, kType.types).equivNonnull _ - def get(i: Int): Row = streams(i)(heads(i)) + def get(i: Int): Row = streams(i)(heads(i)) - def advance(i: Int): Unit = { - heads(i) += 1 - var winner = if (heads(i) < streams(i).length) i else k - var j = (i + k) / 2 - while (j != 0 && tournament(j) != -1) { - val challenger = tournament(j) - if (challenger != k && (winner == k || ordering.lteq(get(challenger), get(winner)))) { - tournament(j) = winner - winner = challenger + def advance(i: Int): Unit = { + heads(i) += 1 + var winner = if (heads(i) < streams(i).length) i else k + var j = (i + k) / 2 + while (j != 0 && tournament(j) != -1) { + val challenger = tournament(j) + if ( + challenger != k && (winner == k || ordering.lteq(get(challenger), get(winner))) + ) { + tournament(j) = winner + winner = challenger + } + j = j / 2 } - j = j / 2 + tournament(j) = winner } - tournament(j) = winner - } - for (i <- 0 until k) advance(i) + for (i <- 0 until k) advance(i) - val builder = new mutable.ArrayBuffer[Any]() - while (tournament(0) != k) { - val i = tournament(0) - val elt = Array.fill[Row](k)(null) - elt(i) = streams(i)(heads(i)) - val curKey = getKey(elt(i)) - advance(i) - var j = tournament(0) - while (j != k && hasKey(getKey(get(j)), curKey)) { - elt(j) = streams(j)(heads(j)) - advance(j) - j = tournament(0) + val builder = new mutable.ArrayBuffer[Any]() + while (tournament(0) != k) { + val i = tournament(0) + val elt = Array.fill[Row](k)(null) + elt(i) = streams(i)(heads(i)) + val curKey = getKey(elt(i)) + advance(i) + var j = tournament(0) + while (j != k && hasKey(getKey(get(j)), curKey)) { + elt(j) = streams(j)(heads(j)) + advance(j) + j = tournament(0) + } + builder += interpret( + joinF, + env.bind(curKeyName -> curKey, curValsName -> elt.toFastSeq), + args, + ) } - builder += interpret( - joinF, - env.bind(curKeyName -> curKey, curValsName -> elt.toFastSeq), - args, - ) + builder.toFastSeq } - builder.toFastSeq - } - case StreamFilter(a, name, cond) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - aValue.asInstanceOf[IndexedSeq[Any]].filter { element => - // casting to boolean treats null as false - interpret(cond, env.bind(name, element), args).asInstanceOf[Boolean] + case StreamFilter(a, name, cond) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + aValue.asInstanceOf[IndexedSeq[Any]].filter { element => + // casting to boolean treats null as false + interpret(cond, env.bind(name, element), args).asInstanceOf[Boolean] + } } - } - case StreamTakeWhile(a, name, cond) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - aValue.asInstanceOf[IndexedSeq[Any]].takeWhile { element => - // casting to boolean treats null as false - interpret(cond, env.bind(name, element), args).asInstanceOf[Boolean] + case StreamTakeWhile(a, name, cond) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + aValue.asInstanceOf[IndexedSeq[Any]].takeWhile { element => + // casting to boolean treats null as false + interpret(cond, env.bind(name, element), args).asInstanceOf[Boolean] + } } - } - case StreamDropWhile(a, name, cond) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - aValue.asInstanceOf[IndexedSeq[Any]].dropWhile { element => - // casting to boolean treats null as false - interpret(cond, env.bind(name, element), args).asInstanceOf[Boolean] + case StreamDropWhile(a, name, cond) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + aValue.asInstanceOf[IndexedSeq[Any]].dropWhile { element => + // casting to boolean treats null as false + interpret(cond, env.bind(name, element), args).asInstanceOf[Boolean] + } } - } - case StreamFlatMap(a, name, body) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - aValue.asInstanceOf[IndexedSeq[Any]].flatMap { element => - val r = interpret(body, env.bind(name, element), args).asInstanceOf[IndexedSeq[Any]] - if (r != null) - r - else - None + case StreamFlatMap(a, name, body) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + aValue.asInstanceOf[IndexedSeq[Any]].flatMap { element => + val r = interpret(body, env.bind(name, element), args).asInstanceOf[IndexedSeq[Any]] + if (r != null) + r + else + None + } } - } - case StreamFold(a, zero, accumName, valueName, body) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - var zeroValue = interpret(zero, env, args) - aValue.asInstanceOf[IndexedSeq[Any]].foreach { element => - zeroValue = - interpret(body, env.bind(accumName -> zeroValue, valueName -> element), args) + case StreamFold(a, zero, accumName, valueName, body) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + var zeroValue = interpret(zero, env, args) + aValue.asInstanceOf[IndexedSeq[Any]].foreach { element => + zeroValue = + interpret(body, env.bind(accumName -> zeroValue, valueName -> element), args) + } + zeroValue } - zeroValue - } - case StreamFold2(a, accum, valueName, seq, res) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - val accVals = accum.map { case (name, value) => (name, interpret(value, env, args)) } - var e = env.bindIterable(accVals) - aValue.asInstanceOf[IndexedSeq[Any]].foreach { elt => - e = e.bind(valueName, elt) - accVals.indices.foreach(i => e = e.bind(accum(i)._1, interpret(seq(i), e, args))) + case StreamFold2(a, accum, valueName, seq, res) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + val accVals = accum.map { case (name, value) => (name, interpret(value, env, args)) } + var e = env.bindIterable(accVals) + aValue.asInstanceOf[IndexedSeq[Any]].foreach { elt => + e = e.bind(valueName, elt) + accVals.indices.foreach(i => e = e.bind(accum(i)._1, interpret(seq(i), e, args))) + } + interpret(res, e.delete(valueName), args) } - interpret(res, e.delete(valueName), args) - } - case StreamScan(a, zero, accumName, valueName, body) => - val aValue = interpret(a, env, args) - if (aValue == null) - null - else { - val zeroValue = interpret(zero, env, args) - aValue.asInstanceOf[IndexedSeq[Any]].scanLeft(zeroValue) { (accum, elt) => - interpret(body, env.bind(accumName -> accum, valueName -> elt), args) + case StreamScan(a, zero, accumName, valueName, body) => + val aValue = interpret(a, env, args) + if (aValue == null) + null + else { + val zeroValue = interpret(zero, env, args) + aValue.asInstanceOf[IndexedSeq[Any]].scanLeft(zeroValue) { (accum, elt) => + interpret(body, env.bind(accumName -> accum, valueName -> elt), args) + } } - } - case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => - val lValue = interpret(left, env, args).asInstanceOf[IndexedSeq[Any]] - val rValue = interpret(right, env, args).asInstanceOf[IndexedSeq[Any]] + case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => + val lValue = interpret(left, env, args).asInstanceOf[IndexedSeq[Any]] + val rValue = interpret(right, env, args).asInstanceOf[IndexedSeq[Any]] - if (lValue == null || rValue == null) - null - else { - val (lKeyTyp, lGetKey) = - tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).select(lKey) - val (rKeyTyp, rGetKey) = - tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).select(rKey) - assert(lKeyTyp isJoinableWith rKeyTyp) - val keyOrd = TBaseStruct.getJoinOrdering(ctx.stateManager, lKeyTyp.types) + if (lValue == null || rValue == null) + null + else { + val (lKeyTyp, lGetKey) = + tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).select(lKey) + val (rKeyTyp, rGetKey) = + tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).select(rKey) + assert(lKeyTyp isJoinableWith rKeyTyp) + val keyOrd = TBaseStruct.getJoinOrdering(ctx.stateManager, lKeyTyp.types) - def compF(lelt: Any, relt: Any): Int = - keyOrd.compare(lGetKey(lelt.asInstanceOf[Row]), rGetKey(relt.asInstanceOf[Row])) - def joinF(lelt: Any, relt: Any): Any = - interpret(join, env.bind(l -> lelt, r -> relt), args) + def compF(lelt: Any, relt: Any): Int = + keyOrd.compare(lGetKey(lelt.asInstanceOf[Row]), rGetKey(relt.asInstanceOf[Row])) + def joinF(lelt: Any, relt: Any): Any = + interpret(join, env.bind(l -> lelt, r -> relt), args) - val builder = scala.collection.mutable.ArrayBuilder.make[(Option[Int], Option[Int])] - var i = 0 - var j = 0 + val builder = scala.collection.mutable.ArrayBuilder.make[(Option[Int], Option[Int])] + var i = 0 + var j = 0 - while (i < lValue.length && j < rValue.length) { - val lelt = lValue(i) - val relt = rValue(j) - val c = compF(lelt, relt) - if (c < 0) { + while (i < lValue.length && j < rValue.length) { + val lelt = lValue(i) + val relt = rValue(j) + val c = compF(lelt, relt) + if (c < 0) { + builder += ((Some(i), None)) + i += 1 + } else if (c > 0) { + builder += ((None, Some(j))) + j += 1 + } else { + builder += ((Some(i), Some(j))) + i += 1 + if (i == lValue.length || compF(lValue(i), relt) > 0) + j += 1 + } + } + while (i < lValue.length) { builder += ((Some(i), None)) i += 1 - } else if (c > 0) { + } + while (j < rValue.length) { builder += ((None, Some(j))) j += 1 - } else { - builder += ((Some(i), Some(j))) - i += 1 - if (i == lValue.length || compF(lValue(i), relt) > 0) - j += 1 } - } - while (i < lValue.length) { - builder += ((Some(i), None)) - i += 1 - } - while (j < rValue.length) { - builder += ((None, Some(j))) - j += 1 - } - val outerResult = builder.result() - val elts: Iterator[(Option[Int], Option[Int])] = joinType match { - case "inner" => outerResult.iterator.filter { case (l, r) => - l.isDefined && r.isDefined - } - case "outer" => outerResult.iterator - case "left" => outerResult.iterator.filter { case (l, _) => l.isDefined } - case "right" => outerResult.iterator.filter { case (_, r) => r.isDefined } - } - elts.map { case (lIdx, rIdx) => - joinF(lIdx.map(lValue.apply).orNull, rIdx.map(rValue.apply).orNull) + val outerResult = builder.result() + val elts: Iterator[(Option[Int], Option[Int])] = joinType match { + case "inner" => outerResult.iterator.filter { case (l, r) => + l.isDefined && r.isDefined + } + case "outer" => outerResult.iterator + case "left" => outerResult.iterator.filter { case (l, _) => l.isDefined } + case "right" => outerResult.iterator.filter { case (_, r) => r.isDefined } + } + elts.map { case (lIdx, rIdx) => + joinF(lIdx.map(lValue.apply).orNull, rIdx.map(rValue.apply).orNull) + } + .toFastSeq } - .toFastSeq - } - case StreamFor(a, valueName, body) => - val aValue = interpret(a, env, args) - if (aValue != null) { - aValue.asInstanceOf[IndexedSeq[Any]].foreach { element => - interpret(body, env.bind(valueName -> element), args) - } - } - () - case MakeStruct(fields) => - Row.fromSeq(fields.map { case (_, fieldIR) => interpret(fieldIR, env, args) }) - case SelectFields(old, fields) => - val oldt = tcoerce[TStruct](old.typ) - val oldRow = interpret(old, env, args).asInstanceOf[Row] - if (oldRow == null) - null - else - Row.fromSeq(fields.map(id => oldRow.get(oldt.fieldIdx(id)))) - case InsertFields(old, fields, fieldOrder) => - var struct = interpret(old, env, args) - if (struct != null) - fieldOrder match { - case Some(fds) => - val m = fields.toMap - val oldIndices = - old.typ.asInstanceOf[TStruct].fields.map(f => f.name -> f.index).toMap - Row.fromSeq(fds.map(name => - m.get(name).map(interpret(_, env, args)).getOrElse( - struct.asInstanceOf[Row].get(oldIndices(name)) - ) - )) - case None => - var t = old.typ.asInstanceOf[TStruct] - fields.foreach { case (name, body) => - val (newT, ins) = t.insert(body.typ, FastSeq(name)) - t = newT.asInstanceOf[TStruct] - struct = ins(struct, interpret(body, env, args)) - } - struct + case StreamFor(a, valueName, body) => + val aValue = interpret(a, env, args) + if (aValue != null) { + aValue.asInstanceOf[IndexedSeq[Any]].foreach { element => + interpret(body, env.bind(valueName -> element), args) + } } - else - null - - case GetField(o, name) => - val oValue = interpret(o, env, args) - if (oValue == null) - null - else { - val oType = o.typ.asInstanceOf[TStruct] - val fieldIndex = oType.fieldIdx(name) - oValue.asInstanceOf[Row].get(fieldIndex) - } - case MakeTuple(types) => - Row.fromSeq(types.map { case (_, x) => interpret(x, env, args) }) - case GetTupleElement(o, idx) => - val oValue = interpret(o, env, args) - if (oValue == null) - null - else - oValue.asInstanceOf[Row].get(o.typ.asInstanceOf[TTuple].fieldIndex(idx)) - case In(i, _) => - val (a, _) = args(i) - a - case Die(message, _, errorId) => - val message_ = interpret(message).asInstanceOf[String] - fatal(if (message_ != null) message_ else "", errorId) - case ConsoleLog(message, result) => - val message_ = interpret(message).asInstanceOf[String] - logger.info(message_) - interpret(result) - case ir @ ApplyIR(_, _, _, _, _) => - interpret(ir.explicitNode, env, args) - case ApplySpecial("lor", _, Seq(left_, right_), _, _) => - val left = interpret(left_) - if (left == true) - true - else { - val right = interpret(right_) - if (right == true) - true - else if (left == null || right == null) + () + case MakeStruct(fields) => + Row.fromSeq(fields.map { case (_, fieldIR) => interpret(fieldIR, env, args) }) + case SelectFields(old, fields) => + val oldt = tcoerce[TStruct](old.typ) + val oldRow = interpret(old, env, args).asInstanceOf[Row] + if (oldRow == null) null - else false - } - case ApplySpecial("land", _, Seq(left_, right_), _, _) => - val left = interpret(left_) - if (left == false) - false - else { - val right = interpret(right_) - if (right == false) - false - else if (left == null || right == null) + else + Row.fromSeq(fields.map(id => oldRow.get(oldt.fieldIdx(id)))) + case InsertFields(old, fields, fieldOrder) => + var struct = interpret(old, env, args) + if (struct != null) + fieldOrder match { + case Some(fds) => + val m = fields.toMap + val oldIndices = + old.typ.asInstanceOf[TStruct].fields.map(f => f.name -> f.index).toMap + Row.fromSeq(fds.map(name => + m.get(name).map(interpret(_, env, args)).getOrElse( + struct.asInstanceOf[Row].get(oldIndices(name)) + ) + )) + case None => + var t = old.typ.asInstanceOf[TStruct] + fields.foreach { case (name, body) => + val (newT, ins) = t.insert(body.typ, FastSeq(name)) + t = newT.asInstanceOf[TStruct] + struct = ins(struct, interpret(body, env, args)) + } + struct + } + else null - else true - } - case ir: AbstractApplyNode[_] => - val argTuple = - PType.canonical(TTuple(ir.args.map(_.typ): _*)).setRequired(true).asInstanceOf[PTuple] - ctx.r.pool.scopedRegion { region => - val (rt, f) = functionMemo.getOrElseUpdate( - ir, { - val in = Ref(freshName(), argTuple.virtualType) - val wrappedIR = ir.mapChildrenWithIndex { case (_, i) => GetTupleElement(in, i) } - val (rt, makeFunction) = - Compile[AsmFunction2RegionLongLong]( - ctx, - FastSeq(( - in.name, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argTuple)), - )), - FastSeq(classInfo[Region], LongInfo), - LongInfo, - MakeTuple.ordered(FastSeq(wrappedIR)), - ) - (rt.get, makeFunction(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region)) - }, - ) - val rvb = new RegionValueBuilder(ctx.stateManager) - rvb.set(region) - rvb.start(argTuple) - rvb.startTuple() - ir.args.zip(argTuple.types).foreach { case (arg, t) => - val argValue = interpret(arg, env, args) - rvb.addAnnotation(t.virtualType, argValue) + case GetField(o, name) => + val oValue = interpret(o, env, args) + if (oValue == null) + null + else { + val oType = o.typ.asInstanceOf[TStruct] + val fieldIndex = oType.fieldIdx(name) + oValue.asInstanceOf[Row].get(fieldIndex) } - rvb.endTuple() - val offset = rvb.end() - - try { - val resultOffset = f(region, offset) - SafeRow( - rt.asInstanceOf[PTypeReferenceSingleCodeType].pt.asInstanceOf[PTuple], - resultOffset, - ).get(0) - } catch { - case e: Exception => - fatal(s"error while calling '${ir.implementation.name}': ${e.getMessage}", e) + case MakeTuple(types) => + Row.fromSeq(types.map { case (_, x) => interpret(x, env, args) }) + case GetTupleElement(o, idx) => + val oValue = interpret(o, env, args) + if (oValue == null) + null + else + oValue.asInstanceOf[Row].get(o.typ.asInstanceOf[TTuple].fieldIndex(idx)) + case In(i, _) => + val (a, _) = args(i) + a + case Die(message, _, errorId) => + val message_ = interpret(message, env, args).asInstanceOf[String] + fatal(if (message_ != null) message_ else "", errorId) + case ConsoleLog(message, result) => + val message_ = interpret(message, env, args).asInstanceOf[String] + logger.info(message_) + interpret(result, env, args) + case ir @ ApplyIR(_, _, _, _, _) => + interpret(ir.explicitNode, env, args) + case ApplySpecial("lor", _, Seq(left_, right_), _, _) => + val left = interpret(left_, env, args) + if (left == true) + true + else { + val right = interpret(right_, env, args) + if (right == true) + true + else if (left == null || right == null) + null + else false } - } - case TableCount(child) => - PartitionCounts(child) - .map(_.sum) - .getOrElse(ExecuteRelational(ctx, child).asTableValue(ctx).rvd.count()) - case TableGetGlobals(child) => - ExecuteRelational(ctx, child).asTableValue(ctx).globals.safeJavaValue - case TableCollect(child) => - val tv = ExecuteRelational(ctx, child).asTableValue(ctx) - Row(tv.rvd.collect(ctx).toFastSeq, tv.globals.safeJavaValue) - case TableMultiWrite(children, writer) => - val tvs = children.map(child => ExecuteRelational(ctx, child).asTableValue(ctx)) - writer(ctx, tvs) - case TableWrite(child, writer) => - writer(ctx, ExecuteRelational(ctx, child).asTableValue(ctx)) - case BlockMatrixWrite(child, writer) => - writer(ctx, child.execute(ctx)) - case BlockMatrixMultiWrite(blockMatrices, writer) => - writer(ctx, blockMatrices.map(_.execute(ctx))) - case TableToValueApply(child, function) => - function.execute(ctx, ExecuteRelational(ctx, child).asTableValue(ctx)) - case BlockMatrixToValueApply(child, function) => - function.execute(ctx, child.execute(ctx)) - case BlockMatrixCollect(child) => - val bm = child.execute(ctx) - // transpose because breeze toArray is column major - val breezeMat = bm.transpose().toBreezeMatrix() - val shape = IndexedSeq(bm.nRows, bm.nCols) - SafeNDArray(shape, breezeMat.toArray) - case x @ TableAggregate(child, query) => - val value = ExecuteRelational(ctx, child).asTableValue(ctx) - val fsBc = ctx.fsBc - - val globalsBc = value.globals.broadcast(ctx.theHailClassLoader) - val globalsOffset = value.globals.value.offset - - val extracted = agg.Extract(ctx, query, Requiredness(x, ctx)).independent - val aggSigs = extracted.sigs + case ApplySpecial("land", _, Seq(left_, right_), _, _) => + val left = interpret(left_, env, args) + if (left == false) + false + else { + val right = interpret(right_, env, args) + if (right == false) + false + else if (left == null || right == null) + null + else true + } + case ir: AbstractApplyNode[_] => + val in = Ref(freshName(), TTuple(ir.args.map(_.typ): _*)) + val argPType = PType.canonical(in.typ).setRequired(true).asInstanceOf[PTuple] + val wrappedIR = ir.mapChildrenWithIndex { case (_, i) => GetTupleElement(in, i) } - val wrapped = if (aggSigs.isEmpty) { - val (Some(PTypeReferenceSingleCodeType(rt: PTuple)), f) = + val (Some(rt), mkApply) = Compile[AsmFunction2RegionLongLong]( ctx, FastSeq(( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + in.name, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argPType)), )), FastSeq(classInfo[Region], LongInfo), LongInfo, - MakeTuple.ordered(FastSeq(extracted.result)), + MakeTuple.ordered(FastSeq(wrappedIR)), ) - // TODO Is this right? where does wrapped run? - ctx.scopedExecution((hcl, fs, htc, r) => - SafeRow(rt, f(hcl, fs, htc, r)(r, globalsOffset)) - ) - } else { - val spec = BufferSpec.blockedUncompressed + val inArgs = ir.args.map(interpret(_, env, args)) - val (_, initOp) = CompileWithAggregators[AsmFunction2RegionLongUnit]( - ctx, - aggSigs.states, - FastSeq(( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), - )), - FastSeq(classInfo[Region], LongInfo), - UnitInfo, - extracted.init, - ) + ctx.scopedExecution { (hcl, fs, htc, r) => + val rvb = new RegionValueBuilder(ctx.stateManager) + rvb.set(r) + rvb.start(argPType) + rvb.startTuple() + for ((argv, t) <- inArgs.zip(argPType.types)) + rvb.addAnnotation(t.virtualType, argv) + rvb.endTuple() + val offset = rvb.end() - val (_, partitionOpSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( - ctx, - aggSigs.states, - FastSeq( - ( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), - ), - ( - TableIR.rowName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.rvd.rowPType)), - ), - ), - FastSeq(classInfo[Region], LongInfo, LongInfo), - UnitInfo, - extracted.seqPerElt, - ) + try + SafeRow( + rt.asInstanceOf[PTypeReferenceSingleCodeType].pt.asInstanceOf[PTuple], + mkApply(hcl, fs, htc, r)(r, offset), + ).get(0) + catch { + case e: Exception => + fatal(s"error while calling '${ir.implementation.name}': ${e.getMessage}", e) + } + } + case TableCount(child) => + PartitionCounts(child) + .map(_.sum) + .getOrElse(ExecuteRelational(ctx, child).asTableValue(ctx).rvd.count()) + case TableGetGlobals(child) => + ExecuteRelational(ctx, child).asTableValue(ctx).globals.safeJavaValue + case TableCollect(child) => + val tv = ExecuteRelational(ctx, child).asTableValue(ctx) + Row(tv.rvd.collect(ctx).toFastSeq, tv.globals.safeJavaValue) + case TableMultiWrite(children, writer) => + val tvs = children.map(child => ExecuteRelational(ctx, child).asTableValue(ctx)) + writer(ctx, tvs) + case TableWrite(child, writer) => + writer(ctx, ExecuteRelational(ctx, child).asTableValue(ctx)) + case BlockMatrixWrite(child, writer) => + writer(ctx, child.execute(ctx)) + case BlockMatrixMultiWrite(blockMatrices, writer) => + writer(ctx, blockMatrices.map(_.execute(ctx))) + case TableToValueApply(child, function) => + function.execute(ctx, ExecuteRelational(ctx, child).asTableValue(ctx)) + case BlockMatrixToValueApply(child, function) => + function.execute(ctx, child.execute(ctx)) + case BlockMatrixCollect(child) => + val bm = child.execute(ctx) + // transpose because breeze toArray is column major + val breezeMat = bm.transpose().toBreezeMatrix() + val shape = IndexedSeq(bm.nRows, bm.nCols) + SafeNDArray(shape, breezeMat.toArray) + case x @ TableAggregate(child, query) => + val value = ExecuteRelational(ctx, child).asTableValue(ctx) - val useTreeAggregate = aggSigs.shouldTreeAggregate - val isCommutative = aggSigs.isCommutative - logger.info(s"Aggregate: useTreeAggregate=$useTreeAggregate") - logger.info(s"Aggregate: commutative=$isCommutative") + val globalsBc = value.globals.broadcast(ctx.theHailClassLoader) + val globalsOffset = value.globals.value.offset - // A mutable reference to a byte array. If someone higher up the - // call stack holds a WrappedByteArray, we can set the reference - // to null to allow the array to be GCed. - class WrappedByteArray(_bytes: Array[Byte]) { - private var ref: Array[Byte] = _bytes - def bytes: Array[Byte] = ref - def clear(): Unit = ref = null - } + val extracted = agg.Extract(ctx, query, Requiredness(x, ctx)).independent + val aggSigs = extracted.sigs - // creates a region, giving ownership to the caller - val read: (HailClassLoader, HailTaskContext) => (WrappedByteArray => RegionValue) = { - val deserialize = aggSigs.deserialize(ctx, spec) - (hcl: HailClassLoader, htc: HailTaskContext) => { - (a: WrappedByteArray) => - val r = Region(Region.SMALL, htc.getRegionPool()) - val res = deserialize(hcl, htc, r, a.bytes) - a.clear() - RegionValue(r, res) + val wrapped = if (aggSigs.isEmpty) { + val (Some(PTypeReferenceSingleCodeType(rt: PTuple)), f) = + Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(( + TableIR.globalName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + MakeTuple.ordered(FastSeq(extracted.result)), + ) + + // TODO Is this right? where does wrapped run? + ctx.scopedExecution { (hcl, fs, htc, r) => + SafeRow(rt, f(hcl, fs, htc, r)(r, globalsOffset)) } - } + } else { + val spec = BufferSpec.blockedUncompressed - // consumes a region, taking ownership from the caller - val write: (HailClassLoader, HailTaskContext, RegionValue) => WrappedByteArray = { - val serialize = aggSigs.serialize(ctx, spec) - (hcl: HailClassLoader, htc: HailTaskContext, rv: RegionValue) => { - val a = serialize(hcl, htc, rv.region, rv.offset) - rv.region.invalidate() - new WrappedByteArray(a) + val (_, initOp) = + CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, + aggSigs.states, + FastSeq(( + TableIR.globalName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + extracted.init, + ) + + val (_, partitionOpSeq) = + CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + ctx, + aggSigs.states, + FastSeq( + ( + TableIR.globalName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + ), + ( + TableIR.rowName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.rvd.rowPType)), + ), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + UnitInfo, + extracted.seqPerElt, + ) + + val useTreeAggregate = aggSigs.shouldTreeAggregate + val isCommutative = aggSigs.isCommutative + logger.info(s"Aggregate: useTreeAggregate=$useTreeAggregate") + logger.info(s"Aggregate: commutative=$isCommutative") + + // A mutable reference to a byte array. If someone higher up the + // call stack holds a WrappedByteArray, we can set the reference + // to null to allow the array to be GCed. + class WrappedByteArray(_bytes: Array[Byte]) { + private var ref: Array[Byte] = _bytes + def bytes: Array[Byte] = ref + def clear(): Unit = ref = null } - } - // takes ownership of both inputs, returns ownership of result - val combOpF: (HailClassLoader, HailTaskContext, RegionValue, RegionValue) => RegionValue = - aggSigs.combOpF(ctx, spec) + val read: Compiled[(Region, WrappedByteArray) => Long] = + aggSigs.deserialize(ctx, spec).map { f => (r, a) => + val res = f(r, a.bytes) + a.clear() + res + } + + val write: Compiled[(Region, Long) => WrappedByteArray] = + aggSigs.serialize(ctx, spec).map { fn => (r, offset) => + new WrappedByteArray(fn(r, offset)) + } + + val combOpF = aggSigs.combOpF(ctx, spec) + + val seqF: Compiled[(Region, Iterator[Long]) => Long] = { + (hcl, fs, ctx, r) => + val globalsOffset = globalsBc.value.readRegionValue(r, hcl) + val init = initOp(hcl, fs, ctx, r) + val seqOps = partitionOpSeq(hcl, fs, ctx, r) + (aggRegion, it) => + init.newAggState(aggRegion) + init(r, globalsOffset) + seqOps.setAggState(aggRegion, init.getAggOffset()) - // returns ownership of a new region holding the partition aggregation - // result - def itF(theHailClassLoader: HailClassLoader, i: Int, ctx: RVDContext, it: Iterator[Long]) - : RegionValue = { - val partRegion = ctx.partitionRegion - val globalsOffset = globalsBc.value.readRegionValue(partRegion, theHailClassLoader) - val init = initOp(theHailClassLoader, fsBc.value, SparkTaskContext.get(), partRegion) - val seqOps = - partitionOpSeq(theHailClassLoader, fsBc.value, SparkTaskContext.get(), partRegion) - val aggRegion = ctx.freshRegion(Region.SMALL) + aggRegion.pool.scopedSmallRegion { inner => + for (ptr <- it) { + seqOps(inner, globalsOffset, ptr) + inner.clear() + } + } - init.newAggState(aggRegion) - init(partRegion, globalsOffset) - seqOps.setAggState(aggRegion, init.getAggOffset()) - it.foreach { ptr => - seqOps(ctx.region, globalsOffset, ptr) - ctx.region.clear() + seqOps.getAggOffset() } - RegionValue(aggRegion, seqOps.getAggOffset()) - } + // writes the `Zero` value into the caller-provided region + val mkZero: Compiled[Region => Long] = { (hcl, fs, htc, outer) => + val initF = initOp(hcl, fs, htc, outer) + r => + initF.newAggState(r) + initF(r, globalsBc.value.readRegionValue(r, hcl)) + initF.getAggOffset() + } - // creates a new region holding the zero value, giving ownership to - // the caller - val mkZero = (theHailClassLoader: HailClassLoader, tc: HailTaskContext) => { - val region = Region(Region.SMALL, tc.getRegionPool()) - val initF = initOp(theHailClassLoader, fsBc.value, tc, region) - initF.newAggState(region) - initF(region, globalsBc.value.readRegionValue(region, theHailClassLoader)) - RegionValue(region, initF.getAggOffset()) - } + val rvF = + value.rvd.combine[WrappedByteArray, Long]( + ctx.fsBc, + mkZero, + seqF, + read, + write, + combOpF, + isCommutative, + someIf(useTreeAggregate, ctx.branchingFactor), + ) - val rv = value.rvd.combine[WrappedByteArray, RegionValue]( - ctx, mkZero, itF, read, write, combOpF, isCommutative, useTreeAggregate) + val (Some(PTypeReferenceSingleCodeType(rTyp: PTuple)), f) = + CompileWithAggregators[AsmFunction2RegionLongLong]( + ctx, + aggSigs.states, + FastSeq(( + TableIR.globalName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + MakeTuple.ordered(FastSeq(extracted.result)), + ) - val (Some(PTypeReferenceSingleCodeType(rTyp: PTuple)), f) = - CompileWithAggregators[AsmFunction2RegionLongLong]( - ctx, - aggSigs.states, - FastSeq(( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), - )), - FastSeq(classInfo[Region], LongInfo), - LongInfo, - MakeTuple.ordered(FastSeq(extracted.result)), - ) - assert(rTyp.types(0).virtualType == query.typ) + assert(rTyp.types(0).virtualType == query.typ) - ctx.r.pool.scopedRegion { r => - val resF = f(ctx.theHailClassLoader, fsBc.value, ctx.taskContext, r) - resF.setAggState(rv.region, rv.offset) - val resAddr = resF(r, globalsOffset) - val res = SafeRow(rTyp, resAddr) - resF.storeAggsToRegion() - rv.region.invalidate() - res + ctx.scopedExecution { (hcl, fs, htc, r) => + val offset = rvF(hcl, fs, htc, r)(r) + val resF = f(hcl, fs, htc, r) + + resF.setAggState(r, offset) + val resAddr = resF(r, globalsOffset) + val res = SafeRow(rTyp, resAddr) + resF.storeAggsToRegion() + res + } } - } - wrapped.get(0) - case UUID4(_) => - uuid4() - } + wrapped.get(0) + case UUID4(_) => + uuid4() + } + + ctx.time(interpret(ir, env, args)) } } diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala index 580b21db821..763d1857756 100644 --- a/hail/hail/src/is/hail/expr/ir/TableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala @@ -3,7 +3,7 @@ package is.hail.expr.ir import is.hail.annotations._ import is.hail.asm4s._ import is.hail.asm4s.implicits.valueToRichCodeInputBuffer -import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer} +import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext} import is.hail.collection.FastSeq import is.hail.collection.implicits.toRichIterable import is.hail.expr.ir.defs._ @@ -28,6 +28,8 @@ import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.virtual._ import is.hail.utils._ +import scala.collection.mutable + import java.io.{Closeable, InputStream} import org.apache.spark.sql.Row @@ -862,6 +864,13 @@ case class PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: Str override def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats) } +class TaskFinalizer extends (() => Unit) { + private[this] val closeables = mutable.ArrayBuffer.empty[Closeable] + override def apply(): Unit = closeables.foreach(_.close()) + def addCloseable(c: Closeable): Unit = closeables += c + def clear(): Unit = closeables.clear() +} + case class PartitionNativeIntervalReader( sm: HailStateManager, tablePath: String, @@ -942,8 +951,10 @@ case class PartitionNativeIntervalReader( val currIdxInPartition = mb.genFieldThisRef[Long]("n_to_read") val stopIdxInPartition = mb.genFieldThisRef[Long]("n_to_read") + val finalizer = mb.genFieldThisRef[TaskFinalizer]("finalizer") - cb.assign(finalizer, cb.emb.ecb.getTaskContext.invoke[TaskFinalizer]("newFinalizer")) + cb.assign(finalizer, Code.newInstance[TaskFinalizer]()) + cb += mb.getTaskContext.invoke[() => Unit, Unit]("onClose", finalizer) val startPartitionIndex = mb.genFieldThisRef[Int]("start_part") val currPartitionIdx = mb.genFieldThisRef[Int]("curr_part") diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index a8b82be5a2a..da00fe8f058 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -3,11 +3,12 @@ package is.hail.expr.ir import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.backend.spark.{SparkBackend, SparkTaskContext} +import is.hail.backend.spark.{unsafeHailClassLoaderForSparkWorkers, SparkBackend, SparkTaskContext} import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.implicits.toRichIterable import is.hail.expr.TableAnnotationImpex +import is.hail.expr.ir.TableValue.readToBytes import is.hail.expr.ir.agg.IndependentExtractedAggs import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.{RVDToTableStage, TableStage, TableStageToRVD} @@ -303,6 +304,13 @@ object TableValue extends Logging { ), ) } + + private[TableValue] def readToBytes(is: DataInputStream): Array[Byte] = { + val len = is.readInt() + val b = new Array[Byte](len) + is.readFully(b) + b + } } case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow, rvd: RVD) @@ -354,7 +362,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val sb = new StringBuilder() it.map { ptr => - val ur = new UnsafeRow(localSignature, ctx.r, ptr) + val ur = new UnsafeRow(localSignature, ctx.region, ptr) sb.clear() localTypes.indices.foreachBetween { i => sb ++= TableAnnotationImpex.exportAnnotation(ur.get(i), localTypes(i)) @@ -509,18 +517,18 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow .repartition(ctx, prevRVD.partitioner.strictify()) .boundary .mapPartitionsWithIndex(newRVDType) { (i, hcl, ctx, it) => - val partRegion = ctx.partitionRegion + val partRegion = ctx.r val globalsOff = globalsBc.value.readRegionValue(partRegion, hcl) val initialize = makeInit( hcl, fsBc.value, - SparkTaskContext.get(), + ctx, partRegion, ) - val sequence = makeSeq(hcl, fsBc.value, SparkTaskContext.get(), partRegion) - val newRowF = makeRow(hcl, fsBc.value, SparkTaskContext.get(), partRegion) + val sequence = makeSeq(hcl, fsBc.value, ctx, partRegion) + val newRowF = makeRow(hcl, fsBc.value, ctx, partRegion) val aggRegion = ctx.freshRegion() @@ -548,11 +556,11 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow aggRegion.clear() initialize.newAggState(aggRegion) - initialize(ctx.r, globalsOff) + initialize(ctx.region, globalsOff) sequence.setAggState(aggRegion, initialize.getAggOffset()) do { - sequence(ctx.r, globalsOff, current) + sequence(ctx.region, globalsOff, current) current = 0 } while (hasNext && keyOrd.equiv(rowKey.value.offset, current)) newRowF.setAggState(aggRegion, sequence.getAggOffset()) @@ -632,9 +640,9 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow typ.copy(rowType = newRow.typ), globals, rvd.boundary.mapPartitionsWithIndex(rvdType) { (i, hcl, ctx, it) => - val globalRegion = ctx.partitionRegion - val lenF = l(hcl, fsBc.value, SparkTaskContext.get(), globalRegion) - val rowF = f(hcl, fsBc.value, SparkTaskContext.get(), globalRegion) + val globalRegion = ctx.r + val lenF = l(hcl, fsBc.value, ctx, globalRegion) + val rowF = f(hcl, fsBc.value, ctx, globalRegion) it.flatMap { ptr => val len = lenF(ctx.region, ptr) new Iterator[Long] { @@ -681,12 +689,12 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow copy(rvd = rvd.filterWithContext[(AsmFunction3RegionLongLongBoolean, Long)]( { (_, hcl, rvdCtx) => - val globalRegion = rvdCtx.partitionRegion + val globalRegion = rvdCtx.r ( f( hcl, fsBc.value, - SparkTaskContext.get(), + rvdCtx, globalRegion, ), localGlobals.value.readRegionValue(globalRegion, hcl), @@ -858,15 +866,21 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val serialize = aggSigs.serialize(ctx, spec) val deserialize = aggSigs.deserialize(ctx, spec) - val combOp = aggSigs.combOpFSerializedWorkersOnly(ctx, spec) - - val tc = ctx.taskContext - val initF = makeInit(ctx.theHailClassLoader, fsBc.value, tc, ctx.r) - val globalsOffset = globals.value.offset - val initAggs = ctx.r.pool.scopedRegion { aggRegion => - initF.newAggState(aggRegion) - initF(ctx.r, globalsOffset) - serialize(ctx.theHailClassLoader, tc, aggRegion, initF.getAggOffset()) + val combOp = { + val loadFn = aggSigs.combOpFSerializedFromRegionPool(ctx, spec) + (as: Array[Byte], bs: Array[Byte]) => + val htc = SparkTaskContext.get + val comb = loadFn(unsafeHailClassLoaderForSparkWorkers, fsBc.value, htc, htc.r) + comb(as, bs) + } + + val initAggs = ctx.scopedExecution { (hcl, fs, htc, r) => + val initF = makeInit(hcl, fs, htc, htc.r) + val write = serialize(hcl, fs, htc, htc.r) + + initF.newAggState(r) + initF(htc.r, globals.value.offset) + write(r, initF.getAggOffset()) } val newRowType = PCanonicalStruct( @@ -880,11 +894,10 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val rdd = rvd .boundary .mapPartitionsWithIndex { (i, hcl, ctx, it) => - val partRegion = ctx.partitionRegion - val tc = SparkTaskContext.get() + val partRegion = ctx.r val globals = globalsBc.value.readRegionValue(partRegion, hcl) val makeKey = { - val f = makeKeyF(hcl, fsBc.value, tc, partRegion) + val f = makeKeyF(hcl, fsBc.value, ctx, partRegion) ptr: Long => { val keyOff = f(ctx.region, ptr, globals) SafeRow.read(localKeyPType, keyOff).asInstanceOf[Row] @@ -892,11 +905,11 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } val makeAgg = { () => val aggRegion = ctx.freshRegion() - RegionValue(aggRegion, deserialize(hcl, tc, aggRegion, initAggs)) + RegionValue(aggRegion, deserialize(hcl, fsBc.value, ctx, ctx.r)(aggRegion, initAggs)) } val seqOp = { - val f = makeSeq(hcl, fsBc.value, SparkTaskContext.get(), partRegion) + val f = makeSeq(hcl, fsBc.value, ctx, partRegion) (ptr: Long, agg: RegionValue) => { f.setAggState(agg.region, agg.offset) f(ctx.region, globals, ptr) @@ -905,9 +918,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } } val serializeAndCleanupAggs = { rv: RegionValue => - val a = serialize(hcl, tc, rv.region, rv.offset) - rv.region.close() - a + using(rv.region)(r => serialize(hcl, fsBc.value, ctx, r)(r, rv.offset)) } new BufferedAggregatorIterator[Long, RegionValue, Array[Byte], Row]( @@ -916,7 +927,8 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow makeKey, seqOp, serializeAndCleanupAggs, - localBufferSize) + localBufferSize, + ) }.aggregateByKey(initAggs, nPartitions.getOrElse(rvd.getNumPartitions))(combOp, combOp) val keyType = tcoerce[TStruct](newKey.typ) @@ -924,10 +936,9 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val region = ctx.region val rvb = new RegionValueBuilder(sm) - val partRegion = ctx.partitionRegion - val tc = SparkTaskContext.get() + val partRegion = ctx.r val globals = globalsBc.value.readRegionValue(partRegion, hcl) - val annotate = makeAnnotate(hcl, fsBc.value, tc, partRegion) + val annotate = makeAnnotate(hcl, fsBc.value, ctx, partRegion) it.map { case (key, aggs) => rvb.set(region) @@ -939,7 +950,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow i += 1 } - val aggOff = deserialize(hcl, tc, region, aggs) + val aggOff = deserialize(hcl, fsBc.value, ctx, ctx.r)(region, aggs) annotate.setAggState(region, aggOff) rvb.addAllFields(rTyp, region, annotate(region, globals)) rvb.endStruct() @@ -974,8 +985,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow )), ) - val resultOff = - f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r, globals.value.offset) + val resultOff = f(ctx.theHailClassLoader, ctx.fs, ctx, ctx.r)(ctx.r, globals.value.offset) val newType = typ.copy(globalType = newGlobals.typ.asInstanceOf[TStruct]) copy(typ = newType, globals = BroadcastRow(ctx, RegionValue(ctx.r, resultOff), resultPType)) @@ -1031,9 +1041,8 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow makeIterator( hcl, fsBc.value, - SparkTaskContext.get(), consumerCtx, - globalsBc.value.readRegionValue(consumerCtx.partitionRegion, hcl), + globalsBc.value.readRegionValue(consumerCtx.r, hcl), boxedPartition, ).map(l => l.longValue()) } @@ -1057,18 +1066,14 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val newType = typ.copy(rowType = extracted.result.typ.asInstanceOf[TStruct]) if (aggSigs.isEmpty) { - val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = + val (Some(PTypeReferenceSingleCodeType(rTyp)), rowFn) = Compile[AsmFunction3RegionLongLongLong]( ctx, FastSeq( - ( - TableIR.globalName, + TableIR.globalName -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - ), - ( - TableIR.rowName, + TableIR.rowName -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), - ), ), FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, @@ -1078,28 +1083,23 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow )), ) - val rowIterationNeedsGlobals = Mentions(extracted.result, TableIR.globalName) val globalsBc = - if (rowIterationNeedsGlobals) - globals.broadcast(ctx.theHailClassLoader) - else - null - - val fsBc = ctx.fsBc - val itF = { (i: Int, hcl: HailClassLoader, ctx: RVDContext, it: Iterator[Long]) => - val globalRegion = ctx.partitionRegion - val globals = if (rowIterationNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, hcl) - else - 0 - - val newRow = f(hcl, fsBc.value, SparkTaskContext.get(), globalRegion) - it.map(ptr => newRow(ctx.r, globals, ptr)) - } + someIf( + Mentions(extracted.result, TableIR.globalName), + globals.broadcast(ctx.theHailClassLoader), + ) return copy( typ = newType, - rvd = rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF), + rvd = rvd.mapPartitions(RVDType(rTyp.asInstanceOf[PStruct], typ.key)) { + (hcl: HailClassLoader, ctx: RVDContext, it: Iterator[Long]) => + val globals = globalsBc + .map(_.value.readRegionValue(ctx.r, hcl)) + .getOrElse(0L) + + val newRow = rowFn(hcl, fsBc.value, ctx, ctx.r) + it.map(ptr => newRow(ctx.region, globals, ptr)) + }, ) } @@ -1108,10 +1108,10 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val rowIterationNeedsGlobals = Mentions(extracted.result, TableIR.globalName) val globalsBc = - if (rowIterationNeedsGlobals || scanInitNeedsGlobals || scanSeqNeedsGlobals) - globals.broadcast(ctx.theHailClassLoader) - else - null + someIf( + rowIterationNeedsGlobals || scanInitNeedsGlobals || scanSeqNeedsGlobals, + globals.broadcast(ctx.theHailClassLoader), + ) val spec = BufferSpec.blockedUncompressed @@ -1121,53 +1121,47 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow // 3. load in partition aggregations, comb op as necessary, serialize. // 4. load in partStarts, calculate newRow based on those results. - val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit]( - ctx, - aggSigs.states, - FastSeq(( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - )), - FastSeq(classInfo[Region], LongInfo), - UnitInfo, - extracted.init, - ) - - val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( - ctx, - aggSigs.states, - FastSeq( - ( - TableIR.globalName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), + val (_, initFn) = + CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, + aggSigs.states, + FastSeq( + TableIR.globalName -> + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)) ), - ( - TableIR.rowName, - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + extracted.init, + ) + + val (_, eltSeqFn) = + CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + ctx, + aggSigs.states, + FastSeq( + TableIR.globalName -> + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), + TableIR.rowName -> + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), ), - ), - FastSeq(classInfo[Region], LongInfo, LongInfo), - UnitInfo, - extracted.seqPerElt, - ) + FastSeq(classInfo[Region], LongInfo, LongInfo), + UnitInfo, + extracted.seqPerElt, + ) - val read = aggSigs.deserialize(ctx, spec) - val write = aggSigs.serialize(ctx, spec) - val combOpFNeedsPool = aggSigs.combOpFSerializedFromRegionPool(ctx, spec) + val readFn = aggSigs.deserialize(ctx, spec) + val writeFn = aggSigs.serialize(ctx, spec) + val combOpFn = aggSigs.combOpFSerializedFromRegionPool(ctx, spec) - val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = + val (Some(PTypeReferenceSingleCodeType(rTyp)), rowFn) = CompileWithAggregators[AsmFunction3RegionLongLongLong]( ctx, aggSigs.states, FastSeq( - ( - TableIR.globalName, + TableIR.globalName -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globals.t)), - ), - ( - TableIR.rowName, + TableIR.rowName -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)), - ), ), FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, @@ -1178,43 +1172,49 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow ) // 1. init op on all aggs and write out to initPath - val initAgg = ctx.r.pool.scopedRegion { aggRegion => - ctx.r.pool.scopedRegion { fRegion => - val init = initF(ctx.theHailClassLoader, fsBc.value, ctx.taskContext, fRegion) - init.newAggState(aggRegion) - init(fRegion, globals.value.offset) - write(ctx.theHailClassLoader, ctx.taskContext, aggRegion, init.getAggOffset()) + val initAgg = ctx.scopedExecution { (hcl, fs, htc, r) => + val init = initFn(hcl, fs, htc, r) + val serialize = writeFn(hcl, fs, htc, r) + htc.r.pool.scopedRegion { inner => + init.newAggState(r) + init(inner, globals.value.offset) + serialize(inner, init.getAggOffset()) } } if (ctx.getFlag("distributed_scan_comb_op") != null && extracted.sigs.shouldTreeAggregate) { - val fsBc = ctx.fsBc val tmpBase = ctx.createTmpPath("table-map-rows-distributed-scan") val d = digitsNeeded(rvd.getNumPartitions) val files = rvd.mapPartitionsWithIndex { (i, hcl, ctx, it) => val path = tmpBase + "/" + partFile(d, i, TaskContext.get()) + + val fs = fsBc.value + val globalRegion = ctx.freshRegion() - val globals = if (scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, hcl) - else 0 + val globals = + someIf(scanSeqNeedsGlobals, globalsBc) + .flatten + .map(_.value.readRegionValue(globalRegion, hcl)) + .getOrElse(0L) + + val read = readFn(hcl, fs, ctx, globalRegion) + val seq = eltSeqFn(hcl, fs, ctx, globalRegion) + val write = writeFn(hcl, fs, ctx, globalRegion) ctx.r.pool.scopedSmallRegion { aggRegion => - val tc = SparkTaskContext.get() - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) + seq.setAggState(aggRegion, read(aggRegion, initAgg)) - seq.setAggState( - aggRegion, - read(hcl, tc, aggRegion, initAgg), - ) it.foreach { ptr => seq(ctx.region, globals, ptr) ctx.region.clear() } - using(new DataOutputStream(fsBc.value.create(path))) { os => - val bytes = write(hcl, tc, aggRegion, seq.getAggOffset()) + + using(new DataOutputStream(fs.create(path))) { os => + val bytes = write(aggRegion, seq.getAggOffset()) os.writeInt(bytes.length) os.write(bytes) } + Iterator.single(path) } }.collect() @@ -1238,82 +1238,78 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val file1 = filesToMerge(i * 2) val file2 = filesToMerge(i * 2 + 1) - def readToBytes(is: DataInputStream): Array[Byte] = { - val len = is.readInt() - val b = new Array[Byte](len) - is.readFully(b) - b - } + val fs = fsBc.value - val b1 = using(new DataInputStream(fsBc.value.open(file1)))(readToBytes) - val b2 = using(new DataInputStream(fsBc.value.open(file2)))(readToBytes) - using(new DataOutputStream(fsBc.value.create(path))) { os => - val bytes = combOpFNeedsPool(() => - (ctx.r.pool, hcl, SparkTaskContext.get()) - )(b1, b2) + val combine = combOpFn(hcl, fs, ctx, ctx.r) + + val b1 = using(new DataInputStream(fs.open(file1)))(readToBytes) + val b2 = using(new DataInputStream(fs.open(file2)))(readToBytes) + + using(new DataOutputStream(fs.create(path))) { os => + val bytes = combine(b1, b2) os.writeInt(bytes.length) os.write(bytes) } + Iterator.single(path) + }.collect() } + fileStack += filesToMerge - val itF = { (i: Int, hcl: HailClassLoader, ctx: RVDContext, it: Iterator[Long]) => - val globalRegion = ctx.freshRegion() - val globals = if (rowIterationNeedsGlobals || scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, hcl) - else - 0 - val partitionAggs = { - var x = i - val ab = ArraySeq.newBuilder[String] - fileStack.result().foreach { files => - assert(x <= files.length) - if (x % 2 != 0) { - x -= 1 - ab += files(x) - } - assert(x % 2 == 0) - x = x / 2 - } - assert(x == 0) - var b = initAgg - ab.result().reverseIterator.foreach { path => - def readToBytes(is: DataInputStream): Array[Byte] = { - val len = is.readInt() - val b = new Array[Byte](len) - is.readFully(b) - b + return copy( + typ = newType, + rvd = rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key)) { + (i: Int, hcl: HailClassLoader, ctx: RVDContext, it: Iterator[Long]) => + val fs = fsBc.value + + val globalRegion = ctx.freshRegion() + val globals = + someIf(rowIterationNeedsGlobals || scanSeqNeedsGlobals, globalsBc) + .flatten + .map(_.value.readRegionValue(globalRegion, hcl)) + .getOrElse(0L) + + val partitionAggs = { + var x = i + val ab = ArraySeq.newBuilder[String] + fileStack.result().foreach { files => + assert(x <= files.length) + if (x % 2 != 0) { + x -= 1 + ab += files(x) + } + assert(x % 2 == 0) + x = x / 2 + } + assert(x == 0) + + val combine = combOpFn(hcl, fs, ctx, globalRegion) + + var acc = initAgg + for (path <- ab.result().reverseIterator) + acc = combine(acc, using(new DataInputStream(fs.open(path)))(readToBytes)) + + acc } - b = combOpFNeedsPool(() => - (ctx.r.pool, hcl, SparkTaskContext.get()) - )(b, using(new DataInputStream(fsBc.value.open(path)))(readToBytes)) - } - b - } + val aggRegion = ctx.freshRegion() + val newRow = rowFn(hcl, fs, ctx, globalRegion) + val seq = eltSeqFn(hcl, fs, ctx, globalRegion) - val aggRegion = ctx.freshRegion() - val tc = SparkTaskContext.get() - val newRow = f(hcl, fsBc.value, tc, globalRegion) - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) - var aggOff = read(hcl, tc, aggRegion, partitionAggs) + var aggOff = readFn(hcl, fs, ctx, globalRegion)(aggRegion, partitionAggs) - val res = it.map { ptr => - newRow.setAggState(aggRegion, aggOff) - val newPtr = newRow(ctx.region, globals, ptr) - aggOff = newRow.getAggOffset() - seq.setAggState(aggRegion, aggOff) - seq(ctx.region, globals, ptr) - aggOff = seq.getAggOffset() - newPtr - } - res - } - return copy( - typ = newType, - rvd = rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF), + it.map { ptr => + newRow.setAggState(aggRegion, aggOff) + val newPtr = newRow(ctx.region, globals, ptr) + aggOff = newRow.getAggOffset() + seq.setAggState(aggRegion, aggOff) + seq(ctx.region, globals, ptr) + aggOff = seq.getAggOffset() + newPtr + } + }, ) } @@ -1321,31 +1317,39 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow val scanPartitionAggs = SpillingCollectIterator( ctx.localTmpdir, ctx.fs, - rvd.mapPartitionsWithIndex { (i, hcl, ctx, it) => - val globalRegion = ctx.partitionRegion - val globals = if (scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, hcl) - else 0 + rvd.mapPartitions { (hcl, ctx, it) => + val fs = fsBc.value - SparkTaskContext.get().getRegionPool().scopedSmallRegion { aggRegion => - val tc = SparkTaskContext.get() - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) + val globalRegion = ctx.freshRegion() + val globals = + someIf(scanSeqNeedsGlobals, globalsBc) + .flatten + .map(_.value.readRegionValue(globalRegion, hcl)) + .getOrElse(0L) - seq.setAggState(aggRegion, read(hcl, tc, aggRegion, initAgg)) - it.foreach { ptr => + val read = readFn(hcl, fs, ctx, globalRegion) + val seq = eltSeqFn(hcl, fs, ctx, globalRegion) + val write = writeFn(hcl, fs, ctx, globalRegion) + + ctx.r.pool.scopedSmallRegion { aggRegion => + seq.setAggState(aggRegion, read(aggRegion, initAgg)) + + for (ptr <- it) { seq(ctx.region, globals, ptr) ctx.region.clear() } - Iterator.single(write(hcl, tc, aggRegion, seq.getAggOffset())) + + Iterator.single(write(aggRegion, seq.getAggOffset())) } }, ctx.getFlag("max_leader_scans").toInt, ) // 3. load in partition aggregations, comb op as necessary, write back out. - val partAggs = scanPartitionAggs.scanLeft(initAgg)(combOpFNeedsPool(() => - (ctx.r.pool, ctx.theHailClassLoader, ctx.taskContext) - )) + val partAggs = + scanPartitionAggs.scanLeft(initAgg)( + combOpFn(ctx.theHailClassLoader, ctx.fs, ctx, ctx.r) + ) val scanAggCount = rvd.getNumPartitions val partitionIndices = new Array[Long](scanAggCount) val scanAggsPerPartitionFile = ctx.createTmpPath("table-map-rows-scan-aggs-part") @@ -1361,14 +1365,22 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } // 4. load in partStarts, calculate newRow based on those results. - val itF = { - (i: Int, hcl: HailClassLoader, ctx: RVDContext, filePosition: Long, it: Iterator[Long]) => - val globalRegion = ctx.partitionRegion - val globals = if (rowIterationNeedsGlobals || scanSeqNeedsGlobals) - globalsBc.value.readRegionValue(globalRegion, hcl) - else - 0 - val partitionAggs = using(fsBc.value.openNoCompression(scanAggsPerPartitionFile)) { is => + copy( + typ = newType, + rvd = rvd.mapPartitionsWithIndexAndValue( + RVDType(rTyp.asInstanceOf[PStruct], typ.key), + partitionIndices, + ) { (_, hcl, ctx, filePosition, it) => + val fs = fsBc.value + + val globalRegion = ctx.r + val globals = + someIf(rowIterationNeedsGlobals || scanSeqNeedsGlobals, globalsBc) + .flatten + .map(_.value.readRegionValue(globalRegion, hcl)) + .getOrElse(0L) + + val partitionAggs = using(fs.openNoCompression(scanAggsPerPartitionFile)) { is => is.seek(filePosition) val aggSize = is.readInt() val partAggs = new Array[Byte](aggSize) @@ -1385,10 +1397,9 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } val aggRegion = ctx.freshRegion() - val tc = SparkTaskContext.get() - val newRow = f(hcl, fsBc.value, tc, globalRegion) - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) - var aggOff = read(hcl, tc, aggRegion, partitionAggs) + val newRow = rowFn(hcl, fs, ctx, globalRegion) + val seq = eltSeqFn(hcl, fs, ctx, globalRegion) + var aggOff = readFn(hcl, fs, ctx, globalRegion)(aggRegion, partitionAggs) var idx = 0 it.map { ptr => @@ -1400,14 +1411,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow aggOff = seq.getAggOffset() off } - } - - copy( - typ = newType, - rvd = rvd.mapPartitionsWithIndexAndValue( - RVDType(rTyp.asInstanceOf[PStruct], typ.key), - partitionIndices, - )(itF), + }, ) } diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index adca1d09b94..b124a92e96b 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -1,9 +1,8 @@ package is.hail.expr.ir.agg -import is.hail.annotations.{Region, RegionPool, RegionValue} +import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.backend.{ExecuteContext, HailTaskContext} -import is.hail.backend.spark.{unsafeHailClassLoaderForSparkWorkers, SparkTaskContext} +import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.compat.mutable.Growable @@ -11,6 +10,7 @@ import is.hail.collection.implicits.toRichIterable import is.hail.expr.ir import is.hail.expr.ir._ import is.hail.expr.ir.defs._ +import is.hail.expr.ir.implicits._ import is.hail.io.BufferSpec import is.hail.types.{tcoerce, TypeWithRequiredness, VirtualTypeWithReq} import is.hail.types.physical.stypes.EmitType @@ -19,8 +19,6 @@ import is.hail.types.virtual._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.TaskContext - class UnsupportedExtraction(msg: String) extends Exception(msg) object AggStateSig { @@ -229,122 +227,96 @@ class AggSignatures(val sigs: IndexedSeq[PhysicalAggSig]) { } def deserialize(ctx: ExecuteContext, spec: BufferSpec) - : ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = { - val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( - ctx, - states, - FastSeq(), - FastSeq(classInfo[Region]), - UnitInfo, - DeserializeAggs(0, 0, spec, states), - ) + : Compiled[(Region, Array[Byte]) => Long] = { + val (_, loadFn) = + CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, + states, + FastSeq(), + FastSeq(classInfo[Region]), + UnitInfo, + DeserializeAggs(0, 0, spec, states), + ) - val fsBc = ctx.fsBc; - { (hcl: HailClassLoader, htc: HailTaskContext, aggRegion: Region, bytes: Array[Byte]) => - val f2 = f(hcl, fsBc.value, htc, aggRegion) - f2.newAggState(aggRegion) - f2.setSerializedAgg(0, bytes) - f2(aggRegion) - f2.getAggOffset() + loadFn map { fn => (aggRegion, bytes) => + fn.newAggState(aggRegion) + fn.setSerializedAgg(0, bytes) + fn(aggRegion) + fn.getAggOffset() } } - def serialize(ctx: ExecuteContext, spec: BufferSpec) - : (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = { - val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( - ctx, - states, - FastSeq(), - FastSeq(classInfo[Region]), - UnitInfo, - SerializeAggs(0, 0, spec, states), - ) - - val fsBc = ctx.fsBc; - { (hcl: HailClassLoader, htc: HailTaskContext, aggRegion: Region, off: Long) => - val f2 = f(hcl, fsBc.value, htc, aggRegion) - f2.setAggState(aggRegion, off) - f2(aggRegion) - f2.storeAggsToRegion() - f2.getSerializedAgg(0) - } - } + def serialize(ctx: ExecuteContext, spec: BufferSpec): Compiled[(Region, Long) => Array[Byte]] = { + val (_, loadFn) = + CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, + states, + FastSeq(), + FastSeq(classInfo[Region]), + UnitInfo, + SerializeAggs(0, 0, spec, states), + ) - def combOpFSerializedWorkersOnly(ctx: ExecuteContext, spec: BufferSpec) - : (Array[Byte], Array[Byte]) => Array[Byte] = { - combOpFSerializedFromRegionPool(ctx, spec) { () => - val htc = SparkTaskContext.get() - val hcl = unsafeHailClassLoaderForSparkWorkers - if (htc == null) { - throw new UnsupportedOperationException( - s"Can't get htc. On worker = ${TaskContext.get() != null}" - ) - } - (htc.getRegionPool(), hcl, htc) + loadFn map { fn => (aggRegion, off) => + fn.setAggState(aggRegion, off) + fn(aggRegion) + fn.storeAggsToRegion() + fn.getSerializedAgg(0) } } def combOpFSerializedFromRegionPool(ctx: ExecuteContext, spec: BufferSpec) - : (() => (RegionPool, HailClassLoader, HailTaskContext)) => ( - (Array[Byte], Array[Byte]) => Array[Byte], - ) = { - val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit]( - ctx, - states ++ states, - FastSeq(), - FastSeq(classInfo[Region]), - UnitInfo, - Begin(FastSeq( - DeserializeAggs(0, 0, spec, states), - DeserializeAggs(nAggs, 1, spec, states), - Begin(sigs.zipWithIndex.map { case (sig, i) => CombOp(i, i + nAggs, sig) }), - SerializeAggs(0, 0, spec, states), - )), - ) + : Compiled[(Array[Byte], Array[Byte]) => Array[Byte]] = { + val (_, loadFn) = + CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, + states ++ states, + FastSeq(), + FastSeq(classInfo[Region]), + UnitInfo, + Begin(FastSeq( + DeserializeAggs(0, 0, spec, states), + DeserializeAggs(nAggs, 1, spec, states), + Begin(sigs.zipWithIndex.map { case (sig, i) => CombOp(i, i + nAggs, sig) }), + SerializeAggs(0, 0, spec, states), + )), + ) - val fsBc = ctx.fsBc - poolGetter: (() => (RegionPool, HailClassLoader, HailTaskContext)) => { - (bytes1: Array[Byte], bytes2: Array[Byte]) => - val (pool, hcl, htc) = poolGetter() - pool.scopedSmallRegion { r => - val f2 = f(hcl, fsBc.value, htc, r) - f2.newAggState(r) - f2.setSerializedAgg(0, bytes1) - f2.setSerializedAgg(1, bytes2) - f2(r) - f2.storeAggsToRegion() - f2.getSerializedAgg(0) + (hcl, fs, htc, r) => + val fn = loadFn(hcl, fs, htc, r) + (bytes1, bytes2) => + htc.r.pool.scopedSmallRegion { r => + fn.newAggState(r) + fn.setSerializedAgg(0, bytes1) + fn.setSerializedAgg(1, bytes2) + fn(r) + fn.storeAggsToRegion() + fn.getSerializedAgg(0) } - } } - // Takes ownership of both input regions, and returns ownership of region in - // resulting RegionValue. - def combOpF(ctx: ExecuteContext, spec: BufferSpec) - : (HailClassLoader, HailTaskContext, RegionValue, RegionValue) => RegionValue = { - val fb = ir.EmitFunctionBuilder[AsmFunction4RegionLongRegionLongLong]( + def combOpF(ctx: ExecuteContext, spec: BufferSpec): Compiled[(Region, Long, Long) => Long] = { + val fb = ir.EmitFunctionBuilder[(Region, Long, Long) => Long]( ctx, "combOpF3", - FastSeq[ParamType](classInfo[Region], LongInfo, classInfo[Region], LongInfo), + FastSeq[ParamType](classInfo[Region], LongInfo, LongInfo), LongInfo, ) - val leftAggRegion = fb.genFieldThisRef[Region]("agg_combine_left_top_region") + val combRegion = fb.genFieldThisRef[Region]("agg_combine_region") val leftAggOff = fb.genFieldThisRef[Long]("agg_combine_left_off") val rightAggRegion = fb.genFieldThisRef[Region]("agg_combine_right_top_region") val rightAggOff = fb.genFieldThisRef[Long]("agg_combine_right_off") fb.emit(EmitCodeBuilder.scopedCode(fb.emb) { cb => - cb.assign(leftAggRegion, fb.getCodeParam[Region](1)) + cb.assign(combRegion, fb.getCodeParam[Region](1)) cb.assign(leftAggOff, fb.getCodeParam[Long](2)) - cb.assign(rightAggRegion, fb.getCodeParam[Region](3)) - cb.assign(rightAggOff, fb.getCodeParam[Long](4)) + cb.assign(rightAggOff, fb.getCodeParam[Long](3)) val leftStates = StateTuple(states.map(s => AggStateSig.getState(s, fb.ecb))) - val leftAggState = new TupleAggregatorState(fb.ecb, leftStates, leftAggRegion, leftAggOff) + val leftAggState = new TupleAggregatorState(fb.ecb, leftStates, combRegion, leftAggOff) val rightStates = StateTuple(states.map(s => AggStateSig.getState(s, fb.ecb))) - val rightAggState = - new TupleAggregatorState(fb.ecb, rightStates, rightAggRegion, rightAggOff) + val rightAggState = new TupleAggregatorState(fb.ecb, rightStates, rightAggRegion, rightAggOff) leftStates.createStates(cb) leftAggState.load(cb) @@ -354,7 +326,7 @@ class AggSignatures(val sigs: IndexedSeq[PhysicalAggSig]) { for (i <- 0 until nAggs) { val rvAgg = Extract.getAgg(sigs(i)) - rvAgg.combOp(ctx, cb, leftAggRegion, leftAggState.states(i), rightAggState.states(i)) + rvAgg.combOp(ctx, cb, combRegion, leftAggState.states(i), rightAggState.states(i)) } leftAggState.store(cb) @@ -362,15 +334,7 @@ class AggSignatures(val sigs: IndexedSeq[PhysicalAggSig]) { leftAggOff }) - val f = fb.resultWithIndex() - val fsBc = ctx.fsBc - - { (hcl: HailClassLoader, htc: HailTaskContext, l: RegionValue, r: RegionValue) => - val comb = f(hcl, fsBc.value, htc, l.region) - l.setOffset(comb(l.region, l.offset, r.region, r.offset)) - r.region.invalidate() - l - } + fb.resultWithIndex() } } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala index 8de56a85f15..4ba8291b823 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -3,7 +3,7 @@ package is.hail.expr.ir.lowering import is.hail.annotations.{BroadcastRow, Region, RegionValue} import is.hail.asm4s._ import is.hail.backend.{BroadcastValue, ExecuteContext} -import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} +import is.hail.backend.spark.AnonymousDependency import is.hail.collection.FastSeq import is.hail.expr.ir._ import is.hail.expr.ir.defs.{ @@ -47,7 +47,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType), ) - val gbAddr = f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r) + val gbAddr = f(ctx.theHailClassLoader, ctx.fs, ctx, ctx.r)(ctx.r) val globRow = BroadcastRow(ctx, RegionValue(ctx.r, gbAddr), globType) @@ -70,10 +70,10 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader val partF = rowF( hcl, fsBc.value, - SparkTaskContext.get(), - ctx.partitionRegion, + ctx, + ctx.r, ) - it.map(elt => partF(ctx.r, elt)) + it.map(elt => partF(ctx.region, elt)) }, )) } @@ -138,14 +138,16 @@ object TableStageToRVD { )), ) - val (Some(PTypeReferenceSingleCodeType(gbPType: PStruct)), f) = Compile[AsmFunction1RegionLong]( - ctx, - FastSeq(), - FastSeq(classInfo[Region]), - LongInfo, - globalsAndBroadcastVals, - ) - val gbAddr = f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r) + val (Some(PTypeReferenceSingleCodeType(gbPType: PStruct)), f) = + Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + globalsAndBroadcastVals, + ) + + val gbAddr = f(ctx.theHailClassLoader, ctx.fs, ctx, ctx.r)(ctx.r) val globPType = gbPType.fieldType("globals").asInstanceOf[PStruct] val globRow = BroadcastRow(ctx, RegionValue(ctx.r, gbPType.loadField(gbAddr, 0)), globPType) @@ -209,16 +211,15 @@ object TableStageToRVD { new ByteArrayInputStream(encodedContext), hcl, ) - .readRegionValue(rvdContext.partitionRegion) + .readRegionValue(rvdContext.r) val decodedBroadcastVals = makeBcDec( new ByteArrayInputStream(encodedBcVals.value), hcl, ) - .readRegionValue(rvdContext.partitionRegion) + .readRegionValue(rvdContext.r) makeIterator( hcl, fsBc.value, - SparkTaskContext.get(), rvdContext, decodedContext, decodedBroadcastVals, diff --git a/hail/hail/src/is/hail/io/bgen/BgenRDDPartitions.scala b/hail/hail/src/is/hail/io/bgen/BgenRDDPartitions.scala index 88c8711765e..e81313e3c22 100644 --- a/hail/hail/src/is/hail/io/bgen/BgenRDDPartitions.scala +++ b/hail/hail/src/is/hail/io/bgen/BgenRDDPartitions.scala @@ -94,7 +94,7 @@ object BgenRDDPartitions extends Logging { } val allPositions = partFirstVariantIndex ++ partLastVariantIndex.map(_ - 1L) - val keys = getKeysFromFile(file.indexPath, allPositions) + val keys = getKeysFromFile(ctx, file.indexPath, allPositions) val rangeBounds = (0 until nPartitions).map { i => Interval( keys(i), diff --git a/hail/hail/src/is/hail/io/bgen/LoadBgen.scala b/hail/hail/src/is/hail/io/bgen/LoadBgen.scala index ed1b563693c..3e83c52a5df 100644 --- a/hail/hail/src/is/hail/io/bgen/LoadBgen.scala +++ b/hail/hail/src/is/hail/io/bgen/LoadBgen.scala @@ -205,7 +205,8 @@ object LoadBgen extends Logging { require(files.length == indexFilePaths.length) val headers = getFileHeaders(fs, files.map(_.getPath)) - val cacheByRG: mutable.Map[Option[String], (String, Array[Long]) => Array[AnyRef]] = + val cacheByRG + : mutable.Map[Option[String], (ExecuteContext, String, Array[Long]) => Array[AnyRef]] = mutable.Map.empty headers.zip(indexFilePaths).map { case (h, indexFilePath) => @@ -232,7 +233,7 @@ object LoadBgen extends Logging { val nVariants = metadata.nKeys val rangeBounds = if (nVariants > 0) { - val Array(start, end) = getKeys(indexFilePath, Array[Long](0L, nVariants - 1)) + val Array(start, end) = getKeys(ctx, indexFilePath, Array[Long](0L, nVariants - 1)) Interval(start, end, includesStart = true, includesEnd = true) } else null diff --git a/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala b/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala index e45bcfe8d30..f87398a86aa 100644 --- a/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala +++ b/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala @@ -565,7 +565,7 @@ object StagedBGENReader { def queryIndexByPosition( ctx: ExecuteContext, indexSpec: AbstractIndexSpec, - ): (String, Array[Long]) => Array[AnyRef] = { + ): (ExecuteContext, String, Array[Long]) => Array[AnyRef] = { val fb = EmitFunctionBuilder[String, Array[Long], Array[AnyRef]](ctx, "bgen_query_index") fb.emitWithBuilder { cb => @@ -597,11 +597,8 @@ object StagedBGENReader { } val res = fb.resultWithIndex(); - { (path: String, indices: Array[Long]) => - ctx.r.pool.scopedRegion { r => - res.apply(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, r) - .apply(path, indices) - } + { (ctx, path: String, indices: Array[Long]) => + ctx.scopedExecution((hcl, fs, tc, r) => res(hcl, fs, tc, r)(path, indices)) } } diff --git a/hail/hail/src/is/hail/io/index/IndexWriter.scala b/hail/hail/src/is/hail/io/index/IndexWriter.scala index 8b30ab0207b..29105de140e 100644 --- a/hail/hail/src/is/hail/io/index/IndexWriter.scala +++ b/hail/hail/src/is/hail/io/index/IndexWriter.scala @@ -89,16 +89,16 @@ object IndexWriter { annotationType: PType, branchingFactor: Int = 4096, attributes: Map[String, Any] = Map.empty[String, Any], - ): (String, HailClassLoader, HailTaskContext, RegionPool) => IndexWriter = { + ): (String, HailClassLoader, HailTaskContext) => IndexWriter = { val sm = ctx.stateManager; val f = StagedIndexWriter.build(ctx, keyType, annotationType, branchingFactor); - { (path: String, hcl: HailClassLoader, htc: HailTaskContext, pool: RegionPool) => + { (path: String, hcl: HailClassLoader, htc: HailTaskContext) => new IndexWriter( sm, keyType, annotationType, - f(path, hcl, htc, pool, attributes), - pool, + f(path, hcl, htc, attributes), + htc.r.pool, attributes, ) } @@ -301,7 +301,7 @@ object StagedIndexWriter { keyType: PType, annotationType: PType, branchingFactor: Int = 4096, - ): (String, HailClassLoader, HailTaskContext, RegionPool, Map[String, Any]) => CompiledIndexWriter = { + ): (String, HailClassLoader, HailTaskContext, Map[String, Any]) => CompiledIndexWriter = { val fb = EmitFunctionBuilder[CompiledIndexWriter]( ctx, "indexwriter", @@ -344,10 +344,9 @@ object StagedIndexWriter { path: String, hcl: HailClassLoader, htc: HailTaskContext, - pool: RegionPool, attributes: Map[String, Any], ) => - pool.scopedRegion { r => + htc.r.pool.scopedRegion { r => // FIXME: This seems wrong? But also, anywhere we use broadcasting for the FS is wrong. val f = makeFB(hcl, fsBc.value, htc, r) f.init(path, attributes) diff --git a/hail/hail/src/is/hail/io/index/StagedIndexReader.scala b/hail/hail/src/is/hail/io/index/StagedIndexReader.scala index b8314e21f70..439865064b8 100644 --- a/hail/hail/src/is/hail/io/index/StagedIndexReader.scala +++ b/hail/hail/src/is/hail/io/index/StagedIndexReader.scala @@ -3,10 +3,9 @@ package is.hail.io.index import is.hail.annotations._ import is.hail.asm4s._ import is.hail.asm4s.implicits.valueToRichCodeInputBuffer -import is.hail.backend.TaskFinalizer import is.hail.collection.FastSeq import is.hail.expr.ir.{ - BinarySearch, EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode, + BinarySearch, EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode, TaskFinalizer, } import is.hail.expr.ir.functions.IntervalFunctions.{ arrayOfStructFindIntervalRange, compareStructWithPartitionIntervalEndpoint, diff --git a/hail/hail/src/is/hail/linalg/BlockMatrix.scala b/hail/hail/src/is/hail/linalg/BlockMatrix.scala index 9eec728c327..dc58b853ab8 100644 --- a/hail/hail/src/is/hail/linalg/BlockMatrix.scala +++ b/hail/hail/src/is/hail/linalg/BlockMatrix.scala @@ -2354,7 +2354,7 @@ class WriteBlocksRDD( val writeBlocksPart = split.asInstanceOf[WriteBlocksRDDPartition] val start = writeBlocksPart.start writeBlocksPart.range.zip(writeBlocksPart.parentPartitions).foreach { case (pi, pPart) => - using(RVDContext.default(SparkTaskContext.get().getRegionPool())) { ctx => + using(RVDContext.default(SparkTaskContext.get)) { ctx => val it = crdd.iterator(pPart, context, unsafeHailClassLoaderForSparkWorkers, ctx) if (pi == start) { diff --git a/hail/hail/src/is/hail/methods/IBD.scala b/hail/hail/src/is/hail/methods/IBD.scala index 774d892a400..b4197d60894 100644 --- a/hail/hail/src/is/hail/methods/IBD.scala +++ b/hail/hail/src/is/hail/methods/IBD.scala @@ -249,7 +249,7 @@ object IBD { val rowPType = input.rvRowPType val unnormalizedIbse = input.rvd.mapPartitions { (_, ctx, it) => - val rv = RegionValue(ctx.r) + val rv = RegionValue(ctx.region) val view = HardCallView(rowPType) it.map { ptr => rv.setOffset(ptr) diff --git a/hail/hail/src/is/hail/methods/LogisticRegression.scala b/hail/hail/src/is/hail/methods/LogisticRegression.scala index a000c13f38d..256c05f8161 100644 --- a/hail/hail/src/is/hail/methods/LogisticRegression.scala +++ b/hail/hail/src/is/hail/methods/LogisticRegression.scala @@ -145,7 +145,7 @@ case class LogisticRegression( rvb.start(newRVDType.rowType) rvb.startStruct() - rvb.addFields(fullRowType, ctx.r, ptr, copiedFieldIndices) + rvb.addFields(fullRowType, ctx.region, ptr, copiedFieldIndices) rvb.startArray(_yVecs.cols) logregAnnotations.foreach { stats => rvb.startStruct() diff --git a/hail/hail/src/is/hail/methods/PoissonRegression.scala b/hail/hail/src/is/hail/methods/PoissonRegression.scala index b21010b58b3..c105b690008 100644 --- a/hail/hail/src/is/hail/methods/PoissonRegression.scala +++ b/hail/hail/src/is/hail/methods/PoissonRegression.scala @@ -112,7 +112,7 @@ case class PoissonRegression( rvb.start(newRVDType.rowType) rvb.startStruct() - rvb.addFields(fullRowType, ctx.r, ptr, copiedFieldIndices) + rvb.addFields(fullRowType, ctx.region, ptr, copiedFieldIndices) poisRegTestBc.value .test(X, yBc.value, nullFitBc.value, "poisson", maxIter = maxIterations, tol = tolerance) .addToRVB(rvb) diff --git a/hail/hail/src/is/hail/methods/Skat.scala b/hail/hail/src/is/hail/methods/Skat.scala index 048aa59b021..d134484c478 100644 --- a/hail/hail/src/is/hail/methods/Skat.scala +++ b/hail/hail/src/is/hail/methods/Skat.scala @@ -363,7 +363,7 @@ case class Skat( fatal(s"Row weights must be non-negative, got $weight") val key = Annotation.copy( keyType.virtualType, - UnsafeRow.read(keyType, ctx.r, fullRowType.loadField(ptr, keyIndex)), + UnsafeRow.read(keyType, ctx.region, fullRowType.loadField(ptr, keyIndex)), ) val data = new Array[Double](n) diff --git a/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala b/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala index 1bc24b29591..3704f1c86e3 100644 --- a/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala +++ b/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala @@ -85,12 +85,12 @@ object AbstractRVDSpec { val (part0Count, bytesWritten) = using(fs.create(partsPath + "/" + filePath)) { os => - using(RVDContext.default(execCtx.r.pool)) { ctx => + using(RVDContext.default(execCtx)) { ctx => RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(execCtx, rowType))( execCtx.theHailClassLoader, ctx, rows.iterator.map { a => - rowType.unstagedStoreJavaObject(execCtx.stateManager, a, ctx.r) + rowType.unstagedStoreJavaObject(execCtx.stateManager, a, ctx.region) }, os, null, diff --git a/hail/hail/src/is/hail/rvd/RVD.scala b/hail/hail/src/is/hail/rvd/RVD.scala index 70d42b6b4fc..b233a9f5be2 100644 --- a/hail/hail/src/is/hail/rvd/RVD.scala +++ b/hail/hail/src/is/hail/rvd/RVD.scala @@ -2,15 +2,16 @@ package is.hail.rvd import is.hail.annotations._ import is.hail.asm4s.HailClassLoader -import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext} -import is.hail.backend.spark.{SparkBackend, SparkTaskContext} +import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager} +import is.hail.backend.spark.SparkBackend import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.implicits.{arrayToRichIndexedSeq, toRichIterable, toRichOrderedArray} -import is.hail.expr.ir.InferPType +import is.hail.expr.ir.{Compiled, InferPType} import is.hail.expr.ir.PruneDeadFields.isSupertype -import is.hail.expr.ir.agg.AggExecuteContextExtensions +import is.hail.expr.ir.implicits._ import is.hail.io._ +import is.hail.io.fs.FS import is.hail.io.index.IndexWriter import is.hail.rvd.RVD.RichIteratorLong import is.hail.sparkextras._ @@ -86,7 +87,7 @@ class RVD( def stabilize(ctx: ExecuteContext, enc: AbstractTypedCodecSpec): RDD[Array[Byte]] = { val makeEnc = enc.buildEncoder(ctx, rowPType) - crdd.cmapPartitions((hcl, ctx, it) => RegionValue.toBytes(hcl, makeEnc, ctx.r, it)).run + crdd.cmapPartitions((hcl, ctx, it) => RegionValue.toBytes(hcl, makeEnc, ctx.region, it)).run } def encodedRDD(ctx: ExecuteContext, enc: AbstractTypedCodecSpec): RDD[Array[Byte]] = @@ -105,8 +106,8 @@ class RVD( val encoder = new ByteArrayEncoder(hcl, makeEnc) TaskContext.get().addTaskCompletionListener[Unit](_ => encoder.close()): Unit it.map { ptr => - val keys: Any = SafeRow.selectFields(localRowPType, ctx.r, ptr)(kFieldIdx) - val bytes = encoder.regionValueToBytes(ctx.r, ptr) + val keys: Any = SafeRow.selectFields(localRowPType, ctx.region, ptr)(kFieldIdx) + val bytes = encoder.regionValueToBytes(ctx.region, ptr) (keys, bytes) } }.run @@ -275,7 +276,7 @@ class RVD( (ur, key) }, { case ((ur, key), _, ctx, ptr) => - ur.set(ctx.r, ptr) + ur.set(ctx.region, ptr) partBc.value.contains(key) }, ) @@ -417,7 +418,7 @@ class RVD( val localType = typ repartition(execCtx, partitioner.strictify()) .mapPartitions(typ)((_, ctx, it) => - OrderedRVIterator(localType, it.toIteratorRV(ctx.r), ctx, sm) + OrderedRVIterator(localType, it.toIteratorRV(ctx.region), ctx, sm) .staircase .map(_.value.offset) ) @@ -659,7 +660,7 @@ class RVD( (_, _, _) => new UnsafeRow(kPType), { case (kUR, _, ctx, ptr) => ctx.rvb.start(kType) - ctx.rvb.selectRegionValue(rowPType, kRowFieldIdx, ctx.r, ptr) + ctx.rvb.selectRegionValue(rowPType, kRowFieldIdx, ctx.region, ptr) kUR.set(ctx.region, ctx.rvb.end()) !intervalsBc.value.contains(kUR) }, @@ -672,7 +673,7 @@ class RVD( val kRowFieldIdx = typ.kFieldIdx val pred: (RVDContext, Long) => Boolean = (ctx: RVDContext, ptr: Long) => { - val ur = new UnsafeRow(localRowPType, ctx.r, ptr) + val ur = new UnsafeRow(localRowPType, ctx.region, ptr) val key = Row.fromSeq( kRowFieldIdx.map(i => ur.get(i)) ) @@ -709,27 +710,26 @@ class RVD( } def combine[U: ClassTag, T: ClassTag]( - execCtx: ExecuteContext, - mkZero: (HailClassLoader, HailTaskContext) => T, - itF: (HailClassLoader, Int, RVDContext, Iterator[Long]) => T, - deserialize: (HailClassLoader, HailTaskContext) => (U => T), - serialize: (HailClassLoader, HailTaskContext, T) => U, - combOp: (HailClassLoader, HailTaskContext, T, T) => T, + fsBc: BroadcastValue[FS], + mkZero: Compiled[Region => T], + itF: Compiled[(Region, Iterator[Long]) => T], + deserialize: Compiled[(Region, U) => T], + serialize: Compiled[(Region, T) => U], + combOp: Compiled[(Region, T, T) => T], commutative: Boolean, - tree: Boolean, - ): T = { - var reduced = crdd.cmapPartitionsWithIndex[U] { (i, hcl, ctx, it) => - Iterator.single( - serialize( - hcl, - SparkTaskContext.get(), - itF(hcl, i, ctx, it), - ) - ) + treeBranchFactor: Option[Int], + ): Compiled[Region => T] = { + + var reduced = crdd.cmapPartitions[U] { (hcl, ctx, it) => + val write = serialize(hcl, fsBc.value, ctx, ctx.r) + val elem = itF(hcl, fsBc.value, ctx, ctx.r) + ctx.r.pool.scopedSmallRegion { aggRegion => + Iterator.single(write(aggRegion, elem(aggRegion, it))) + } } - if (tree) { - val depth = treeAggDepth(getNumPartitions, execCtx.branchingFactor) + treeBranchFactor.foreach { factor => + val depth = treeAggDepth(getNumPartitions, factor) val scale = math.max( math.ceil(math.pow(getNumPartitions.toDouble, 1.0 / depth)).toInt, 2, @@ -740,38 +740,66 @@ class RVD( val nParts = reduced.getNumPartitions val newNParts = nParts / scale logger.info(s"starting tree aggregate stage $i ($nParts => $newNParts partitions)") - reduced = reduced - .mapPartitionsWithIndex { (i, it) => - it.map(x => (itemPartition(i, nParts, newNParts), (i, x))) - } - .partitionBy(new Partitioner { - override def getPartition(key: Any): Int = key.asInstanceOf[Int] - override def numPartitions: Int = newNParts - }) - .cmapPartitions { (hcl, _, it) => - val htc = SparkTaskContext.get() - var acc = mkZero(hcl, htc) - it.foreach { case (_, (_, v)) => - acc = combOp(hcl, htc, acc, deserialize(hcl, htc)(v)) + reduced = + reduced + .mapPartitionsWithIndex { (i, it) => + it.map(x => (itemPartition(i, nParts, newNParts), (i, x))) } - Iterator.single(serialize(hcl, htc, acc)) - } + .partitionBy( + new Partitioner { + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + override def numPartitions: Int = newNParts + } + ) + .cmapPartitions { (hcl, ctx, it) => + val fs = fsBc.value + + val zero = mkZero(hcl, fs, ctx, ctx.r) + val read = deserialize(hcl, fs, ctx, ctx.r) + val write = serialize(hcl, fs, ctx, ctx.r) + val combine = combOp(hcl, fs, ctx, ctx.r) + + ctx.r.pool.scopedSmallRegion { r => + var acc = zero(ctx.r) + + for ((_, (_, v)) <- it) { + acc = combine(ctx.r, acc, read(r, v)) + r.clear() + } + + Iterator.single(write(ctx.r, acc)) + } + } + i += 1 } } - val ac = Combiner( - mkZero(execCtx.theHailClassLoader, execCtx.taskContext), - (acc1: T, acc2: T) => combOp(execCtx.theHailClassLoader, execCtx.taskContext, acc1, acc2), - commutative, - true, - ) - sparkContext.runJob( - reduced.run, - (it: Iterator[U]) => singletonElement(it), - (i, x: U) => ac.combine(i, deserialize(execCtx.theHailClassLoader, execCtx.taskContext)(x)), - ) - ac.result() + for { + zeroFn <- mkZero + combineFn <- combOp + read <- deserialize + } yield { r: Region => + val ac = + Combiner( + zero = zeroFn(r), + combine = combineFn(r, _, _), + commutative = commutative, + associative = true, + ) + + r.pool.scopedSmallRegion { inner => + sparkContext.runJob( + reduced.run, + (it: Iterator[U]) => singletonElement(it), + (i, x: U) => { + ac.combine(i, read(inner, x)); inner.clear() + }, + ) + } + + ac.result() + } } def count(): Long = @@ -794,8 +822,8 @@ class RVD( val enc = TypedCodecSpec(execCtx, rowPType, BufferSpec.wireSpec) val encodedData = collectAsBytes(execCtx, enc) val (pType: PStruct, dec) = enc.buildDecoder(execCtx, rowType) - execCtx.r.pool.scopedRegion { region => - RegionValue.fromBytes(execCtx.theHailClassLoader, dec, region, encodedData.iterator) + execCtx.scopedExecution { (hcl, _, _, region) => + RegionValue.fromBytes(hcl, dec, region, encodedData.iterator) .map { ptr => val row = SafeRow(pType, ptr) region.clear() @@ -1011,7 +1039,7 @@ class RVD( start = Row(interval.start), end = Row(interval.end), ) - val bytes = encoder.regionValueToBytes(ctx.r, ptr) + val bytes = encoder.regionValueToBytes(ctx.region, ptr) partBc.value.queryInterval(wrappedInterval).map(i => ((i, interval), bytes)) } else Iterator() @@ -1127,9 +1155,10 @@ object RVD extends Logging { val localType = typ val sm = ctx.stateManager - val keyInfo = keys.crunJobWithIndex { (i, _, rvdContext, it) => + val keyInfo = keys.cmapPartitionsWithIndex { (i, _, ctx, it) => if (it.hasNext) - Some(RVDPartitionInfo( + Iterator.single(RVDPartitionInfo( + ctx, sm, localType, partitionKey, @@ -1137,11 +1166,11 @@ object RVD extends Logging { i, it, partitionSeed(i), - rvdContext, )) else - None - }.flatten + Iterator.empty + } + .collect() val kOrd = PartitionBoundOrdering(sm, typ.kType.virtualType).toOrdering keyInfo.sortBy(_.min)(kOrd) @@ -1341,7 +1370,7 @@ object RVD extends Logging { val srcRowPType = rvd.rowPType val newRVDType = rvd.typ.copy(rowType = unifiedRowPType, key = unifiedKey) rvd.map(newRVDType)((_, ctx, ptr) => - unifiedRowPType.copyFromAddress(sm, ctx.r, srcRowPType, ptr, false) + unifiedRowPType.copyFromAddress(sm, ctx.region, srcRowPType, ptr, false) ) } } @@ -1406,12 +1435,9 @@ object RVD extends Logging { IndexSpec.defaultAnnotation(execCtx, "../../index", localTyp.kType, withOffsetField = true) val makeRowsEnc = rowsCodecSpec.buildEncoder(execCtx, fullRowType) val makeEntriesEnc = entriesCodecSpec.buildEncoder(execCtx, fullRowType) - val _makeIndexWriter = + val makeIndexWriter = IndexWriter.builder(execCtx, localTyp.kType, +PCanonicalStruct("entries_offset" -> PInt64())) - val makeIndexWriter: (String, HailClassLoader, RegionPool) => IndexWriter = - _makeIndexWriter(_, _, SparkTaskContext.get(), _) - val partDigits = digitsNeeded(nPartitions) for (i <- 0 until nRVDs) { val path = paths(i) diff --git a/hail/hail/src/is/hail/rvd/RVDContext.scala b/hail/hail/src/is/hail/rvd/RVDContext.scala index d4ec9c9f5f8..fffb11f8b52 100644 --- a/hail/hail/src/is/hail/rvd/RVDContext.scala +++ b/hail/hail/src/is/hail/rvd/RVDContext.scala @@ -1,32 +1,28 @@ package is.hail.rvd -import is.hail.annotations.{Region, RegionPool, RegionValueBuilder} -import is.hail.backend.HailStateManager +import is.hail.annotations.{Region, RegionValueBuilder} +import is.hail.backend.{HailStateManager, HailTaskContext} import scala.collection.mutable object RVDContext { - - def default(pool: RegionPool) = { - val partRegion = Region(pool = pool) - val ctx = new RVDContext(partRegion, Region(pool = pool)) - ctx.own(partRegion) - ctx - } + def default(tc: HailTaskContext) = + new RVDContext(tc.r, Region(pool = tc.r.pool)) } -class RVDContext(val partitionRegion: Region, val r: Region) extends AutoCloseable { - private[this] val children = new mutable.HashSet[AutoCloseable]() +class RVDContext(override val r: Region, val region: Region) + extends HailTaskContext with AutoCloseable { + private[this] val children = mutable.HashSet.empty[AutoCloseable] private def own(child: AutoCloseable): Unit = children += child private[this] def disown(child: AutoCloseable): Unit = assert(children.remove(child)) - own(r) + own(region) def freshContext(): RVDContext = { - val ctx = new RVDContext(partitionRegion, Region(pool = r.pool)) + val ctx = RVDContext.default(this) own(ctx) ctx } @@ -37,9 +33,10 @@ class RVDContext(val partitionRegion: Region, val r: Region) extends AutoCloseab r2 } - def region: Region = r + override def onClose(f: () => Unit): Unit = + own(() => f()) - private[this] val theRvb = new RegionValueBuilder(HailStateManager(Map.empty), r) + private[this] val theRvb = new RegionValueBuilder(HailStateManager(Map.empty), region) def rvb = theRvb // frees the memory associated with this context diff --git a/hail/hail/src/is/hail/rvd/RVDPartitionInfo.scala b/hail/hail/src/is/hail/rvd/RVDPartitionInfo.scala index 64d38365647..a088a4cf0d6 100644 --- a/hail/hail/src/is/hail/rvd/RVDPartitionInfo.scala +++ b/hail/hail/src/is/hail/rvd/RVDPartitionInfo.scala @@ -1,7 +1,7 @@ package is.hail.rvd import is.hail.annotations.{Region, RegionValue, SafeRow, WritableRegionValue} -import is.hail.backend.HailStateManager +import is.hail.backend.{HailStateManager, HailTaskContext} import is.hail.types.virtual.Type import is.hail.utils._ @@ -29,6 +29,7 @@ object RVDPartitionInfo extends Logging { final val KSORTED = 2 def apply( + ctx: HailTaskContext, sm: HailStateManager, typ: RVDType, partitionKey: Int, @@ -36,9 +37,8 @@ object RVDPartitionInfo extends Logging { partitionIndex: Int, it: Iterator[Long], seed: Long, - producerContext: RVDContext, - ): RVDPartitionInfo = { - using(RVDContext.default(producerContext.r.pool)) { localctx => + ): RVDPartitionInfo = + using(RVDContext.default(ctx)) { localctx => val kPType = typ.kType val pkOrd = typ.copy(key = typ.key.take(partitionKey)).kOrd(sm) val minF = WritableRegionValue(sm, kPType, localctx.freshRegion()) @@ -65,7 +65,6 @@ object RVDPartitionInfo extends Logging { i += 1 } - producerContext.region.clear() while (it.hasNext) { val f = it.next() @@ -101,7 +100,6 @@ object RVDPartitionInfo extends Logging { samples(j.toInt).set(f, deepCopy = true) } - producerContext.region.clear() i += 1 } @@ -117,5 +115,4 @@ object RVDPartitionInfo extends Logging { contextStr, ) } - } } diff --git a/hail/hail/src/is/hail/sparkextras/ContextRDD.scala b/hail/hail/src/is/hail/sparkextras/ContextRDD.scala index c1534400ba1..ab7bf5e8983 100644 --- a/hail/hail/src/is/hail/sparkextras/ContextRDD.scala +++ b/hail/hail/src/is/hail/sparkextras/ContextRDD.scala @@ -154,7 +154,7 @@ object ContextRDD { class ContextRDD[T: ClassTag](val rdd: RDD[Element[T]]) extends Serializable { private[this] def sparkManagedContext[U](f: (HailClassLoader, RVDContext) => U): U = { - val c = RVDContext.default(SparkTaskContext.get().getRegionPool()) + val c = RVDContext.default(SparkTaskContext.get) TaskContext.get().addTaskCompletionListener[Unit]((_: TaskContext) => c.close()): Unit f(unsafeHailClassLoaderForSparkWorkers, c) } @@ -219,20 +219,6 @@ class ContextRDD[T: ClassTag](val rdd: RDD[Element[T]]) extends Serializable { part.flatMap(x => inCtx((hcl, ctx) => f(i, hcl, ctx, x))) }) - // Gives consumer ownership of the context. Consumer is responsible for freeing - // resources per element. - def crunJobWithIndex[U: ClassTag](f: (Int, HailClassLoader, RVDContext, Iterator[T]) => U) - : Array[U] = - sparkContext.runJob( - rdd, - { (taskContext, it: Iterator[Element[T]]) => - val c = RVDContext.default(SparkTaskContext.get().getRegionPool()) - val hcl = unsafeHailClassLoaderForSparkWorkers - val ans = f(taskContext.partitionId(), hcl, c, it.flatMap(_(hcl, c))) - ans - }, - ) - def cmapPartitionsAndContext[U: ClassTag]( f: (HailClassLoader, RVDContext, Iterator[Element[T]]) => Iterator[U], preservesPartitioning: Boolean = false, diff --git a/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDD.scala b/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDD.scala index 9a9d7bccae3..1b4d3bde71c 100644 --- a/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDD.scala +++ b/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDD.scala @@ -1,8 +1,6 @@ package is.hail.sparkextras.implicits - -import is.hail.annotations.RegionPool import is.hail.asm4s.HailClassLoader -import is.hail.backend.ExecuteContext +import is.hail.backend.{ExecuteContext, HailTaskContext} import is.hail.expr.ir.partFile import is.hail.io.FileWriteMetadata import is.hail.io.fs.FS @@ -24,7 +22,7 @@ object RichContextRDD { rootPath: String, f: String, idxRelPath: String, - mkIdxWriter: (String, HailClassLoader, RegionPool) => IndexWriter, + mkIdxWriter: (String, HailClassLoader, HailTaskContext) => IndexWriter, stageLocally: Boolean, fs: FS, localTmpdir: String, @@ -36,18 +34,17 @@ object RichContextRDD { if (idxRelPath != null) rootPath + "/" + idxRelPath + "/" + f + ".idx" else null val (filename, idxFilename) = if (stageLocally) { - val context = TaskContext.get() val partPath = ExecuteContext.createTmpPathNoCleanup(localTmpdir, "write-partitions-part") val idxPath = partPath + ".idx" - context.addTaskCompletionListener[Unit] { (context: TaskContext) => + ctx.onClose { () => fs.delete(partPath, recursive = false) fs.delete(idxPath, recursive = true) - }: Unit + } partPath -> idxPath } else finalFilename -> finalIdxFilename val os = fs.create(filename) - val iw = mkIdxWriter(idxFilename, hcl, ctx.r.pool) + val iw = mkIdxWriter(idxFilename, hcl, ctx) // write must close `os` and `iw` val (rowCount, bytesWritten) = write(hcl, ctx, it, os, iw) @@ -97,7 +94,7 @@ class RichContextRDD[T](val crdd: ContextRDD[T]) extends AnyVal { path: String, idxRelPath: String, stageLocally: Boolean, - mkIdxWriter: (String, HailClassLoader, RegionPool) => IndexWriter, + mkIdxWriter: (String, HailClassLoader, HailTaskContext) => IndexWriter, write: (HailClassLoader, RVDContext, Iterator[T], OutputStream, IndexWriter) => (Long, Long), ): Array[FileWriteMetadata] = { val localTmpdir = ctx.localTmpdir diff --git a/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDDRegionValue.scala b/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDDRegionValue.scala index cd9af24279e..b7be1a6e06b 100644 --- a/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDDRegionValue.scala +++ b/hail/hail/src/is/hail/sparkextras/implicits/RichContextRDDRegionValue.scala @@ -3,7 +3,6 @@ package is.hail.sparkextras.implicits import is.hail.annotations._ import is.hail.asm4s.HailClassLoader import is.hail.backend.ExecuteContext -import is.hail.backend.spark.SparkTaskContext import is.hail.expr.ir.partFile import is.hail.io.{AbstractTypedCodecSpec, Encoder, FileWriteMetadata} import is.hail.io.fs.FS @@ -45,7 +44,7 @@ object RichContextRDDRegionValue { it.foreach { ptr => if (iw != null) { val off = en.indexOffset() - val key = SafeRow.selectFields(rowType, ctx.r, ptr)(indexKeyFieldIndices) + val key = SafeRow.selectFields(rowType, ctx.region, ptr)(indexKeyFieldIndices) iw.appendRow(key, off, Row()) } en.writeByte(1) @@ -87,7 +86,7 @@ object RichContextRDDRegionValue { ctx: RVDContext, partDigits: Int, stageLocally: Boolean, - makeIndexWriter: (String, HailClassLoader, RegionPool) => IndexWriter, + makeIndexWriter: (String, HailClassLoader, RVDContext) => IndexWriter, makeRowsEnc: (OutputStream) => Encoder, makeEntriesEnc: (OutputStream) => Encoder, ): FileWriteMetadata = { @@ -121,20 +120,20 @@ object RichContextRDDRegionValue { using(fs.create(entriesPartPath)) { entriesOS => val trackedEntriesOS = new ByteTrackingOutputStream(entriesOS) using(makeEntriesEnc(trackedEntriesOS)) { entriesEN => - using(makeIndexWriter(idxPath, hcl, ctx.r.pool)) { iw => + using(makeIndexWriter(idxPath, hcl, ctx)) { iw => var rowCount = 0L it.foreach { ptr => val rows_off = rowsEN.indexOffset() val ents_off = entriesEN.indexOffset() - val key = SafeRow.selectFields(fullRowType, ctx.r, ptr)(t.kFieldIdx) + val key = SafeRow.selectFields(fullRowType, ctx.region, ptr)(t.kFieldIdx) iw.appendRow(key, rows_off, Row(ents_off)) rowsEN.writeByte(1) - rowsEN.writeRegionValue(ctx.r, ptr) + rowsEN.writeRegionValue(ctx.region, ptr) entriesEN.writeByte(1) - entriesEN.writeRegionValue(ctx.r, ptr) + entriesEN.writeRegionValue(ctx.region, ptr) ctx.region.clear() @@ -225,7 +224,7 @@ class RichContextRDDLong(val crdd: ContextRDD[Long]) extends AnyVal { def toCRDDRegionValue: ContextRDD[RegionValue] = boundary.cmapPartitionsWithContext { (hcl, ctx, part) => - val rv = RegionValue(ctx.r) + val rv = RegionValue(ctx.region) part(hcl, ctx).map { ptr => rv.setOffset(ptr); rv } } @@ -241,10 +240,8 @@ class RichContextRDDLong(val crdd: ContextRDD[Long]) extends AnyVal { ctx, path, idxRelPath, - stageLocally, { - val f1 = IndexWriter.builder(ctx, t.kType, +PCanonicalStruct()) - f1(_, _, SparkTaskContext.get(), _) - }, + stageLocally, + IndexWriter.builder(ctx, t.kType, +PCanonicalStruct()), RichContextRDDRegionValue.writeRowsPartition( encoding.buildEncoder(ctx, t.rowType), t.kFieldIdx, diff --git a/hail/hail/src/is/hail/stats/LinearMixedModel.scala b/hail/hail/src/is/hail/stats/LinearMixedModel.scala index 8637a5b514a..66c54235507 100644 --- a/hail/hail/src/is/hail/stats/LinearMixedModel.scala +++ b/hail/hail/src/is/hail/stats/LinearMixedModel.scala @@ -105,7 +105,7 @@ class LinearMixedModel(lmmData: LMMData) { val r0 = 0 to 0 val r1 = 1 until f - val region = Region(pool = SparkTaskContext.get().getRegionPool()) + val region = Region(pool = SparkTaskContext.get.r.pool) val rv = RegionValue(region) val rvb = new RegionValueBuilder(ctx.stateManager, region) @@ -177,7 +177,7 @@ class LinearMixedModel(lmmData: LMMData) { val r0 = 0 to 0 val r1 = 1 until f - val region = Region(pool = SparkTaskContext.get().getRegionPool()) + val region = Region(pool = SparkTaskContext.get.r.pool) val rv = RegionValue(region) val rvb = new RegionValueBuilder(ctx.stateManager, region) diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index c339889cc51..8cb296678af 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -3,7 +3,7 @@ package is.hail import is.hail.ExecStrategy.ExecStrategy import is.hail.annotations._ import is.hail.asm4s.HailClassLoader -import is.hail.backend.{Backend, ExecuteContext, OwningTempFileManager} +import is.hail.backend.{Backend, ExecuteContext, HailTaskContext, OwningTempFileManager} import is.hail.backend.spark.SparkBackend import is.hail.collection.{FastSeq, ImmutableMap} import is.hail.collection.implicits.toRichIterable @@ -48,6 +48,7 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { def pool: RegionPool = ctx.r.pool def sc: SparkContext = ctx.backend.asSpark.sc def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader + def taskContext: HailTaskContext = ctx_ private[this] lazy val resources: String = sys.env.getOrElse("MILL_TEST_RESOURCE_DIR", "hail/test/resources") diff --git a/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala b/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala index cef44deb5fd..b616623e1a6 100644 --- a/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala @@ -613,7 +613,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec deepCopy = false, ).a } - val cp1 = f1.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)() + val cp1 = f1.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r)() assert(SafeRow.read(t2, cp1) == Row(value)) val f2 = EmitFunctionBuilder[Long](ctx, "stagedCopy2") @@ -626,7 +626,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec deepCopy = false, ).a } - val cp2 = f2.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)() + val cp2 = f2.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r)() assert(SafeRow.read(t1, cp2) == Row(value)) } } diff --git a/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala b/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala index 012b5e16ad1..0d9d138b52c 100644 --- a/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala +++ b/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala @@ -26,7 +26,7 @@ class CodeSuite extends HailSuite { sum } - val result = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)() + val result = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, ctx.r)() assert(result == 10) } diff --git a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala index 04401ada94b..75debbd0bf2 100644 --- a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala @@ -108,8 +108,8 @@ class Aggregators2Suite extends HailSuite { ResultOp.makeTuple(Array(aggSig)), ) - val init = initF(theHailClassLoader, ctx.fs, ctx.taskContext, region) - val res = resOneF(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val init = initF(theHailClassLoader, ctx.fs, taskContext, region) + val res = resOneF(theHailClassLoader, ctx.fs, taskContext, region) pool.scopedSmallRegion { aggRegion => init.newAggState(aggRegion) @@ -122,9 +122,9 @@ class Aggregators2Suite extends HailSuite { val serializedParts = seqOps.grouped(math.ceil(seqOps.length / nPartitions.toDouble).toInt).map { seqs => - val init = initF(theHailClassLoader, ctx.fs, ctx.taskContext, region) - val seq = withArgs(Begin(seqs))(theHailClassLoader, ctx.fs, ctx.taskContext, region) - val write = writeF(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val init = initF(theHailClassLoader, ctx.fs, taskContext, region) + val seq = withArgs(Begin(seqs))(theHailClassLoader, ctx.fs, taskContext, region) + val write = writeF(theHailClassLoader, ctx.fs, taskContext, region) pool.scopedSmallRegion { aggRegion => init.newAggState(aggRegion) init(region, argOff) @@ -139,13 +139,13 @@ class Aggregators2Suite extends HailSuite { }.toArray pool.scopedSmallRegion { aggRegion => - val combOp = combAndDuplicate(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val combOp = combAndDuplicate(theHailClassLoader, ctx.fs, taskContext, region) combOp.newAggState(aggRegion) serializedParts.zipWithIndex.foreach { case (s, i) => combOp.setSerializedAgg(i, s) } combOp(region) - val res = resF(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val res = resF(theHailClassLoader, ctx.fs, taskContext, region) res.setAggState(aggRegion, combOp.getAggOffset()) val double = SafeRow(rt, res(region)) transformResult match { diff --git a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala index 7290f55334d..1e464109005 100644 --- a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala @@ -81,7 +81,7 @@ class EmitStreamSuite extends HailSuite { val f = fb.resultWithIndex() (arg: T) => pool.scopedRegion { r => - val off = call(f(theHailClassLoader, ctx.fs, ctx.taskContext, r), r, arg) + val off = call(f(theHailClassLoader, ctx.fs, taskContext, r), r, arg) if (off == 0L) null else @@ -187,7 +187,7 @@ class EmitStreamSuite extends HailSuite { } val f = fb.resultWithIndex() pool.scopedRegion { r => - val len = f(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) + val len = f(theHailClassLoader, ctx.fs, taskContext, r)(r) if (len < 0) None else Some(len) } } @@ -923,7 +923,7 @@ class EmitStreamSuite extends HailSuite { ) assert( - SafeRow.read(pt, f(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r, input)) == Row(null) + SafeRow.read(pt, f(theHailClassLoader, ctx.fs, taskContext, r)(r, input)) == Row(null) ) } } diff --git a/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala b/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala index d6289968f29..44e5f4ceb82 100644 --- a/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala @@ -129,7 +129,7 @@ class FunctionSuite extends HailSuite { i } pool.scopedRegion { r => - assert(fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, ctx.taskContext, r)() == 2) + assert(fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, taskContext, r)() == 2) } } } diff --git a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala index 75f3c2ed4a3..faeabb8b422 100644 --- a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala @@ -64,7 +64,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { fb.ecb.getOrderingFunction(cv1.st, cv2.st, op) .apply(cb, EmitValue.present(cv1), EmitValue.present(cv2)) } - fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) + fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r) } @Test def testMissingNonequalComparisons(): Unit = { @@ -87,7 +87,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { fb.ecb.getOrderingFunction(ev1.st, ev2.st, op) .apply(cb, ev1, ev2) } - fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) + fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r) } forAll(genTypeVal[TStruct](ctx)) { case (t, a) => @@ -492,7 +492,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val asArray = SafeIndexedSeq(pArray, soff) - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, region) val i = f(region, soff, eoff) val ordering = t.ordering(sm) @@ -550,7 +550,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val asArray = SafeIndexedSeq(PCanonicalArray(pDict.elementType), soff).map(_.asInstanceOf[Row]) - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, region) val i = f(region, soff, eoff) val ordering = pDict.keyType.virtualType.ordering(sm) diff --git a/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala index dd545371bff..8b1928c06cc 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala @@ -89,7 +89,7 @@ object BTreeBackedSet { val inputBuffer = new StreamBufferSpec().buildInputBuffer(new ByteArrayInputStream(serialized)) val set = new BTreeBackedSet(ctx, region, n) - set.root = fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region)( + set.root = fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx, region)( region, inputBuffer, ) @@ -115,7 +115,7 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { root } - fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region) + fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx, region) } private val getF = { @@ -138,7 +138,7 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { cb.if_(key.isEmpty(cb, elt), key.storeKey(cb, elt, m, v)) root } - fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region) + fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx, region) } private val getResultsF = { @@ -177,7 +177,7 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { ) returnArray } - fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region) + fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx, region) } private val bulkStoreF = { @@ -204,7 +204,7 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { ob2.flush() } - fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region) + fb.resultWithIndex()(ctx.theHailClassLoader, ctx.fs, ctx, region) } def clear(): Unit = { diff --git a/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala index bc51df922a6..59027a39fac 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala @@ -48,7 +48,7 @@ class TakeByAggregatorSuite extends HailSuite { tba.result(cb, argR, rt).a } - val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) + val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r)(r) val result = SafeRow.read(rt, o) assert( result == ((n - 1) to 0 by -1) @@ -86,7 +86,7 @@ class TakeByAggregatorSuite extends HailSuite { tba.result(cb, argR, rt).a } - val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) + val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r)(r) val result = SafeRow.read(rt, o) assert(result == FastSeq(0, 1, 2, 3, null, null, null)) } @@ -131,7 +131,7 @@ class TakeByAggregatorSuite extends HailSuite { resultOff } - val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) + val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, r)(r) val pqOffset = Region.loadAddress(o) val pq = SafeRow.read(rt, pqOffset) val collOffset = Region.loadAddress(o + 8) diff --git a/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala b/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala index dc8e260806c..535b383e4be 100644 --- a/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala @@ -56,7 +56,7 @@ class DownsampleSuite extends HailSuite { } pool.scopedSmallRegion { r => - fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, ctx.taskContext, r).apply(pool) + fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, taskContext, r).apply(pool) } } } diff --git a/hail/hail/test/src/is/hail/io/IndexSuite.scala b/hail/hail/test/src/is/hail/io/IndexSuite.scala index 1e094021096..a108078fc2a 100644 --- a/hail/hail/test/src/is/hail/io/IndexSuite.scala +++ b/hail/hail/test/src/is/hail/io/IndexSuite.scala @@ -46,8 +46,7 @@ class IndexSuite extends HailSuite { val iw = IndexWriter.builder(ctx, keyType, annotationType, branchingFactor, attributes)( file, theHailClassLoader, - ctx.taskContext, - pool, + taskContext, ) data.zip(annotations).zipWithIndex.foreach { case ((s, a), offset) => iw.appendRow(s, offset.toLong, a) diff --git a/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala b/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala index 060ceb1157c..f16dff18c41 100644 --- a/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala +++ b/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala @@ -21,6 +21,6 @@ class LIRSplitSuite extends HailSuite { cb.invokeVoid(mb, cb.this_, const(1L)) Code._empty } - f.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)() + f.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, ctx.r)() } } diff --git a/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala b/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala index 4bf3dbbc575..3d89494b6fe 100644 --- a/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala +++ b/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala @@ -72,7 +72,7 @@ class ETypeSuite extends HailSuite { val buffer = new MemoryBuffer val ob = new MemoryOutputBuffer(buffer) - fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r).apply(x, ob) + fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, ctx.r).apply(x, ob) ob.flush() buffer.clearPos() @@ -85,7 +85,7 @@ class ETypeSuite extends HailSuite { outPType.store(cb, regArg, decoded, deepCopy = false) } - val result = fb2.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r).apply( + val result = fb2.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, ctx.r).apply( ctx.r, new MemoryInputBuffer(buffer), ) diff --git a/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala b/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala index 93c81e40472..6a974bab69f 100644 --- a/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala @@ -93,7 +93,7 @@ class PNDArraySuite extends PhysicalTestUtils { Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normA) } - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, region) assert(f(region) < 1e-14) } @@ -178,7 +178,7 @@ class PNDArraySuite extends PhysicalTestUtils { Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normA) } - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, region) assert(f(region) < 1e-14) } @@ -312,7 +312,7 @@ class PNDArraySuite extends PhysicalTestUtils { Code._empty } - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, region) f(region) succeed @@ -374,7 +374,7 @@ class PNDArraySuite extends PhysicalTestUtils { Code._empty } - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, taskContext, region) f(region) succeed