Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hail/hail/src/is/hail/expr/ir/AggOp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package is.hail.expr.ir

import is.hail.collection.FastSeq
import is.hail.expr.ir.agg._
import is.hail.io.TypedCodecSpec
import is.hail.types.physical.PStruct
import is.hail.types.virtual._

sealed trait AggOp {}
Expand All @@ -28,6 +30,7 @@ final case class ImputeType() extends AggOp
final case class NDArraySum() extends AggOp
final case class NDArrayMultiplyAdd() extends AggOp
final case class Fold() extends AggOp
/** Aggregation op whose aggregator encodes the rows it is fed with `codec` and writes them to an
  * output stream; its result type is `TString`.
  *
  * @param codec
  *   codec spec for the rows; the seq-op argument type is its `encodedVirtualType`
  * @param indexKey
  *   when defined, the physical key struct handed to the writer state, which builds a staged index
  *   writer for it
  *
  * NOTE(review): "TBD" in the name looks like a placeholder - confirm the final name.
  */
final case class WriteTBD(codec: TypedCodecSpec, indexKey: Option[PStruct]) extends AggOp

// exists === map(p).sum, needs short-circuiting aggs
// forall === map(p).product, needs short-circuiting aggs
Expand Down Expand Up @@ -55,6 +58,7 @@ object AggOp {
case (Downsample(), Seq(_, _, _)) => DownsampleAggregator.resultType
case (NDArraySum(), Seq(t)) => t
case (NDArrayMultiplyAdd(), Seq(a: TNDArray, _)) => a
case (WriteTBD(_, _), _) => TString
case _ => throw new UnsupportedExtraction(this.toString)
}

Expand Down
62 changes: 47 additions & 15 deletions hail/hail/src/is/hail/expr/ir/TableWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,6 @@ object TableNativeWriter {
// write out partitioner key, which may be stricter than table key
val partitioner = ts.partitioner
val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType))
val rowWriter = PartitionNativeWriter(
rowSpec,
pKey.fieldNames,
s"$path/rows/parts/",
Some(s"$path/index/" -> pKey),
if (stageLocally) Some(FileSystems.getDefault.getPath(
ctx.localTmpdir,
s"hail_staging_tmp_${UUID.randomUUID()}",
"rows",
"parts",
))
else None,
)

val globalWriter =
PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/parts/", None, None)
Expand All @@ -106,8 +93,53 @@ object TableNativeWriter {
}
}(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("table_native_writer") {
(rows, ctxRef) =>
val file = GetField(ctxRef, "writeCtx")
WritePartition(rows, file + UUID4(), rowWriter)
bindIR(GetField(ctxRef, "writeCtx") + UUID4()) { file =>
val root = s"$path/rows/parts/"
val partPath = file + UUID4()

val zero = makestruct(
"distinct" -> !pKey.fieldNames.isEmpty,
"firstKey" -> NA(pKey.virtualType),
"lastKey" -> NA(pKey.virtualType),
)
val partResult = streamAggIR(rows) { row =>
makestruct(
"partpath" -> ApplyAggOp(
WriteTBD(rowSpec, Some(pKey)),
partPath,
Str(root),
Str(s"$path/index/"),
)(row),
"partitionCounts" -> ApplyAggOp(Count())(),
"keyMeta" -> aggFoldIR(zero) { accum =>
bindIRs(SelectFields(row, pKey.fieldNames), GetField(accum, "lastKey")) {
case Seq(key, prev) =>
makestruct(
"distinct" -> (GetField(accum, "distinct") && Coalesce(FastSeq(
prev.cne(key),
True(),
))),
"firstKey" -> Coalesce(FastSeq(GetField(accum, "firstKey"), key)),
"lastKey" -> Coalesce(FastSeq(key, prev)),
)
}
} { (accum1, accum2) =>
Die("unreachable: calling combop on writer fold makes no sense", zero.typ)
},
)
}
bindIR(partResult) { result =>
bindIR(GetField(result, "keyMeta")) { keymeta =>
makestruct(
"filePath" -> GetField(result, "partpath"),
"partitionCounts" -> GetField(result, "partitionCounts"),
"distinctlyKeyed" -> GetField(keymeta, "distinct"),
"firstKey" -> GetField(keymeta, "firstKey"),
"lastKey" -> GetField(keymeta, "lastKey"),
)
}
}
}
} { (parts, globals) =>
val writeGlobals = WritePartition(
MakeStream(FastSeq(globals), TStream(globals.typ)),
Expand Down
5 changes: 4 additions & 1 deletion hail/hail/src/is/hail/expr/ir/TypeCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,10 @@ object TypeCheck {
s"${args.map(_.typ)} != ${aggSig.initOpTypes}",
)
case SeqOp(_, args, aggSig) =>
assert(args.map(_.typ) == aggSig.seqOpTypes)
assert(
args.map(_.typ) == aggSig.seqOpTypes,
s"${args.map(_.typ)} != ${aggSig.seqOpTypes}\n$aggSig",
)
case _: CombOp =>
case _: ResultOp =>
case AggStateValue(_, _) =>
Expand Down
10 changes: 10 additions & 0 deletions hail/hail/src/is/hail/expr/ir/agg/Extract.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import is.hail.expr.ir._
import is.hail.expr.ir.defs._
import is.hail.io.BufferSpec
import is.hail.types.{tcoerce, TypeWithRequiredness, VirtualTypeWithReq}
import is.hail.types.physical.PStruct
import is.hail.types.physical.stypes.EmitType
import is.hail.types.virtual._

Expand Down Expand Up @@ -59,6 +60,9 @@ object AggStateSig {
seqVTypes.head.setRequired(false)
) // set required to false to handle empty aggs
case NDArrayMultiplyAdd() => NDArrayMultiplyAddStateSig(seqVTypes.head.setRequired(false))
case WriteTBD(codecSpec, key) =>
assert(codecSpec.encodedVirtualType == seqVTypes.head.t)
WriteSig(seqVTypes.head, key)
case _ => throw new UnsupportedExtraction(op.toString)
}
}
Expand Down Expand Up @@ -92,6 +96,7 @@ object AggStateSig {
val vWithReq = resultEmitType.typeWithRequiredness
new TypedRegionBackedAggState(vWithReq, cb)
case LinearRegressionStateSig() => new LinearRegressionAggregatorState(cb)
case WriteSig(_, key) => new StreamWriterState(cb, key)
}
}

