From 59186f63ea05749d7221a21bc428dc0122effe83 Mon Sep 17 00:00:00 2001 From: Chris Vittal Date: Fri, 24 Apr 2026 13:51:58 -0400 Subject: [PATCH 1/6] [query] Implement table partition writer as an aggregator --- hail/hail/src/is/hail/expr/ir/AggOp.scala | 2 + .../src/is/hail/expr/ir/TableWriter.scala | 84 +++++++++++++++---- hail/hail/src/is/hail/expr/ir/TypeCheck.scala | 5 +- .../src/is/hail/expr/ir/agg/Extract.scala | 8 ++ .../expr/ir/agg/StreamWriterAggregator.scala | 83 ++++++++++++++++++ 5 files changed, 166 insertions(+), 16 deletions(-) create mode 100644 hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala diff --git a/hail/hail/src/is/hail/expr/ir/AggOp.scala b/hail/hail/src/is/hail/expr/ir/AggOp.scala index ebc5885ebf9..c561e0ea81e 100644 --- a/hail/hail/src/is/hail/expr/ir/AggOp.scala +++ b/hail/hail/src/is/hail/expr/ir/AggOp.scala @@ -28,6 +28,7 @@ final case class ImputeType() extends AggOp final case class NDArraySum() extends AggOp final case class NDArrayMultiplyAdd() extends AggOp final case class Fold() extends AggOp +final case class WriteTBD(codec: is.hail.io.TypedCodecSpec) extends AggOp // exists === map(p).sum, needs short-circuiting aggs // forall === map(p).product, needs short-circuiting aggs @@ -55,6 +56,7 @@ object AggOp { case (Downsample(), Seq(_, _, _)) => DownsampleAggregator.resultType case (NDArraySum(), Seq(t)) => t case (NDArrayMultiplyAdd(), Seq(a: TNDArray, _)) => a + case (WriteTBD(_), _) => TString case _ => throw new UnsupportedExtraction(this.toString) } diff --git a/hail/hail/src/is/hail/expr/ir/TableWriter.scala b/hail/hail/src/is/hail/expr/ir/TableWriter.scala index c770165403a..8688141f9c3 100644 --- a/hail/hail/src/is/hail/expr/ir/TableWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/TableWriter.scala @@ -74,19 +74,19 @@ object TableNativeWriter { // write out partitioner key, which may be stricter than table key val partitioner = ts.partitioner val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - val rowWriter = PartitionNativeWriter( - rowSpec, - pKey.fieldNames, - s"$path/rows/parts/", - Some(s"$path/index/" -> pKey), - if (stageLocally) Some(FileSystems.getDefault.getPath( - ctx.localTmpdir, - s"hail_staging_tmp_${UUID.randomUUID()}", - "rows", - "parts", - )) - else None, - ) + // val _@rowWriter = PartitionNativeWriter( + // rowSpec, + // pKey.fieldNames, + // s"$path/rows/parts/", + // Some(s"$path/index/" -> pKey), + // if (stageLocally) Some(FileSystems.getDefault.getPath( + // ctx.localTmpdir, + // s"hail_staging_tmp_${UUID.randomUUID()}", + // "rows", + // "parts", + // )) + // else None, + // ) val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/parts/", None, None) @@ -106,8 +106,62 @@ object TableNativeWriter { } }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("table_native_writer") { (rows, ctxRef) => - val file = GetField(ctxRef, "writeCtx") - WritePartition(rows, file + UUID4(), rowWriter) + bindIR(GetField(ctxRef, "writeCtx") + UUID4()) { file => + val root = s"$path/rows/parts/" + val partPath = Str(root) + file + UUID4() + + val zero = makestruct( + "distinct" -> !pKey.fieldNames.isEmpty, + "firstKey" -> NA(pKey.virtualType), + "lastKey" -> NA(pKey.virtualType), + ) + val partResult = streamAggIR(rows) { row => + makestruct( + "fullpartpath" -> ApplyAggOp(WriteTBD(rowSpec), partPath)(row), + "partitionCounts" -> ApplyAggOp(Count())(), + "keyMeta" -> aggFoldIR(zero) { accum => + bindIRs(SelectFields(row, pKey.fieldNames), GetField(accum, "lastKey")) { + case Seq(key, prev) => + makestruct( + "distinct" -> (GetField(accum, "distinct") && Coalesce(FastSeq( + prev.cne(key), + True(), + ))), + "firstKey" -> Coalesce(FastSeq(GetField(accum, "firstKey"), key)), + "lastKey" -> Coalesce(FastSeq(key, prev)), + ) + } + } { (accum1, accum2) => + Die("unreachable: calling combop on writer fold makes no sense", zero.typ) + /* val stillDistinct = GetField(accum1, "distinct") && GetField( accum2, + * "distinct", ) && Coalesce(FastSeq( GetField(accum1, + * "lastKey").cne(GetField(accum2, "firstKey")), True(), )) val first = + * Coalesce(FastSeq(GetField(accum1, "firstKey"), GetField(accum2, "firstKey"))) + * val last = + * Coalesce(FastSeq(GetField(accum2, "lastKey"), GetField(accum1, "lastKey"))) + * makestruct( "distinct" -> stillDistinct, "firstKey" -> first, "lastKey" -> + * last, ) */ + }, + ) + } + bindIR(partResult) { result => + bindIR(GetField(result, "keyMeta")) { keymeta => + makestruct( + "filePath" -> invoke( + "slice", + TString, + GetField(result, "fullpartpath"), + I32(root.length()), + I32(Int.MaxValue), + ), + "partitionCounts" -> GetField(result, "partitionCounts"), + "distinctlyKeyed" -> GetField(keymeta, "distinct"), + "firstKey" -> GetField(keymeta, "firstKey"), + "lastKey" -> GetField(keymeta, "lastKey"), + ) + } + } + } } { (parts, globals) => val writeGlobals = WritePartition( MakeStream(FastSeq(globals), TStream(globals.typ)), diff --git a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala index 51238a140e5..d2d934d8759 100644 --- a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala +++ b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala @@ -489,7 +489,10 @@ object TypeCheck { s"${args.map(_.typ)} != ${aggSig.initOpTypes}", ) case SeqOp(_, args, aggSig) => - assert(args.map(_.typ) == aggSig.seqOpTypes) + assert( + args.map(_.typ) == aggSig.seqOpTypes, + s"${args.map(_.typ)} != ${aggSig.seqOpTypes}\n$aggSig", + ) case _: CombOp => case _: ResultOp => case AggStateValue(_, _) => diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index a17b0577cfc..2ca8b40ff6a 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -59,6 +59,9 @@ object AggStateSig { seqVTypes.head.setRequired(false) ) // set required to false to handle empty aggs case NDArrayMultiplyAdd() => NDArrayMultiplyAddStateSig(seqVTypes.head.setRequired(false)) + case WriteTBD(codecSpec) => + assert(codecSpec.encodedVirtualType == seqVTypes.head.t) + WriteSig(seqVTypes.head) case _ => throw new UnsupportedExtraction(op.toString) } } @@ -92,6 +95,7 @@ object AggStateSig { val vWithReq = resultEmitType.typeWithRequiredness new TypedRegionBackedAggState(vWithReq, cb) case LinearRegressionStateSig() => new LinearRegressionAggregatorState(cb) + case WriteSig(_) => new StreamWriterState(cb) } } @@ -142,6 +146,8 @@ case class FoldStateSig( combOpIR: IR, ) extends AggStateSig(ArraySeq(resultEmitType.typeWithRequiredness), None) +case class WriteSig(rowType: VirtualTypeWithReq) extends AggStateSig(ArraySeq(rowType), None) + object PhysicalAggSig { def apply(op: AggOp, state: AggStateSig): PhysicalAggSig = BasicPhysicalAggSig(op, state) @@ -467,6 +473,8 @@ object Extract { new NDArrayMultiplyAddAggregator(nda) case PhysicalAggSig(Fold(), FoldStateSig(res, accumName, otherAccumName, combOpIR)) => new FoldAggregator(res, accumName, otherAccumName, combOpIR) + case PhysicalAggSig(WriteTBD(codec), WriteSig(_)) => + new StreamWriterAggregator(codec) } def apply(ctx: ExecuteContext, ir: IR, r: RequirednessAnalysis, isScan: Boolean = false) diff --git a/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala new file mode 100644 index 00000000000..d3af9d21de5 --- /dev/null +++ b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala @@ -0,0 +1,83 @@ +package is.hail.expr.ir.agg + +import is.hail.annotations.Region +import is.hail.asm4s._ +import is.hail.asm4s.implicits.{valueToRichCodeOutputBuffer, valueToRichCodeRegion} +import is.hail.backend.ExecuteContext +import is.hail.collection.compat.immutable.ArraySeq +import is.hail.expr.ir._ +import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec} +import is.hail.types.physical._ +import is.hail.types.physical.stypes.EmitType +import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaStringValue} +import is.hail.types.virtual._ +import is.hail.utils.fatal + +class StreamWriterState(override val kb: EmitClassBuilder[_]) extends AggregatorState { + val outb: Settable[OutputBuffer] = kb.genFieldThisRef[OutputBuffer]() + val path: Settable[String] = kb.genFieldThisRef[String]() + + override def storageType = PCanonicalStringRequired + + override def createState(cb: EmitCodeBuilder): Unit = {} + + override def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = {} + + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = fatal("makes no sense to load a writer's state") + + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = {} + + override def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = ??? + + override def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = ??? + + override def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = ??? +} + +class StreamWriterAggregator(spec: TypedCodecSpec) extends StagedAggregator { + type State = StreamWriterState + + val initOpTypes: IndexedSeq[Type] = ArraySeq(TString) + val seqOpTypes: IndexedSeq[Type] = ArraySeq(spec.encodedVirtualType) + val resultEmitType = EmitType(SJavaString, true) + + override protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { + val Array(pathEC) = init + val path = pathEC.toI(cb).getOrFatal(cb, "path cannot be missing").asString.loadString(cb) + val os = cb.emb.createUnbuffered(path) + cb.assign(state.path, path) + cb.assign(state.outb, spec.buildCodeOutputBuffer(os)) + } + + override protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { + val Array(rowEC) = seq + val row = rowEC.toI(cb).getOrFatal(cb, "row cannot be missing") + val encoder = spec.encodedType.buildEncoder(row.st, cb.emb.ecb) + cb += state.outb.writeByte(1.asInstanceOf[Byte]) + encoder.apply(cb, row, state.outb) + } + + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) + : IEmitCode = { + cb += state.outb.writeByte(0.asInstanceOf[Byte]) + cb += state.outb.flush() + cb += state.outb.close() + IEmitCode.present(cb, new SJavaStringValue(state.path)) + } + + override protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: State, + other: State, + ): Unit = fatal("makes no sense to call a combop on the writer") +} From cdf9e02040d37ac4b690d7ed02bf08e9d1090f01 Mon Sep 17 00:00:00 2001 From: Chris Vittal Date: Fri, 1 May 2026 12:14:47 -0400 Subject: [PATCH 2/6] index writing --- hail/hail/src/is/hail/expr/ir/AggOp.scala | 6 +- .../src/is/hail/expr/ir/TableWriter.scala | 12 +--- .../src/is/hail/expr/ir/agg/Extract.scala | 13 +++-- .../expr/ir/agg/StreamWriterAggregator.scala | 57 +++++++++++++++---- hail/hail/src/is/hail/io/BufferSpecs.scala | 17 ++++-- hail/hail/src/is/hail/io/OutputBuffers.scala | 8 +-- 6 files changed, 76 insertions(+), 37 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/AggOp.scala b/hail/hail/src/is/hail/expr/ir/AggOp.scala index c561e0ea81e..4c525b5f0e7 100644 --- a/hail/hail/src/is/hail/expr/ir/AggOp.scala +++ b/hail/hail/src/is/hail/expr/ir/AggOp.scala @@ -3,6 +3,8 @@ package is.hail.expr.ir import is.hail.collection.FastSeq import is.hail.expr.ir.agg._ import is.hail.types.virtual._ +import is.hail.types.physical.PStruct +import is.hail.io.TypedCodecSpec sealed trait AggOp {} final case class ApproxCDF() extends AggOp @@ -28,7 +30,7 @@ final case class ImputeType() extends AggOp final case class NDArraySum() extends AggOp final case class NDArrayMultiplyAdd() extends AggOp final case class Fold() extends AggOp -final case class WriteTBD(codec: is.hail.io.TypedCodecSpec) extends AggOp +final case class WriteTBD(codec: TypedCodecSpec, indexKey: Option[PStruct]) extends AggOp // exists === map(p).sum, needs short-circuiting aggs // forall === map(p).product, needs short-circuiting aggs @@ -56,7 +58,7 @@ object AggOp { case (Downsample(), Seq(_, _, _)) => DownsampleAggregator.resultType case (NDArraySum(), Seq(t)) => t case (NDArrayMultiplyAdd(), Seq(a: TNDArray, _)) => a - case (WriteTBD(_), _) => TString + case (WriteTBD(_, _), _) => TString case _ => throw new UnsupportedExtraction(this.toString) } diff --git a/hail/hail/src/is/hail/expr/ir/TableWriter.scala b/hail/hail/src/is/hail/expr/ir/TableWriter.scala index 8688141f9c3..8845a93f45c 100644 --- a/hail/hail/src/is/hail/expr/ir/TableWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/TableWriter.scala @@ -108,7 +108,7 @@ object TableNativeWriter { (rows, ctxRef) => bindIR(GetField(ctxRef, "writeCtx") + UUID4()) { file => val root = s"$path/rows/parts/" - val partPath = Str(root) + file + UUID4() + val partPath = file + UUID4() val zero = makestruct( "distinct" -> !pKey.fieldNames.isEmpty, @@ -117,7 +117,7 @@ object TableNativeWriter { ) val partResult = streamAggIR(rows) { row => makestruct( - "fullpartpath" -> ApplyAggOp(WriteTBD(rowSpec), partPath)(row), + "partpath" -> ApplyAggOp(WriteTBD(rowSpec, Some(pKey)), partPath, Str(root), Str(s"$path/index/"))(row), "partitionCounts" -> ApplyAggOp(Count())(), "keyMeta" -> aggFoldIR(zero) { accum => bindIRs(SelectFields(row, pKey.fieldNames), GetField(accum, "lastKey")) { @@ -147,13 +147,7 @@ object TableNativeWriter { bindIR(partResult) { result => bindIR(GetField(result, "keyMeta")) { keymeta => makestruct( - "filePath" -> invoke( - "slice", - TString, - GetField(result, "fullpartpath"), - I32(root.length()), - I32(Int.MaxValue), - ), + "filePath" -> GetField(result, "partpath"), "partitionCounts" -> GetField(result, "partitionCounts"), "distinctlyKeyed" -> GetField(keymeta, "distinct"), "firstKey" -> GetField(keymeta, "firstKey"), diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index 2ca8b40ff6a..8cb95478485 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -12,6 +12,7 @@ import is.hail.expr.ir._ import is.hail.expr.ir.defs._ import is.hail.io.BufferSpec import is.hail.types.{tcoerce, TypeWithRequiredness, VirtualTypeWithReq} +import is.hail.types.physical.PStruct import is.hail.types.physical.stypes.EmitType import is.hail.types.virtual._ @@ -59,9 +60,9 @@ object AggStateSig { seqVTypes.head.setRequired(false) ) // set required to false to handle empty aggs case NDArrayMultiplyAdd() => NDArrayMultiplyAddStateSig(seqVTypes.head.setRequired(false)) - case WriteTBD(codecSpec) => + case WriteTBD(codecSpec, key) => assert(codecSpec.encodedVirtualType == seqVTypes.head.t) - WriteSig(seqVTypes.head) + WriteSig(seqVTypes.head, key) case _ => throw new UnsupportedExtraction(op.toString) } } @@ -95,7 +96,7 @@ object AggStateSig { val vWithReq = resultEmitType.typeWithRequiredness new TypedRegionBackedAggState(vWithReq, cb) case LinearRegressionStateSig() => new LinearRegressionAggregatorState(cb) - case WriteSig(_) => new StreamWriterState(cb) + case WriteSig(_, key) => new StreamWriterState(cb, key) } } @@ -146,7 +147,7 @@ case class FoldStateSig( combOpIR: IR, ) extends AggStateSig(ArraySeq(resultEmitType.typeWithRequiredness), None) -case class WriteSig(rowType: VirtualTypeWithReq) extends AggStateSig(ArraySeq(rowType), None) +case class WriteSig(rowType: VirtualTypeWithReq, indexKey: Option[PStruct]) extends AggStateSig(ArraySeq(rowType), None) object PhysicalAggSig { def apply(op: AggOp, state: AggStateSig): PhysicalAggSig = BasicPhysicalAggSig(op, state) @@ -473,8 +474,8 @@ object Extract { new NDArrayMultiplyAddAggregator(nda) case PhysicalAggSig(Fold(), FoldStateSig(res, accumName, otherAccumName, combOpIR)) => new FoldAggregator(res, accumName, otherAccumName, combOpIR) - case PhysicalAggSig(WriteTBD(codec), WriteSig(_)) => - new StreamWriterAggregator(codec) + case PhysicalAggSig(WriteTBD(codec, indexKey), WriteSig(_, _)) => + new StreamWriterAggregator(codec, indexKey.isDefined) } def apply(ctx: ExecuteContext, ir: IR, r: RequirednessAnalysis, isScan: Boolean = false) diff --git a/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala index d3af9d21de5..ec57bb5944b 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala @@ -2,20 +2,25 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.asm4s.implicits.{valueToRichCodeOutputBuffer, valueToRichCodeRegion} +import is.hail.asm4s.implicits.{valueToRichCodeOutputBuffer} import is.hail.backend.ExecuteContext import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir._ import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec} +import is.hail.io.index.StagedIndexWriter import is.hail.types.physical._ -import is.hail.types.physical.stypes.EmitType +import is.hail.types.physical.stypes.{EmitType, SValue} import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaStringValue} import is.hail.types.virtual._ import is.hail.utils.fatal -class StreamWriterState(override val kb: EmitClassBuilder[_]) extends AggregatorState { +class StreamWriterState(override val kb: EmitClassBuilder[_], indexKey: Option[PStruct]) extends AggregatorState { val outb: Settable[OutputBuffer] = kb.genFieldThisRef[OutputBuffer]() - val path: Settable[String] = kb.genFieldThisRef[String]() + val part: Settable[String] = kb.genFieldThisRef[String]() + val indexWriter = indexKey.map { key => + val branchingFactor = Option(kb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) + StagedIndexWriter.withDefaults(key, kb, branchingFactor = branchingFactor) + } override def storageType = PCanonicalStringRequired @@ -40,20 +45,46 @@ class StreamWriterState(override val kb: EmitClassBuilder[_]) extends Aggregator override def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = ??? override def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = ??? + + private[agg] def addToIndex(cb: EmitCodeBuilder, codeRow: SValue): Unit = indexWriter.foreach { iw => + val row = codeRow.asBaseStruct + val rowKey = row.subset(indexKey.get.fieldNames: _*) + iw.add(cb, IEmitCode.present(cb, rowKey), outb.invoke[Long]("indexOffset"), + IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L))) + } } -class StreamWriterAggregator(spec: TypedCodecSpec) extends StagedAggregator { +class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends StagedAggregator { type State = StreamWriterState - val initOpTypes: IndexedSeq[Type] = ArraySeq(TString) + val initOpTypes: IndexedSeq[Type] = ArraySeq( + TString, // partfile base name + TString, // path root _with_ 'directory' separator + ) ++ (if (indexed) Some(TString) else None) // if indexed, index root path _with_ 'directory' separator val seqOpTypes: IndexedSeq[Type] = ArraySeq(spec.encodedVirtualType) val resultEmitType = EmitType(SJavaString, true) override protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { - val Array(pathEC) = init - val path = pathEC.toI(cb).getOrFatal(cb, "path cannot be missing").asString.loadString(cb) - val os = cb.emb.createUnbuffered(path) - cb.assign(state.path, path) + val (partEC, rootEC, ixrootEC) = init match { + case Array(root, part) => + require(!indexed) + (root, part, None) + case Array(root, part, ixroot) => + require(indexed) + (root, part, Some(ixroot)) + } + + val root = rootEC.toI(cb).getOrFatal(cb, "path cannot be missing").asString.loadString(cb) + val part = partEC.toI(cb).getOrFatal(cb, "part cannot be missing").asString.loadString(cb) + val os = cb.emb.createUnbuffered(root.concat(part)) + + state.indexWriter.foreach { iw => + val root = ixrootEC.get.toI(cb).getOrFatal(cb, "index path cannot be missing").asString.loadString(cb) + val path = cb.memoize(root.concat(part).concat(".idx")) + iw.init(cb, path, cb.memoize(cb.emb.getObject[Map[String, Any]](Map.empty))) + } + + cb.assign(state.part, part) cb.assign(state.outb, spec.buildCodeOutputBuffer(os)) } @@ -61,16 +92,20 @@ class StreamWriterAggregator(spec: TypedCodecSpec) extends StagedAggregator { val Array(rowEC) = seq val row = rowEC.toI(cb).getOrFatal(cb, "row cannot be missing") val encoder = spec.encodedType.buildEncoder(row.st, cb.emb.ecb) + + state.addToIndex(cb, row) cb += state.outb.writeByte(1.asInstanceOf[Byte]) encoder.apply(cb, row, state.outb) } + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) : IEmitCode = { cb += state.outb.writeByte(0.asInstanceOf[Byte]) cb += state.outb.flush() cb += state.outb.close() - IEmitCode.present(cb, new SJavaStringValue(state.path)) + state.indexWriter.foreach(_.close(cb)) + IEmitCode.present(cb, new SJavaStringValue(state.part)) } override protected def _combOp( diff --git a/hail/hail/src/is/hail/io/BufferSpecs.scala b/hail/hail/src/is/hail/io/BufferSpecs.scala index 6a883155f94..8946cec0b34 100644 --- a/hail/hail/src/is/hail/io/BufferSpecs.scala +++ b/hail/hail/src/is/hail/io/BufferSpecs.scala @@ -4,6 +4,7 @@ import is.hail.asm4s._ import is.hail.compatibility.{LEB128BufferSpec, LZ4BlockBufferSpec} import is.hail.io.compress.LZ4 import is.hail.rvd.AbstractRVDSpec +import is.hail.utils.implicits.ByteTrackingOutputStream import java.io._ @@ -266,8 +267,10 @@ object StreamBlockBufferSpec { final class StreamBlockBufferSpec extends BlockBufferSpec { override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer(in) - override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = - new StreamBlockOutputBuffer(out) + override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = out match { + case out: ByteTrackingOutputStream => new StreamBlockOutputBuffer(out) + case out => new StreamBlockOutputBuffer(new ByteTrackingOutputStream(out)) + } override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = Code.newInstance[StreamBlockInputBuffer, InputStream](in) @@ -280,19 +283,23 @@ final class StreamBlockBufferSpec extends BlockBufferSpec { object StreamBlockBufferSpec2 { def extract(jv: JValue): StreamBlockBufferSpec2 = new StreamBlockBufferSpec2 + + def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = out match { + case out: ByteTrackingOutputStream => new StreamBlockOutputBuffer2(out) + case out => new StreamBlockOutputBuffer2(new ByteTrackingOutputStream(out)) + } } final class StreamBlockBufferSpec2 extends BlockBufferSpec { override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer2(in) - override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = - new StreamBlockOutputBuffer2(out) + override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = StreamBlockBufferSpec2.buildOutputBuffer(out) override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = Code.newInstance[StreamBlockInputBuffer2, InputStream](in) override def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.newInstance[StreamBlockOutputBuffer2, OutputStream](out) + Code.invokeScalaObject1[OutputStream, OutputBlockBuffer](StreamBlockBufferSpec2.getClass, "buildOutputBuffer", out) override def equals(other: Any): Boolean = other.isInstanceOf[StreamBlockBufferSpec2] } diff --git a/hail/hail/src/is/hail/io/OutputBuffers.scala b/hail/hail/src/is/hail/io/OutputBuffers.scala index 543bde1f854..a7678380a09 100644 --- a/hail/hail/src/is/hail/io/OutputBuffers.scala +++ b/hail/hail/src/is/hail/io/OutputBuffers.scala @@ -250,7 +250,7 @@ final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends } } -final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer { +final class StreamBlockOutputBuffer(out: ByteTrackingOutputStream) extends OutputBlockBuffer { private val lenBuf = new Array[Byte](4) override def flush(): Unit = @@ -265,10 +265,10 @@ final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer out.write(buf, 0, len) } - override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten + override def getPos(): Long = out.bytesWritten } -final class StreamBlockOutputBuffer2(out: OutputStream) extends OutputBlockBuffer { +final class StreamBlockOutputBuffer2(out: ByteTrackingOutputStream) extends OutputBlockBuffer { override def flush(): Unit = out.flush() @@ -288,7 +288,7 @@ final class StreamBlockOutputBuffer2(out: OutputStream) extends OutputBlockBuffe out.write(buf, 0, len) } - override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten + override def getPos(): Long = out.bytesWritten } final class LZ4OutputBlockBuffer(lz4: LZ4, blockSize: Int, out: OutputBlockBuffer) From ac36cb412f2ab109dd39cb114c1bc202276a9043 Mon Sep 17 00:00:00 2001 From: Chris Vittal Date: Fri, 1 May 2026 13:27:40 -0400 Subject: [PATCH 3/6] lint --- hail/hail/src/is/hail/expr/ir/AggOp.scala | 4 +-- .../src/is/hail/expr/ir/TableWriter.scala | 7 +++- .../src/is/hail/expr/ir/agg/Extract.scala | 3 +- .../expr/ir/agg/StreamWriterAggregator.scala | 33 ++++++++++++------- hail/hail/src/is/hail/io/BufferSpecs.scala | 9 +++-- 5 files changed, 38 insertions(+), 18 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/AggOp.scala b/hail/hail/src/is/hail/expr/ir/AggOp.scala index 4c525b5f0e7..c149a545e95 100644 --- a/hail/hail/src/is/hail/expr/ir/AggOp.scala +++ b/hail/hail/src/is/hail/expr/ir/AggOp.scala @@ -2,9 +2,9 @@ package is.hail.expr.ir import is.hail.collection.FastSeq import is.hail.expr.ir.agg._ -import is.hail.types.virtual._ -import is.hail.types.physical.PStruct import is.hail.io.TypedCodecSpec +import is.hail.types.physical.PStruct +import is.hail.types.virtual._ sealed trait AggOp {} final case class ApproxCDF() extends AggOp diff --git a/hail/hail/src/is/hail/expr/ir/TableWriter.scala b/hail/hail/src/is/hail/expr/ir/TableWriter.scala index 8845a93f45c..675d2497379 100644 --- a/hail/hail/src/is/hail/expr/ir/TableWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/TableWriter.scala @@ -117,7 +117,12 @@ object TableNativeWriter { ) val partResult = streamAggIR(rows) { row => makestruct( - "partpath" -> ApplyAggOp(WriteTBD(rowSpec, Some(pKey)), partPath, Str(root), Str(s"$path/index/"))(row), + "partpath" -> ApplyAggOp( + WriteTBD(rowSpec, Some(pKey)), + partPath, + Str(root), + Str(s"$path/index/"), + )(row), "partitionCounts" -> ApplyAggOp(Count())(), "keyMeta" -> aggFoldIR(zero) { accum => bindIRs(SelectFields(row, pKey.fieldNames), GetField(accum, "lastKey")) { diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index 8cb95478485..4f8cb414e12 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -147,7 +147,8 @@ case class FoldStateSig( combOpIR: IR, ) extends AggStateSig(ArraySeq(resultEmitType.typeWithRequiredness), None) -case class WriteSig(rowType: VirtualTypeWithReq, indexKey: Option[PStruct]) extends AggStateSig(ArraySeq(rowType), None) +case class WriteSig(rowType: VirtualTypeWithReq, indexKey: Option[PStruct]) + extends AggStateSig(ArraySeq(rowType), None) object PhysicalAggSig { def apply(op: AggOp, state: AggStateSig): PhysicalAggSig = BasicPhysicalAggSig(op, state) diff --git a/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala index ec57bb5944b..8ed483f31e2 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/StreamWriterAggregator.scala @@ -2,7 +2,7 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.asm4s.implicits.{valueToRichCodeOutputBuffer} +import is.hail.asm4s.implicits.valueToRichCodeOutputBuffer import is.hail.backend.ExecuteContext import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir._ @@ -14,11 +14,14 @@ import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaStringValue} import is.hail.types.virtual._ import is.hail.utils.fatal -class StreamWriterState(override val kb: EmitClassBuilder[_], indexKey: Option[PStruct]) extends AggregatorState { +class StreamWriterState(override val kb: EmitClassBuilder[_], indexKey: Option[PStruct]) + extends AggregatorState { val outb: Settable[OutputBuffer] = kb.genFieldThisRef[OutputBuffer]() val part: Settable[String] = kb.genFieldThisRef[String]() + val indexWriter = indexKey.map { key => - val branchingFactor = Option(kb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) + val branchingFactor = + Option(kb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) StagedIndexWriter.withDefaults(key, kb, branchingFactor = branchingFactor) } @@ -46,12 +49,17 @@ class StreamWriterState(override val kb: EmitClassBuilder[_], indexKey: Option[P override def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = ??? - private[agg] def addToIndex(cb: EmitCodeBuilder, codeRow: SValue): Unit = indexWriter.foreach { iw => - val row = codeRow.asBaseStruct - val rowKey = row.subset(indexKey.get.fieldNames: _*) - iw.add(cb, IEmitCode.present(cb, rowKey), outb.invoke[Long]("indexOffset"), - IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L))) - } + private[agg] def addToIndex(cb: EmitCodeBuilder, codeRow: SValue): Unit = + indexWriter.foreach { iw => + val row = codeRow.asBaseStruct + val rowKey = row.subset(indexKey.get.fieldNames: _*) + iw.add( + cb, + IEmitCode.present(cb, rowKey), + outb.invoke[Long]("indexOffset"), + IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L)), + ) + } } class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends StagedAggregator { @@ -60,7 +68,8 @@ class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends Sta val initOpTypes: IndexedSeq[Type] = ArraySeq( TString, // partfile base name TString, // path root _with_ 'directory' separator - ) ++ (if (indexed) Some(TString) else None) // if indexed, index root path _with_ 'directory' separator + ) ++ (if (indexed) Some(TString) + else None) // if indexed, index root path _with_ 'directory' separator val seqOpTypes: IndexedSeq[Type] = ArraySeq(spec.encodedVirtualType) val resultEmitType = EmitType(SJavaString, true) @@ -79,7 +88,8 @@ class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends Sta val os = cb.emb.createUnbuffered(root.concat(part)) state.indexWriter.foreach { iw => - val root = ixrootEC.get.toI(cb).getOrFatal(cb, "index path cannot be missing").asString.loadString(cb) + val root = + ixrootEC.get.toI(cb).getOrFatal(cb, "index path cannot be missing").asString.loadString(cb) val path = cb.memoize(root.concat(part).concat(".idx")) iw.init(cb, path, cb.memoize(cb.emb.getObject[Map[String, Any]](Map.empty))) } @@ -98,7 +108,6 @@ class StreamWriterAggregator(spec: TypedCodecSpec, indexed: Boolean) extends Sta encoder.apply(cb, row, state.outb) } - override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) : IEmitCode = { cb += state.outb.writeByte(0.asInstanceOf[Byte]) diff --git a/hail/hail/src/is/hail/io/BufferSpecs.scala b/hail/hail/src/is/hail/io/BufferSpecs.scala index 8946cec0b34..f4e303c909f 100644 --- a/hail/hail/src/is/hail/io/BufferSpecs.scala +++ b/hail/hail/src/is/hail/io/BufferSpecs.scala @@ -293,13 +293,18 @@ object StreamBlockBufferSpec2 { final class StreamBlockBufferSpec2 extends BlockBufferSpec { override def buildInputBuffer(in: InputStream): InputBlockBuffer = new StreamBlockInputBuffer2(in) - override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = StreamBlockBufferSpec2.buildOutputBuffer(out) + override def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = + StreamBlockBufferSpec2.buildOutputBuffer(out) override def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = Code.newInstance[StreamBlockInputBuffer2, InputStream](in) override def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.invokeScalaObject1[OutputStream, OutputBlockBuffer](StreamBlockBufferSpec2.getClass, "buildOutputBuffer", out) + Code.invokeScalaObject1[OutputStream, OutputBlockBuffer]( + StreamBlockBufferSpec2.getClass, + "buildOutputBuffer", + out, + ) override def equals(other: Any): Boolean = other.isInstanceOf[StreamBlockBufferSpec2] } From bf4b656d4860b846c6af5c75431ef3297ba53ed4 Mon Sep 17 00:00:00 2001 From: Chris Vittal Date: Sat, 9 May 2026 10:13:27 -0400 Subject: [PATCH 4/6] StreamBlockOutputBuffer can be more general It's only used for block matrix --- hail/hail/src/is/hail/io/OutputBuffers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/hail/src/is/hail/io/OutputBuffers.scala b/hail/hail/src/is/hail/io/OutputBuffers.scala index a7678380a09..d9d6929b901 100644 --- a/hail/hail/src/is/hail/io/OutputBuffers.scala +++ b/hail/hail/src/is/hail/io/OutputBuffers.scala @@ -250,7 +250,7 @@ final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends } } -final class StreamBlockOutputBuffer(out: ByteTrackingOutputStream) extends OutputBlockBuffer { +final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer { private val lenBuf = new Array[Byte](4) override def flush(): Unit = From 1de79b831605654382f8965673af1dfe44e0331c Mon Sep 17 00:00:00 2001 From: Chris Vittal Date: Tue, 12 May 2026 09:51:01 -0400 Subject: [PATCH 5/6] fix --- hail/hail/src/is/hail/io/OutputBuffers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/hail/src/is/hail/io/OutputBuffers.scala b/hail/hail/src/is/hail/io/OutputBuffers.scala index d9d6929b901..85206eba3a1 100644 --- a/hail/hail/src/is/hail/io/OutputBuffers.scala +++ b/hail/hail/src/is/hail/io/OutputBuffers.scala @@ -265,7 +265,7 @@ final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer out.write(buf, 0, len) } - override def getPos(): Long = out.bytesWritten + override def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten } final class StreamBlockOutputBuffer2(out: ByteTrackingOutputStream) extends OutputBlockBuffer { From 04aab46f4cd0e675b318e36905f649eb949934e0 Mon Sep 17 00:00:00 2001 From: Chris Vittal Date: Tue, 12 May 2026 15:30:19 -0400 Subject: [PATCH 6/6] Remove unneded comments --- .../src/is/hail/expr/ir/TableWriter.scala | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/TableWriter.scala b/hail/hail/src/is/hail/expr/ir/TableWriter.scala index 675d2497379..4c2502f5dc3 100644 --- a/hail/hail/src/is/hail/expr/ir/TableWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/TableWriter.scala @@ -74,19 +74,6 @@ object TableNativeWriter { // write out partitioner key, which may be stricter than table key val partitioner = ts.partitioner val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - // val _@rowWriter = PartitionNativeWriter( - // rowSpec, - // pKey.fieldNames, - // s"$path/rows/parts/", - // Some(s"$path/index/" -> pKey), - // if (stageLocally) Some(FileSystems.getDefault.getPath( - // ctx.localTmpdir, - // s"hail_staging_tmp_${UUID.randomUUID()}", - // "rows", - // "parts", - // )) - // else None, - // ) val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/parts/", None, None) @@ -138,14 +125,6 @@ object TableNativeWriter { } } { (accum1, accum2) => Die("unreachable: calling combop on writer fold makes no sense", zero.typ) - /* val stillDistinct = GetField(accum1, "distinct") && GetField( accum2, - * "distinct", ) && Coalesce(FastSeq( GetField(accum1, - * "lastKey").cne(GetField(accum2, "firstKey")), True(), )) val first = - * Coalesce(FastSeq(GetField(accum1, "firstKey"), GetField(accum2, "firstKey"))) - * val last = - * Coalesce(FastSeq(GetField(accum2, "lastKey"), GetField(accum1, "lastKey"))) - * makestruct( "distinct" -> stillDistinct, "firstKey" -> first, "lastKey" -> - * last, ) */ }, ) }