diff --git a/hail/hail/src/is/hail/expr/ir/AggOp.scala b/hail/hail/src/is/hail/expr/ir/AggOp.scala index ebc5885ebf9..c149a545e95 100644 --- a/hail/hail/src/is/hail/expr/ir/AggOp.scala +++ b/hail/hail/src/is/hail/expr/ir/AggOp.scala @@ -2,6 +2,8 @@ package is.hail.expr.ir import is.hail.collection.FastSeq import is.hail.expr.ir.agg._ +import is.hail.io.TypedCodecSpec +import is.hail.types.physical.PStruct import is.hail.types.virtual._ sealed trait AggOp {} @@ -28,6 +30,7 @@ final case class ImputeType() extends AggOp final case class NDArraySum() extends AggOp final case class NDArrayMultiplyAdd() extends AggOp final case class Fold() extends AggOp +final case class WriteTBD(codec: TypedCodecSpec, indexKey: Option[PStruct]) extends AggOp // exists === map(p).sum, needs short-circuiting aggs // forall === map(p).product, needs short-circuiting aggs @@ -55,6 +58,7 @@ object AggOp { case (Downsample(), Seq(_, _, _)) => DownsampleAggregator.resultType case (NDArraySum(), Seq(t)) => t case (NDArrayMultiplyAdd(), Seq(a: TNDArray, _)) => a + case (WriteTBD(_, _), _) => TString case _ => throw new UnsupportedExtraction(this.toString) } diff --git a/hail/hail/src/is/hail/expr/ir/TableWriter.scala b/hail/hail/src/is/hail/expr/ir/TableWriter.scala index c770165403a..4c2502f5dc3 100644 --- a/hail/hail/src/is/hail/expr/ir/TableWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/TableWriter.scala @@ -74,19 +74,6 @@ object TableNativeWriter { // write out partitioner key, which may be stricter than table key val partitioner = ts.partitioner val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - val rowWriter = PartitionNativeWriter( - rowSpec, - pKey.fieldNames, - s"$path/rows/parts/", - Some(s"$path/index/" -> pKey), - if (stageLocally) Some(FileSystems.getDefault.getPath( - ctx.localTmpdir, - s"hail_staging_tmp_${UUID.randomUUID()}", - "rows", - "parts", - )) - else None, - ) val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/parts/", None, None) @@ -106,8 +93,53 @@ object TableNativeWriter { } }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("table_native_writer") { (rows, ctxRef) => - val file = GetField(ctxRef, "writeCtx") - WritePartition(rows, file + UUID4(), rowWriter) + bindIR(GetField(ctxRef, "writeCtx") + UUID4()) { file => + val root = s"$path/rows/parts/" + val partPath = file + UUID4() + + val zero = makestruct( + "distinct" -> !pKey.fieldNames.isEmpty, + "firstKey" -> NA(pKey.virtualType), + "lastKey" -> NA(pKey.virtualType), + ) + val partResult = streamAggIR(rows) { row => + makestruct( + "partpath" -> ApplyAggOp( + WriteTBD(rowSpec, Some(pKey)), + partPath, + Str(root), + Str(s"$path/index/"), + )(row), + "partitionCounts" -> ApplyAggOp(Count())(), + "keyMeta" -> aggFoldIR(zero) { accum => + bindIRs(SelectFields(row, pKey.fieldNames), GetField(accum, "lastKey")) { + case Seq(key, prev) => + makestruct( + "distinct" -> (GetField(accum, "distinct") && Coalesce(FastSeq( + prev.cne(key), + True(), + ))), + "firstKey" -> Coalesce(FastSeq(GetField(accum, "firstKey"), key)), + "lastKey" -> Coalesce(FastSeq(key, prev)), + ) + } + } { (accum1, accum2) => + Die("unreachable: calling combop on writer fold makes no sense", zero.typ) + }, + ) + } + bindIR(partResult) { result => + bindIR(GetField(result, "keyMeta")) { keymeta => + makestruct( + "filePath" -> GetField(result, "partpath"), + "partitionCounts" -> GetField(result, "partitionCounts"), + "distinctlyKeyed" -> GetField(keymeta, "distinct"), + "firstKey" -> GetField(keymeta, "firstKey"), + "lastKey" -> GetField(keymeta, "lastKey"), + ) + } + } + } } { (parts, globals) => val writeGlobals = WritePartition( MakeStream(FastSeq(globals), TStream(globals.typ)), diff --git a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala index 51238a140e5..d2d934d8759 100644 --- a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala +++ b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala @@ -489,7 +489,10 @@ object TypeCheck { s"${args.map(_.typ)} != ${aggSig.initOpTypes}", ) case SeqOp(_, args, aggSig) => - assert(args.map(_.typ) == aggSig.seqOpTypes) + assert( + args.map(_.typ) == aggSig.seqOpTypes, + s"${args.map(_.typ)} != ${aggSig.seqOpTypes}\n$aggSig", + ) case _: CombOp => case _: ResultOp => case AggStateValue(_, _) => diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index a17b0577cfc..4f8cb414e12 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -12,6 +12,7 @@ import is.hail.expr.ir._ import is.hail.expr.ir.defs._ import is.hail.io.BufferSpec import is.hail.types.{tcoerce, TypeWithRequiredness, VirtualTypeWithReq} +import is.hail.types.physical.PStruct import is.hail.types.physical.stypes.EmitType import is.hail.types.virtual._ @@ -59,6 +60,9 @@ object AggStateSig { seqVTypes.head.setRequired(false) ) // set required to false to handle empty aggs case NDArrayMultiplyAdd() => NDArrayMultiplyAddStateSig(seqVTypes.head.setRequired(false)) + case WriteTBD(codecSpec, key) => + assert(codecSpec.encodedVirtualType == seqVTypes.head.t) + WriteSig(seqVTypes.head, key) case _ => throw new UnsupportedExtraction(op.toString) } } @@ -92,6 +96,7 @@ object AggStateSig { val vWithReq = resultEmitType.typeWithRequiredness new TypedRegionBackedAggState(vWithReq, cb) case LinearRegressionStateSig() => new LinearRegressionAggregatorState(cb) + case WriteSig(_, key) => new StreamWriterState(cb, key) } } @@ -142,6 +147,9 @@ case class FoldStateSig( combOpIR: IR, ) extends AggStateSig(ArraySeq(resultEmitType.typeWithRequiredness), None) +case class WriteSig(rowType: VirtualTypeWithReq, indexKey: Option[PStruct]) + extends AggStateSig(ArraySeq(rowType), None) + object PhysicalAggSig { def apply(op: AggOp, state: AggStateSig): PhysicalAggSig = BasicPhysicalAggSig(op, state) @@ -467,6 +475,8 @@ object Extract { new NDArrayMultiplyAddAggregator(nda) case PhysicalAggSig(Fold(), FoldStateSig(res, accumName, otherAccumName, combOpIR)) => new FoldAggregator(res, accumName, otherAccumName, combOpIR) + case PhysicalAggSig(WriteTBD(codec, indexKey), WriteSig(_, _)) => + new StreamWriterAggregator(codec, indexKey.isDefined) } def apply(ctx: ExecuteContext, ir: IR, r: RequirednessAnalysis, isScan: Boolean = false) diff --git a/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala new file mode 100644 index 00000000000..8ed483f31e2 --- /dev/null +++ b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala @@ -0,0 +1,127 @@ +package is.hail.expr.ir.agg + +import is.hail.annotations.Region +import is.hail.asm4s._ +import is.hail.asm4s.implicits.valueToRichCodeOutputBuffer +import is.hail.backend.ExecuteContext +import is.hail.collection.compat.immutable.ArraySeq +import is.hail.expr.ir._ +import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec} +import is.hail.io.index.StagedIndexWriter +import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SValue} +import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaStringValue} +import is.hail.types.virtual._ +import is.hail.utils.fatal + +class StreamWriterState(override val kb: EmitClassBuilder[_], indexKey: Option[PStruct]) + extends AggregatorState { + val outb: Settable[OutputBuffer] = kb.genFieldThisRef[OutputBuffer]() + val part: Settable[String] = kb.genFieldThisRef[String]() + + val indexWriter = indexKey.map { key => + val branchingFactor = + Option(kb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) + StagedIndexWriter.withDefaults(key, kb, branchingFactor = branchingFactor) + } + + override def storageType = PCanonicalStringRequired + + override def createState(cb: EmitCodeBuilder): Unit = {} + + override def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = {} + + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = fatal("makes no sense to load a writer's state") + + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = {} + + override def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = ??? + + override def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = ??? + + override def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = ??? + + private[agg] def addToIndex(cb: EmitCodeBuilder, codeRow: SValue): Unit = + indexWriter.foreach { iw => + val row = codeRow.asBaseStruct + val rowKey = row.subset(indexKey.get.fieldNames: _*) + iw.add( + cb, + IEmitCode.present(cb, rowKey), + outb.invoke[Long]("indexOffset"), + IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L)), + ) + } +} + +class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends StagedAggregator { + type State = StreamWriterState + + val initOpTypes: IndexedSeq[Type] = ArraySeq( + TString, // partfile base name + TString, // path root _with_ 'directory' separator + ) ++ (if (indexed) Some(TString) + else None) // if indexed, index root path _with_ 'directory' separator + val seqOpTypes: IndexedSeq[Type] = ArraySeq(spec.encodedVirtualType) + val resultEmitType = EmitType(SJavaString, true) + + override protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { + val (partEC, rootEC, ixrootEC) = init match { + case Array(root, part) => + require(!indexed) + (root, part, None) + case Array(root, part, ixroot) => + require(indexed) + (root, part, Some(ixroot)) + } + + val root = rootEC.toI(cb).getOrFatal(cb, "path cannot be missing").asString.loadString(cb) + val part = partEC.toI(cb).getOrFatal(cb, "part cannot be missing").asString.loadString(cb) + val os = cb.emb.createUnbuffered(root.concat(part)) + + state.indexWriter.foreach { iw => + val root = + ixrootEC.get.toI(cb).getOrFatal(cb, "index path cannot be missing").asString.loadString(cb) + val path = cb.memoize(root.concat(part).concat(".idx")) + iw.init(cb, path, cb.memoize(cb.emb.getObject[Map[String, Any]](Map.empty))) + } + + cb.assign(state.part, part) + cb.assign(state.outb, spec.buildCodeOutputBuffer(os)) + } + + override protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { + val Array(rowEC) = seq + val row = rowEC.toI(cb).getOrFatal(cb, "row cannot be missing") + val encoder = spec.encodedType.buildEncoder(row.st, cb.emb.ecb) + + state.addToIndex(cb, row) + cb += state.outb.writeByte(1.asInstanceOf[Byte]) + encoder.apply(cb, row, state.outb) + } + + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) + : IEmitCode = { + cb += state.outb.writeByte(0.asInstanceOf[Byte]) + cb += state.outb.flush() + cb += state.outb.close() + state.indexWriter.foreach(_.close(cb)) + IEmitCode.present(cb, new SJavaStringValue(state.part)) + } + + override protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: State, + other: State, + ): Unit = fatal("makes no sense to call a combop on the writer") +} diff --git a/hail/hail/src/is/hail/io/BufferSpecs.scala b/hail/hail/src/is/hail/io/BufferSpecs.scala index 6a883155f94..f4e303c909f 100644 --- a/hail/hail/src/is/hail/io/BufferSpecs.scala +++ b/hail/hail/src/is/hail/io/BufferSpecs.scala @@ -4,6 +4,7 @@ import is.hail.asm4s._ import is.hail.compatibility.{LEB128BufferSpec, LZ4BlockBufferSpec} import is.hail.io.compress.LZ4 import is.hail.rvd.AbstractRVDSpec +import is.hail.utils.implicits.ByteTrackingOutputStream import java.io._ @@ -266,8 +267,10 @@ object StreamBlockBufferSpec { final class StreamBlockBufferSpec extends BlockBufferSpec { override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer(in) - override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = - new StreamBlockOutputBuffer(out) + override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = out match { + case out: ByteTrackingOutputStream => new StreamBlockOutputBuffer(out) + case out => new StreamBlockOutputBuffer(new ByteTrackingOutputStream(out)) + } override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = Code.newInstance[StreamBlockInputBuffer, InputStream](in) @@ -280,19 +283,28 @@ final class StreamBlockBufferSpec extends BlockBufferSpec { object StreamBlockBufferSpec2 { def extract(jv: JValue): StreamBlockBufferSpec2 = new StreamBlockBufferSpec2 + + def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = out match { + case out: ByteTrackingOutputStream => new StreamBlockOutputBuffer2(out) + case out => new StreamBlockOutputBuffer2(new ByteTrackingOutputStream(out)) + } } final class StreamBlockBufferSpec2 extends BlockBufferSpec { override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer2(in) override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = - new StreamBlockOutputBuffer2(out) + StreamBlockBufferSpec2.buildOutputBuffer(out) override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = Code.newInstance[StreamBlockInputBuffer2, InputStream](in) override def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.newInstance[StreamBlockOutputBuffer2, OutputStream](out) + Code.invokeScalaObject1[OutputStream, OutputBlockBuffer]( + StreamBlockBufferSpec2.getClass, + "buildOutputBuffer", + out, + ) override def equals(other: Any): Boolean = other.isInstanceOf[StreamBlockBufferSpec2] } diff --git a/hail/hail/src/is/hail/io/OutputBuffers.scala b/hail/hail/src/is/hail/io/OutputBuffers.scala index 543bde1f854..85206eba3a1 100644 --- a/hail/hail/src/is/hail/io/OutputBuffers.scala +++ b/hail/hail/src/is/hail/io/OutputBuffers.scala @@ -268,7 +268,7 @@ final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten } -final class StreamBlockOutputBuffer2(out: OutputStream) extends OutputBlockBuffer { +final class StreamBlockOutputBuffer2(out: ByteTrackingOutputStream) extends OutputBlockBuffer { override def flush(): Unit = out.flush() @@ -288,7 +288,7 @@ final class StreamBlockOutputBuffer2(out: OutputStream) extends OutputBlockBuffe out.write(buf, 0, len) } - override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten + override def getPos(): Long = out.bytesWritten } final class LZ4OutputBlockBuffer(lz4: LZ4, blockSize: Int, out: OutputBlockBuffer)