Expand Down Expand Up @@ -142,6 +147,9 @@ case class FoldStateSig(
combOpIR: IR,
) extends AggStateSig(ArraySeq(resultEmitType.typeWithRequiredness), None)

/** State signature for the stream-writing aggregation (`WriteTBD`).
  *
  * @param rowType
  *   virtual type (with requiredness) of the first seq-op argument; it is the only element type
  *   passed up to [[AggStateSig]]
  * @param indexKey
  *   optional physical key struct, forwarded to `StreamWriterState` when the state is constructed
  */
case class WriteSig(rowType: VirtualTypeWithReq, indexKey: Option[PStruct])
extends AggStateSig(ArraySeq(rowType), None)

object PhysicalAggSig {
def apply(op: AggOp, state: AggStateSig): PhysicalAggSig = BasicPhysicalAggSig(op, state)

Expand Down Expand Up @@ -467,6 +475,8 @@ object Extract {
new NDArrayMultiplyAddAggregator(nda)
case PhysicalAggSig(Fold(), FoldStateSig(res, accumName, otherAccumName, combOpIR)) =>
new FoldAggregator(res, accumName, otherAccumName, combOpIR)
case PhysicalAggSig(WriteTBD(codec, indexKey), WriteSig(_, _)) =>
new StreamWriterAggregator(codec, indexKey.isDefined)
}

def apply(ctx: ExecuteContext, ir: IR, r: RequirednessAnalysis, isScan: Boolean = false)
Expand Down
127 changes: 127 additions & 0 deletions hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package is.hail.expr.ir.agg

import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.asm4s.implicits.valueToRichCodeOutputBuffer
import is.hail.backend.ExecuteContext
import is.hail.collection.compat.immutable.ArraySeq
import is.hail.expr.ir._
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec}
import is.hail.io.index.StagedIndexWriter
import is.hail.types.physical._
import is.hail.types.physical.stypes.{EmitType, SValue}
import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaStringValue}
import is.hail.types.virtual._
import is.hail.utils.fatal

/** Aggregator state used by [[StreamWriterAggregator]].
  *
  * The state consists of generated-class fields only: the output buffer rows are written to, the
  * part-file name and, when `indexKey` is defined, a staged index writer. `createState`, `newState`
  * and `store` are deliberately no-ops; `load` is rejected and `copyFrom` / `serialize` /
  * `deserialize` are not implemented.
  *
  * @param kb
  *   class builder that owns the generated fields
  * @param indexKey
  *   physical key struct to index on; `None` means no index writer is created
  */
