Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions hail/hail/src/is/hail/annotations/RegionPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ final class RegionPool private (strictMemoryCheck: Boolean, threadName: String,
}
}

def getRegion(): Region = getRegion(Region.REGULAR)

def getRegion(size: Int): Region = {
def getRegion(size: Int = Region.REGULAR): Region = {
val r = new Region(size, this)
r.memory = getMemory(size)
r
Expand Down Expand Up @@ -154,11 +152,29 @@ final class RegionPool private (strictMemoryCheck: Boolean, threadName: String,

def report(context: String): Unit = {
val inBlocks = bytesInBlocks()
val (chunksAllocated, cacheHits) = chunkCache.getUsage()

logger.info(
s"RegionPool: $context: ${readableBytes(totalAllocatedBytes)} allocated (${readableBytes(inBlocks)} blocks / " +
s"${readableBytes(totalAllocatedBytes - inBlocks)} chunks), regions.size = ${regions.size}, " +
s"$numJavaObjects current java objects, thread $threadID: $threadName"
s"""RegionPool: $context
| thread:
| id: $threadID
| name: $threadName
| objects: $numJavaObjects
| allocations:
| peak: $getHighestTotalUsage
| total: ${readableBytes(totalAllocatedBytes)}
| blocks: ${readableBytes(inBlocks)}
| chunks: ${readableBytes(totalAllocatedBytes - inBlocks)}
| regions:
| total: ${regions.size}
| free: ${freeRegions.size}
| blocks:
| total: ${blocks.sum}
| free: ${freeBlocks.view.map(_.size).sum}
| chunks:
| total: $chunksAllocated
| reused: $cacheHits
| """.stripMargin
)
// logger.info("-----------STACK_TRACES---------")
// val stacks: String = regions.result().toIndexedSeq.flatMap(r => r.stackTrace.map((r.getTotalChunkMemory(), _))).foldLeft("")((a: String, b) => a + "\n" + b.toString())
Expand All @@ -170,8 +186,6 @@ final class RegionPool private (strictMemoryCheck: Boolean, threadName: String,
/** Runs `f` against a `Region.SMALL`-sized region drawn from this pool, handing the region to
  * `using` for the duration of the call (presumably closed afterwards; confirm against `using`).
  */
def scopedSmallRegion[T](f: Region => T): T = {
  val small = Region(Region.SMALL, pool = this)
  using(small)(f)
}
/** Runs `f` against a `Region.TINY`-sized region drawn from this pool, handing the region to
  * `using` for the duration of the call (presumably closed afterwards; confirm against `using`).
  */
def scopedTinyRegion[T](f: Region => T): T = {
  val tiny = Region(Region.TINY, pool = this)
  using(tiny)(f)
}

override def finalize(): Unit = close()

private[this] var closed: Boolean = false

override def close(): Unit = {
Expand Down
22 changes: 12 additions & 10 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package is.hail.backend
import is.hail.HailFeatureFlags
import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{BaseIR, CompileCache, Compiled}
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.lowering.IrMetadata
Expand Down Expand Up @@ -72,7 +71,7 @@ object ExecuteContext {
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
)(
f: ExecuteContext => T
): T = {
): T =
RegionPool.scoped { pool =>
pool.scopedRegion { region =>
using(new ExecuteContext(
Expand All @@ -94,7 +93,6 @@ object ExecuteContext {
))(f(_))
}
}
}

def createTmpPathNoCleanup(tmpdir: String, prefix: String, extension: String = null): String = {
val random = new SecureRandom()
Expand All @@ -113,7 +111,7 @@ class ExecuteContext(
val backend: Backend,
val references: Map[String, ReferenceGenome],
val fs: FS,
val r: Region,
override val r: Region,
val timer: ExecutionTimer,
val tempFileManager: TempFileManager,
val theHailClassLoader: HailClassLoader,
Expand All @@ -123,7 +121,7 @@ class ExecuteContext(
val CompileCache: CompileCache,
val PersistedIrCache: mutable.Map[Int, BaseIR],
val PersistedCoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
) extends Closeable {
) extends HailTaskContext with Closeable {

val rngNonce: Long =
try
Expand All @@ -142,10 +140,14 @@ class ExecuteContext(

val memo: mutable.Map[Any, Any] = new mutable.HashMap[Any, Any]()

val taskContext: HailTaskContext = new LocalTaskContext(0, 0)
private[this] val onCloseTasks = mutable.ArrayBuffer.empty[() => Unit]
override def onClose(f: () => Unit): Unit = onCloseTasks += f

/** Invokes the compiled function `f` under `time`, passing this context's class loader,
  * filesystem and region, with this context itself as the task context argument.
  */
def run[A](f: Compiled[A])(implicit E: Enclosing): A =
  time {
    f(theHailClassLoader, fs, this, r)
  }

def scopedExecution[T](f: Compiled[T])(implicit E: Enclosing): T =
using(new LocalTaskContext(0, 0))(tc => time(f(theHailClassLoader, fs, tc, r)))
r.pool.scopedRegion(r => local(r = r)(_.run(f)))

/** Asks the temp-file manager for a new temporary path.
  *
  * @param prefix    forwarded to `tempFileManager.newTmpPath`
  * @param extension forwarded to `tempFileManager.newTmpPath`; defaults to `null`
  * @param local     when true the path is rooted at `localTmpdir`, otherwise at `tmpdir`
  */
def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String = {
  val baseDir = if (local) localTmpdir else tmpdir
  tempFileManager.newTmpPath(baseDir, prefix, extension)
}
Expand All @@ -159,8 +161,8 @@ class ExecuteContext(
def shouldLogIR(): Boolean = !shouldNotLogIR()

/** Runs every callback registered through `onClose`, in registration order, then closes the
  * temp-file manager.
  *
  * The temp-file manager is closed in a `finally` so that a callback that throws cannot leave
  * temporary files behind; the callback's exception still propagates to the caller.
  *
  * The former `taskContext.close()` call is gone: this context now plays the task-context role
  * itself (it extends `HailTaskContext`), so there is no separate task context left to close.
  */
override def close(): Unit =
  try onCloseTasks.foreach(_())
  finally tempFileManager.close()

def time[A](block: => A)(implicit E: Enclosing): A =
Expand All @@ -179,7 +181,7 @@ class ExecuteContext(
flags: HailFeatureFlags = this.flags,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: CompileCache = this.CompileCache,
compileCache: CompileCache = this.CompileCache,
persistedIrCache: mutable.Map[Int, BaseIR] = this.PersistedIrCache,
persistedCoercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.PersistedCoercerCache,
)(
Expand All @@ -198,7 +200,7 @@ class ExecuteContext(
flags,
irMetadata,
blockMatrixCache,
codeCache,
compileCache,
persistedIrCache,
persistedCoercerCache,
))(f)
Expand Down
55 changes: 20 additions & 35 deletions hail/hail/src/is/hail/backend/HailTaskContext.scala
Original file line number Diff line number Diff line change
@@ -1,50 +1,35 @@
package is.hail.backend

import is.hail.annotations.RegionPool
import is.hail.utils._
import is.hail.annotations.{Region, RegionPool}
import is.hail.utils.using

import scala.collection.mutable

import java.io.Closeable
trait HailTaskContext {

class TaskFinalizer {
val closeables = mutable.ArrayBuffer.empty[Closeable]
/** region whose lifetime is at least as long as this task */
def r: Region

def clear(): Unit =
closeables.clear()

def addCloseable(c: Closeable): Unit =
closeables += c

def closeAll(): Unit = closeables.foreach(_.close())
/** register an action that will be called when this task completes */
def onClose(f: () => Unit): Unit
}

abstract class HailTaskContext extends AutoCloseable with Logging {
def stageId(): Int

def partitionId(): Int

def attemptNumber(): Int

private lazy val thePool = RegionPool()

def getRegionPool(): RegionPool = thePool
object HailTaskContext {

  /** Builds a [[PartitionContext]] for `partId` and evaluates `f` with it via `using`
    * (presumably closing the context once `f` returns or throws; confirm against `using`).
    */
  def runPartition[A](partId: Int)(f: HailTaskContext => A): A = {
    val partitionCtx = new PartitionContext(partId)
    using(partitionCtx)(f)
  }
}

val finalizers = mutable.ArrayBuffer.empty[TaskFinalizer]
class PartitionContext(partId: Int) extends HailTaskContext with AutoCloseable {
private[this] val onCloseTasks = mutable.ArrayBuffer.empty[() => Unit]

def newFinalizer(): TaskFinalizer = {
val f = new TaskFinalizer
finalizers += f
f
}
private[this] val pool = RegionPool()
override val r: Region = Region(pool = pool)
override def onClose(f: () => Unit): Unit = onCloseTasks += f

/** Tears down this partition's resources: runs the registered on-close callbacks in
  * registration order, closes the task region, logs pool statistics, then closes the pool.
  *
  * Each stage is guarded by `finally` so that a throwing callback (or a failing region close)
  * cannot leak the region or the pool; the first exception still propagates to the caller.
  */
override def close(): Unit =
  try onCloseTasks.foreach(_())
  finally {
    try r.close()
    finally {
      // NOTE(review): the RegionPool change visible alongside this one defines `report(context)`;
      // confirm that `logStats` actually exists on RegionPool.
      pool.logStats(s"Partition $partId")
      pool.close()
    }
  }
}
27 changes: 4 additions & 23 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ import com.fasterxml.jackson.core.StreamReadConstraints

/** A [[BroadcastValue]] that simply holds `value` directly; it is the type returned by
  * `LocalBackend.broadcast`.
  */
class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable

class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskContext {
override def attemptNumber(): Int = 0
}

object LocalBackend extends Backend with Logging {

// From https://github.com/hail-is/hail/issues/14580 :
Expand All @@ -43,16 +39,7 @@ object LocalBackend extends Backend with Logging {
override def broadcast[T: ClassTag](value: T): BroadcastValue[T] =
new LocalBroadcastValue[T](value)

private[this] var stageIdx: Int = 0

private[this] def nextStageId(): Int =
synchronized {
val current = stageIdx
stageIdx += 1
current
}

override def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext = {
override def runtimeContext(ctx: ExecuteContext): DriverRuntimeContext =
new DriverRuntimeContext {

override val executionCache: ExecutionCache =
Expand All @@ -77,14 +64,10 @@ object LocalBackend extends Backend with Logging {
var failure: Option[Throwable] =
None

val stageId = nextStageId()

try
for (idx <- todo)
results += using(new LocalTaskContext(idx, stageId)) { htc =>
htc.getRegionPool().scopedRegion { r =>
f(ctx.theHailClassLoader, ctx.fs, htc, r)(globals, contexts(idx)) -> idx
}
results += ctx.scopedExecution { (hcl, fs, ctx, r) =>
(f(hcl, fs, ctx, r)(globals, contexts(idx)), idx)
}
catch {
case NonFatal(t) =>
Expand All @@ -94,12 +77,10 @@ object LocalBackend extends Backend with Logging {
(failure, results.result())
}
}
}

override def defaultParallelism: Int = 1

override def close(): Unit =
synchronized { stageIdx = 0 }
override def close(): Unit = {}

private[this] def _jvmLowerAndExecute(
ctx: ExecuteContext,
Expand Down
7 changes: 2 additions & 5 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import is.hail.Revision
import is.hail.backend._
import is.hail.backend.Backend.PartitionFn
import is.hail.backend.ExecutionCache.Flags.UseFastRestarts
import is.hail.backend.local.LocalTaskContext
import is.hail.backend.service.ServiceBackend.Flags._
import is.hail.collection.FastSeq
import is.hail.collection.compat.immutable.ArraySeq
Expand Down Expand Up @@ -361,10 +360,8 @@ class ServiceBackend(
partitions.getOrElse(contexts.indices) match {
case Seq(k) =>
try
using(new LocalTaskContext(k, stageCount)) { htc =>
None -> htc.getRegionPool().scopedRegion { r =>
FastSeq(f(ctx.theHailClassLoader, ctx.fs, htc, r)(globals, contexts(k)) -> k)
}
ctx.scopedExecution { (hcl, fs, htc, r) =>
(None, FastSeq(f(hcl, fs, htc, r)(globals, contexts(k)) -> k))
}
catch {
case NonFatal(t) => Some(t) -> ArraySeq.empty
Expand Down
17 changes: 5 additions & 12 deletions hail/hail/src/is/hail/backend/service/Worker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ import java.util
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger

class ServiceTaskContext(val partitionId: Int) extends HailTaskContext {
override def stageId(): Int = 0

override def attemptNumber(): Int = 0
}

class WorkerTimer extends Logging {

var startTimes: mutable.Map[String, Long] = mutable.Map()
Expand Down Expand Up @@ -250,12 +244,11 @@ object Worker extends Logging {
inputs.flatMap { case (globals, context, f) =>
timer.enter("execute") {
try
using(new ServiceTaskContext(partition)) { htc =>
retryTransientErrors {
htc.getRegionPool().scopedRegion { r =>
Right(f(hcl, fs, htc, r)(globals, context))
}
}
HailTaskContext.runPartition(partition) { htc =>
retryTransientErrors(
Right(f(hcl, fs, htc, htc.r)(globals, context)),
Some(() => htc.r.clear()),
)
}
catch {
case t: Throwable => Left(t)
Expand Down
34 changes: 12 additions & 22 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,23 @@ class SparkBroadcastValue[T](bc: Broadcast[T]) extends BroadcastValue[T] with Se
}

object SparkTaskContext {
def get(): SparkTaskContext = taskContext.get
def get: HailTaskContext = taskContext.get

private[this] val taskContext: ThreadLocal[SparkTaskContext] =
new ThreadLocal[SparkTaskContext]() {
override def initialValue(): SparkTaskContext = {
private[this] val taskContext: ThreadLocal[HailTaskContext] =
new ThreadLocal[HailTaskContext]() {
/** Builds the task context for the calling thread the first time it is requested.
  *
  * Must be called from inside a running Spark task: the assertion fails when
  * `TaskContext.get()` returns null. The context is closed, and this thread-local slot
  * cleared, by a Spark task-completion listener registered here.
  */
override def initialValue(): HailTaskContext = {
  val sparkTC = TaskContext.get()
  assert(sparkTC != null, "Spark Task Context was null, maybe this ran on the driver?")

  // PartitionContext takes a partition id (it is reported as "Partition <id>" on close),
  // so pass Spark's partition id rather than the stage id.
  val htc = new PartitionContext(sparkTC.partitionId())

  // `remove()` resolves to the enclosing ThreadLocal, dropping the cached context.
  sparkTC.addTaskCompletionListener[Unit] { _ => htc.close(); remove(); }: Unit

  htc
}
}

def finish(): Unit = {
taskContext.get().close()
def finish(): Unit =
taskContext.remove()
}
}

class SparkTaskContext private[spark] (ctx: TaskContext) extends HailTaskContext {
self =>
override def stageId(): Int = ctx.stageId()
override def partitionId(): Int = ctx.partitionId()
override def attemptNumber(): Int = ctx.attemptNumber()
}

object SparkBackend extends Logging {
Expand Down Expand Up @@ -267,11 +259,9 @@ class SparkBackend(val spark: SparkSession) extends Backend with Logging {

override def compute(partition: Partition, context: TaskContext)
: Iterator[Array[Byte]] = {
val htc = SparkTaskContext.get()
htc.getRegionPool().scopedRegion { r =>
val g = f(unsafeHailClassLoaderForSparkWorkers, new HadoopFS(fsConfig), htc, r)
Iterator.single(g(globals, partition.asInstanceOf[RDDPartition].data))
}
val ctx = SparkTaskContext.get
val g = f(unsafeHailClassLoaderForSparkWorkers, new HadoopFS(fsConfig), ctx, ctx.r)
Iterator.single(g(globals, partition.asInstanceOf[RDDPartition].data))
}
}

Expand Down
Loading