class StreamWriterState(override val kb: EmitClassBuilder[_], indexKey: Option[PStruct])
    extends AggregatorState {
  // Assigned by the aggregator's init op.
  val outb: Settable[OutputBuffer] = kb.genFieldThisRef[OutputBuffer]()
  val part: Settable[String] = kb.genFieldThisRef[String]()

  // Defined exactly when `indexKey` is defined. The branching factor comes from the
  // "index_branching_factor" flag and falls back to 4096 when the flag is unset.
  val indexWriter = indexKey.map { key =>
    val branchingFactor =
      Option(kb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096)
    StagedIndexWriter.withDefaults(key, kb, branchingFactor = branchingFactor)
  }

  override def storageType = PCanonicalStringRequired

  /** No-op. */
  override def createState(cb: EmitCodeBuilder): Unit = {}

  /** No-op. */
  override def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = {}

  /** Always fails: a writer's state cannot be loaded. */
  override def load(
    cb: EmitCodeBuilder,
    regionLoader: (EmitCodeBuilder, Value[Region]) => Unit,
    src: Value[Long],
  ): Unit = fatal("makes no sense to load a writer's state")

  /** No-op. */
  override def store(
    cb: EmitCodeBuilder,
    regionStorer: (EmitCodeBuilder, Value[Region]) => Unit,
    dest: Value[Long],
  ): Unit = {}

  // Not implemented: each of the following throws `NotImplementedError` when called.
  override def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = ???

  override def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = ???

  override def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = ???

  /** Adds an index entry for `codeRow` when this state has an index writer; otherwise does nothing.
    *
    * The entry's key is the subset of the row's fields named by the index key, and its offset is
    * read from the output buffer's `indexOffset` method.
    */
  private[agg] def addToIndex(cb: EmitCodeBuilder, codeRow: SValue): Unit =
    // Pair the writer with its key type instead of calling `indexKey.get`.
    for {
      iw <- indexWriter
      key <- indexKey
    } {
      val rowKey = codeRow.asBaseStruct.subset(key.fieldNames: _*)
      iw.add(
        cb,
        IEmitCode.present(cb, rowKey),
        outb.invoke[Long]("indexOffset"),
        IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L)),
      )
    }
}

/** Staged aggregator that encodes every row it is fed into an output stream and returns the
  * part-file name as its result.
  *
  * Init-op arguments, in the order declared by `initOpTypes`:
  *   1. part-file base name
  *   1. root path, including the trailing 'directory' separator
  *   1. (only when `indexed`) index root path, including the trailing 'directory' separator
  *
  * @param spec
  *   codec used to encode rows; its `encodedVirtualType` is the single seq-op argument type
  * @param indexed
  *   whether a third init argument (the index root path) is expected
  */
class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends StagedAggregator {
  type State = StreamWriterState

  val initOpTypes: IndexedSeq[Type] = ArraySeq(
    TString, // partfile base name
    TString, // path root _with_ 'directory' separator
  ) ++ (if (indexed) Some(TString)
        else None) // if indexed, index root path _with_ 'directory' separator
  val seqOpTypes: IndexedSeq[Type] = ArraySeq(spec.encodedVirtualType)
  val resultEmitType = EmitType(SJavaString, true)

  /** Opens the output stream at `root + part`, initialises the index writer (if the state has one)
    * at `indexRoot + part + ".idx"`, and stores the part name and output buffer in `state`.
    */
  override protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = {
    // Bind the arguments in the order `initOpTypes` declares them: part name first, then root.
    val (partEC, rootEC, ixrootEC) = init match {
      case Array(part, root) =>
        require(!indexed)
        (part, root, None)
      case Array(part, root, ixroot) =>
        require(indexed)
        (part, root, Some(ixroot))
      case _ =>
        throw new IllegalArgumentException(
          s"expected ${initOpTypes.length} init arguments, got ${init.length}"
        )
    }

    val root = rootEC.toI(cb).getOrFatal(cb, "path cannot be missing").asString.loadString(cb)
    val part = partEC.toI(cb).getOrFatal(cb, "part cannot be missing").asString.loadString(cb)
    val os = cb.emb.createUnbuffered(root.concat(part))

    state.indexWriter.foreach { iw =>
      // The state's index writer and the `indexed` flag are expected to agree; fail with a clear
      // message rather than an anonymous `None.get` if they do not.
      val ixrootCode = ixrootEC.getOrElse(
        throw new IllegalStateException(
          "writer state has an index writer but no index root path was supplied"
        )
      )
      val ixroot =
        ixrootCode.toI(cb).getOrFatal(cb, "index path cannot be missing").asString.loadString(cb)
      val path = cb.memoize(ixroot.concat(part).concat(".idx"))
      iw.init(cb, path, cb.memoize(cb.emb.getObject[Map[String, Any]](Map.empty)))
    }

    cb.assign(state.part, part)
    cb.assign(state.outb, spec.buildCodeOutputBuffer(os))
  }

  /** Indexes the row (when the state has an index writer), then writes a `1` byte followed by the
    * encoded row to the state's output buffer. A missing row is fatal.
    */
  override protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = {
    val Array(rowEC) = seq
    val row = rowEC.toI(cb).getOrFatal(cb, "row cannot be missing")
    val encoder = spec.encodedType.buildEncoder(row.st, cb.emb.ecb)

    // The index entry is added before the row is written, so the recorded offset is taken
    // ahead of this row's bytes.
    state.addToIndex(cb, row)
    cb += state.outb.writeByte(1.asInstanceOf[Byte])
    encoder.apply(cb, row, state.outb)
  }

  /** Writes a terminating `0` byte, flushes and closes the output buffer, closes the index writer
    * if present, and returns the part-file name.
    */
  override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region])
    : IEmitCode = {
    cb += state.outb.writeByte(0.asInstanceOf[Byte])
    cb += state.outb.flush()
    cb += state.outb.close()
    state.indexWriter.foreach(_.close(cb))
    IEmitCode.present(cb, new SJavaStringValue(state.part))
  }

  /** Always fails: combining two writer states is not supported. */
  override protected def _combOp(
    ctx: ExecuteContext,
    cb: EmitCodeBuilder,
    region: Value[Region],
    state: State,
    other: State,
  ): Unit = fatal("makes no sense to call a combop on the writer")
}
20 changes: 16 additions & 4 deletions hail/hail/src/is/hail/io/BufferSpecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.asm4s._
import is.hail.compatibility.{LEB128BufferSpec, LZ4BlockBufferSpec}
import is.hail.io.compress.LZ4
import is.hail.rvd.AbstractRVDSpec
import is.hail.utils.implicits.ByteTrackingOutputStream

import java.io._

Expand Down Expand Up @@ -266,8 +267,10 @@ object StreamBlockBufferSpec {
final class StreamBlockBufferSpec extends BlockBufferSpec {
override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer(in)

override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer =
new StreamBlockOutputBuffer(out)
override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = out match {
case out: ByteTrackingOutputStream => new StreamBlockOutputBuffer(out)
case out => new StreamBlockOutputBuffer(new ByteTrackingOutputStream(out))
}

override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] =
Code.newInstance[StreamBlockInputBuffer, InputStream](in)
Expand All @@ -280,19 +283,28 @@ final class StreamBlockBufferSpec extends BlockBufferSpec {

object StreamBlockBufferSpec2 {

  /** Returns a fresh spec; the JSON value is not inspected. */
  def extract(jv: JValue): StreamBlockBufferSpec2 = new StreamBlockBufferSpec2

  /** Wraps `out` in a [[StreamBlockOutputBuffer2]].
    *
    * A stream that is already a [[ByteTrackingOutputStream]] is used as-is; any other stream is
    * wrapped in a new one first. The name and signature are kept stable because generated code
    * looks this method up by name.
    */
  def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = {
    val tracked = out match {
      case alreadyTracked: ByteTrackingOutputStream => alreadyTracked
      case untracked => new ByteTrackingOutputStream(untracked)
    }
    new StreamBlockOutputBuffer2(tracked)
  }
}

final class StreamBlockBufferSpec2 extends BlockBufferSpec {
override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer2(in)

override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer =
new StreamBlockOutputBuffer2(out)
StreamBlockBufferSpec2.buildOutputBuffer(out)

override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] =
Code.newInstance[StreamBlockInputBuffer2, InputStream](in)

override def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] =
Code.newInstance[StreamBlockOutputBuffer2, OutputStream](out)
Code.invokeScalaObject1[OutputStream, OutputBlockBuffer](
StreamBlockBufferSpec2.getClass,
"buildOutputBuffer",
out,
)

override def equals(other: Any): Boolean = other.isInstanceOf[StreamBlockBufferSpec2]
}
Expand Down
4 changes: 2 additions & 2 deletions hail/hail/src/is/hail/io/OutputBuffers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer
override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten
}

final class StreamBlockOutputBuffer2(out: OutputStream) extends OutputBlockBuffer {
final class StreamBlockOutputBuffer2(out: ByteTrackingOutputStream) extends OutputBlockBuffer {
override def flush(): Unit =
out.flush()

Expand All @@ -288,7 +288,7 @@ final class StreamBlockOutputBuffer2(out: OutputStream) extends OutputBlockBuffe
out.write(buf, 0, len)
}

override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten
override def getPos(): Long = out.bytesWritten
}

final class LZ4OutputBlockBuffer(lz4: LZ4, blockSize: Int, out: OutputBlockBuffer)
Expand Down
Loading