From 7d6bbfd1e2ab8bf31bb862961ee3a3c0f91dded8 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 8 Jun 2026 11:27:29 -0400 Subject: [PATCH 1/5] [query] no-sharing in `MatrixWriter` --- .../src/is/hail/expr/ir/MatrixWriter.scala | 802 +++++++++--------- .../hail/expr/ir/lowering/LowerTableIR.scala | 72 +- 2 files changed, 416 insertions(+), 458 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala b/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala index 3c7739630e0..dd7dfcd7b6d 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala @@ -7,6 +7,7 @@ import is.hail.backend.ExecuteContext import is.hail.collection.{ByteArrayBuilder, FastSeq} import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.{JSONAnnotationImpex, Nat} +import is.hail.expr.ir.{Memoized => M} import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.TableStage import is.hail.expr.ir.streams.StreamProducer @@ -87,8 +88,8 @@ sealed trait MatrixWriterComponents { def stage: TableStage def setup: IR def writePartitionType: Type - def writePartition(rows: IR, ctx: Atom): IR - def finalizeWrite(parts: IR, globals: IR): IR + def writePartition(rows: Atom, ctx: Atom): IR + def finalizeWrite(parts: Atom, globals: Atom): IR } object MatrixNativeWriter { @@ -143,12 +144,6 @@ object MatrixNativeWriter { val partitioner = lowered.partitioner val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - val emptyWriter = - PartitionNativeWriter(emptySpec, IndexedSeq(), s"$path/globals/globals/parts/", None, None) - val globalWriter = - PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/rows/parts/", None, None) - val colWriter = - PartitionNativeWriter(colSpec, IndexedSeq(), s"$path/cols/rows/parts/", None, None) val rowWriter = SplitPartitionNativeWriter( rowSpec, s"$path/rows/rows/parts/", @@ -161,39 +156,6 @@ object MatrixNativeWriter { else None, ) - val globalTableWriter = TableSpecWriter( - s"$path/globals", - TableType(tm.globalType, FastSeq(), TStruct.empty), - "rows", - "globals", - "../references", - log = false, - ) - val colTableWriter = TableSpecWriter( - s"$path/cols", - tm.colsTableType.copy(key = FastSeq[String]()), - "rows", - "../globals/rows", - "../references", - log = false, - ) - val rowTableWriter = TableSpecWriter( - s"$path/rows", - tm.rowsTableType, - "rows", - "../globals/rows", - "../references", - log = false, - ) - val entriesTableWriter = TableSpecWriter( - s"$path/entries", - TableType(tm.entriesRVType, FastSeq(), tm.globalType), - "rows", - "../globals/rows", - "../references", - log = false, - ) - new MatrixWriterComponents { override val stage: TableStage = @@ -205,7 +167,7 @@ object MatrixNativeWriter { oldCtx, ToStream(Literal(TArray(TString), partFiles)), ArrayZipBehavior.AssertSameLength, - )((ctxElt, pf) => MakeStruct(FastSeq("oldCtx" -> ctxElt, "writeCtx" -> pf))) + )((ctxElt, pf) => makestruct("oldCtx" -> ctxElt, "writeCtx" -> pf)) }(GetField(_, "oldCtx")) override val setup: IR = @@ -223,145 +185,166 @@ object MatrixNativeWriter { override def writePartitionType: Type = rowWriter.returnType - override def writePartition(rows: IR, ctx: Atom): IR = + override def writePartition(rows: Atom, ctx: Atom): IR = WritePartition(rows, GetField(ctx, "writeCtx") + UUID4(), rowWriter) - override def finalizeWrite(parts: IR, globals: IR): IR = { - // parts is array of partition results - val writeEmpty = WritePartition( - MakeStream(FastSeq(makestruct()), TStream(TStruct.empty)), - Str(partFile(1, 0)), - emptyWriter, - ) - val writeCols = - WritePartition(ToStream(GetField(globals, colsFieldName)), Str(partFile(1, 0)), colWriter) - val writeGlobals = WritePartition( - MakeStream( - FastSeq(SelectFields(globals, tm.globalType.fieldNames)), - TStream(tm.globalType), - ), - Str(partFile(1, 0)), - globalWriter, - ) + override def finalizeWrite(parts: Atom, globals: Atom): IR = + M.eval { + for { + partFile <- Str(partFile(1, 0)) + // parts is array of partition results + writeEmpty <- WritePartition( + MakeStream.single(makestruct()), + partFile, + PartitionNativeWriter( + emptySpec, + FastSeq(), + s"$path/globals/globals/parts/", + None, + None, + ), + ) - val matrixWriter = MatrixSpecWriter(path, tm, "rows/rows", "globals/rows", "cols/rows", - "entries/rows", "references", log = true) + colInfo <- WritePartition( + ToStream(GetField(globals, colsFieldName)), + partFile, + PartitionNativeWriter(colSpec, FastSeq(), s"$path/cols/rows/parts/", None, None), + ) - val rowsIndexSpec = IndexSpec.defaultAnnotation(ctx, "../../index", tcoerce[PStruct](pKey)) - val entriesIndexSpec = - IndexSpec.defaultAnnotation( - ctx, - "../../index", - tcoerce[PStruct](pKey), - withOffsetField = true, - ) + writeGlobals <- WritePartition( + MakeStream.single(SelectFields(globals, tm.globalType.fieldNames)), + partFile, + PartitionNativeWriter(globalSpec, FastSeq(), s"$path/globals/rows/parts/", None, None), + ) - bindIR(writeCols) { colInfo => - bindIR(parts) { partInfo => - Begin(FastSeq( - WriteMetadata( - MakeArray(GetField(writeEmpty, "filePath")), - RVDSpecWriter( - s"$path/globals/globals", - RVDSpecMaker(emptySpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), - ), + _ <- WriteMetadata( + MakeArray(GetField(writeEmpty, "filePath")), + RVDSpecWriter( + s"$path/globals/globals", + RVDSpecMaker(emptySpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), ), - WriteMetadata( - MakeArray(GetField(writeGlobals, "filePath")), - RVDSpecWriter( - s"$path/globals/rows", - RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), - ), + ) + + _ <- WriteMetadata( + MakeArray(GetField(writeGlobals, "filePath")), + RVDSpecWriter( + s"$path/globals/rows", + RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), ), - WriteMetadata( - MakeArray(MakeStruct(FastSeq( - "partitionCounts" -> I64(1), - "distinctlyKeyed" -> True(), - "firstKey" -> MakeStruct(FastSeq()), - "lastKey" -> MakeStruct(FastSeq()), - ))), - globalTableWriter, + ) + + _ <- WriteMetadata( + MakeArray(makestruct( + "partitionCounts" -> I64(1), + "distinctlyKeyed" -> True(), + "firstKey" -> makestruct(), + "lastKey" -> makestruct(), + )), + TableSpecWriter( + s"$path/globals", + TableType(tm.globalType, FastSeq(), TStruct.empty), + "rows", + "globals", + "../references", + log = false, ), - WriteMetadata( - MakeArray(GetField(colInfo, "filePath")), - RVDSpecWriter( - s"$path/cols/rows", - RVDSpecMaker(colSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), - ), + ) + + _ <- WriteMetadata( + MakeArray(GetField(colInfo, "filePath")), + RVDSpecWriter( + s"$path/cols/rows", + RVDSpecMaker(colSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), ), - WriteMetadata( - MakeArray(SelectFields( - colInfo, - IndexedSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), - )), - colTableWriter, + ) + + _ <- WriteMetadata( + MakeArray( + selectIR(colInfo, "partitionCounts", "distinctlyKeyed", "firstKey", "lastKey") ), - bindIR(ToArray(mapIR(ToStream(partInfo))(fc => GetField(fc, "filePath")))) { - files => - Begin(FastSeq( - WriteMetadata( - files, - RVDSpecWriter( - s"$path/rows/rows", - RVDSpecMaker(rowSpec, lowered.partitioner, rowsIndexSpec), - ), - ), - WriteMetadata( - files, - RVDSpecWriter( - s"$path/entries/rows", - RVDSpecMaker( - entrySpec, - RVDPartitioner.unkeyed(ctx.stateManager, lowered.numPartitions), - entriesIndexSpec, - ), - ), - ), - )) + TableSpecWriter( + s"$path/cols", + tm.colsTableType.copy(key = FastSeq[String]()), + "rows", + "../globals/rows", + "../references", + log = false, + ), + ) + + files <- mapArray(parts)(GetField(_, "filePath")) + + rowsIndexSpec = IndexSpec.defaultAnnotation(ctx, "../../index", pKey) + _ <- WriteMetadata( + files, + RVDSpecWriter( + s"$path/rows/rows", + RVDSpecMaker(rowSpec, lowered.partitioner, rowsIndexSpec), + ), + ) + + entriesIndexSpec = + IndexSpec.defaultAnnotation(ctx, "../../index", pKey, withOffsetField = true) + + _ <- WriteMetadata( + files, + RVDSpecWriter( + s"$path/entries/rows", + RVDSpecMaker( + entrySpec, + RVDPartitioner.unkeyed(ctx.stateManager, lowered.numPartitions), + entriesIndexSpec, + ), + ), + ) + + _ <- WriteMetadata( + mapArray(parts) { part => + selectIR(part, "partitionCounts", "distinctlyKeyed", "firstKey", "lastKey") }, - bindIR(ToArray(mapIR(ToStream(partInfo)) { fc => - SelectFields( - fc, - FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), + TableSpecWriter( + s"$path/rows", + tm.rowsTableType, + "rows", + "../globals/rows", + "../references", + log = false, + ), + ) + + _ <- WriteMetadata( + mapArray(parts) { part => + insertIR( + selectIR(part, "partitionCounts", "distinctlyKeyed"), + "firstKey" -> makestruct(), + "lastKey" -> makestruct(), ) - })) { countsAndKeyInfo => - Begin(FastSeq( - WriteMetadata(countsAndKeyInfo, rowTableWriter), - WriteMetadata( - ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKeyInfo => - InsertFields( - SelectFields( - countAndKeyInfo, - IndexedSeq("partitionCounts", "distinctlyKeyed"), - ), - IndexedSeq( - "firstKey" -> MakeStruct(FastSeq()), - "lastKey" -> MakeStruct(FastSeq()), - ), - ) - }), - entriesTableWriter, - ), - WriteMetadata( - makestruct( - "cols" -> GetField(colInfo, "partitionCounts"), - "rows" -> ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKey => - GetField(countAndKey, "partitionCounts") - }), - ), - matrixWriter, - ), - )) }, - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(path)), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/globals")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/cols")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/rows")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/entries")), - )) - } + TableSpecWriter( + s"$path/entries", + TableType(tm.entriesRVType, FastSeq(), tm.globalType), + "rows", + "../globals/rows", + "../references", + log = false, + ), + ) + + _ <- WriteMetadata( + makestruct( + "cols" -> GetField(colInfo, "partitionCounts"), + "rows" -> mapArray(parts)(part => GetField(part, "partitionCounts")), + ), + MatrixSpecWriter(path, tm, "rows/rows", "globals/rows", "cols/rows", "entries/rows"), + ) + + _ <- WriteMetadata(makestruct(), RelationalCommit(path)) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/globals")) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/cols")) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/rows")) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/entries")) + } yield Void() } - } } } } @@ -407,9 +390,6 @@ case class SplitPartitionNativeWriter( stageFolder: Option[Path], ) extends PartitionWriter { - val filenameType = PCanonicalString(required = true) - def pContextType = PCanonicalString() - val keyType = spec1.encodedVirtualType.asInstanceOf[TStruct].select(keyFieldNames)._1 override def ctxType: Type = TString @@ -501,11 +481,9 @@ case class SplitPartitionNativeWriter( val pCount = mb.newLocal[Long]("partition_count") cb.assign(pCount, 0L) + // True until proven otherwise, if there's a key to care about all. val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed") - cb.assign( - distinctlyKeyed, - !keyFieldNames.isEmpty, - ) // True until proven otherwise, if there's a key to care about all. + cb.assign(distinctlyKeyed, keyFieldNames.nonEmpty) val keyEmitType = EmitType(spec1.decodedPType(keyType).sType, false) @@ -562,7 +540,7 @@ case class SplitPartitionNativeWriter( keyType.fields.map(f => EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))): _* ) - if (!keyFieldNames.isEmpty) { + if (keyFieldNames.nonEmpty) { cb.if_( distinctlyKeyed, { lastSeenSettable.loadI(cb).consume( @@ -650,9 +628,7 @@ class MatrixSpecHelper( globalRelPath: String, colRelPath: String, entryRelPath: String, - refRelPath: String, typ: MatrixType, - log: Boolean, ) extends Logging with Serializable { def write(fs: FS, nCols: Long, partCounts: Array[Long]): Unit = { val spec = MatrixTableSpecParameters( @@ -686,8 +662,6 @@ case class MatrixSpecWriter( globalRelPath: String, colRelPath: String, entryRelPath: String, - refRelPath: String, - log: Boolean, ) extends MetadataWriter { override def annotationType: Type = TStruct("cols" -> TInt64, "rows" -> TArray(TInt64)) @@ -712,7 +686,7 @@ case class MatrixSpecWriter( }, ) cb += cb.emb.getObject(new MatrixSpecHelper(path, rowRelPath, globalRelPath, colRelPath, - entryRelPath, refRelPath, typ, log)) + entryRelPath, typ)) .invoke[FS, Long, Array[Long], Unit]( "write", cb.emb.getFS, @@ -787,25 +761,18 @@ case class MatrixVCFWriter( ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq( - "oldCtx" -> ctxElt, - "partFile" -> pf, - )) + makestruct("__old_ctx" -> ctxElt, "__part_file" -> pf) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { - (rows, ctxRef) => - val partFile = GetField(ctxRef, "partFile") + UUID4() + Str(ext) - val ctx = MakeStruct(FastSeq( - "cols" -> GetField(ts.globals, colsFieldName), - "partFile" -> partFile, - )) + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_vcf_writer") { (rows, ctxRef) => + val partFile = GetField(ctxRef, "__part_file") + UUID4() + Str(ext) + val ctx = makestruct("cols" -> GetField(ts.globals, colsFieldName), "partFile" -> partFile) WritePartition(rows, ctx, lineWriter) - } { (parts, globals) => - val ctx = - MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "partFiles" -> parts)) - val commit = VCFExportFinalizer(tm, path, appendStr, metadata, exportType, tabix) - Begin(FastSeq(WriteMetadata(ctx, commit))) - } + } { (parts, globals) => + val ctx = makestruct("cols" -> GetField(globals, colsFieldName), "partFiles" -> parts) + val commit = VCFExportFinalizer(tm, path, appendStr, metadata, exportType, tabix) + WriteMetadata(ctx, commit) + } } private def getAppendHeaderValue(fs: FS): Option[String] = append.map { f => @@ -1421,10 +1388,6 @@ case class MatrixGENWriter( r: RTable, ): IR = { val tm = MatrixType.fromTableType(ts.tableType, colsFieldName, entriesFieldName, colKey) - - val sampleWriter = new GenSampleWriter - - val lineWriter = GenVariantWriter(tm, entriesFieldName, precision) val folder = ctx.createTmpPath("export-gen") ts.mapContexts { oldCtx => @@ -1434,25 +1397,22 @@ case class MatrixGENWriter( ArraySeq.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-"), ) - zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq( - "oldCtx" -> ctxElt, - "partFile" -> pf, - )) + zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctx, pf) => + makestruct("__old_ctx" -> ctx, "__part_file" -> pf) + } + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_gen_writer") { (rows, ctxRef) => + val ctx = GetField(ctxRef, "__part_file") + UUID4() + WritePartition(rows, ctx, GenVariantWriter(tm, entriesFieldName, precision)) + } { (parts, globals) => + val cols = ToStream(GetField(globals, colsFieldName)) + val sampleFileName = Str(s"$path.sample") + val writeSamples = WritePartition(cols, sampleFileName, new GenSampleWriter) + val commitSamples = SimpleMetadataWriter(TString) + + val commit = TableTextFinalizer(s"$path.gen", ts.rowType, " ", header = false) + Begin(FastSeq(WriteMetadata(writeSamples, commitSamples), WriteMetadata(parts, commit))) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_gen_writer") { - (rows, ctxRef) => - val ctx = GetField(ctxRef, "partFile") + UUID4() - WritePartition(rows, ctx, lineWriter) - } { (parts, globals) => - val cols = ToStream(GetField(globals, colsFieldName)) - val sampleFileName = Str(s"$path.sample") - val writeSamples = WritePartition(cols, sampleFileName, sampleWriter) - val commitSamples = SimpleMetadataWriter(TString) - - val commit = TableTextFinalizer(s"$path.gen", ts.rowType, " ", header = false) - Begin(FastSeq(WriteMetadata(writeSamples, commitSamples), WriteMetadata(parts, commit))) - } } } @@ -1633,38 +1593,34 @@ case class MatrixBGENWriter( ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) + + // hint: don't writeHeader + val variantCounts = + if (writeHeader) ToStream(ts.countPerPartition().deepCopy) + else mapIR(rangeIR(ts.numPartitions))(_ => NA(TInt64)) + val partFiles = ToStream(Literal( TArray(TString), ArraySeq.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-"), )) - val numVariants = if (writeHeader) ToStream(ts.countPerPartition()) - else ToStream(MakeArray(ArraySeq.tabulate(ts.numPartitions)(_ => NA(TInt64)): _*)) - - val ctxElt = Ref(freshName(), tcoerce[TStream](oldCtx.typ).elementType) - val pf = Ref(freshName(), tcoerce[TStream](partFiles.typ).elementType) - val nv = Ref(freshName(), tcoerce[TStream](numVariants.typ).elementType) - - StreamZip( - FastSeq(oldCtx, partFiles, numVariants), - FastSeq(ctxElt.name, pf.name, nv.name), - MakeStruct(FastSeq("oldCtx" -> ctxElt, "numVariants" -> nv, "partFile" -> pf)), - ArrayZipBehavior.AssertSameLength, - ) - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { - (rows, ctxRef) => - val partFile = GetField(ctxRef, "partFile") + UUID4() - val ctx = MakeStruct(FastSeq( + + zipIR(FastSeq(oldCtx, variantCounts, partFiles), ArrayZipBehavior.AssertSameLength) { + case Seq(ctx, vc, pf) => + makestruct("__old_ctx" -> ctx, "__num_variants" -> vc, "__part_file" -> pf) + } + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_vcf_writer") { (rows, ctx) => + val writeCtx = makestruct( "cols" -> GetField(ts.globals, colsFieldName), - "numVariants" -> GetField(ctxRef, "numVariants"), - "partFile" -> partFile, - )) - WritePartition(rows, ctx, partWriter) - } { (results, globals) => - val ctx = - MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "results" -> results)) - val commit = BGENExportFinalizer(tm, path, exportType, compressionInt) - Begin(FastSeq(WriteMetadata(ctx, commit))) - } + "numVariants" -> GetField(ctx, "__num_variants"), + "partFile" -> (GetField(ctx, "__part_file") + UUID4()), + ) + WritePartition(rows, writeCtx, partWriter) + } { (results, globals) => + val ctx = makestruct("cols" -> GetField(globals, colsFieldName), "results" -> results) + val commit = BGENExportFinalizer(tm, path, exportType, compressionInt) + WriteMetadata(ctx, commit) + } } } @@ -1676,7 +1632,7 @@ case class BGENPartitionWriter( ) extends PartitionWriter { require(typ.entryType.hasField("GP") && typ.entryType.fieldType("GP") == TArray(TFloat64)) - val ctxType: Type = + override val ctxType: Type = TStruct("cols" -> TArray(typ.colType), "numVariants" -> TInt64, "partFile" -> TString) override def returnType: TStruct = @@ -2134,31 +2090,27 @@ case class MatrixPLINKWriter( ) zip2(oldCtx, ToStream(files), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq( - "oldCtx" -> ctxElt, - "file" -> pf, - )) + makestruct("__old_ctx" -> ctxElt, "__files" -> pf) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_plink_writer") { - (rows, ctxRef) => - val id = UUID4() - val bedFile = GetTupleElement(GetField(ctxRef, "file"), 0) + id - val bimFile = GetTupleElement(GetField(ctxRef, "file"), 1) + id - val ctx = MakeStruct(FastSeq("bedFile" -> bedFile, "bimFile" -> bimFile)) - WritePartition(rows, ctx, lineWriter) - } { (parts, globals) => - val commit = PLINKExportFinalizer(tm, path, tmpBedDir + "/header") - val famWriter = TableTextPartitionWriter(tm.colsTableType.rowType, "\t", writeHeader = false) - val famPath = Str(path + ".fam") - val cols = ToStream(GetField(globals, colsFieldName)) - val writeFam = WritePartition(cols, famPath, famWriter) - bindIR(writeFam) { fpath => + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_plink_writer") { (rows, ctxRef) => + bindIRs(UUID4(), GetField(ctxRef, "__files")) { case Seq(id, files) => + val bedFile = GetTupleElement(files, 0) + id + val bimFile = GetTupleElement(files, 1) + id + val ctx = makestruct("bedFile" -> bedFile, "bimFile" -> bimFile) + WritePartition(rows, ctx, lineWriter) + } + } { (parts, globals) => + val commit = PLINKExportFinalizer(tm, path, tmpBedDir + "/header") + val famWriter = + TableTextPartitionWriter(tm.colsTableType.rowType, "\t", writeHeader = false) + val cols = ToStream(GetField(globals, colsFieldName)) + val fpath = WritePartition(cols, Str(path + ".fam"), famWriter) Begin(FastSeq( WriteMetadata(parts, commit), WriteMetadata(fpath, SimpleMetadataWriter(fpath.typ)), )) } - } } } @@ -2166,10 +2118,10 @@ case class PLINKPartitionWriter(typ: MatrixType, entriesFieldName: String) exten val ctxType = TStruct("bedFile" -> TString, "bimFile" -> TString) override def returnType = TStruct("bedFile" -> TString, "bimFile" -> TString) - val locusIdx = typ.rowType.fieldIdx("locus") - val allelesIdx = typ.rowType.fieldIdx("alleles") - val varidIdx = typ.rowType.fieldIdx("varid") - val cmPosIdx = typ.rowType.fieldIdx("cm_position") + private[this] val locusIdx = typ.rowType.fieldIdx("locus") + private[this] val allelesIdx = typ.rowType.fieldIdx("alleles") + private[this] val varidIdx = typ.rowType.fieldIdx("varid") + private[this] val cmPosIdx = typ.rowType.fieldIdx("cm_position") override def unionTypeRequiredness( r: TypeWithRequiredness, @@ -2356,9 +2308,7 @@ case class MatrixBlockMatrixWriter( val numBlockCols: Int = BlockMatrixType.numBlocks(numCols.toLong, blockSize) val lastBlockNumCols = (numCols - 1) % blockSize + 1 - val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(paritionIR => - StreamLen(paritionIR) - ) + val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(StreamLen(_)) val inputRowCountPerPartition: IndexedSeq[Int] = CompileAndEvaluate[IndexedSeq[Int]](ctx, rowCountIR) val inputPartStartsPlusLast = inputRowCountPerPartition.scanLeft(0L)(_ + _) @@ -2369,32 +2319,33 @@ case class MatrixBlockMatrixWriter( val numBlockRows: Int = BlockMatrixType.numBlocks(numRows, blockSize) // Zip contexts with partition starts and ends - val zippedWithStarts = ts.mapContexts { oldContextsStream => - zipIR( - IndexedSeq( - oldContextsStream, - ToStream(Literal(TArray(TInt64), inputPartStarts)), - ToStream(Literal(TArray(TInt64), inputPartStops)), - ), - ArrayZipBehavior.AssertSameLength, - ) { case IndexedSeq(oldCtx, partStart, partStop) => - MakeStruct(FastSeq( - "mwOld" -> oldCtx, - "mwStartIdx" -> Cast(partStart, TInt32), - "mwStopIdx" -> Cast(partStop, TInt32), - )) - } - }(newCtx => GetField(newCtx, "mwOld")) + val zippedWithStarts = + ts.mapContexts { oldContextsStream => + zipIR( + FastSeq( + oldContextsStream, + ToStream(Literal(TArray(TInt64), inputPartStarts)), + ToStream(Literal(TArray(TInt64), inputPartStops)), + ), + ArrayZipBehavior.AssertSameLength, + ) { case Seq(oldCtx, partStart, partStop) => + makestruct( + "__mw_old_ctx" -> oldCtx, + "__mw_start_idx" -> partStart.toI, + "__mw_stop_idx" -> partStop.toI, + ) + } + }(GetField(_, "__mw_old_ctx")) // Now label each row with its idx. - val perRowIdxId = genUID() - val partsZippedWithIdx = zippedWithStarts.mapPartitionWithContext { (part, ctx) => - zip2( - part, - rangeIR(GetField(ctx, "mwStartIdx"), GetField(ctx, "mwStopIdx")), - ArrayZipBehavior.AssertSameLength, - )((partRow, idx) => insertIR(partRow, (perRowIdxId, idx))) - } + val partsZippedWithIdx = + zippedWithStarts.mapPartitionWithContext { (part, ctx) => + zip2( + part, + rangeIR(GetField(ctx, "__mw_start_idx"), GetField(ctx, "__mw_stop_idx")), + ArrayZipBehavior.AssertSameLength, + )((partRow, idx) => insertIR(partRow, "__per_row_idx" -> idx)) + } /* Two steps, make a partitioner that works currently based on row_idx splits, then resplit * accordingly. */ @@ -2404,84 +2355,93 @@ case class MatrixBlockMatrixWriter( } val rowIdxPartitioner = - new RVDPartitioner(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals) + new RVDPartitioner(ctx.stateManager, TStruct("__per_row_idx" -> TInt32), inputRowIntervals) val keyedByRowIdx = partsZippedWithIdx.changePartitionerNoRepartition(rowIdxPartitioner) // Now create a partitioner that makes appropriately sized blocks - val desiredRowStarts = (0 until numBlockRows).map(_ * blockSize) + val desiredRowStarts = ArraySeq.tabulate(numBlockRows)(_ * blockSize) val desiredRowStops = desiredRowStarts.drop(1) :+ numRows.toInt - val desiredRowIntervals = desiredRowStarts.zip(desiredRowStops).map { - case (intervalStart, intervalEnd) => - Interval(Row(intervalStart), Row(intervalEnd), true, false) - } + val desiredRowIntervals = + desiredRowStarts + .view + .zip(desiredRowStops) + .map { case (start, end) => Interval(Row(start), Row(end), true, false) } + .to(ArraySeq) val blockSizeGroupsPartitioner = - RVDPartitioner.generate(ctx.stateManager, TStruct((perRowIdxId, TInt32)), desiredRowIntervals) + RVDPartitioner.generate( + ctx.stateManager, + TStruct("__per_row_idx" -> TInt32), + desiredRowIntervals, + ) val rowsInBlockSizeGroups: TableStage = keyedByRowIdx.repartitionNoShuffle(ctx, blockSizeGroupsPartitioner) - def createBlockMakingContexts(tablePartsStreamIR: IR): IR = { + def createBlockMakingContexts(tablePartsStreamIR: Atom): IR = flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { - case (tableSinglePartCtx, blockRowIdx) => - mapIR(rangeIR(I32(numBlockCols))) { blockColIdx => - MakeStruct(FastSeq( - "oldTableCtx" -> tableSinglePartCtx, - "blockStart" -> (blockColIdx * I32(blockSize)), - "blockSize" -> If( + (tableSinglePartCtx, blockRowIdx) => + mapIR(rangeIR(numBlockCols)) { blockColIdx => + makestruct( + "__old_table_ctx" -> tableSinglePartCtx, + "__block_start" -> (blockColIdx * I32(blockSize)), + "__block_size" -> If( blockColIdx ceq I32(numBlockCols - 1), - I32(lastBlockNumCols), - I32(blockSize), + lastBlockNumCols, + blockSize, ), - "blockColIdx" -> blockColIdx, - "blockRowIdx" -> blockRowIdx, - )) + "__block_col_idx" -> blockColIdx, + "__block_row_idx" -> blockRowIdx, + ) } }) - } - val tableOfNDArrays = rowsInBlockSizeGroups.mapContexts(createBlockMakingContexts)(ir => - GetField(ir, "oldTableCtx") - ).mapPartitionWithContext { (partIr, ctxRef) => - bindIR(GetField(ctxRef, "blockStart")) { blockStartRef => - val numColsOfBlock = GetField(ctxRef, "blockSize") - val arrayOfSlicesAndIndices = ToArray(mapIR(partIr) { singleRow => - val mappedSlice = ToArray(mapIR(ToStream(sliceArrayIR( - GetField(singleRow, entriesFieldName), - blockStartRef, - blockStartRef + numColsOfBlock, - )))(entriesStructRef => - GetField(entriesStructRef, entryField) - )) - MakeStruct(FastSeq( - perRowIdxId -> GetField(singleRow, perRowIdxId), - "rowOfData" -> mappedSlice, - )) - }) - bindIR(arrayOfSlicesAndIndices) { arrayOfSlicesAndIndicesRef => - val idxOfResult = GetField(ArrayRef(arrayOfSlicesAndIndicesRef, I32(0)), perRowIdxId) - val ndarrayData = ToArray(flatMapIR(ToStream(arrayOfSlicesAndIndicesRef)) { idxAndSlice => - ToStream(GetField(idxAndSlice, "rowOfData")) - }) - val numRowsOfBlock = ArrayLen(arrayOfSlicesAndIndicesRef) - val shape = maketuple(Cast(numRowsOfBlock, TInt64), Cast(numColsOfBlock, TInt64)) - val ndarray = MakeNDArray(ndarrayData, shape, True(), ErrorIDs.NO_ERROR) - MakeStream( - FastSeq(MakeStruct(FastSeq( - perRowIdxId -> idxOfResult, - "blockRowIdx" -> GetField(ctxRef, "blockRowIdx"), - "blockColIdx" -> GetField(ctxRef, "blockColIdx"), - "ndBlock" -> ndarray, - ))), - TStream(TStruct( - perRowIdxId -> TInt32, - "blockRowIdx" -> TInt32, - "blockColIdx" -> TInt32, - "ndBlock" -> ndarray.typ, - )), - ) + val tableOfNDArrays = + rowsInBlockSizeGroups + .mapContexts(createBlockMakingContexts)(GetField(_, "__old_table_ctx")) + .mapPartitionWithContext { (part, ctx) => + M.eval { + for { + blockStart <- GetField(ctx, "__block_start") + blockSize <- GetField(ctx, "__block_size") + blockEnd <- blockStart + blockSize + + data <- streamAggIR(part) { row => + makestruct( + "__num_rows" -> + ApplyAggOp(Count())(), + "__result_idx" -> + ArrayRef( + ApplyAggOp( + FastSeq(I32(1)), + FastSeq(GetField(row, "__per_row_idx")), + Take(), + ), + I32(0), + ), + "__block_data" -> { + val slices = sliceArrayIR(GetField(row, entriesFieldName), blockStart, blockEnd) + val elem = Ref(Name("__elem"), TIterable.elementType(slices.typ)) + val collect = ApplyAggOp(Collect())(GetField(elem, entryField)) + AggExplode(ToStream(slices), elem.name, collect, isScan = false) + }, + ) + } + + numRowsOfBlock <- GetField(data, "__num_rows") + idxOfResult <- GetField(data, "__result_idx") + ndarrayData <- GetField(data, "__block_data") + shape <- maketuple(numRowsOfBlock, blockSize.toL) + ndarray <- MakeNDArray(ndarrayData, shape, True(), ErrorIDs.NO_ERROR) + } yield MakeStream.single( + makestruct( + "__per_row_idx" -> idxOfResult, + "__block_row_idx" -> GetField(ctx, "__block_row_idx"), + "__block_col_idx" -> GetField(ctx, "__block_col_idx"), + "__block" -> ndarray, + ) + ) + } } - } - } val elementType = tm.entryType.fieldType(entryField) val etype = EBlockMatrixNDArray( @@ -2499,25 +2459,24 @@ case class MatrixBlockMatrixWriter( val pathsWithColMajorIndices = tableOfNDArrays.mapCollect("matrix_block_matrix_writer") { partition => ToArray(mapIR(partition) { singleNDArrayTuple => - bindIR(GetField(singleNDArrayTuple, "blockRowIdx") + (GetField( - singleNDArrayTuple, - "blockColIdx", - ) * numBlockRows)) { colMajorIndex => - val blockPath = - Str(s"$path/parts/part-") + - invoke("str", TString, colMajorIndex) + Str("-") + UUID4() - maketuple( - colMajorIndex, - WriteValue(GetField(singleNDArrayTuple, "ndBlock"), blockPath, writer), - ) + M.eval { + for { + rowIdx <- GetField(singleNDArrayTuple, "__block_row_idx") + colIdx <- GetField(singleNDArrayTuple, "__block_col_idx") + colMajorIndex <- rowIdx + (colIdx * numBlockRows) + blockPath <- strConcat(s"$path/parts/part-", colMajorIndex, "-", UUID4()) + ndarray <- GetField(singleNDArrayTuple, "__block") + } yield maketuple(colMajorIndex, WriteValue(ndarray, blockPath, writer)) } }) } - val flatPathsAndIndices = flatMapIR(ToStream(pathsWithColMajorIndices))(ToStream(_)) - val sortedColMajorPairs = sortIR(flatPathsAndIndices) { case (l, r) => - ApplyComparisonOp(LT, GetTupleElement(l, 0), GetTupleElement(r, 0)) - } - val flatPaths = ToArray(mapIR(ToStream(sortedColMajorPairs))(GetTupleElement(_, 1))) + + val sortedColMajorPairs = + sortIR(flatten(pathsWithColMajorIndices)) { (l, r) => + GetTupleElement(l, 0) < GetTupleElement(r, 0) + } + + val flatPaths = mapArray(sortedColMajorPairs)(GetTupleElement(_, 1)) val bmt = BlockMatrixType( elementType, numRows, @@ -2527,7 +2486,7 @@ case class MatrixBlockMatrixWriter( ) RelationalWriter.scoped(path, overwrite, None)(WriteMetadata( flatPaths, - BlockMatrixNativeMetadataWriter(path, false, bmt), + BlockMatrixNativeMetadataWriter(path, stageLocally = false, bmt), )) } } @@ -2563,20 +2522,20 @@ case class MatrixNativeMultiWriter( require(tables.map(_._4.tableType.keyType).distinct.length == 1) val unionType = TTuple(components.map(c => TIterable.elementType(c.stage.contexts.typ)): _*) - val contextUnionType = TStruct("matrixId" -> TInt32, "options" -> unionType) + val contextUnionType = TStruct("__matrix_id" -> TInt32, "__options" -> unionType) val emptyUnionIRs: IndexedSeq[(Int, IR)] = - IndexedSeq.tabulate(unionType.size)(i => i -> NA(unionType.types(i))) + ArraySeq.tabulate(unionType.size)(i => i -> NA(unionType.types(i))) val concatenatedContexts = flatten( MakeArray( components.zipWithIndex.map { case (c, matrixId) => ToArray(mapIR(c.stage.contexts) { ctx => - MakeStruct(FastSeq( - "matrixId" -> I32(matrixId), - "options" -> MakeTuple(emptyUnionIRs.updated(matrixId, matrixId -> ctx.ir)), - )) + makestruct( + "__matrix_id" -> I32(matrixId), + "__options" -> MakeTuple(emptyUnionIRs.updated(matrixId, matrixId -> ctx.ir)), + ) }) }, TArray(TArray(contextUnionType)), @@ -2587,42 +2546,49 @@ case class MatrixNativeMultiWriter( n.str -> ir }) - Begin(FastSeq( - Begin(components.map(_.setup)), - Let( - components.flatMap(_.stage.letBindings), - bindIR(cdaIR(concatenatedContexts, allBroadcasts, "matrix_multi_writer") { - case (ctx, globals) => - bindIR(GetField(ctx, "options")) { options => - Switch( - GetField(ctx, "matrixId"), - default = Die("MatrixId exceeds matrix count", components.head.writePartitionType), - cases = components.zipWithIndex.map { case (component, i) => - val binds = component.stage.broadcastVals.map { case (name, _) => - name -> GetField(globals, name.str) - } - - Let( - binds, - bindIR(GetTupleElement(options, i)) { ctxRef => - component.writePartition(component.stage.partition(ctxRef), ctxRef) - }, - ) - }, - ) - } - }) { cdaResult => - val partitionCountScan = - components.map(_.stage.numPartitions).scanLeft(0)(_ + _) - - Begin(components.zipWithIndex.map { case (c, i) => - c.finalizeWrite( - ArraySlice(cdaResult, partitionCountScan(i), Some(partitionCountScan(i + 1))), - c.stage.globals, - ) + M.eval { + for { + _ <- M.defer { b => + components.foreach(c => b.memoize(c.setup)) + components.foreach(_.stage.letBindings.foreach { case (name, value) => + b.strictMemoize(value, name) }) - }, - ), - )) + Void() + } + + result <- cdaIR(concatenatedContexts, allBroadcasts, "matrix_multi_writer") { + (ctx, globals) => + Switch( + GetField(ctx, "__matrix_id"), + default = Die("MatrixId exceeds matrix count", components.head.writePartitionType), + cases = components.zipWithIndex.map { case (component, i) => + val binds = component.stage.broadcastVals.map { case (name, _) => + name -> GetField(globals, name.str) + } + + Let( + binds, + IRBuilder.scoped { b => + val options = GetField(ctx, "__options") + val ctxRef = b.memoize(GetTupleElement(options, i)) + val rows = b.memoize(component.stage.partition(ctxRef)) + component.writePartition(rows, ctxRef) + }, + ) + }, + ) + } + + partCounts = components.map(_.stage.numPartitions).scanLeft(0)(_ + _) + + _ <- M.defer { b => + components.zipWithIndex.foreach { case (c, i) => + val part = b.memoize(sliceArrayIR(result, partCounts(i), partCounts(i + 1))) + b.memoize(c.finalizeWrite(part, c.stage.globals)) + } + Void() + } + } yield Void() + } } } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala index eee29399281..8ede42e1bf6 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -248,28 +248,20 @@ class TableStage( ) } - def mapPartitionWithContext(f: (IR, Ref) => IR): TableStage = + def mapPartitionWithContext(f: (IR, Atom) => IR): TableStage = copy(partitionIR = f(partitionIR, Ref(ctxRefName, ctxType))) - def mapContexts(f: IR => IR)(getOldContext: IR => IR): TableStage = { - val newContexts = f(contexts) - TableStage( - letBindings, - broadcastVals, - globals, - partitioner, - dependency, - newContexts, - ctxRef => bindIR(getOldContext(ctxRef))(partition(_)), + def mapContexts(f: Atom => IR)(getOldContext: Atom => IR): TableStage = { + val newContexts = bindIR(contexts)(f) + val newCtxRef = Ref(freshName(), TIterable.elementType(newContexts.typ)) + copy( + contexts = newContexts, + ctxRefName = newCtxRef.name, + partitionIR = bindIR(getOldContext(newCtxRef))(partition(_)), ) } - def zipContextsWithIdx(): TableStage = { - def getOldContext(ctx: IR) = GetField(ctx, "elt") - mapContexts(zipWithIndex)(getOldContext) - } - - def mapGlobals(f: IR => IR): TableStage = { + def mapGlobals(f: Atom => IR): TableStage = { val newGlobals = f(globals) val globalsRef = Ref(freshName(), newGlobals.typ) @@ -280,16 +272,16 @@ class TableStage( ) } - def mapCollect(staticID: String, dynamicID: IR = NA(TString))(f: IR => IR): IR = - mapCollectWithGlobals(staticID, dynamicID)(f)((parts, globals) => parts) + def mapCollect(staticID: String, dynamicID: IR = NA(TString))(f: Atom => IR): IR = + mapCollectWithGlobals(staticID, dynamicID)(f)((parts, _) => parts) def mapCollectWithGlobals( staticID: String, dynamicID: IR = NA(TString), )( - mapF: IR => IR + mapF: Atom => IR )( - body: (IR, IR) => IR + body: (Atom, Atom) => IR ): IR = mapCollectWithContextsAndGlobals(staticID, dynamicID)((part, ctx) => mapF(part))(body) @@ -298,9 +290,9 @@ class TableStage( staticID: String, dynamicID: IR = NA(TString), )( - mapF: (IR, Ref) => IR + mapF: (Atom, Atom) => IR )( - body: (IR, IR) => IR + body: (Atom, Atom) => IR ): IR = { val broadcastRefs = MakeStruct(broadcastVals.map { case (n, ir) => n.str -> ir }) @@ -312,7 +304,7 @@ class TableStage( glob.name, Let( broadcastVals.map { case (name, _) => name -> GetField(glob, name.str) }, - mapF(partitionIR, Ref(ctxRefName, ctxType)), + bindIR(partitionIR)(mapF(_, Ref(ctxRefName, ctxType))), ), dynamicID, staticID, @@ -323,7 +315,7 @@ class TableStage( } def collectWithGlobals(staticID: String, dynamicID: IR = NA(TString)): IR = - mapCollectWithGlobals(staticID, dynamicID)(ToArray) { (parts, globals) => + mapCollectWithGlobals(staticID, dynamicID)(ToArray(_)) { (parts, globals) => MakeStruct(FastSeq( "rows" -> ToArray(flatMapIR(ToStream(parts))(ToStream(_))), "global" -> globals, @@ -848,7 +840,7 @@ object LowerTableIR extends Logging { ) val writer = ETypeValueWriter(codecSpec) val reader = ETypeValueReader(codecSpec) - lcWithInitBinding.mapCollectWithGlobals("table_aggregate")({ part: IR => + lcWithInitBinding.mapCollectWithGlobals("table_aggregate") { part => Let( FastSeq(TableIR.globalName -> lc.globals), RunAgg( @@ -860,7 +852,7 @@ object LowerTableIR extends Logging { aggSigs.states, ), ) - }) { case (collected, globals) => + } { case (collected, globals) => val treeAggFunction = freshName() val currentAggStates = Ref(freshName(), TArray(TString)) val iterNumber = Ref(freshName(), TInt32) @@ -938,7 +930,7 @@ object LowerTableIR extends Logging { } } } else { - lcWithInitBinding.mapCollectWithGlobals("table_aggregate_singlestage")({ part: IR => + lcWithInitBinding.mapCollectWithGlobals("table_aggregate_singlestage") { part => Let( FastSeq(TableIR.globalName -> lc.globals), RunAgg( @@ -950,7 +942,7 @@ object LowerTableIR extends Logging { aggSigs.states, ), ) - }) { case (collected, globals) => + } { case (collected, globals) => Let( FastSeq(TableIR.globalName -> globals), RunAgg( @@ -1319,7 +1311,7 @@ object LowerTableIR extends Logging { case TableHead(child, targetNumRows) => val loweredChild = lower(child) - def streamLenOrMax(a: IR): IR = + def streamLenOrMax(a: Atom): IR = if (targetNumRows <= Integer.MAX_VALUE) StreamLen(StreamTake(a, targetNumRows.toInt)) else @@ -1345,9 +1337,9 @@ object LowerTableIR extends Logging { val loopBody = bindIR( loweredChild - .mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef)) { - ctx: IR => ctx - } + .mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef))( + identity + ) .mapCollect( "table_head_recursive_count", strConcat( @@ -1487,7 +1479,7 @@ object LowerTableIR extends Logging { ToStream(childContexts), maxIR(totalNumPartitions - howManyPartsToTryRef, 0), ) - ) { ctx: IR => ctx } + )(identity) .mapCollect( "table_tail_recursive_count", strConcat( @@ -1496,7 +1488,7 @@ object LowerTableIR extends Logging { Str(", nParts="), invoke("str", TString, howManyPartsToTryRef), ), - )(StreamLen) + )(StreamLen(_)) ) { counts => If( (Cast( @@ -1643,7 +1635,7 @@ object LowerTableIR extends Logging { val writer = ETypeValueWriter(codecSpec) val reader = ETypeValueReader(codecSpec) val partitionPrefixSumFiles = - lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums")({ part: IR => + lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums") { part => Let( FastSeq(TableIR.globalName -> lcWithInitBinding.globals), RunAgg( @@ -1656,7 +1648,7 @@ object LowerTableIR extends Logging { ), ) // Collected is TArray of TString - }) { case (collected, _) => + } { case (collected, _) => def combineGroup(partArrayRef: IR): IR = { Begin(FastSeq( bindIR(ReadValue( @@ -1836,8 +1828,8 @@ object LowerTableIR extends Logging { } else { val partitionAggs = - lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage")({ - part: IR => + lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage") { + part => Let( FastSeq(TableIR.globalName -> lc.globals), RunAgg( @@ -1849,7 +1841,7 @@ object LowerTableIR extends Logging { aggSigs.states, ), ) - }) { case (collected, globals) => + } { case (collected, globals) => Let( FastSeq(TableIR.globalName -> globals), ToArray(StreamTake( From 12df28ed539b751cf907fff6a3f1f736fabb5b0c Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 8 Jun 2026 11:39:26 -0400 Subject: [PATCH 2/5] [query] no-sharing in `MatrixReader` lowering --- hail/hail/src/is/hail/expr/ir/IR.scala | 19 ++-- hail/hail/src/is/hail/expr/ir/MatrixIR.scala | 107 +++++++----------- .../src/is/hail/expr/ir/TableIRSuite.scala | 2 +- 3 files changed, 51 insertions(+), 77 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index e91756b1d81..980072a9934 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -97,20 +97,16 @@ package defs { } object Let { - def apply(bindings: IndexedSeq[(Name, IR)], body: IR): Block = - Block( - bindings.map { case (name, value) => Binding(name, value) }, - body, - ) + def apply(bindings: IndexedSeq[(Name, IR)], body: IR): IR = + if (bindings.isEmpty) body + else Block(bindings.map { case (n, v) => Binding(n, v) }, body) - def void(bindings: IndexedSeq[(Name, IR)]): IR = { - if (bindings.isEmpty) { - Void() - } else { + def void(bindings: IndexedSeq[(Name, IR)]): IR = + if (bindings.isEmpty) Void() + else { assert(bindings.last._2.typ == TVoid) Let(bindings.init, bindings.last._2) } - } } object Begin { @@ -538,6 +534,9 @@ package defs { MakeArray(args.toFastSeq, TArray(args.head.typ)) } + def empty(elementType: Type): MakeArray = + MakeArray(FastSeq.empty[IR], TArray(elementType)) + def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null) : MakeArray = { assert(requestedType != null || args.nonEmpty) diff --git a/hail/hail/src/is/hail/expr/ir/MatrixIR.scala b/hail/hail/src/is/hail/expr/ir/MatrixIR.scala index 6c86056729a..1d06d866ce1 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixIR.scala @@ -4,7 +4,6 @@ import is.hail.annotations._ import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq -import is.hail.expr.ir.DeprecatedIRBuilder._ import is.hail.expr.ir.analyses.{ColumnCount, PartitionCounts} import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.MatrixToMatrixFunction @@ -198,20 +197,14 @@ abstract class MatrixHybridReader extends TableReaderWithExtraUID with MatrixRea tr, InsertFields( Ref(TableIR.rowName, tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.entryType), - )), + FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray.empty(requestedType.entryType)), ), ) tr = TableMapGlobals( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.colType), - )), + FastSeq(LowerMatrixIR.colsFieldName -> MakeArray.empty(requestedType.colType)), ), ) } @@ -261,7 +254,7 @@ case class MatrixNativeReaderParameters( class MatrixNativeReader( val params: MatrixNativeReaderParameters, - spec: AbstractMatrixTableSpec, + val spec: AbstractMatrixTableSpec, ) extends MatrixReader { override def pathsUsed: Seq[String] = FastSeq(params.path) @@ -300,20 +293,14 @@ class MatrixNativeReader( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.colType), - )), + FastSeq(LowerMatrixIR.colsFieldName -> MakeArray.empty(requestedType.colType)), ), ) TableMapRows( tr, InsertFields( Ref(TableIR.rowName, tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.entryType), - )), + FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray.empty(requestedType.entryType)), ), ) } else { @@ -325,36 +312,27 @@ class MatrixNativeReader( spec.rowsSpec, spec.entriesSpec, ) - val tr: TableIR = TableRead(tt, dropRows, trdr) + val tr = TableRead(tt, dropRows, trdr) val colsRVDSpec = spec.colsSpec.rowsSpec val partFiles = colsRVDSpec.absolutePartPaths(spec.colsSpec.rowsComponent.absolutePath(colsPath)) - val cols = if (partFiles.length == 1) { + def readCols(index: IR, path: IR): IR = ReadPartition( - MakeStruct(ArraySeq("partitionIndex" -> I64(0), "partitionPath" -> Str(partFiles.head))), + makestruct("partitionIndex" -> index.toL, "partitionPath" -> path), requestedType.colType, PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName), ) - } else { - val contextType = TStruct("partitionIndex" -> TInt64, "partitionPath" -> TString) - val partNames = MakeArray( - partFiles.zipWithIndex.map { case (path, idx) => - MakeStruct(ArraySeq("partitionIndex" -> I64(idx.toLong), "partitionPath" -> Str(path))) - }, - TArray(contextType), - ) - val elt = Ref(freshName(), contextType) - StreamFlatMap( - partNames, - elt.name, - ReadPartition( - elt, - requestedType.colType, - PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName), - ), + + val cols = + if (partFiles.length == 1) readCols(0, Str(partFiles.head)) + else flatten( + zip2( + iota(0, 1), + ToStream(Literal(TArray(TString), partFiles)), + ArrayZipBehavior.TakeMinLength, + )(readCols(_, _)) ) - } TableMapGlobals( tr, @@ -377,8 +355,6 @@ class MatrixNativeReader( case that: MatrixNativeReader => params == that.params case _ => false } - - def getSpec(): AbstractMatrixTableSpec = this.spec } object MatrixRangeReader { @@ -402,13 +378,13 @@ object MatrixRangeReader { case class MatrixRangeReaderParameters(nRows: Int, nCols: Int, nPartitions: Option[Int]) case class MatrixRangeReader( - val params: MatrixRangeReaderParameters, + params: MatrixRangeReaderParameters, nPartitionsAdj: Int, ) extends MatrixReader { override def pathsUsed: Seq[String] = FastSeq() - override def rowUIDType = TInt64 - override def colUIDType = TInt64 + override def rowUIDType: Type = TInt64 + override def colUIDType: Type = TInt64 override def fullMatrixTypeWithoutUIDs: MatrixType = MatrixType( globalType = TStruct.empty, @@ -432,32 +408,39 @@ case class MatrixRangeReader( dropCols: Boolean, dropRows: Boolean, ): TableIR = { + import DeprecatedIRBuilder._ + val nRowsAdj = if (dropRows) 0 else params.nRows val nColsAdj = if (dropCols) 0 else params.nCols var ht = TableRange(nRowsAdj, params.nPartitions.getOrElse(ctx.backend.defaultParallelism)) .rename(Map("idx" -> "row_idx")) - if (requestedType.colType.hasField(colUIDFieldName)) - ht = ht.mapGlobals(makeStruct(LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> makeStruct( - 'col_idx -> 'i, - Symbol(colUIDFieldName) -> 'i.toL, - )))) + + ht = if (requestedType.colType.hasField(colUIDFieldName)) + ht.mapGlobals(makeStruct( + LowerMatrixIR.colsField -> + irRange(0, nColsAdj).map('i ~> + makeStruct( + 'col_idx -> 'i, + Symbol(colUIDFieldName) -> 'i.toL, + )) + )) else - ht = ht.mapGlobals(makeStruct(LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i)))) + ht.mapGlobals(makeStruct( + LowerMatrixIR.colsField -> + irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i)) + )) + if (requestedType.rowType.hasField(rowUIDFieldName)) - ht = ht.mapRows('row.insertFields( - LowerMatrixIR.entriesField -> irRange(0, nColsAdj).map('i ~> makeStruct()), + ht.mapRows('row.insertFields( + LowerMatrixIR.entriesField -> irRange(0, nColsAdj).map('i ~> makestruct()), Symbol(rowUIDFieldName) -> 'row('row_idx).toL, )) else - ht = ht.mapRows('row.insertFields( + ht.mapRows('row.insertFields( LowerMatrixIR.entriesField -> - irRange(0, nColsAdj).map('i ~> makeStruct()) + irRange(0, nColsAdj).map('i ~> makestruct()) )) - - ht } override def toJValue: JValue = { @@ -484,14 +467,6 @@ object MatrixRead { !reader.fullMatrixTypeWithoutUIDs.colType.hasField(MatrixReader.colUIDFieldName)) new MatrixRead(typ, dropCols, dropRows, reader) } - - def preserveExistingUIDs( - typ: MatrixType, - dropCols: Boolean, - dropRows: Boolean, - reader: MatrixReader, - ): MatrixRead = - new MatrixRead(typ, dropCols, dropRows, reader) } case class MatrixRead( diff --git a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala index 61cd1e34bcb..4d5a6e449cc 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala @@ -1474,7 +1474,7 @@ class TableIRSuite extends HailSuite { val entriesPath = getTestResource("sample.vcf.mt/entries") val mnr = MatrixNativeReader(fs, getTestResource("sample.vcf.mt")) - val mnrSpec = mnr.getSpec() + val mnrSpec = mnr.spec val reader = TableNativeZippedReader(rowsPath, entriesPath, None, mnrSpec.rowsSpec, mnrSpec.entriesSpec) From 43fcbebcfc970a0fd99a356b5557853df2c75f87 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 8 Jun 2026 11:48:59 -0400 Subject: [PATCH 3/5] [query] `DeprecatedIRBuilder` uses `BindingEnv` --- .../is/hail/expr/ir/DeprecatedIRBuilder.scala | 161 +++++++++--------- 1 file changed, 84 insertions(+), 77 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala index 84458e7dbb3..6a0b2d2696b 100644 --- a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala +++ b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala @@ -3,13 +3,14 @@ package is.hail.expr.ir import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.implicits.toRichIterable +import is.hail.expr.ir.Scope._ import is.hail.expr.ir.defs._ import is.hail.types.virtual._ import scala.language.dynamics object DeprecatedIRBuilder { - type E = Env[Type] + type E = BindingEnv[Type] implicit def funcToIRProxy(ir: E => IR): IRProxy = new IRProxy(ir) @@ -25,7 +26,7 @@ object DeprecatedIRBuilder { implicit def booleanToProxy(b: Boolean): IRProxy = if (b) True() else False() implicit def ref(s: Symbol): IRProxy = (env: E) => - Ref(Name(s.name), env.lookup(Name(s.name))) + Ref(Name(s.name), env.eval.lookup(Name(s.name))) implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s) @@ -56,9 +57,8 @@ object DeprecatedIRBuilder { def concatStructs(struct1: IRProxy, struct2: IRProxy): IRProxy = (env: E) => { val s2Type = struct2(env).typ.asInstanceOf[TStruct] - let(__struct2 = struct2) { - struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*) - }(env) + (let(__struct2 = struct2) in + struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*))(env) } def makeTuple(values: IRProxy*): IRProxy = (env: E) => @@ -69,34 +69,31 @@ object DeprecatedIRBuilder { initOpArgs: IndexedSeq[IRProxy] = FastSeq(), seqOpArgs: IndexedSeq[IRProxy] = FastSeq(), ): IRProxy = (env: E) => { - val i = initOpArgs.map(x => x(env)) - val s = seqOpArgs.map(x => x(env)) + val i = initOpArgs.map(x => x(env.noAgg)) + val s = seqOpArgs.map(x => x(env.promoteAgg)) ApplyAggOp(i, s, op) } def aggFilter(filterCond: IRProxy, query: IRProxy, isScan: Boolean = false): IRProxy = (env: E) => - AggFilter(filterCond(env), query(env), isScan) + AggFilter(filterCond(env.promoteAgg), query(env), isScan) class TableIRProxy(val tir: TableIR) extends AnyVal { - def empty: E = Env.empty - - def globalEnv: E = typ.globalEnv - - def env: E = typ.rowEnv + def empty: E = BindingEnv.empty def typ: TableType = tir.typ def getGlobals: IR = TableGetGlobals(tir) def mapGlobals(newGlobals: IRProxy): TableIR = - TableMapGlobals(tir, newGlobals(globalEnv)) + TableMapGlobals(tir, newGlobals(BindingEnv(typ.globalEnv))) def mapRows(newRow: IRProxy): TableIR = - TableMapRows(tir, newRow(env)) + TableMapRows(tir, newRow(BindingEnv(typ.rowEnv, scan = Some(typ.rowEnv)))) def explode(sym: Symbol): TableIR = TableExplode(tir, FastSeq(sym.name)) - def aggregateByKey(aggIR: IRProxy): TableIR = TableAggregateByKey(tir, aggIR(env)) + def aggregateByKey(aggIR: IRProxy): TableIR = + TableAggregateByKey(tir, aggIR(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv)))) def keyBy(keys: IndexedSeq[String], isSorted: Boolean = false): TableIR = TableKeyBy(tir, keys, isSorted) @@ -108,7 +105,7 @@ object DeprecatedIRBuilder { rename(Map.empty, globalMap) def filter(ir: IRProxy): TableIR = - TableFilter(tir, ir(env)) + TableFilter(tir, ir(BindingEnv(typ.rowEnv))) def distinct(): TableIR = TableDistinct(tir) @@ -129,7 +126,7 @@ object DeprecatedIRBuilder { } def aggregate(ir: IRProxy): IR = - TableAggregate(tir, ir(env)) + TableAggregate(tir, ir(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv)))) } class IRProxy(val ir: E => IR) extends AnyVal with Dynamic { @@ -214,7 +211,7 @@ object DeprecatedIRBuilder { case _: TStruct => GetField(eval, lookup.name) case _: TArray => - ArrayRef(ir(env), ref(lookup)(env)) + ArrayRef(eval, ref(lookup)(env)) } } @@ -259,65 +256,70 @@ object DeprecatedIRBuilder { def isNA: IRProxy = (env: E) => IsNA(ir(env)) - def orElse(alt: IRProxy): IRProxy = { env: E => - val uid = freshName() - val eir = ir(env) - Let(FastSeq(uid -> eir), If(IsNA(Ref(uid, eir.typ)), alt(env), Ref(uid, eir.typ))) - } + def orElse(alt: IRProxy): IRProxy = + (env: E) => bindIR(ir(env))(x => If(IsNA(x), alt(env), x)) - def filter(pred: LambdaProxy): IRProxy = (env: E) => { + def filter(pred: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(pred.s.name) -> TIterable.elementType(array.typ) ToArray(StreamFilter( ToStream(array), - Name(pred.s.name), - pred.body(env.bind(Name(pred.s.name) -> eltType)), + binding._1, + pred.body(env.bindEval(binding)), )) } - def map(f: LambdaProxy): IRProxy = (env: E) => { + def map(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) ToArray(StreamMap( ToStream(array), - Name(f.s.name), - f.body(env.bind(Name(f.s.name) -> eltType)), + binding._1, + f.body(env.bindEval(binding)), )) } - def aggExplode(f: LambdaProxy): IRProxy = (env: E) => { - val array = ir(env) + def aggExplode(f: LambdaProxy): IRProxy = { env: E => + val array = ir(env.promoteAgg) + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) AggExplode( ToStream(array), - Name(f.s.name), - f.body(env.bind(Name(f.s.name), array.typ.asInstanceOf[TArray].elementType)), + binding._1, + f.body(env.bindEval(binding).bindAgg(binding)), isScan = false, ) } - def flatMap(f: LambdaProxy): IRProxy = (env: E) => { + def flatMap(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) ToArray(StreamFlatMap( ToStream(array), - Name(f.s.name), - ToStream(f.body(env.bind(Name(f.s.name) -> eltType))), + binding._1, + ToStream(f.body(env.bindEval(binding))), )) } - def streamAgg(f: LambdaProxy): IRProxy = (env: E) => { + def flatten: IRProxy = + flatMap('a ~> 'a) + + def streamAgg(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType - StreamAgg(ToStream(array), Name(f.s.name), f.body(env.bind(Name(f.s.name) -> eltType))) + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) + StreamAgg( + ToStream(array), + binding._1, + f.body(env.bindEval(binding).createAgg), + ) } - def streamAggScan(f: LambdaProxy): IRProxy = (env: E) => { + def streamAggScan(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) ToArray(StreamAggScan( ToStream(array), - Name(f.s.name), - f.body(env.bind(Name(f.s.name) -> eltType)), + binding._1, + f.body(env.bindEval(binding).createScan), )) } @@ -338,14 +340,20 @@ object DeprecatedIRBuilder { knownLength: Option[IRProxy], )( aggBody: IRProxy - ): IRProxy = (env: E) => { - val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + ): IRProxy = { env: E => + val array = ir(env.promoteAgg) + + val bindings = + FastSeq( + Name(elementsSym.name) -> TIterable.elementType(array.typ), + Name(indexSym.name) -> TInt32, + ) + AggArrayPerElement( array, - Name(elementsSym.name), - Name(indexSym.name), - aggBody.apply(env.bind(Name(elementsSym.name) -> eltType, Name(indexSym.name) -> TInt32)), + bindings(0)._1, + bindings(1)._1, + aggBody(env.bindEval(bindings: _*).bindAgg(bindings: _*)), knownLength.map(_(env)), isScan = false, ) @@ -361,18 +369,17 @@ object DeprecatedIRBuilder { def toDict: IRProxy = (env: E) => ToDict(ToStream(ir(env))) def parallelize(nPartitions: Option[Int] = None): TableIR = - TableParallelize(ir(Env.empty), nPartitions) + TableParallelize(ir(BindingEnv.empty), nPartitions) - def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = { - val element = Symbol(genUID()) - ir - .map(element ~> + def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = + ir.map( + '__elem ~> makeTuple( - element.selectFields(keyFields: _*), - element.dropFieldList(keyFields), - )) + '__elem.selectFields(keyFields: _*), + '__elem.dropFieldList(keyFields), + ) + ) .toDict - } def tupleElement(i: Int): IRProxy = (env: E) => GetTupleElement(ir(env), i) @@ -391,8 +398,13 @@ object DeprecatedIRBuilder { def bind(bindings: IndexedSeq[BindingProxy], body: IRProxy, env: E): IR = { var newEnv = env val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) => - val resolvedValue = value(newEnv) - newEnv = newEnv.bind(Name(sym.name) -> resolvedValue.typ) + val resolvedValue = + value( + if (scope == AGG) newEnv.promoteAgg + else if (scope == SCAN) newEnv.promoteScan + else newEnv + ) + newEnv = newEnv.bindInScope(Name(sym.name), resolvedValue.typ, scope) Binding(Name(sym.name), resolvedValue, scope) } Block(resolvedBindings, body(newEnv)) @@ -413,26 +425,21 @@ object DeprecatedIRBuilder { class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { def apply(body: IRProxy): IRProxy = in(body) - - def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) } + def in(body: IRProxy): IRProxy = { env: E => LetProxy.bind(bindings, body, env) } } object aggLet extends Dynamic { - def applyDynamicNamed(method: String)(args: (String, IRProxy)*): AggLetProxy = { + def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { assert(method == "apply") - new AggLetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) }) + new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) }) } } - class AggLetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { - def apply(body: IRProxy): IRProxy = in(body) - - def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) } - } - - object MapIRProxy { - def apply(f: (IRProxy) => IRProxy)(x: IRProxy): IRProxy = (e: E) => - MapIR(x => f(x)(e))(x(e)) + object scanLet extends Dynamic { + def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { + assert(method == "apply") + new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.SCAN) }) + } } def subst(x: IRProxy, env: BindingEnv[IRProxy]): IRProxy = (e: E) => From 3b7b73d2b6b36300700bf50ee56502f91f26b6ae Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Thu, 28 May 2026 15:11:46 -0400 Subject: [PATCH 4/5] [query] no-sharing in matrix ir lowering --- .../is/hail/expr/ir/DeprecatedIRBuilder.scala | 11 - .../src/is/hail/expr/ir/ForwardLets.scala | 2 +- .../src/is/hail/expr/ir/LowerMatrixIR.scala | 1216 ++++++++--------- hail/hail/src/is/hail/expr/ir/Optimize.scala | 2 +- hail/hail/src/is/hail/expr/ir/Pretty.scala | 6 +- .../src/is/hail/expr/ir/agg/Extract.scala | 5 +- .../expr/ir/lowering/LowerBlockMatrixIR.scala | 16 +- .../hail/expr/ir/lowering/LoweringPass.scala | 10 +- .../expr/ir/lowering/invariant/package.scala | 6 +- hail/hail/test/src/is/hail/HailSuite.scala | 4 +- .../is/hail/expr/ir/Aggregators2Suite.scala | 42 +- .../src/is/hail/expr/ir/MatrixIRSuite.scala | 10 +- .../is/hail/expr/ir/table/TableGenSuite.scala | 32 +- 13 files changed, 649 insertions(+), 713 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala index 6a0b2d2696b..5d24fbe7053 100644 --- a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala +++ b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala @@ -424,7 +424,6 @@ object DeprecatedIRBuilder { } class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { - def apply(body: IRProxy): IRProxy = in(body) def in(body: IRProxy): IRProxy = { env: E => LetProxy.bind(bindings, body, env) } } @@ -442,15 +441,5 @@ object DeprecatedIRBuilder { } } - def subst(x: IRProxy, env: BindingEnv[IRProxy]): IRProxy = (e: E) => - Subst( - x(e), - BindingEnv( - env.eval.mapValues(_(e)), - agg = env.agg.map(_.mapValues(_(e))), - scan = env.scan.map(_.mapValues(_(e))), - ), - ) - def lift(f: (IR) => IRProxy)(x: IRProxy): IRProxy = (e: E) => f(x(e))(e) } diff --git a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala index 5355ba3762d..91a0a203f01 100644 --- a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala +++ b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala @@ -40,7 +40,7 @@ object ForwardLets extends Logging { else { logger.info( f"Eliminating unused binding:\n" + - f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value, preserveNames = true).trim}" + f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value).trim}" ) env } diff --git a/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala b/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala index ecab3fef958..940f80452a3 100644 --- a/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala @@ -4,6 +4,7 @@ import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.compat.mutable.Growable +import is.hail.expr.ir.{Memoized => M} import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.{WrappedMatrixToTableFunction, WrappedMatrixToValueFunction} import is.hail.types.virtual._ @@ -15,44 +16,40 @@ object LowerMatrixIR { val colsField: Symbol = Symbol(colsFieldName) val entriesField: Symbol = Symbol(entriesFieldName) - def apply(ctx: ExecuteContext, ir: IR): IR = { + def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = { val ab = ArraySeq.newBuilder[(Name, IR)] - val l1 = lower(ctx, ir, ab) - ab.result().foldRight[IR](l1) { case ((ident, value), body) => - RelationalLet(ident, value, body) - } - } - def apply(ctx: ExecuteContext, tir: TableIR): TableIR = { - val ab = ArraySeq.newBuilder[(Name, IR)] - val l1 = lower(ctx, tir, ab) - ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => - RelationalLetTable(ident, value, body) - } - } - - def apply(ctx: ExecuteContext, mir: MatrixIR): TableIR = { - val ab = ArraySeq.newBuilder[(Name, IR)] - - val l1 = lower(ctx, mir, ab) - ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => - RelationalLetTable(ident, value, body) - } - } - - def apply(ctx: ExecuteContext, bmir: BlockMatrixIR): BlockMatrixIR = { - val ab = ArraySeq.newBuilder[(Name, IR)] + val lowered = + ir0 match { + case ir: IR => + val l1 = lower(ctx, ir, ab) + ab.result().foldRight[IR](l1) { case ((ident, value), body) => + RelationalLet(ident, value, body) + } + case tir: TableIR => + val l1 = lower(ctx, tir, ab) + ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => + RelationalLetTable(ident, value, body) + } + case mir: MatrixIR => + val l1 = lower(ctx, mir, ab) + ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => + RelationalLetTable(ident, value, body) + } + case bmir: BlockMatrixIR => + val l1 = lower(ctx, bmir, ab) + assert(ab.result().isEmpty) + l1 + } - val l1 = lower(ctx, bmir, ab) - assert(ab.result().isEmpty) - l1 + NormalizeNames()(ctx, lowered) } - private[this] def lowerChildren( + private def lowerChildren( ctx: ExecuteContext, ir: BaseIR, ab: Growable[(Name, IR)], - ): BaseIR = { + ): BaseIR = ir.mapChildren { case tir: TableIR => lower(ctx, tir, ab) case mir: MatrixIR => throw new RuntimeException(s"expect specialized lowering rule for " + @@ -60,59 +57,38 @@ object LowerMatrixIR { case bmir: BlockMatrixIR => lower(ctx, bmir, ab) case vir: IR => lower(ctx, vir, ab) } - } def colVals(tir: TableIR): IR = GetField(Ref(TableIR.globalName, tir.typ.globalType), colsFieldName) - def globals(tir: TableIR): IR = + def globals(tir: TableIR): IR = { + val globalType = tir.typ.globalType SelectFields( - Ref(TableIR.globalName, tir.typ.globalType), - tir.typ.globalType.fieldNames.diff(FastSeq(colsFieldName)), + Ref(TableIR.globalName, globalType), + globalType.fieldNames.diff(FastSeq(colsFieldName)), ) + } - def nCols(tir: TableIR): IR = ArrayLen(colVals(tir)) + def rowVal(tir: TableIR): IR = { + val rowType = tir.typ.rowType + SelectFields( + Ref(TableIR.rowName, rowType), + rowType.fieldNames.diff(FastSeq(entriesFieldName)), + ) + } def entries(tir: TableIR): IR = GetField(Ref(TableIR.rowName, tir.typ.rowType), entriesFieldName) import is.hail.expr.ir.DeprecatedIRBuilder._ - def matrixSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { - val e = Env[IRProxy]( - MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*), - MatrixIR.rowName -> 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) - BindingEnv(e, agg = Some(e), scan = Some(e)) - } - - def matrixGlobalSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { - val e = - Env[IRProxy](MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*)) - BindingEnv(e, agg = Some(e), scan = Some(e)) - } - - def matrixSubstEnvIR(child: MatrixIR, lowered: TableIR): BindingEnv[IR] = { - val e = Env[IR]( - MatrixIR.globalName -> SelectFields( - Ref(TableIR.globalName, lowered.typ.globalType), - child.typ.globalType.fieldNames, - ), - MatrixIR.rowName -> SelectFields( - Ref(TableIR.rowName, lowered.typ.rowType), - child.typ.rowType.fieldNames, - ), - ) - BindingEnv(e, agg = Some(e), scan = Some(e)) - } - private def bindingsToStruct(bindings: IndexedSeq[(Name, IR)]): MakeStruct = MakeStruct(bindings.map { case (n, ir) => n.str -> ir }) - private def unwrapStruct(bindings: IndexedSeq[(Name, IR)], struct: Atom): IndexedSeq[(Name, IR)] = + private def unwrapStruct(bindings: IndexedSeq[(Name, _)], struct: Atom): IndexedSeq[(Name, IR)] = bindings.map { case (name, _) => name -> GetField(struct, name.str) } - private[this] def lower( + private def lower( ctx: ExecuteContext, mir: MatrixIR, liftedRelationalLets: Growable[(Name, IR)], @@ -127,7 +103,7 @@ object LowerMatrixIR { case CastTableToMatrix(child, entries, cols, _) => val lc = lower(ctx, child, liftedRelationalLets) - val row = Ref(TableIR.rowName, lc.typ.rowType) + val row: Atom = Ref(TableIR.rowName, lc.typ.rowType) val glob = Ref(TableIR.globalName, lc.typ.globalType) TableMapRows( lc, @@ -141,14 +117,12 @@ object LowerMatrixIR { entriesLen cne colsLen, Die( strConcat( - Str( - "length mismatch between entry array and column array in 'to_matrix_table_row_major': " - ), - invoke("str", TString, entriesLen), - Str(" entries, "), - invoke("str", TString, colsLen), - Str(" cols, at "), - invoke("str", TString, SelectFields(row, child.typ.key)), + "length mismatch between entry array and column array in 'to_matrix_table_row_major': ", + entriesLen, + " entries, ", + colsLen, + " cols, at ", + SelectFields(row, child.typ.key), ), row.typ, -1, @@ -169,18 +143,16 @@ object LowerMatrixIR { if (colMap.nonEmpty) { val newColsType = TArray(child.typ.colType.rename(colMap)) - t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq(( - colsFieldName, - newColsType, - ))))) + t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq( + colsFieldName -> newColsType + )))) } if (entryMap.nonEmpty) { val newEntriesType = TArray(child.typ.entryType.rename(entryMap)) - t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq(( - entriesFieldName, - newEntriesType, - ))))) + t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq( + entriesFieldName -> newEntriesType + )))) } t @@ -190,22 +162,36 @@ object LowerMatrixIR { case MatrixFilterRows(child, pred) => lower(ctx, child, liftedRelationalLets) - .filter(subst(lower(ctx, pred, liftedRelationalLets), matrixSubstEnv(child))) + .filter( + let( + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in lower(ctx, pred, liftedRelationalLets) + ) case MatrixFilterCols(child, pred) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('newColIdx -> - irRange(0, 'global(colsField).len) - .filter('i ~> - (let(sa = 'global(colsField)('i)) - in subst(lower(ctx, pred, liftedRelationalLets), matrixGlobalSubstEnv(child)))))) - .mapRows('row.insertFields( - entriesField -> 'global('newColIdx).map('i ~> 'row(entriesField)('i)) - )) - .mapGlobals('global - .insertFields(colsField -> - 'global('newColIdx).map('i ~> 'global(colsField)('i))) - .dropFields('newColIdx)) + .mapGlobals( + 'global.insertFields( + '__new_col_idx -> + (let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in irRange(0, '__cols.len).filter('__col_idx ~> + (let(sa = '__cols('__col_idx)) in + lower(ctx, pred, liftedRelationalLets)))) + ) + ) + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields(entriesField -> 'global('__new_col_idx).map('i ~> '__entries('i))) + ) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global + .insertFields(colsField -> 'global('__new_col_idx).map('i ~> '__cols('i))) + .dropFields('__new_col_idx) + ) case MatrixAnnotateRowsTable(child, table, root, product) => val kt = table.typ.keyType @@ -225,44 +211,48 @@ object LowerMatrixIR { case MatrixChooseCols(child, oldIndices) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('newColIdx -> oldIndices.map(I32))) - .mapRows('row.insertFields( - entriesField -> 'global('newColIdx).map('i ~> 'row(entriesField)('i)) - )) - .mapGlobals('global - .insertFields(colsField -> 'global('newColIdx).map('i ~> 'global(colsField)('i))) - .dropFields('newColIdx)) + .mapGlobals('global.insertFields('__new_col_idx -> Literal(TArray(TInt32), oldIndices))) + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields(entriesField -> 'global('__new_col_idx).map('i ~> '__entries('i))) + ) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global + .insertFields(colsField -> 'global('__new_col_idx).map('i ~> '__cols('i))) + .dropFields('__new_col_idx) + ) case MatrixAnnotateColsTable(child, table, root) => - val col = Symbol(genUID()) - val colKey = makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => - Symbol(tk) -> col(Symbol(mck)) - }: _*) lower(ctx, child, liftedRelationalLets) - .mapGlobals(let(__dictfield = - lower(ctx, table, liftedRelationalLets) - .keyBy(FastSeq()) - .collect() - .apply('rows) - .arrayStructToDict(table.typ.key) - ) { - 'global.insertFields(colsField -> - 'global(colsField).map(col ~> col.insertFields(Symbol(root) -> '__dictfield.invoke( - "get", - table.typ.valueType, - colKey, - )))) - }) + .mapGlobals( + let( + __dictfield = + lower(ctx, table, liftedRelationalLets) + .keyBy(FastSeq()) + .collect() + .apply('rows) + .arrayStructToDict(table.typ.key) + ) in 'global.insertFields( + colsField -> { + val key = + makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => + Symbol(tk) -> '__cols(Symbol(mck)) + }: _*) + + 'global(colsField).map('__cols ~> + '__cols.insertFields( + Symbol(root) -> '__dictfield.invoke("get", table.typ.valueType, key) + )) + } + ) + ) case MatrixMapGlobals(child, newGlobals) => lower(ctx, child, liftedRelationalLets) .mapGlobals( - subst( - lower(ctx, newGlobals, liftedRelationalLets), - BindingEnv(Env[IRProxy]( - TableIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*) - )), - ) + (let(global = 'global.selectFields(child.typ.globalType.fieldNames: _*)) in + lower(ctx, newGlobals, liftedRelationalLets)) .insertFields(colsField -> 'global(colsField)) ) @@ -271,12 +261,12 @@ object LowerMatrixIR { def lift(ir: IR, builder: Growable[(Name, IR)]): IR = ir match { case a: ApplyScanOp => val s = freshName() - builder += ((s, a)) + builder += (s -> a) Ref(s, a.typ) case a @ AggFold(_, _, _, _, _, true) => val s = freshName() - builder += ((s, a)) + builder += (s -> a) Ref(s, a.typ) case AggFilter(filt, body, true) => @@ -305,14 +295,14 @@ object LowerMatrixIR { val aggIR = AggGroupBy(a, bindingsToStruct(aggs), true) val uid = Ref(freshName(), aggIR.typ) builder += (uid.name -> aggIR) - val elementType = aggIR.typ.asInstanceOf[TDict].elementType - val valueType = elementType.types(1) - val valueUID = Ref(freshName(), valueType) + ToDict(mapIR(ToStream(uid)) { eltUID => - Let( - (valueUID.name -> GetField(eltUID, "value")) +: unwrapStruct(aggs, valueUID), - MakeTuple.ordered(FastSeq(GetField(eltUID, "key"), liftedBody)), - ) + bindIR(GetField(eltUID, "value")) { value => + Let( + unwrapStruct(aggs, value), + maketuple(GetField(eltUID, "key"), liftedBody), + ) + } }) case AggArrayPerElement(a, elementName, indexName, body, knownLength, true) => @@ -332,9 +322,8 @@ object LowerMatrixIR { case Block(bindings, body) => val newBindings = ArraySeq.newBuilder[Binding] def go(i: Int, builder: Growable[(Name, IR)]): IR = { - if (i == bindings.length) { - lift(body, builder) - } else bindings(i) match { + if (i == bindings.length) lift(body, builder) + else bindings(i) match { case Binding(name, value, Scope.SCAN) => val ab = ArraySeq.newBuilder[(Name, IR)] val liftedBody = go(i + 1, ab) @@ -361,56 +350,49 @@ object LowerMatrixIR { val ab = ArraySeq.newBuilder[(Name, IR)] val b0 = lift(ir, ab) - val scans = ab.result() - val scanStruct = MakeStruct(scans.map { case (n, ir) => n.str -> ir }) - - val scanResultRef = Ref(freshName(), scanStruct.typ) - - val b1 = if (ContainsAgg(b0)) { - irRange(0, 'row(entriesField).len) - .filter('i ~> !'row(entriesField)('i).isNA) - .streamAgg('i ~> - (aggLet(sa = 'global(colsField)('i), g = 'row(entriesField)('i)) - in b0)) - } else + val b1 = if (ContainsAgg(b0)) + irRange(0, '__entries.len) + .filter('i ~> !'__entries('i).isNA) + .streamAgg('i ~> (aggLet(sa = '__cols('i), g = '__entries('i)) in b0)) + else irToProxy(b0) - letDyn( - ((scanResultRef.name, irToProxy(scanStruct)) - +: scans.map { case (name, _) => - name -> irToProxy(GetField(scanResultRef, name.str)) - }): _* - )(b1) + scanLet( + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in (letDyn(ab.result().map { case (name, expr) => name -> irToProxy(expr) }: _*) in b1) } - val lc = lower(ctx, child, liftedRelationalLets) - lc.mapRows(let(n_cols = 'global(colsField).len) { - liftScans(Subst(lower(ctx, newRow, liftedRelationalLets), matrixSubstEnvIR(child, lc))) + lower(ctx, child, liftedRelationalLets).mapRows( + (let( + __cols = 'global(colsField), + __entries = 'row(entriesField), + n_cols = '__cols.len, + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in liftScans(lower(ctx, newRow, liftedRelationalLets))) .insertFields(entriesField -> 'row(entriesField)) - }) + ) case MatrixMapCols(child, newCol, _) => - val loweredChild = lower(ctx, child, liftedRelationalLets) + val lc = lower(ctx, child, liftedRelationalLets) def lift(ir: IR, scanBindings: Growable[(Name, IR)], aggBindings: Growable[(Name, IR)]) : IR = ir match { case a: ApplyScanOp => val s = freshName() - scanBindings += ((s, a)) + scanBindings += (s -> a) Ref(s, a.typ) case a: ApplyAggOp => val s = freshName() - aggBindings += ((s, a)) + aggBindings += (s -> a) Ref(s, a.typ) case a @ AggFold(_, _, _, _, _, isScan) => val s = freshName() - if (isScan) { - scanBindings += ((s, a)) - } else { - aggBindings += ((s, a)) - } + if (isScan) scanBindings += (s -> a) + else aggBindings += (s -> a) Ref(s, a.typ) case AggFilter(filt, body, isScan) => @@ -420,11 +402,11 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val structResult = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val structResult = bindingsToStruct(aggs) val uid = Ref(freshName(), structResult.typ) builder += (uid.name -> AggFilter(filt, structResult, isScan)) - Let(aggs.map { case (name, _) => name -> GetField(uid, name.str) }, liftedBody) + Let(unwrapStruct(aggs, uid), liftedBody) case AggExplode(a, name, body, isScan) => val ab = ArraySeq.newBuilder[(Name, IR)] @@ -433,10 +415,10 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val structResult = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val structResult = bindingsToStruct(aggs) val uid = Ref(freshName(), structResult.typ) builder += (uid.name -> AggExplode(a, name, structResult, isScan)) - Let(aggs.map { case (name, _) => name -> GetField(uid, name.str) }, liftedBody) + Let(unwrapStruct(aggs, uid), liftedBody) case AggGroupBy(a, body, isScan) => val ab = ArraySeq.newBuilder[(Name, IR)] @@ -445,24 +427,15 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val aggIR = AggGroupBy(a, MakeStruct(aggs.map { case (n, ir) => n.str -> ir }), isScan) + val aggIR = AggGroupBy(a, bindingsToStruct(aggs), isScan) val uid = Ref(freshName(), aggIR.typ) builder += (uid.name -> aggIR) - val valueUID = freshName() - val elementType = aggIR.typ.asInstanceOf[TDict].elementType - val valueType = elementType.types(1) ToDict(mapIR(ToStream(uid)) { eltUID => - MakeTuple.ordered( - FastSeq( - GetField(eltUID, "key"), - Let( - (valueUID -> GetField(eltUID, "value")) +: - aggs.map { case (name, _) => - name -> GetField(Ref(valueUID, valueType), name.str) - }, - liftedBody, - ), - ) + maketuple( + GetField(eltUID, "key"), + bindIR(GetField(eltUID, "value")) { value => + Let(unwrapStruct(aggs, value), liftedBody) + }, ) }) @@ -473,22 +446,19 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val aggBody = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val aggBody = bindingsToStruct(aggs) val aggIR = AggArrayPerElement(a, elementName, indexName, aggBody, knownLength, isScan) val uid = Ref(freshName(), aggIR.typ) builder += (uid.name -> aggIR) - ToArray(mapIR(ToStream(uid)) { eltUID => - Let(aggs.map { case (name, _) => name -> GetField(eltUID, name.str) }, liftedBody) - }) + ToArray(mapIR(ToStream(uid))(eltUID => Let(unwrapStruct(aggs, eltUID), liftedBody))) case Block(bindings, body) => val newBindings = ArraySeq.newBuilder[Binding] def go(i: Int, scanBindings: Growable[(Name, IR)], aggBindings: Growable[(Name, IR)]) - : IR = { - if (i == bindings.length) { - lift(body, scanBindings, aggBindings) - } else bindings(i) match { + : IR = + if (i == bindings.length) lift(body, scanBindings, aggBindings) + else bindings(i) match { case Binding(name, value, Scope.EVAL) => val lifted = lift(value, scanBindings, aggBindings) val liftedBody = go(i + 1, scanBindings, aggBindings) @@ -496,24 +466,23 @@ object LowerMatrixIR { liftedBody case Binding(name, value, scope) => val ab = ArraySeq.newBuilder[(Name, IR)] - val liftedBody = if (scope == Scope.SCAN) - go(i + 1, ab, aggBindings) - else - go(i + 1, scanBindings, ab) + val liftedBody = + if (scope == Scope.SCAN) go(i + 1, ab, aggBindings) + else go(i + 1, scanBindings, ab) val builder = if (scope == Scope.SCAN) scanBindings else aggBindings val aggs = ab.result() - val structResult = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val structResult = bindingsToStruct(aggs) - val uid = freshName() - builder += (uid -> Block(FastSeq(Binding(name, value, scope)), structResult)) - newBindings ++= aggs.map { case (name, _) => - Binding(name, GetField(Ref(uid, structResult.typ), name.str), Scope.EVAL) - } + val uid = Ref(freshName(), structResult.typ) + builder += (uid.name -> Block(FastSeq(Binding(name, value, scope)), structResult)) + newBindings ++= unwrapStruct(aggs, uid).map(b => + Binding(b._1, b._2, Scope.EVAL) + ) liftedBody } - } + val newBody = go(0, scanBindings, aggBindings) Block(newBindings.result().reverse, newBody) @@ -528,207 +497,176 @@ object LowerMatrixIR { val aggBuilder = ArraySeq.newBuilder[(Name, IR)] val b0 = lift( - Subst(lower(ctx, newCol, liftedRelationalLets), matrixSubstEnvIR(child, loweredChild)), + lower(ctx, newCol, liftedRelationalLets), scanBuilder, aggBuilder, ) + val aggs = aggBuilder.result() val scans = scanBuilder.result() - val idx = Ref(freshName(), TInt32) - val idxSym = Symbol(idx.name.str) - - val noOp: (IRProxy => IRProxy, IRProxy => IRProxy) = (identity[IRProxy], identity[IRProxy]) + val noOp: (IRProxy => IRProxy, IRProxy => IRProxy) = + (identity[IRProxy], identity[IRProxy]) val ( aggOutsideTransformer: (IRProxy => IRProxy), aggInsideTransformer: (IRProxy => IRProxy), - ) = if (aggs.isEmpty) - noOp - else { - val aggStruct = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) - - val aggResult = loweredChild.aggregate( - aggLet(va = 'row.selectFields(child.typ.rowType.fieldNames: _*)) { - makeStruct( - ('count, applyAggOp(Count(), FastSeq(), FastSeq())), - ( - 'array_aggs, - irRange(0, 'global(colsField).len) - .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( - let(sa = 'global(colsField)('__result_idx)) { - aggLet( - sa = 'global(colsField)('__element_idx), - g = 'row(entriesField)('__element_idx), - ) { - aggFilter(!'g.isNA, aggStruct) - } - } - ), - ), + ) = + if (aggs.isEmpty) noOp + else { + val aggResult = + lc.deepCopy.aggregate( + let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in (aggLet( + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in makeStruct( + 'n_rows -> + applyAggOp(Count(), FastSeq(), FastSeq()), + 'array_aggs -> + irRange(0, '__cols.len) + .aggElements('__element_idx, '__result_idx, Some('__cols.len))( + let(sa = '__cols('__result_idx)) in + (aggLet(sa = '__cols('__element_idx), g = '__entries('__element_idx)) in + aggFilter(!'g.isNA, bindingsToStruct(aggs))) + ), + )) ) - } - ) - val ident = freshName() - liftedRelationalLets += ((ident, aggResult)) + val ident = freshName() + liftedRelationalLets += (ident -> aggResult) - val aggResultRef = Ref(freshName(), aggResult.typ) - val aggResultElementRef = Ref( - freshName(), - aggResult.typ.asInstanceOf[TStruct] - .fieldType("array_aggs") - .asInstanceOf[TArray].elementType, - ) + val bindResult: IRProxy => IRProxy = + let( + __agg_result = RelationalRef(ident, aggResult.typ), + __array_aggs = '__agg_result('array_aggs), + n_rows = '__agg_result('n_rows), + ) in _ - val bindResult: IRProxy => IRProxy = letDyn(( - aggResultRef.name, - irToProxy(RelationalRef(ident, aggResult.typ)), - )).apply(_) - val bodyResult: IRProxy => IRProxy = (x: IRProxy) => - letDyn(( - aggResultRef.name, - irToProxy(RelationalRef(ident, aggResult.typ)), - )) - .apply(let( - n_rows = Symbol(aggResultRef.name.str)('count), - array_aggs = Symbol(aggResultRef.name.str)('array_aggs), - ) { - letDyn((aggResultElementRef.name, 'array_aggs(idx))) { - aggs.foldLeft[IRProxy](x) { case (acc, (name, _)) => - letDyn((name, GetField(aggResultElementRef, name.str)))(acc) - } - } - }) - (bindResult, bodyResult) - } + def bodyResult(body: IRProxy): IRProxy = + let(__agg_elem = '__array_aggs('__col_idx)) in + (letDyn(aggs.map { case (n, _) => n -> '__agg_elem(Symbol(n.str)) }: _*) in + body) + + (bindResult, bodyResult _) + } val ( scanOutsideTransformer: (IRProxy => IRProxy), scanInsideTransformer: (IRProxy => IRProxy), - ) = if (scans.isEmpty) - noOp - else { - val scanStruct = bindingsToStruct(scans) - - val scanResultArray = ToArray(StreamAggScan( - ToStream(GetField(Ref(TableIR.globalName, loweredChild.typ.globalType), colsFieldName)), - MatrixIR.colName, - scanStruct, - )) - - val scanResultRef = Ref(freshName(), scanResultArray.typ) - val scanResultElementRef = - Ref(freshName(), scanResultArray.typ.asInstanceOf[TArray].elementType) - - val bindResult: IRProxy => IRProxy = - letDyn((scanResultRef.name, scanResultArray)).apply(_) - val bodyResult: IRProxy => IRProxy = (x: IRProxy) => - letDyn(( - scanResultElementRef.name, - ArrayRef(scanResultRef, idx), - ))( - scans.foldLeft[IRProxy](x) { case (acc, (name, _)) => - letDyn((name, GetField(scanResultElementRef, name.str)))(acc) - } - ) - (bindResult, bodyResult) - } + ) = + if (scans.isEmpty) noOp + else { + val scanStruct = bindingsToStruct(scans) + + val bindResult: IRProxy => IRProxy = + let(__scan_result = '__cols.streamAggScan('sa ~> scanStruct)) in _ - loweredChild.mapGlobals('global.insertFields(colsField -> - aggOutsideTransformer(scanOutsideTransformer(irRange(0, 'global(colsField).len).map( - idxSym ~> let(__cols_array = 'global(colsField), sa = '__cols_array(idxSym)) { - aggInsideTransformer(scanInsideTransformer(b0)) - } - ))))) + def bodyResult(body: IRProxy): IRProxy = + let(__scan_elem = '__scan_result('__col_idx)) in + (letDyn(scans.map { case (n, _) => n -> '__scan_elem(Symbol(n.str)) }: _*) in + body) + + (bindResult, bodyResult _) + } + + lc.mapGlobals( + let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in 'global.insertFields( + colsField -> + aggOutsideTransformer( + scanOutsideTransformer( + irRange(0, '__cols.len).map('__col_idx ~> + (let(sa = '__cols('__col_idx)) in + aggInsideTransformer(scanInsideTransformer(b0)))) + ) + ) + ) + ) case MatrixFilterEntries(child, pred) => - val lc = lower(ctx, child, liftedRelationalLets) - lc.mapRows('row.insertFields(entriesField -> - irRange(0, 'global(colsField).len).map { - 'i ~> - let(g = 'row(entriesField)('i)) { - irIf(let(sa = 'global(colsField)('i)) - in !subst(lower(ctx, pred, liftedRelationalLets), matrixSubstEnv(child))) { - NA(child.typ.entryType) - } { - 'g - } - } - })) + val mtype = child.typ + lower(ctx, child, liftedRelationalLets) + .mapRows( + let( + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(mtype.rowType.fieldNames: _*), + ) in 'row.insertFields( + entriesField -> + irRange(0, '__cols.len).map('i ~> + (let(sa = '__cols('i), g = '__entries('i)) in + irIf(lower(ctx, pred, liftedRelationalLets))('g)(NA(mtype.entryType)))) + ) + ) case MatrixUnionCols(left, right, joinType) => - val rightEntries = genUID() - val rightCols = genUID() - val ll = lower(ctx, left, liftedRelationalLets).distinct() def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy = - if (joinType == "inner") - 'row(entries) - else - irIf('row(entries).isNA) { - irRange(0, 'global(cols).len) - .map('a ~> irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => - (f, NA(right.typ.entryType.fieldType(f))) - )))) - } { - 'row(entries) - } + if (joinType == "inner") 'row(entries) + else let(__entries = 'row(entries)) in + irIf(!'__entries.isNA)('__entries)( + irRange(0, 'global(cols).len).map('a ~> + MakeStruct(right.typ.entryType.fields.map(f => (f.name, NA(f.typ))))) + ) + + val ll = lower(ctx, left, liftedRelationalLets).distinct() val rr = lower(ctx, right, liftedRelationalLets).distinct() TableJoin( ll, - rr.mapRows('row.castRename(rr.typ.rowType.rename(Map(entriesFieldName -> rightEntries)))) + rr.mapRows( + 'row.castRename(rr.typ.rowType.rename(Map(entriesFieldName -> '__right_entries.name))) + ) .mapGlobals('global - .insertFields(Symbol(rightCols) -> 'global(colsField)) - .selectFields(rightCols)), + .insertFields('__right_cols -> 'global(colsField)) + .selectFields('__right_cols.name)), joinType, ) .mapRows('row - .insertFields(entriesField -> - makeArray( - handleMissingEntriesArray(entriesField, colsField), - handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols)), - ) - .flatMap('a ~> 'a)) - .dropFields(Symbol(rightEntries))) + .insertFields( + entriesField -> { + val ls = handleMissingEntriesArray(entriesField, colsField) + val rs = handleMissingEntriesArray('__right_entries, '__right_cols) + makeArray(ls, rs).flatten + } + ) + .dropFields('__right_entries)) .mapGlobals('global - .insertFields(colsField -> - makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a)) - .dropFields(Symbol(rightCols))) + .insertFields( + colsField -> + makeArray('global(colsField), 'global('__right_cols)).flatten + ) + .dropFields('__right_cols)) case MatrixMapEntries(child, newEntries) => - val loweredChild = lower(ctx, child, liftedRelationalLets) - val rt = loweredChild.typ.rowType - val gt = loweredChild.typ.globalType + val lc = lower(ctx, child, liftedRelationalLets) TableMapRows( - loweredChild, - InsertFields( - Ref(TableIR.rowName, rt), - FastSeq( - entriesFieldName -> ToArray( - zip2( - ToStream(GetField(Ref(TableIR.rowName, rt), entriesFieldName)), - ToStream(GetField(Ref(TableIR.globalName, gt), colsFieldName)), - ArrayZipBehavior.AssumeSameLength, - ) { (entries, cols) => - Subst( + lc, + M.eval { + for { + cols <- Name("__cols") -> colVals(lc) + entries <- Name("__entries") -> entries(lc) + _ <- MatrixIR.globalName -> globals(lc) + row <- MatrixIR.rowName -> rowVal(lc) + } yield InsertFields( + row, + FastSeq( + entriesFieldName -> + ToArray(StreamZip( + FastSeq(ToStream(cols), ToStream(entries)), + FastSeq(MatrixIR.colName, MatrixIR.entryName), lower(ctx, newEntries, liftedRelationalLets), - BindingEnv(Env( - MatrixIR.globalName -> SelectFields( - Ref(TableIR.globalName, gt), - child.typ.globalType.fieldNames, - ), - MatrixIR.rowName -> SelectFields( - Ref(TableIR.rowName, rt), - child.typ.rowType.fieldNames, - ), - MatrixIR.colName -> cols, - MatrixIR.entryName -> entries, - )), - ) - } - ) - ), - ), + ArrayZipBehavior.AssumeSameLength, + )) + ), + ) + }, ) case MatrixRepartition(child, n, shuffle) => @@ -749,109 +687,120 @@ object LowerMatrixIR { case MatrixRowsHead(child, n) => TableHead(lower(ctx, child, liftedRelationalLets), n) case MatrixRowsTail(child, n) => TableTail(lower(ctx, child, liftedRelationalLets), n) - case MatrixColsHead(child, n) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields(colsField -> 'global(colsField).arraySlice( - 0, - Some(n), - 1, - ))) + case MatrixColsHead(child, n) => + lower(ctx, child, liftedRelationalLets) + .mapGlobals('global.insertFields('__cols -> 'global('__cols).arraySlice(0, Some(n), 1))) .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(0, Some(n), 1))) - case MatrixColsTail(child, n) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields(colsField -> 'global(colsField).arraySlice(-n, None, 1))) + case MatrixColsTail(child, n) => + lower(ctx, child, liftedRelationalLets) + .mapGlobals('global.insertFields('__cols -> 'global('__cols).arraySlice(-n, None, 1))) .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(-n, None, 1))) case MatrixExplodeCols(child, path) => - val loweredChild = lower(ctx, child, liftedRelationalLets) - val lengths = Symbol(genUID()) - val colIdx = Symbol(genUID()) - val nestedIdx = Symbol(genUID()) - val colElementUID1 = Symbol(genUID()) - - val nestedRefs = - path.init.scanLeft('global(colsField)(colIdx): IRProxy)((irp, name) => irp(Symbol(name))) - val postExplodeSelector = path.zip(nestedRefs).zipWithIndex.foldRight[IRProxy](nestedIdx) { - case (((field, ref), i), arg) => - ref.insertFields(Symbol(field) -> - (if (i == nestedRefs.length - 1) - ref(Symbol(field)).toArray(arg) - else - arg)) - } - - val arrayIR = path.foldLeft[IRProxy](colElementUID1) { case (irp, fieldName) => - irp(Symbol(fieldName)) - } - loweredChild - .mapGlobals('global.insertFields(lengths -> 'global(colsField).map({ - colElementUID1 ~> arrayIR.len.orElse(0) - }))) - .mapGlobals('global.insertFields(colsField -> - irRange(0, 'global(colsField).len, 1) - .flatMap({ - colIdx ~> - irRange(0, 'global(lengths)(colIdx), 1) - .map({ - nestedIdx ~> postExplodeSelector + lower(ctx, child, liftedRelationalLets) + .mapGlobals( + let( + __cols = + 'global(colsField), + __lengths = + '__cols.map('__elem ~> + path + .foldLeft[IRProxy]('__elem) { case (irp, f) => irp(Symbol(f)) } + .len + .orElse(0)), + ) in 'global.insertFields( + '__cols -> + irRange(0, '__cols.len).flatMap('__col_idx ~> { + val nestedRefs = + path.init.scanLeft('__cols('__col_idx))((irp, name) => irp(Symbol(name))) + + irRange(0, '__lengths('__col_idx)).map('__length_idx ~> + path.zip(nestedRefs).zipWithIndex.foldRight[IRProxy]('__length_idx) { + case (((field, ref), i), arg) => + val s = Symbol(field) + ref.insertFields( + s -> (if (i == nestedRefs.length - 1) ref(s).toArray(arg) else arg) + ) }) - }))) - .mapRows('row.insertFields(entriesField -> - irRange(0, 'row(entriesField).len, 1) - .flatMap(colIdx ~> - irRange(0, 'global(lengths)(colIdx), 1).map( - Symbol(genUID()) ~> 'row(entriesField)(colIdx) - )))) - .mapGlobals('global.dropFields(lengths)) + }), + '__lengths -> + '__lengths, + ) + ) + .mapRows( + let(__entries = 'row(entriesField), __lengths = 'global('__lengths)) in + 'row.insertFields( + entriesField -> + irRange(0, '__entries.len).flatMap('__col_idx ~> + irRange(0, '__lengths('__col_idx)).map('__unused ~> + '__entries('__col_idx))) + ) + ) + .mapGlobals('global.dropFields('__lengths)) case MatrixAggregateRowsByKey(child, entryExpr, rowExpr) => - val substEnv = matrixSubstEnv(child) - val eeSub = subst(lower(ctx, entryExpr, liftedRelationalLets), substEnv) - val reSub = subst(lower(ctx, rowExpr, liftedRelationalLets), substEnv) lower(ctx, child, liftedRelationalLets) .aggregateByKey( - reSub.insertFields(entriesField -> irRange(0, 'global(colsField).len) - .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( - let(sa = 'global(colsField)('__result_idx)) { - aggLet( - sa = 'global(colsField)('__element_idx), - g = 'row(entriesField)('__element_idx), - ) { - aggFilter(!'g.isNA, eeSub) - } - } - )) + let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in (aggLet( + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in lower(ctx, rowExpr, liftedRelationalLets).insertFields( + entriesField -> + irRange(0, '__cols.len) + .aggElements('__element_idx, '__result_idx, Some('__cols.len))( + let(sa = '__cols('__result_idx)) in + (aggLet(sa = '__cols('__element_idx), g = '__entries('__element_idx)) in + aggFilter(!'g.isNA, lower(ctx, entryExpr, liftedRelationalLets))) + ) + )) ) case MatrixCollectColsByKey(child) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('newColIdx -> - irRange(0, 'global(colsField).len).map { - 'i ~> - makeTuple('global(colsField)('i).selectFields(child.typ.colKey: _*), 'i) - }.groupByKey.toArray)) - .mapRows('row.insertFields(entriesField -> - 'global('newColIdx).map { - 'kv ~> - makeStruct(child.typ.entryType.fieldNames.map { s => - ( - Symbol(s), - 'kv('value).map { - 'i ~> 'row(entriesField)('i)(Symbol(s)) - }, - ) - }: _*) - })) - .mapGlobals('global - .insertFields(colsField -> - 'global('newColIdx).map { - 'kv ~> - 'kv('key).insertFields( - child.typ.colValueStruct.fieldNames.map { s => - (Symbol(s), 'kv('value).map('i ~> 'global(colsField)('i)(Symbol(s)))) - }: _* - ) - }) - .dropFields('newColIdx)) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global.insertFields( + '__new_col_idx -> + irRange(0, '__cols.len) + .map('i ~> makeTuple('__cols('i).selectFields(child.typ.colKey: _*), 'i)) + .groupByKey + .toArray + ) + ) + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields( + entriesField -> + 'global('__new_col_idx).map { + 'kv ~> + makeStruct(child.typ.entryType.fieldNames.map { f => + val s = Symbol(f) + s -> 'kv('value).map('i ~> '__entries('i)(s)) + }: _*) + } + ) + ) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global + .insertFields( + colsField -> + 'global('__new_col_idx).map('kv ~> + 'kv('key).insertFields( + child.typ.colValueStruct.fieldNames.map { f => + val s = Symbol(f) + s -> 'kv('value).map('i ~> '__cols('i)(s)) + }: _* + )) + ) + .dropFields('__new_col_idx) + ) case MatrixExplodeRows(child, path) => TableExplode(lower(ctx, child, liftedRelationalLets), path) @@ -859,63 +808,53 @@ object LowerMatrixIR { case mr: MatrixRead => mr.lower(ctx) case MatrixAggregateColsByKey(child, entryExpr, colExpr) => - val colKey = child.typ.colKey - - val originalColIdx = Symbol(genUID()) - val newColIdx1 = Symbol(genUID()) - val newColIdx2 = Symbol(genUID()) - val colsAggIdx = Symbol(genUID()) - val keyMap = Symbol(genUID()) - val aggElementIdx = Symbol(genUID()) - - val e1 = Env[IRProxy]( - MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*), - MatrixIR.rowName -> 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) - val e2 = Env[IRProxy]( - MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*) - ) - val ceSub = - subst(lower(ctx, colExpr, liftedRelationalLets), BindingEnv(e2, agg = Some(e2))) - val eeSub = - subst(lower(ctx, entryExpr, liftedRelationalLets), BindingEnv(e1, agg = Some(e1))) - lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields(keyMap -> - let(__cols_field = 'global(colsField)) { - irRange(0, '__cols_field.len) - .map(originalColIdx ~> let(__cols_field_element = '__cols_field(originalColIdx)) { - makeStruct( - 'key -> '__cols_field_element.selectFields(colKey: _*), - 'value -> originalColIdx, - ) - }) - .groupByKey - .toArray - })) - .mapRows('row.insertFields(entriesField -> - let(__entries = 'row(entriesField), __key_map = 'global(keyMap)) { - irRange(0, '__key_map.len) - .map(newColIdx1 ~> '__key_map(newColIdx1) - .apply('value) - .streamAgg(aggElementIdx ~> - aggLet(g = '__entries(aggElementIdx), sa = 'global(colsField)(aggElementIdx)) { - aggFilter(!'g.isNA, eeSub) - })) - })) .mapGlobals( - 'global.insertFields(colsField -> - let(__key_map = 'global(keyMap)) { - irRange(0, '__key_map.len) - .map(newColIdx2 ~> + let(__cols = 'global(colsField)) in + 'global.insertFields( + '__key_map -> + irRange(0, '__cols.len) + .map('__old_col_idx ~> + (let(__elem = '__cols('__old_col_idx)) in + makeStruct( + 'key -> '__elem.selectFields(child.typ.colKey: _*), + 'value -> '__old_col_idx, + ))) + .groupByKey + .toArray + ) + ) + .mapRows( + let( + __key_map = 'global('__key_map), + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in 'row.insertFields( + entriesField -> + irRange(0, '__key_map.len).map('__new_col_idx ~> + '__key_map('__new_col_idx)('value).streamAgg('__agg_idx ~> + (aggLet(sa = '__cols('__agg_idx), g = '__entries('__agg_idx)) in + aggFilter(!'g.isNA, lower(ctx, entryExpr, liftedRelationalLets))))) + ) + ) + .mapGlobals( + let( + __cols = 'global(colsField), + __key_map = 'global('__key_map), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in 'global.insertFields( + colsField -> + irRange(0, '__key_map.len).map('__new_col_idx ~> + (let(__elem = '__key_map('__new_col_idx)) in concatStructs( - '__key_map(newColIdx2)('key), - '__key_map(newColIdx2)('value) - .streamAgg(colsAggIdx ~> aggLet(sa = 'global(colsField)(colsAggIdx)) { - ceSub - }), - )) - }).dropFields(keyMap) + '__elem('key), + '__elem('value).streamAgg('__agg_idx ~> + (aggLet(sa = '__cols('__agg_idx)) in + lower(ctx, colExpr, liftedRelationalLets))), + ))) + ) ) case MatrixLiteral(_, tl) => tl @@ -928,85 +867,100 @@ object LowerMatrixIR { lowered } - private[this] def lower(ctx: ExecuteContext, tir: TableIR, ab: Growable[(Name, IR)]): TableIR = { + private def lower(ctx: ExecuteContext, tir: TableIR, ab: Growable[(Name, IR)]): TableIR = { val lowered = tir match { case CastMatrixToTable(child, entries, cols) => lower(ctx, child, ab) - .mapRows('row.selectFields(child.typ.rowType.fieldNames ++ Array(entriesFieldName): _*)) - .mapGlobals('global.selectFields( - child.typ.globalType.fieldNames ++ Array(colsFieldName): _* - )) + .mapRows('row.selectFields(child.typ.rowType.fieldNames :+ entriesFieldName: _*)) + .mapGlobals('global.selectFields(child.typ.globalType.fieldNames :+ colsFieldName: _*)) .rename(Map(entriesFieldName -> entries), Map(colsFieldName -> cols)) case x @ MatrixEntriesTable(child) => val lc = lower(ctx, child, ab) if (child.typ.rowKey.nonEmpty && child.typ.colKey.nonEmpty) { - val oldColIdx = Symbol(genUID()) - val lambdaIdx1 = Symbol(genUID()) - val lambdaIdx2 = Symbol(genUID()) - val lambdaIdx3 = Symbol(genUID()) - val toExplode = Symbol(genUID()) - val values = Symbol(genUID()) lc - .mapGlobals('global.insertFields(oldColIdx -> - irRange(0, 'global(colsField).len) - .map(lambdaIdx1 ~> makeStruct( - 'key -> 'global(colsField)(lambdaIdx1).selectFields(child.typ.colKey: _*), - 'value -> lambdaIdx1, - )) - .sort(ascending = true, onKey = true) - .map(lambdaIdx1 ~> lambdaIdx1('value)))) - .aggregateByKey(makeStruct(values -> applyAggOp( - Collect(), - seqOpArgs = FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*)), - ))) - .mapRows('row.dropFields(values).insertFields(toExplode -> - 'global(oldColIdx) - .flatMap(lambdaIdx1 ~> 'row(values) - .filter(lambdaIdx2 ~> !lambdaIdx2(entriesField)(lambdaIdx1).isNA) - .map(lambdaIdx3 ~> let( - __col = 'global(colsField)(lambdaIdx1), - __entry = lambdaIdx3(entriesField)(lambdaIdx1), - ) { - makeStruct( - child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => - f -> lambdaIdx3(f) - ) ++ - child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col(f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry(f)): _* - ) - })))) - .explode(toExplode) - .mapRows(makeStruct(x.typ.rowType.fieldNames.map { f => - val fd = Symbol(f) - (fd, if (child.typ.rowKey.contains(f)) 'row(fd) else 'row(toExplode)(fd)) - }: _*)) - .mapGlobals('global.dropFields(colsField, oldColIdx)) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global.insertFields( + '__old_col_idx -> + irRange(0, '__cols.len) + .map('__col_idx ~> + makeStruct( + 'key -> '__cols('__col_idx).selectFields(child.typ.colKey: _*), + 'value -> '__col_idx, + )) + .sort(ascending = true, onKey = true) + .map('__elem ~> '__elem('value)) + ) + ) + .aggregateByKey(makeStruct( + '__values -> + applyAggOp( + Collect(), + seqOpArgs = FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*)), + ) + )) + .mapRows( + let(__cols = 'global(colsField)) in + 'row.dropFields('__values).insertFields( + '__explode -> + 'global('__old_col_idx).flatMap('__old_col_idx ~> + (let(__col = '__cols('__old_col_idx)) in + 'row('__values) + .filter('__v ~> !'__v(entriesField)('__old_col_idx).isNA) + .map('__v ~> + (let(__entry = '__v(entriesField)('__old_col_idx)) in + makeStruct( + child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => + f -> '__v(f) + ) ++ + child.typ.colType.fieldNames.map(Symbol(_)).map(f => + f -> '__col(f) + ) ++ + child.typ.entryType.fieldNames.map(Symbol(_)).map(f => + f -> '__entry(f) + ): _* + ))))) + ) + ) + .explode('__explode) + .mapRows( + let(__exploded = 'row('__explode)) in + makeStruct(x.typ.rowType.fieldNames.map { f => + val fd = Symbol(f) + (fd, if (child.typ.rowKey.contains(f)) 'row(fd) else '__exploded(fd)) + }: _*) + ) + .mapGlobals('global.dropFields(colsField, '__old_col_idx)) .keyBy(child.typ.rowKey ++ child.typ.colKey, isSorted = true) } else { - val colIdx = Symbol(genUID()) - val lambdaIdx = Symbol(genUID()) - val result = lc - .mapRows('row.insertFields(colIdx -> irRange(0, 'global(colsField).len) - .filter(lambdaIdx ~> !'row(entriesField)(lambdaIdx).isNA))) - .explode(colIdx) - .mapRows(let( - __col_struct = 'global(colsField)('row(colIdx)), - __entry_struct = 'row(entriesField)('row(colIdx)), - ) { - val newFields = - child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct(f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct(f)) - - 'row.dropFields(entriesField, colIdx).insertFieldsList( - newFields, - ordering = Some(x.typ.rowType.fieldNames), + val result = + lc + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields( + '__col_idx -> + irRange(0, 'global(colsField).len) + .filter('__idx ~> !'__entries('__idx).isNA) + ) ) - }) - .mapGlobals('global.dropFields(colsField)) - if (child.typ.colKey.isEmpty) - result + .explode('__col_idx) + .mapRows { + val newFields = + child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct(f)) ++ + child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct(f)) + + let( + __col_struct = 'global(colsField)('row('__col_idx)), + __entry_struct = 'row(entriesField)('row('__col_idx)), + ) in 'row + .dropFields(entriesField, '__col_idx) + .insertFieldsList(newFields, ordering = Some(x.typ.rowType.fieldNames)) + } + .mapGlobals('global.dropFields(colsField)) + + if (child.typ.colKey.isEmpty) result else { assert(child.typ.rowKey.isEmpty) result.keyBy(child.typ.colKey) @@ -1033,23 +987,23 @@ object LowerMatrixIR { case MatrixColsTable(child) => val colKey = child.typ.colKey - let(__cols_and_globals = lower(ctx, child, ab).getGlobals) { - val sortedCols = if (colKey.isEmpty) - '__cols_and_globals(colsField) - else - '__cols_and_globals(colsField).map { - '__cols_element ~> - makeStruct( - // key struct - '_1 -> '__cols_element.selectFields(colKey: _*), - '_2 -> '__cols_element, - ) - }.sort(true, onKey = true) - .map { - 'elt ~> 'elt('_2) - } - makeStruct('rows -> sortedCols, 'global -> '__cols_and_globals.dropFields(colsField)) - }.parallelize(None).keyBy(child.typ.colKey) + + val sortedCols = + if (colKey.isEmpty) '__cols_and_global(colsField) + else '__cols_and_global(colsField) + .map('__cols_element ~> + makeStruct( + // key struct + '_1 -> '__cols_element.selectFields(colKey: _*), + '_2 -> '__cols_element, + )) + .sort(true, onKey = true) + .map('elt ~> 'elt('_2)) + + (let(__cols_and_global = lower(ctx, child, ab).getGlobals) in + makeStruct('rows -> sortedCols, 'global -> '__cols_and_global.dropFields(colsField))) + .parallelize(None) + .keyBy(child.typ.colKey) case table => lowerChildren(ctx, table, ab).asInstanceOf[TableIR] } @@ -1058,26 +1012,26 @@ object LowerMatrixIR { lowered } - private[this] def lower(ctx: ExecuteContext, bmir: BlockMatrixIR, ab: Growable[(Name, IR)]) + private def lower(ctx: ExecuteContext, bmir: BlockMatrixIR, ab: Growable[(Name, IR)]) : BlockMatrixIR = { - val lowered = bmir match { - case noMatrixChildren => lowerChildren(ctx, noMatrixChildren, ab).asInstanceOf[BlockMatrixIR] - } + val lowered = lowerChildren(ctx, bmir, ab).asInstanceOf[BlockMatrixIR] assertTypeUnchanged(bmir, lowered) lowered } - private[this] def lower(ctx: ExecuteContext, ir: IR, ab: Growable[(Name, IR)]): IR = { + private def lower(ctx: ExecuteContext, ir: IR, ab: Growable[(Name, IR)]): IR = { val lowered = ir match { - case MatrixToValueApply(child, function) => TableToValueApply( + case MatrixToValueApply(child, function) => + TableToValueApply( lower(ctx, child, ab), - function.lower() - .getOrElse(WrappedMatrixToValueFunction( + function.lower().getOrElse( + WrappedMatrixToValueFunction( function, colsFieldName, entriesFieldName, child.typ.colKey, - )), + ) + ), ) case MatrixWrite(child, writer) => TableWrite( @@ -1096,29 +1050,33 @@ object LowerMatrixIR { val lc = lower(ctx, child, ab) TableAggregate( lc, - aggExplodeIR( - filterIR( - zip2( - ToStream(GetField(Ref(TableIR.rowName, lc.typ.rowType), entriesFieldName)), - ToStream(GetField(Ref(TableIR.globalName, lc.typ.globalType), colsFieldName)), - ArrayZipBehavior.AssertSameLength, - ) { case (e, c) => - MakeTuple.ordered(FastSeq(e, c)) + Let( + FastSeq(MatrixIR.globalName -> globals(lc)), + M.agg { + for { + cols <- Name("__cols") -> colVals(lc) + entries <- Name("__entries") -> entries(lc) + _ <- MatrixIR.globalName -> globals(lc) + _ <- MatrixIR.rowName -> rowVal(lc) + } yield aggExplodeIR( + filterIR( + zip2( + ToStream(cols), + ToStream(entries), + ArrayZipBehavior.AssertSameLength, + ) { + (c, e) => maybeIR(e)(e => maketuple(c, e)) + } + )(r => ApplyUnaryPrimOp(Bang, IsNA(r))) + ) { explodedTuple => + M.agg { + (MatrixIR.colName -> GetTupleElement(explodedTuple, 0)) >> + (MatrixIR.entryName -> GetTupleElement(explodedTuple, 1)) >> + query + } } - )(filterTuple => ApplyUnaryPrimOp(Bang, IsNA(GetTupleElement(filterTuple, 0)))) - ) { explodedTuple => - AggLet( - MatrixIR.entryName, - GetTupleElement(explodedTuple, 0), - AggLet( - MatrixIR.colName, - GetTupleElement(explodedTuple, 1), - Subst(query, matrixSubstEnvIR(child, lc)), - isScan = false, - ), - isScan = false, - ) - }, + }, + ), ) case _ => lowerChildren(ctx, ir, ab).asInstanceOf[IR] } @@ -1126,7 +1084,7 @@ object LowerMatrixIR { lowered } - private[this] def assertTypeUnchanged(original: BaseIR, lowered: BaseIR): Unit = + private def assertTypeUnchanged(original: BaseIR, lowered: BaseIR): Unit = if (lowered.typ != original.typ) fatal( s"lowering changed type:\n before: ${original.typ}\n after: ${lowered.typ}\n ${original.getClass.getName} => ${lowered.getClass.getName}" diff --git a/hail/hail/src/is/hail/expr/ir/Optimize.scala b/hail/hail/src/is/hail/expr/ir/Optimize.scala index 9bdd92a30f9..9d0fbee4296 100644 --- a/hail/hail/src/is/hail/expr/ir/Optimize.scala +++ b/hail/hail/src/is/hail/expr/ir/Optimize.scala @@ -54,7 +54,7 @@ object Optimize { catch { case NonFatal(e) => fatal( - s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir, preserveNames = true)}", + s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir)}", e, ) } diff --git a/hail/hail/src/is/hail/expr/ir/Pretty.scala b/hail/hail/src/is/hail/expr/ir/Pretty.scala index a25d72fa8d5..24fa206b16f 100644 --- a/hail/hail/src/is/hail/expr/ir/Pretty.scala +++ b/hail/hail/src/is/hail/expr/ir/Pretty.scala @@ -28,7 +28,7 @@ object Pretty { elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false, - preserveNames: Boolean = false, + preserveNames: Boolean = true, ): String = { val useSSA = ctx != null && ctx.getFlag("use_ssa_logs") != null val pretty = @@ -56,7 +56,7 @@ object Pretty { elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false, - preserveNames: Boolean = false, + preserveNames: Boolean = true, ): String = { val pretty = new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA = true, @@ -773,7 +773,7 @@ class Pretty( if (i == 1) some(MatrixIR.globalName -> "g") else None case _: MatrixMapRows => - if (i == 1) matrixBlockArgs map { _ :+ (Name("n_rows") -> "n_rows") } + if (i == 1) matrixBlockArgs map { _ :+ (Name("n_cols") -> "n_cols") } else None case NDArrayMap(_, name, _) => if (i == 1) some(name -> "elt") diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index faf4a2a6dfb..474e07ed64d 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -708,7 +708,10 @@ object Extract { val init = Begin(initOps) initBuilder += InitOp( i, - knownLength.fold(ArraySeq(init))(ArraySeq(_, init)), + knownLength.fold(ArraySeq(init)) { ir => + bindInitArgRefs(FastSeq(ir)) + ArraySeq(ir, init) + }, checkSig, ) diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala index dcf82b6481d..50fcffc7766 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala @@ -898,14 +898,14 @@ class BlockMatrixStage2 private ( RVDPartitioner.unkeyed(ctx.stateManager, bmTyp.nDefinedBlocks), TableStageDependency.none, contextsIR, - { newCtxRef => - val s = makestruct( - "blockRow" -> GetTupleElement(newCtxRef, 0), - "blockCol" -> GetTupleElement(newCtxRef, 1), - "block" -> Let(FastSeq(ctxRefName -> GetTupleElement(newCtxRef, 2)), _blockIR), - ) - MakeStream(FastSeq(s), TStream(s.typ)) - }, + newCtxRef => + MakeStream.single( + makestruct( + "blockRow" -> GetTupleElement(newCtxRef, 0), + "blockCol" -> GetTupleElement(newCtxRef, 1), + "block" -> Let(FastSeq(ctxRefName -> GetTupleElement(newCtxRef, 2)), _blockIR), + ) + ), ) } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala b/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala index b49d351a931..e7555939a8a 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala @@ -61,15 +61,9 @@ case class OptimizePass(_context: String) extends LoweringPass { case object LowerMatrixToTablePass extends LoweringPass { override val context: String = "LowerMatrixToTable" - override def before: Invariant = AnyIR + override def before: Invariant = LowerableIR override def after: Invariant = before and NoMatrixIR - - override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = ir match { - case x: IR => LowerMatrixIR(ctx, x) - case x: TableIR => LowerMatrixIR(ctx, x) - case x: MatrixIR => LowerMatrixIR(ctx, x) - case x: BlockMatrixIR => LowerMatrixIR(ctx, x) - } + override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = LowerMatrixIR(ctx, ir) } case object LiftRelationalValuesToRelationalLets extends LoweringPass { diff --git a/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala b/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala index 8c6373c1961..6772fbcbf01 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala @@ -41,7 +41,7 @@ package invariant { IRTraversal.trace(ir).foreach { case trace @ ir :: _ => if (!invariant(ir)) throw new UnsatisfiedInvariantError( s"""Invariant ${E.value} forbids - |${trace.take(5).map(Pretty(ctx, _, preserveNames = true)).mkString("\nin\n")} + |${trace.take(5).map(Pretty(ctx, _)).mkString("\nin\n")} |""".stripMargin ) } @@ -102,9 +102,9 @@ package object invariant { !newNames.add(name) || names.put(name, ir).forall { orig => throw new UnsatisfiedInvariantError( s"""Invariant ${implicitly[Enclosing].value} forbids redefinition of '$name' in - |${Pretty.ssaStyle(ir, preserveNames = true)} + |${Pretty.ssaStyle(ir)} |Originally bound in - |${Pretty.ssaStyle(orig, preserveNames = true)}""".stripMargin + |${Pretty.ssaStyle(orig)}""".stripMargin ) } } diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index e6928ba3f9f..f5e297e4d22 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -286,7 +286,7 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { } def assertBMEvalsTo( - bm: BlockMatrixIR, + bm0: BlockMatrixIR, expected: DenseMatrix[Double], )(implicit execStrats: Set[ExecStrategy] ): Unit = { @@ -296,6 +296,8 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { logger.info("skipping interpret and non-lowering compile steps on non-spark backend") execStrats.intersect(ExecStrategy.backendOnly) } + + val bm = bm0.deepCopy filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => try { val res = diff --git a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala index c5055b0dedb..f92431f57c1 100644 --- a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala @@ -990,27 +990,27 @@ class Aggregators2Suite extends HailSuite { TStruct("row_idx" -> TInt32), TStruct.empty, ) - val ir = TableCollect(MatrixColsTable(MatrixMapCols( - MatrixRead(t, false, false, MatrixRangeReader(ctx, 10, 10, None)), - InsertFields( - Ref(MatrixIR.colName, t.colType), - FastSeq(( - "foo", - bindIR(GetField(Ref(MatrixIR.colName, t.colType), "col_idx") + I32(1)) { bar => - AggFilter( - GetField(Ref(MatrixIR.rowName, t.rowType), "row_idx") < I32(5), - bar.toL + bar.toL + ApplyAggOp( - FastSeq(), - FastSeq(GetField(Ref(MatrixIR.rowName, t.rowType), "row_idx").toL), - Sum(), - ), - false, - ) - }, - )), - ), - Some(FastSeq()), - ))) + + val col: Atom = Ref(MatrixIR.colName, t.colType) + val row: Atom = Ref(MatrixIR.rowName, t.rowType) + + val ir = TableCollect( + MatrixColsTable( + MatrixMapCols( + MatrixRead(t, false, false, MatrixRangeReader(ctx, 10, 10, None)), + insertIR( + col, + "foo" -> bindIR(GetField(col, "col_idx").toL + 1L) { colIdx => + aggBindIR(GetField(row, "row_idx")) { rowIdx => + AggFilter(rowIdx < 5, colIdx + colIdx + ApplyAggOp(Sum())(rowIdx.toL), false) + } + }, + ), + Some(FastSeq()), + ) + ) + ) + assertEvalsTo(ir, Row((0 until 10).map(i => Row(i, 2L * i + 12L)), Row()))( ExecStrategy.interpretOnly ) diff --git a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala index d69831ecaad..a6ee73035ee 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala @@ -124,7 +124,7 @@ class MatrixIRSuite extends HailSuite { val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) val newRow = - InsertFields(oldRow, FastSeq("range" -> IRScanCollect(GetField(oldRow, "row_idx")))) + InsertFields(oldRow.ir, FastSeq("range" -> IRScanCollect(GetField(oldRow, "row_idx")))) val newMatrix = MatrixMapRows(mt, newRow) val rows = getRows(newMatrix) @@ -138,7 +138,7 @@ class MatrixIRSuite extends HailSuite { val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) val newRow = InsertFields( - oldRow, + oldRow.ir, FastSeq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldRow, "row_idx").toL)), ) @@ -165,7 +165,7 @@ class MatrixIRSuite extends HailSuite { val oldCol = Ref(MatrixIR.colName, mt.typ.colType) val newCol = - InsertFields(oldCol, FastSeq("range" -> IRScanCollect(GetField(oldCol, "col_idx")))) + InsertFields(oldCol.ir, FastSeq("range" -> IRScanCollect(GetField(oldCol, "col_idx")))) val newMatrix = MatrixMapCols(mt, newCol, None) val cols = getCols(newMatrix) @@ -179,7 +179,7 @@ class MatrixIRSuite extends HailSuite { val oldCol = Ref(MatrixIR.colName, mt.typ.colType) val newCol = InsertFields( - oldCol, + oldCol.ir, FastSeq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldCol, "col_idx").toL)), ) @@ -199,7 +199,7 @@ class MatrixIRSuite extends HailSuite { MatrixKeyRowsBy(baseRange, FastSeq()), InsertFields( row, - FastSeq("row_idx" -> (GetField(row, "row_idx") + start)), + FastSeq("row_idx" -> (GetField(row.ir, "row_idx") + start)), ), ), FastSeq("row_idx"), diff --git a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala index 96c3c12fff8..00f8b717439 100644 --- a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala @@ -6,10 +6,9 @@ import is.hail.collection.FastSeq import is.hail.expr.ir._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{ - ApplyBinaryPrimOp, Atom, ErrorIDs, GetField, MakeStream, MakeStruct, Ref, Str, StreamRange, - TableAggregate, TableGetGlobals, + Atom, ErrorIDs, GetField, MakeStream, MakeStruct, Ref, Str, StreamRange, TableAggregate, + TableGetGlobals, } -import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ import is.hail.utils.{HailException, Interval} @@ -111,21 +110,19 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testLowering(): Unit = { - val table = collect(mkTableGen()) - val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) - assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) + val rows = collect(mkTableGen()) + assertEvalsTo(rows, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } @Test(groups = Array("lowering")) def testNumberOfContextsMatchesPartitions(): Unit = { val errorId = 42 - val table = collect(mkTableGen( + val rows = collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), errorId = Some(errorId), )) - val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[HailException] { - loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) + loweredExecute(ctx, rows, Env.empty, FastSeq(), None) } ex.errorId shouldBe errorId ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") @@ -134,7 +131,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testRowsAreCorrectlyKeyed(): Unit = { val errorId = 56 - val table = collect(mkTableGen( + val rows = collect(mkTableGen( partitioner = Some(new RVDPartitioner( ctx.stateManager, TStruct("a" -> TInt32), @@ -145,9 +142,8 @@ class TableGenSuite extends HailSuite { )), errorId = Some(errorId), )) - val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[SparkException] { - loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) + loweredExecute(ctx, rows, Env.empty, FastSeq(), None) }.getCause.asInstanceOf[HailException] ex.errorId shouldBe errorId @@ -195,19 +191,13 @@ class TableGenSuite extends HailSuite { body: Option[(Atom, Atom) => IR] = None, partitioner: Option[RVDPartitioner] = None, errorId: Option[Int] = None, - ): TableGen = { + ): TableGen = tableGen( contexts.getOrElse(StreamRange(0, 2, 1)), - globals.getOrElse(MakeStruct(IndexedSeq("g" -> 0))), + globals.getOrElse(makestruct("g" -> 0)), partitioner.getOrElse(RVDPartitioner.unkeyed(ctx.stateManager, 2)), errorId.getOrElse(ErrorIDs.NO_ERROR), )( - body.getOrElse { (c, g) => - val elem = MakeStruct(IndexedSeq( - "a" -> ApplyBinaryPrimOp(Multiply(), c, GetField(g, "g")) - )) - MakeStream(IndexedSeq(elem), TStream(elem.typ)) - } + body.getOrElse((c, g) => MakeStream.single(makestruct("a" -> c * GetField(g, "g")))) ) - } } From 17ae4e1e5af3c2a47fa1eacf833ca20c059a45ae Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 8 Jun 2026 23:13:25 -0400 Subject: [PATCH 5/5] [query] LowerMatrixIR without Symbol --- .../is/hail/expr/ir/DeprecatedIRBuilder.scala | 458 ++----- hail/hail/src/is/hail/expr/ir/IR.scala | 8 +- hail/hail/src/is/hail/expr/ir/IRBuilder.scala | 3 + .../src/is/hail/expr/ir/LowerMatrixIR.scala | 1141 +++++++++-------- hail/hail/src/is/hail/expr/ir/MatrixIR.scala | 66 +- .../src/is/hail/expr/ir/MatrixValue.scala | 10 +- .../hail/src/is/hail/expr/ir/TableValue.scala | 8 +- .../src/is/hail/expr/ir/TableWriter.scala | 2 +- .../is/hail/expr/ir/AggregatorsSuite.scala | 39 +- 9 files changed, 739 insertions(+), 996 deletions(-) diff --git a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala index 5d24fbe7053..feb82cd7d59 100644 --- a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala +++ b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala @@ -1,99 +1,38 @@ package is.hail.expr.ir import is.hail.collection.FastSeq -import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.implicits.toRichIterable -import is.hail.expr.ir.Scope._ import is.hail.expr.ir.defs._ import is.hail.types.virtual._ -import scala.language.dynamics - object DeprecatedIRBuilder { - type E = BindingEnv[Type] - - implicit def funcToIRProxy(ir: E => IR): IRProxy = new IRProxy(ir) - - implicit def tableIRToProxy(tir: TableIR): TableIRProxy = - new TableIRProxy(tir) - - implicit def irToProxy(ir: IR): IRProxy = (_: E) => ir - - implicit def strToProxy(s: String): IRProxy = Str(s) - - implicit def intToProxy(i: Int): IRProxy = I32(i) - - implicit def booleanToProxy(b: Boolean): IRProxy = if (b) True() else False() - - implicit def ref(s: Symbol): IRProxy = (env: E) => - Ref(Name(s.name), env.eval.lookup(Name(s.name))) - - implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s) - - implicit def arrayToProxy(seq: IndexedSeq[IRProxy]): IRProxy = (env: E) => { - val irs = seq.map(_(env)) - val elType = irs.head.typ - MakeArray(irs, TArray(elType)) - } - - implicit def arrayIRToProxy(seq: IndexedSeq[IR]): IRProxy = arrayToProxy(seq.map(irToProxy)) - - def irRange(start: IRProxy, end: IRProxy, step: IRProxy = 1): IRProxy = (env: E) => - ToArray(StreamRange(start(env), end(env), step(env))) - - def irArrayLen(a: IRProxy): IRProxy = (env: E) => ArrayLen(a(env)) - - def irIf(cond: IRProxy)(cnsq: IRProxy)(altr: IRProxy): IRProxy = (env: E) => - If(cond(env), cnsq(env), altr(env)) - - def irDie(message: IRProxy, typ: Type): IRProxy = (env: E) => - Die(message(env), typ, -1) - - def makeArray(first: IRProxy, rest: IRProxy*): IRProxy = - arrayToProxy(first +: rest.toFastSeq) - - def makeStruct(fields: (Symbol, IRProxy)*): IRProxy = (env: E) => - MakeStruct(fields.toFastSeq.map { case (s, ir) => (s.name, ir(env)) }) - def concatStructs(struct1: IRProxy, struct2: IRProxy): IRProxy = (env: E) => { - val s2Type = struct2(env).typ.asInstanceOf[TStruct] - (let(__struct2 = struct2) in - struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*))(env) - } + implicit def tableIRToProxy(ir: TableIR): TableIRProxy = + new TableIRProxy(ir) - def makeTuple(values: IRProxy*): IRProxy = (env: E) => - MakeTuple.ordered(values.toFastSeq.map(_(env))) - - def applyAggOp( - op: AggOp, - initOpArgs: IndexedSeq[IRProxy] = FastSeq(), - seqOpArgs: IndexedSeq[IRProxy] = FastSeq(), - ): IRProxy = (env: E) => { - val i = initOpArgs.map(x => x(env.noAgg)) - val s = seqOpArgs.map(x => x(env.promoteAgg)) - ApplyAggOp(i, s, op) - } + implicit def irToProxy(ir: IR): IRProxy = + new IRProxy(ir) - def aggFilter(filterCond: IRProxy, query: IRProxy, isScan: Boolean = false): IRProxy = (env: E) => - AggFilter(filterCond(env.promoteAgg), query(env), isScan) + implicit def atomToIRProxy(a: Atom): IRProxy = + new IRProxy(a) class TableIRProxy(val tir: TableIR) extends AnyVal { - def empty: E = BindingEnv.empty - def typ: TableType = tir.typ + def global: Atom = Ref(TableIR.rowName, tir.typ.globalType) + def row: Atom = Ref(TableIR.rowName, tir.typ.rowType) def getGlobals: IR = TableGetGlobals(tir) - def mapGlobals(newGlobals: IRProxy): TableIR = - TableMapGlobals(tir, newGlobals(BindingEnv(typ.globalEnv))) + def mapGlobals(f: Atom => IR): TableIR = + TableMapGlobals(tir, f(global)) - def mapRows(newRow: IRProxy): TableIR = - TableMapRows(tir, newRow(BindingEnv(typ.rowEnv, scan = Some(typ.rowEnv)))) + def mapRows(f: (Atom, Atom) => IR): TableIR = + TableMapRows(tir, f(global, row)) - def explode(sym: Symbol): TableIR = TableExplode(tir, FastSeq(sym.name)) + def explode(sym: String*): TableIR = TableExplode(tir, sym.toFastSeq) - def aggregateByKey(aggIR: IRProxy): TableIR = - TableAggregateByKey(tir, aggIR(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv)))) + def aggregateByKey(f: (Atom, Atom) => IR): TableIR = + TableAggregateByKey(tir, f(global, row)) def keyBy(keys: IndexedSeq[String], isSorted: Boolean = false): TableIR = TableKeyBy(tir, keys, isSorted) @@ -101,345 +40,112 @@ object DeprecatedIRBuilder { def rename(rowMap: Map[String, String], globalMap: Map[String, String] = Map.empty): TableIR = TableRename(tir, rowMap, globalMap) - def renameGlobals(globalMap: Map[String, String]): TableIR = - rename(Map.empty, globalMap) - - def filter(ir: IRProxy): TableIR = - TableFilter(tir, ir(BindingEnv(typ.rowEnv))) + def filter(f: (Atom, Atom) => IR): TableIR = + TableFilter(tir, f(global, row)) - def distinct(): TableIR = TableDistinct(tir) + def distinct: TableIR = TableDistinct(tir) - def collect(): IRProxy = TableCollect(tir) + def collect: IR = TableCollect(tir) - def collectAsDict(): IRProxy = { - val uid = genUID() - val keyFields = tir.typ.key - val valueFields = tir.typ.valueType.fieldNames + def collectAsDict: IR = keyBy(FastSeq()) - .collect() - .apply('rows) - .map(Symbol(uid) ~> makeTuple( - Symbol(uid).selectFields(keyFields: _*), - Symbol(uid).selectFields(valueFields: _*), - )) + .collect + .get("rows") + .stream + .streamMap(row => maketuple(row.select(tir.typ.key), row.drop(tir.typ.key))) .toDict - } - def aggregate(ir: IRProxy): IR = - TableAggregate(tir, ir(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv)))) + def aggregate(f: (Atom, Atom) => IR): IR = + TableAggregate(tir, f(global, row)) } - class IRProxy(val ir: E => IR) extends AnyVal with Dynamic { - def apply(idx: IRProxy): IRProxy = (env: E) => - ArrayRef(ir(env), idx(env)) - - def invoke(name: String, rt: Type, args: IRProxy*): IRProxy = { env: E => - val irArgs = ArraySeq(ir(env)) ++ args.map(_(env)) - is.hail.expr.ir.invoke(name, rt, irArgs: _*) - } - - def selectDynamic(field: String): IRProxy = (env: E) => - GetField(ir(env), field) - - def +(other: IRProxy): IRProxy = (env: E) => ApplyBinaryPrimOp(Add(), ir(env), other(env)) - - def -(other: IRProxy): IRProxy = (env: E) => ApplyBinaryPrimOp(Subtract(), ir(env), other(env)) + class IRProxy(val ir: IR) extends AnyVal { - def *(other: IRProxy): IRProxy = (env: E) => ApplyBinaryPrimOp(Multiply(), ir(env), other(env)) + def invoke(name: String, rt: Type, args: IR*): IR = + is.hail.expr.ir.invoke(name, rt, ir +: args: _*) - def /(other: IRProxy): IRProxy = - (env: E) => ApplyBinaryPrimOp(FloatingPointDivide(), ir(env), other(env)) + def get(name: String): IR = + GetField(ir, name) - def floorDiv(other: IRProxy): IRProxy = - (env: E) => ApplyBinaryPrimOp(RoundToNegInfDivide(), ir(env), other(env)) + def get(idx: Int): IR = + GetTupleElement(ir, idx) - def &&(other: IRProxy): IRProxy = invoke("land", TBoolean, ir, other) + def at(idx: IR): IR = + ArrayRef(ir, idx) - def ||(other: IRProxy): IRProxy = invoke("lor", TBoolean, ir, other) + def rename(f: TStruct => TStruct): IR = + CastRename(ir, f(ir.typ.asInstanceOf[TStruct])) - def toI: IRProxy = (env: E) => Cast(ir(env), TInt32) + def update(field: String)(f: Atom => IR): IR = + bindIR(ir.get(field))(x => ir.insert(field -> f(x))) - def toL: IRProxy = (env: E) => Cast(ir(env), TInt64) + def insert(fields: (String, IR)*): IR = + InsertFields(ir, fields.toFastSeq) - def toF: IRProxy = (env: E) => Cast(ir(env), TFloat32) + def insert(fields: IndexedSeq[(String, IR)], ordering: Option[IndexedSeq[String]] = None): IR = + InsertFields(ir, fields, ordering) - def toD: IRProxy = (env: E) => Cast(ir(env), TFloat64) + def select(fields: String*): IR = + ir.select(fields.toFastSeq) - def unary_- : IRProxy = (env: E) => ApplyUnaryPrimOp(Negate, ir(env)) + def select(fields: IndexedSeq[String]): IR = + SelectFields(ir, fields) - def unary_! : IRProxy = (env: E) => ApplyUnaryPrimOp(Bang, ir(env)) - - def ceq(other: IRProxy): IRProxy = (env: E) => { - val left = ir(env) - val right = other(env) - ApplyComparisonOp(EQWithNA, left, right) + def drop(fields: IndexedSeq[String]): IR = { + val typ = ir.typ.asInstanceOf[TStruct] + SelectFields(ir, typ.fieldNames.diff(fields)) } - def cne(other: IRProxy): IRProxy = (env: E) => { - val left = ir(env) - val right = other(env) - ApplyComparisonOp(NEQWithNA, left, right) - } + def drop(fields: String*): IR = + drop(fields.toFastSeq) - def <(other: IRProxy): IRProxy = (env: E) => { - val left = ir(env) - val right = other(env) - ApplyComparisonOp(LT, left, right) - } + def len: IR = ArrayLen(ir) + def isNA: IR = IsNA(ir) - def >(other: IRProxy): IRProxy = (env: E) => { - val left = ir(env) - val right = other(env) - ApplyComparisonOp(GT, left, right) - } + def orElse(alt: IR): IR = + bindIR(ir)(x => If(IsNA(x), alt, x)) - def <=(other: IRProxy): IRProxy = (env: E) => { - val left = ir(env) - val right = other(env) - ApplyComparisonOp(LTEQ, left, right) - } + def filter(f: Atom => IR): IR = + filterIR(ir)(f) - def >=(other: IRProxy): IRProxy = (env: E) => { - val left = ir(env) - val right = other(env) - ApplyComparisonOp(GTEQ, left, right) - } + def aggExplode(f: Atom => IR): IR = + aggExplodeIR(ir, isScan = false)(f) - def apply(lookup: Symbol): IRProxy = (env: E) => { - val eval = ir(env) - eval.typ match { - case _: TStruct => - GetField(eval, lookup.name) - case _: TArray => - ArrayRef(eval, ref(lookup)(env)) - } - } + def streamMap(f: Atom => IR): IR = + mapIR(ir)(f) - def castRename(t: Type): IRProxy = (env: E) => CastRename(ir(env), t) + def streamFlatMap(f: Atom => IR): IR = + flatMapIR(ir)(f) - def insertFields(fields: (Symbol, IRProxy)*): IRProxy = insertFieldsList(fields.toFastSeq) + def streamFlatten: IR = + flatten(ir) - def insertFieldsList( - fields: IndexedSeq[(Symbol, IRProxy)], - ordering: Option[IndexedSeq[String]] = None, - ): IRProxy = (env: E) => - InsertFields(ir(env), fields.map { case (s, fir) => (s.name, fir(env)) }, ordering) + def streamAgg(f: Atom => IR): IR = + streamAggIR(ir)(f) - def selectFields(fields: String*): IRProxy = (env: E) => - SelectFields(ir(env), fields.toFastSeq) + def streamAggScan(f: Atom => IR): IR = + streamAggScanIR(ir)(f) - def dropFieldList(fields: IndexedSeq[String]): IRProxy = (env: E) => { - val struct = ir(env) - val typ = struct.typ.asInstanceOf[TStruct] - SelectFields(struct, typ.fieldNames.diff(fields)) - } + def slice(start: IR, stop: Option[IR], step: IR = I32(1)): IR = + ArraySlice(ir, start, stop, step, ErrorIDs.NO_ERROR) - def dropFields(fields: Symbol*): IRProxy = dropFieldList(fields.toFastSeq.map(_.name)) - - def insertStruct(other: IRProxy, ordering: Option[IndexedSeq[String]] = None): IRProxy = - (env: E) => { - val right = other(env) - val sym = freshName() - Let( - FastSeq(sym -> right), - InsertFields( - ir(env), - right.typ.asInstanceOf[TStruct].fieldNames.map(f => - f -> GetField(Ref(sym, right.typ), f) - ), - ordering, - ), - ) - } - - def len: IRProxy = (env: E) => ArrayLen(ir(env)) - - def isNA: IRProxy = (env: E) => IsNA(ir(env)) - - def orElse(alt: IRProxy): IRProxy = - (env: E) => bindIR(ir(env))(x => If(IsNA(x), alt(env), x)) - - def filter(pred: LambdaProxy): IRProxy = { env: E => - val array = ir(env) - val binding = Name(pred.s.name) -> TIterable.elementType(array.typ) - ToArray(StreamFilter( - ToStream(array), - binding._1, - pred.body(env.bindEval(binding)), - )) - } + def aggElements(knownLength: Option[IR] = None)(aggBody: (Atom, Atom) => IR): IR = + aggArrayPerElement(ir, knownLength, isScan = false)(aggBody) - def map(f: LambdaProxy): IRProxy = { env: E => - val array = ir(env) - val binding = Name(f.s.name) -> TIterable.elementType(array.typ) - ToArray(StreamMap( - ToStream(array), - binding._1, - f.body(env.bindEval(binding)), - )) - } - - def aggExplode(f: LambdaProxy): IRProxy = { env: E => - val array = ir(env.promoteAgg) - val binding = Name(f.s.name) -> TIterable.elementType(array.typ) - AggExplode( - ToStream(array), - binding._1, - f.body(env.bindEval(binding).bindAgg(binding)), - isScan = false, - ) - } - - def flatMap(f: LambdaProxy): IRProxy = { env: E => - val array = ir(env) - val binding = Name(f.s.name) -> TIterable.elementType(array.typ) - ToArray(StreamFlatMap( - ToStream(array), - binding._1, - ToStream(f.body(env.bindEval(binding))), - )) - } - - def flatten: IRProxy = - flatMap('a ~> 'a) - - def streamAgg(f: LambdaProxy): IRProxy = { env: E => - val array = ir(env) - val binding = Name(f.s.name) -> TIterable.elementType(array.typ) - StreamAgg( - ToStream(array), - binding._1, - f.body(env.bindEval(binding).createAgg), - ) - } - - def streamAggScan(f: LambdaProxy): IRProxy = { env: E => - val array = ir(env) - val binding = Name(f.s.name) -> TIterable.elementType(array.typ) - ToArray(StreamAggScan( - ToStream(array), - binding._1, - f.body(env.bindEval(binding).createScan), - )) - } + def sort(ascending: IR, onKey: Boolean = false): IR = + ArraySort(ir, ascending, onKey) - def arraySlice(start: IRProxy, stop: Option[IRProxy], step: IRProxy): IRProxy = { - (env: E) => - ArraySlice( - this.ir(env), - start.ir(env), - stop.map(inner => inner.ir(env)), - step.ir(env), - ErrorIDs.NO_ERROR, - ) - } - - def aggElements( - elementsSym: Symbol, - indexSym: Symbol, - knownLength: Option[IRProxy], - )( - aggBody: IRProxy - ): IRProxy = { env: E => - val array = ir(env.promoteAgg) - - val bindings = - FastSeq( - Name(elementsSym.name) -> TIterable.elementType(array.typ), - Name(indexSym.name) -> TInt32, - ) - - AggArrayPerElement( - array, - bindings(0)._1, - bindings(1)._1, - aggBody(env.bindEval(bindings: _*).bindAgg(bindings: _*)), - knownLength.map(_(env)), - isScan = false, - ) - } + def groupByKey: IR = GroupByKey(ir) - def sort(ascending: IRProxy, onKey: Boolean = false): IRProxy = - (env: E) => ArraySort(ToStream(ir(env)), ascending(env), onKey) + def toArray: IR = ToArray(ir) - def groupByKey: IRProxy = (env: E) => GroupByKey(ToStream(ir(env))) + def stream: IR = ToStream(ir) - def toArray: IRProxy = (env: E) => ToArray(ToStream(ir(env))) - - def toDict: IRProxy = (env: E) => ToDict(ToStream(ir(env))) + def toDict: IR = ToDict(ir) def parallelize(nPartitions: Option[Int] = None): TableIR = - TableParallelize(ir(BindingEnv.empty), nPartitions) - - def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = - ir.map( - '__elem ~> - makeTuple( - '__elem.selectFields(keyFields: _*), - '__elem.dropFieldList(keyFields), - ) - ) - .toDict - - def tupleElement(i: Int): IRProxy = (env: E) => GetTupleElement(ir(env), i) - - private[ir] def apply(env: E): IR = ir(env) - } - - class LambdaProxy(val s: Symbol, val body: IRProxy) + TableParallelize(ir, nPartitions) - class SymbolProxy(val s: Symbol) extends AnyVal { - def ~>(body: IRProxy): LambdaProxy = new LambdaProxy(s, body) + def tupleElement(i: Int): IR = GetTupleElement(ir, i) } - - case class BindingProxy(s: Symbol, value: IRProxy, scope: Scope) - - private object LetProxy { - def bind(bindings: IndexedSeq[BindingProxy], body: IRProxy, env: E): IR = { - var newEnv = env - val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) => - val resolvedValue = - value( - if (scope == AGG) newEnv.promoteAgg - else if (scope == SCAN) newEnv.promoteScan - else newEnv - ) - newEnv = newEnv.bindInScope(Name(sym.name), resolvedValue.typ, scope) - Binding(Name(sym.name), resolvedValue, scope) - } - Block(resolvedBindings, body(newEnv)) - } - } - - object let extends Dynamic { - def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { - assert(method == "apply") - letDyn(args.map { case (n, ir) => Name(n) -> ir }: _*) - } - } - - object letDyn { - def apply(args: (Name, IRProxy)*): LetProxy = - new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s.str), b, Scope.EVAL) }) - } - - class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { - def in(body: IRProxy): IRProxy = { env: E => LetProxy.bind(bindings, body, env) } - } - - object aggLet extends Dynamic { - def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { - assert(method == "apply") - new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) }) - } - } - - object scanLet extends Dynamic { - def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { - assert(method == "apply") - new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.SCAN) }) - } - } - - def lift(f: (IR) => IRProxy)(x: IRProxy): IRProxy = (e: E) => f(x(e))(e) } diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index 980072a9934..355fd97142e 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -644,12 +644,8 @@ package defs { val compare = if (!onKey) ApplyComparisonOp(Compare, l, r) else l.typ match { - case elt: TStruct => - val field = elt.fieldNames(0) - ApplyComparisonOp(Compare, GetField(l, field), GetField(r, field)) - case elt: TTuple => - val index = elt.fields(0).index - ApplyComparisonOp(Compare, GetTupleElement(l, index), GetTupleElement(r, index)) + case _: TBaseStruct => + ApplyComparisonOp(Compare, GetFieldByIdx(l, 0), GetFieldByIdx(r, 0)) case telem => fatal(s"ArraySort(.., onKey = true) requires struct or tuple elements, got $telem") } diff --git a/hail/hail/src/is/hail/expr/ir/IRBuilder.scala b/hail/hail/src/is/hail/expr/ir/IRBuilder.scala index c8639e06d18..82c6f30215c 100644 --- a/hail/hail/src/is/hail/expr/ir/IRBuilder.scala +++ b/hail/hail/src/is/hail/expr/ir/IRBuilder.scala @@ -45,6 +45,9 @@ object Memoized { def let[S](n: Name, ir: IR): Memoized[S] = new Let[S](n, ir) def defer[S](f: IRBuilder => IR): Memoized[S] = new Suspend(f) + def sequence[S](bindings: Memoized[S]*): Memoized[S] = + bindings.view.tail.foldLeft(bindings.head)(_ >> _) + sealed abstract class HasScope[S](val scope: Scope) implicit object evalScope extends HasScope[EVAL.type](EVAL) implicit object aggScope extends HasScope[AGG.type](AGG) diff --git a/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala b/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala index 940f80452a3..d47cb03bd82 100644 --- a/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala @@ -5,16 +5,15 @@ import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.compat.mutable.Growable import is.hail.expr.ir.{Memoized => M} +import is.hail.expr.ir.Scope.EVAL import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.{WrappedMatrixToTableFunction, WrappedMatrixToValueFunction} import is.hail.types.virtual._ import is.hail.utils._ object LowerMatrixIR { - val entriesFieldName: String = MatrixType.entriesIdentifier - val colsFieldName: String = "__cols" - val colsField: Symbol = Symbol(colsFieldName) - val entriesField: Symbol = Symbol(entriesFieldName) + val entriesField: String = MatrixType.entriesIdentifier + val colsField: String = "__cols" def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = { val ab = ArraySeq.newBuilder[(Name, IR)] @@ -58,28 +57,6 @@ object LowerMatrixIR { case vir: IR => lower(ctx, vir, ab) } - def colVals(tir: TableIR): IR = - GetField(Ref(TableIR.globalName, tir.typ.globalType), colsFieldName) - - def globals(tir: TableIR): IR = { - val globalType = tir.typ.globalType - SelectFields( - Ref(TableIR.globalName, globalType), - globalType.fieldNames.diff(FastSeq(colsFieldName)), - ) - } - - def rowVal(tir: TableIR): IR = { - val rowType = tir.typ.rowType - SelectFields( - Ref(TableIR.rowName, rowType), - rowType.fieldNames.diff(FastSeq(entriesFieldName)), - ) - } - - def entries(tir: TableIR): IR = - GetField(Ref(TableIR.rowName, tir.typ.rowType), entriesFieldName) - import is.hail.expr.ir.DeprecatedIRBuilder._ private def bindingsToStruct(bindings: IndexedSeq[(Name, IR)]): MakeStruct = @@ -102,37 +79,35 @@ object LowerMatrixIR { ) case CastTableToMatrix(child, entries, cols, _) => - val lc = lower(ctx, child, liftedRelationalLets) - val row: Atom = Ref(TableIR.rowName, lc.typ.rowType) - val glob = Ref(TableIR.globalName, lc.typ.globalType) - TableMapRows( - lc, - bindIR(GetField(row, entries)) { entries => - If( - IsNA(entries), - Die("missing entry array unsupported in 'to_matrix_table_row_major'", row.typ), - bindIRs(ArrayLen(entries), ArrayLen(GetField(glob, cols))) { - case Seq(entriesLen, colsLen) => - If( - entriesLen cne colsLen, - Die( - strConcat( - "length mismatch between entry array and column array in 'to_matrix_table_row_major': ", - entriesLen, - " entries, ", - colsLen, - " cols, at ", - SelectFields(row, child.typ.key), + lower(ctx, child, liftedRelationalLets) + .mapRows { (global, row) => + bindIR(row.get(entries)) { entries => + If( + IsNA(entries), + Die("missing entry array unsupported in 'to_matrix_table_row_major'", row.typ), + bindIRs(entries.len, global.get(cols).len) { + case Seq(entriesLen, colsLen) => + If( + entriesLen cne colsLen, + Die( + strConcat( + "length mismatch between entry array and column array in 'to_matrix_table_row_major': ", + entriesLen, + " entries, ", + colsLen, + " cols, at ", + row.select(child.typ.key), + ), + row.typ, + -1, ), - row.typ, - -1, - ), - row, - ) - }, - ) - }, - ).rename(Map(entries -> entriesFieldName), Map(cols -> colsFieldName)) + row, + ) + }, + ) + } + } + .rename(Map(entries -> entriesField), Map(cols -> colsField)) case MatrixToMatrixApply(child, function) => val loweredChild = lower(ctx, child, liftedRelationalLets) @@ -141,19 +116,19 @@ object LowerMatrixIR { case MatrixRename(child, globalMap, colMap, rowMap, entryMap) => var t = lower(ctx, child, liftedRelationalLets).rename(rowMap, globalMap) - if (colMap.nonEmpty) { - val newColsType = TArray(child.typ.colType.rename(colMap)) - t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq( - colsFieldName -> newColsType - )))) - } + if (colMap.nonEmpty) + t = t.mapGlobals { global => + global.rename(_.insertFields(FastSeq( + colsField -> TArray(child.typ.colType.rename(colMap)) + ))) + } - if (entryMap.nonEmpty) { - val newEntriesType = TArray(child.typ.entryType.rename(entryMap)) - t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq( - entriesFieldName -> newEntriesType - )))) - } + if (entryMap.nonEmpty) + t = t.mapRows { (_, row) => + row.rename(_.insertFields(FastSeq( + entriesField -> TArray(child.typ.entryType.rename(entryMap)) + ))) + } t @@ -161,37 +136,43 @@ object LowerMatrixIR { lower(ctx, child, liftedRelationalLets).keyBy(keys, isSorted) case MatrixFilterRows(child, pred) => - lower(ctx, child, liftedRelationalLets) - .filter( - let( - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) in lower(ctx, pred, liftedRelationalLets) - ) + lower(ctx, child, liftedRelationalLets).filter { (global, row) => + M.eval { + for { + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + } yield lower(ctx, pred, liftedRelationalLets) + } + } case MatrixFilterCols(child, pred) => lower(ctx, child, liftedRelationalLets) - .mapGlobals( - 'global.insertFields( - '__new_col_idx -> - (let( - __cols = 'global(colsField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - ) in irRange(0, '__cols.len).filter('__col_idx ~> - (let(sa = '__cols('__col_idx)) in - lower(ctx, pred, liftedRelationalLets)))) - ) - ) - .mapRows( - let(__entries = 'row(entriesField)) in - 'row.insertFields(entriesField -> 'global('__new_col_idx).map('i ~> '__entries('i))) - ) - .mapGlobals( - let(__cols = 'global(colsField)) in - 'global - .insertFields(colsField -> 'global('__new_col_idx).map('i ~> '__cols('i))) - .dropFields('__new_col_idx) - ) + .mapGlobals { global => + M.eval { + for { + cols <- global.get(colsField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + } yield global.insert( + "__new_col_idx" -> + ToArray(rangeIR(cols.len).filter { idx => + M.eval { + (Name("sa") -> ArrayRef(cols, idx)) >> + lower(ctx, pred, liftedRelationalLets) + } + }) + ) + } + } + .mapRows { (global, row) => + row.update(entriesField) { entries => + mapArray(global.get("__new_col_idx"))(entries.at(_)) + } + } + .mapGlobals { global => + global + .update(colsField)(cols => mapArray(global.get("__new_col_idx"))(cols.at(_))) + .drop("__new_col_idx") + } case MatrixAnnotateRowsTable(child, table, root, product) => val kt = table.typ.keyType @@ -211,53 +192,46 @@ object LowerMatrixIR { case MatrixChooseCols(child, oldIndices) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('__new_col_idx -> Literal(TArray(TInt32), oldIndices))) - .mapRows( - let(__entries = 'row(entriesField)) in - 'row.insertFields(entriesField -> 'global('__new_col_idx).map('i ~> '__entries('i))) - ) - .mapGlobals( - let(__cols = 'global(colsField)) in - 'global - .insertFields(colsField -> 'global('__new_col_idx).map('i ~> '__cols('i))) - .dropFields('__new_col_idx) - ) + .mapGlobals(_.insert("__new_col_idx" -> Literal(TArray(TInt32), oldIndices))) + .mapRows { (global, row) => + row.update(entriesField) { entries => + mapArray(global.get("__new_col_idx"))(entries.at(_)) + } + } + .mapGlobals { global => + global.update(colsField)(cols => mapArray(global.get("__new_col_idx"))(cols.at(_))) + } case MatrixAnnotateColsTable(child, table, root) => - lower(ctx, child, liftedRelationalLets) - .mapGlobals( - let( - __dictfield = - lower(ctx, table, liftedRelationalLets) - .keyBy(FastSeq()) - .collect() - .apply('rows) - .arrayStructToDict(table.typ.key) - ) in 'global.insertFields( - colsField -> { + lower(ctx, child, liftedRelationalLets).mapGlobals { global => + bindIR(lower(ctx, table, liftedRelationalLets).collectAsDict) { annotations => + global.update(colsField) { cols => + mapArray(cols) { col => val key = - makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => - Symbol(tk) -> '__cols(Symbol(mck)) - }: _*) - - 'global(colsField).map('__cols ~> - '__cols.insertFields( - Symbol(root) -> '__dictfield.invoke("get", table.typ.valueType, key) - )) + MakeStruct( + table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => + tk -> col.get(mck) + } + ) + + col.insert(root -> annotations.invoke("get", table.typ.valueType, key)) } - ) - ) + } + } + } case MatrixMapGlobals(child, newGlobals) => - lower(ctx, child, liftedRelationalLets) - .mapGlobals( - (let(global = 'global.selectFields(child.typ.globalType.fieldNames: _*)) in - lower(ctx, newGlobals, liftedRelationalLets)) - .insertFields(colsField -> 'global(colsField)) - ) + lower(ctx, child, liftedRelationalLets).mapGlobals { global => + M.eval { + for { + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + newGlobals <- lower(ctx, newGlobals, liftedRelationalLets) + } yield newGlobals.insert(colsField -> global.get(colsField)) + } + } case MatrixMapRows(child, newRow) => - def liftScans(ir: IR): IRProxy = { + def liftScans(ir: IR, global: Atom, row: Atom, cols: Atom, entries: Atom): IR = { def lift(ir: IR, builder: Growable[(Name, IR)]): IR = ir match { case a: ApplyScanOp => val s = freshName() @@ -351,28 +325,39 @@ object LowerMatrixIR { val b0 = lift(ir, ab) val b1 = if (ContainsAgg(b0)) - irRange(0, '__entries.len) - .filter('i ~> !'__entries('i).isNA) - .streamAgg('i ~> (aggLet(sa = '__cols('i), g = '__entries('i)) in b0)) + rangeIR(0, entries.len) + .filter(i => !entries.at(i).isNA) + .streamAgg { i => + M.agg { + (MatrixIR.colName -> cols.at(i)) >> + (MatrixIR.entryName -> entries.at(i)) >> + b0 + } + } else - irToProxy(b0) + b0 - scanLet( - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) in (letDyn(ab.result().map { case (name, expr) => name -> irToProxy(expr) }: _*) in b1) + M.scan { + M.sequence( + MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames), + MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames), + M.eval(ab.result().foldRight[M[EVAL.type]](b1)((binding, body) => binding >> body)), + ) + } } - lower(ctx, child, liftedRelationalLets).mapRows( - (let( - __cols = 'global(colsField), - __entries = 'row(entriesField), - n_cols = '__cols.len, - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) in liftScans(lower(ctx, newRow, liftedRelationalLets))) - .insertFields(entriesField -> 'row(entriesField)) - ) + lower(ctx, child, liftedRelationalLets).mapRows { (global, row) => + M.eval { + for { + cols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- Name("n_cols") -> cols.len + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + r <- liftScans(lower(ctx, newRow, liftedRelationalLets), global, row, cols, entries) + } yield r.insert(entriesField -> entries.ir) + } + } case MatrixMapCols(child, newCol, _) => val lc = lower(ctx, child, liftedRelationalLets) @@ -505,169 +490,194 @@ object LowerMatrixIR { val aggs = aggBuilder.result() val scans = scanBuilder.result() - val noOp: (IRProxy => IRProxy, IRProxy => IRProxy) = - (identity[IRProxy], identity[IRProxy]) + def unit: IR = makestruct() + def munit: M[EVAL.type] = M.pure(unit) + + def cols = Ref(Name("__cols"), child.typ.colType) + def colIdx = Ref(Name("__col_idx"), TInt32) val ( - aggOutsideTransformer: (IRProxy => IRProxy), - aggInsideTransformer: (IRProxy => IRProxy), + setupAggContext: M[EVAL.type], + setupAggInnerContext: M[EVAL.type], ) = - if (aggs.isEmpty) noOp + if (aggs.isEmpty) (munit, munit) else { val aggResult = - lc.deepCopy.aggregate( - let( - __cols = 'global(colsField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - ) in (aggLet( - __cols = 'global(colsField), - __entries = 'row(entriesField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) in makeStruct( - 'n_rows -> - applyAggOp(Count(), FastSeq(), FastSeq()), - 'array_aggs -> - irRange(0, '__cols.len) - .aggElements('__element_idx, '__result_idx, Some('__cols.len))( - let(sa = '__cols('__result_idx)) in - (aggLet(sa = '__cols('__element_idx), g = '__entries('__element_idx)) in - aggFilter(!'g.isNA, bindingsToStruct(aggs))) - ), - )) - ) + lc.deepCopy.aggregate { (global, row) => + M.eval { + for { + cols <- global.get(colsField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + } yield M.agg { + for { + aggCols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + } yield makestruct( + "__n_rows" -> + ApplyAggOp(Count())(), + "__array_aggs" -> + rangeIR(aggCols.len).aggElements(Some(cols.len)) { (elem, index) => + M.eval { + (MatrixIR.colName -> cols.at(index)) >> + M.agg { + for { + _ <- MatrixIR.colName -> aggCols.at(elem) + g <- MatrixIR.entryName -> entries.at(elem) + } yield AggFilter(!g.isNA, bindingsToStruct(aggs), false) + } + } + }, + ) + } + } + } val ident = freshName() liftedRelationalLets += (ident -> aggResult) - val bindResult: IRProxy => IRProxy = - let( - __agg_result = RelationalRef(ident, aggResult.typ), - __array_aggs = '__agg_result('array_aggs), - n_rows = '__agg_result('n_rows), - ) in _ + val arrayAggs = Ref(Name("__array_aggs"), aggResult.get("__array_aggs").typ) - def bodyResult(body: IRProxy): IRProxy = - let(__agg_elem = '__array_aggs('__col_idx)) in - (letDyn(aggs.map { case (n, _) => n -> '__agg_elem(Symbol(n.str)) }: _*) in - body) + val bindResult: M[EVAL.type] = + for { + result <- RelationalRef(ident, aggResult.typ) + _ <- arrayAggs.name -> result.get("__array_aggs") + _ <- Name("n_rows") -> result.get("__n_rows") + } yield unit + + val bodyResult: M[EVAL.type] = + arrayAggs.asInstanceOf[Atom].at(colIdx).flatMap { elem => + M.sequence(aggs.map[M[EVAL.type]] { case (n, _) => n -> elem.get(n.str) }: _*) + } - (bindResult, bodyResult _) + (bindResult, bodyResult) } val ( - scanOutsideTransformer: (IRProxy => IRProxy), - scanInsideTransformer: (IRProxy => IRProxy), + setupScanContext: M[EVAL.type], + setupScanInnerContext: M[EVAL.type], ) = - if (scans.isEmpty) noOp + if (scans.isEmpty) (munit, munit) else { - val scanStruct = bindingsToStruct(scans) + val ScanResult = StreamAggScan(cols, MatrixIR.colName, bindingsToStruct(scans)) + val scanResult = Ref(Name("__scan_result"), ScanResult.typ) - val bindResult: IRProxy => IRProxy = - let(__scan_result = '__cols.streamAggScan('sa ~> scanStruct)) in _ + val bindResult: M[EVAL.type] = + scanResult.name -> ScanResult - def bodyResult(body: IRProxy): IRProxy = - let(__scan_elem = '__scan_result('__col_idx)) in - (letDyn(scans.map { case (n, _) => n -> '__scan_elem(Symbol(n.str)) }: _*) in - body) + val bodyResult: M[EVAL.type] = + scanResult.asInstanceOf[Atom].at(colIdx).flatMap { elem => + M.sequence(scans.map[M[EVAL.type]] { case (n, _) => n -> elem.get(n.str) }: _*) + } - (bindResult, bodyResult _) + (bindResult, bodyResult) } - lc.mapGlobals( - let( - __cols = 'global(colsField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - ) in 'global.insertFields( - colsField -> - aggOutsideTransformer( - scanOutsideTransformer( - irRange(0, '__cols.len).map('__col_idx ~> - (let(sa = '__cols('__col_idx)) in - aggInsideTransformer(scanInsideTransformer(b0)))) - ) - ) - ) - ) + lc.mapGlobals { global => + M.eval { + for { + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + cs <- cols.name -> global.get(colsField) + _ <- setupAggContext + _ <- setupScanContext + } yield global.insert( + colsField -> + ToArray(StreamMap( + rangeIR(cs.len), + colIdx.name, + M.eval { + M.sequence( + MatrixIR.colName -> cs.at(colIdx), + setupAggInnerContext, + setupScanInnerContext, + b0, + ) + }, + )) + ) + } + } case MatrixFilterEntries(child, pred) => val mtype = child.typ - lower(ctx, child, liftedRelationalLets) - .mapRows( - let( - __cols = 'global(colsField), - __entries = 'row(entriesField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(mtype.rowType.fieldNames: _*), - ) in 'row.insertFields( + lower(ctx, child, liftedRelationalLets).mapRows { (global, row) => + M.eval { + for { + cols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- MatrixIR.globalName -> global.select(mtype.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(mtype.rowType.fieldNames) + } yield row.insert( entriesField -> - irRange(0, '__cols.len).map('i ~> - (let(sa = '__cols('i), g = '__entries('i)) in - irIf(lower(ctx, pred, liftedRelationalLets))('g)(NA(mtype.entryType)))) + ToArray(rangeIR(cols.len).streamMap { i => + M.eval { + for { + _ <- MatrixIR.colName -> cols.at(i) + g <- MatrixIR.entryName -> entries.at(i) + } yield If(lower(ctx, pred, liftedRelationalLets), g, NA(mtype.entryType)) + } + }) ) - ) + } + } case MatrixUnionCols(left, right, joinType) => - def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy = - if (joinType == "inner") 'row(entries) - else let(__entries = 'row(entries)) in - irIf(!'__entries.isNA)('__entries)( - irRange(0, 'global(cols).len).map('a ~> - MakeStruct(right.typ.entryType.fields.map(f => (f.name, NA(f.typ))))) - ) - - val ll = lower(ctx, left, liftedRelationalLets).distinct() - val rr = lower(ctx, right, liftedRelationalLets).distinct() + val ll = lower(ctx, left, liftedRelationalLets).distinct + val rr = lower(ctx, right, liftedRelationalLets).distinct TableJoin( ll, - rr.mapRows( - 'row.castRename(rr.typ.rowType.rename(Map(entriesFieldName -> '__right_entries.name))) - ) - .mapGlobals('global - .insertFields('__right_cols -> 'global(colsField)) - .selectFields('__right_cols.name)), + rr + .mapRows((_, row) => row.rename(_.rename(Map(entriesField -> "__right_entries")))) + .mapGlobals(global => makestruct("__right_cols" -> global.get(colsField))), joinType, ) - .mapRows('row - .insertFields( - entriesField -> { - val ls = handleMissingEntriesArray(entriesField, colsField) - val rs = handleMissingEntriesArray('__right_entries, '__right_cols) - makeArray(ls, rs).flatten + .mapRows { (global, row) => + row + .insert( + entriesField -> { + def handleMissingEntriesArray(entries: String, cols: String): IR = + if (joinType == "inner") row.get(entries) + else row.get(entries).orElse { + ToArray(rangeIR(global.get(cols).len).streamMap { _ => + MakeStruct(right.typ.entryType.fields.map(f => f.name -> NA(f.typ))) + }) + } + + val ls = handleMissingEntriesArray(entriesField, colsField) + val rs = handleMissingEntriesArray("__right_entries", "__right_cols") + MakeArray(ls, rs).stream.streamFlatten.toArray + } + ) + .drop("__right_entries") + } + .mapGlobals { global => + global + .update(colsField) { cols => + MakeArray(cols, global.get("__right_cols")).stream.streamFlatten.toArray } - ) - .dropFields('__right_entries)) - .mapGlobals('global - .insertFields( - colsField -> - makeArray('global(colsField), 'global('__right_cols)).flatten - ) - .dropFields('__right_cols)) + .drop("__right_cols") + } case MatrixMapEntries(child, newEntries) => - val lc = lower(ctx, child, liftedRelationalLets) - TableMapRows( - lc, + lower(ctx, child, liftedRelationalLets).mapRows { (global, row) => M.eval { for { - cols <- Name("__cols") -> colVals(lc) - entries <- Name("__entries") -> entries(lc) - _ <- MatrixIR.globalName -> globals(lc) - row <- MatrixIR.rowName -> rowVal(lc) - } yield InsertFields( - row, - FastSeq( - entriesFieldName -> - ToArray(StreamZip( - FastSeq(ToStream(cols), ToStream(entries)), - FastSeq(MatrixIR.colName, MatrixIR.entryName), - lower(ctx, newEntries, liftedRelationalLets), - ArrayZipBehavior.AssumeSameLength, - )) - ), + cols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + } yield row.insert( + entriesField -> + ToArray(StreamZip( + FastSeq(ToStream(cols), ToStream(entries)), + FastSeq(MatrixIR.colName, MatrixIR.entryName), + lower(ctx, newEntries, liftedRelationalLets), + ArrayZipBehavior.AssumeSameLength, + )) ) - }, - ) + } + } case MatrixRepartition(child, n, shuffle) => TableRepartition(lower(ctx, child, liftedRelationalLets), n, shuffle) @@ -679,8 +689,9 @@ object LowerMatrixIR { // FIXME: this should check that all children have the same column keys. val first = lower(ctx, children.head, liftedRelationalLets) TableUnion(FastSeq(first) ++ - children.tail.map(lower(ctx, _, liftedRelationalLets) - .mapRows('row.selectFields(first.typ.rowType.fieldNames: _*)))) + children.tail.map(lower(ctx, _, liftedRelationalLets).mapRows { (_, row) => + SelectFields(row, first.typ.rowType.fieldNames) + })) case MatrixDistinctByRow(child) => TableDistinct(lower(ctx, child, liftedRelationalLets)) @@ -689,118 +700,134 @@ object LowerMatrixIR { case MatrixColsHead(child, n) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('__cols -> 'global('__cols).arraySlice(0, Some(n), 1))) - .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(0, Some(n), 1))) + .mapGlobals(_.update(colsField)(_.slice(0, Some(n)))) + .mapRows((_, row) => row.update(entriesField)(_.slice(0, Some(n)))) case MatrixColsTail(child, n) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('__cols -> 'global('__cols).arraySlice(-n, None, 1))) - .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(-n, None, 1))) + .mapGlobals(_.update(colsField)(_.slice(-n, None))) + .mapRows((_, row) => row.update(entriesField)(_.slice(-n, None))) case MatrixExplodeCols(child, path) => lower(ctx, child, liftedRelationalLets) - .mapGlobals( - let( - __cols = - 'global(colsField), - __lengths = - '__cols.map('__elem ~> - path - .foldLeft[IRProxy]('__elem) { case (irp, f) => irp(Symbol(f)) } - .len - .orElse(0)), - ) in 'global.insertFields( - '__cols -> - irRange(0, '__cols.len).flatMap('__col_idx ~> { - val nestedRefs = - path.init.scanLeft('__cols('__col_idx))((irp, name) => irp(Symbol(name))) - - irRange(0, '__lengths('__col_idx)).map('__length_idx ~> - path.zip(nestedRefs).zipWithIndex.foldRight[IRProxy]('__length_idx) { - case (((field, ref), i), arg) => - val s = Symbol(field) - ref.insertFields( - s -> (if (i == nestedRefs.length - 1) ref(s).toArray(arg) else arg) - ) - }) - }), - '__lengths -> - '__lengths, - ) - ) - .mapRows( - let(__entries = 'row(entriesField), __lengths = 'global('__lengths)) in - 'row.insertFields( + .mapGlobals { global => + M.eval { + for { + cols <- + global.get(colsField) + lengths <- + mapArray(cols) { elem => + path + .foldLeft[IR](elem)(_.get(_)) + .len + .orElse(0) + } + } yield global.insert( + colsField -> + ToArray(rangeIR(cols.len).streamFlatMap { colIdx => + val nestedRefs = + path.init.scanLeft(cols.at(colIdx))(_.get(_)) + + rangeIR(lengths.at(colIdx)).streamMap { lengthIdx => + path.zip(nestedRefs).zipWithIndex.foldRight[IR](lengthIdx) { + case (((field, ref), i), arg) => + ref.insert( + field -> + (if (i == nestedRefs.length - 1) ref.get(field).toArray.at(arg) + else arg) + ) + } + } + }), + "__lengths" -> + lengths.ir, + ) + } + } + .mapRows { (global, row) => + M.eval { + for { + entries <- row.get(entriesField) + lengths <- global.get("__lengths") + } yield row.insert( entriesField -> - irRange(0, '__entries.len).flatMap('__col_idx ~> - irRange(0, '__lengths('__col_idx)).map('__unused ~> - '__entries('__col_idx))) + ToArray(rangeIR(entries.len).streamFlatMap { idx => + rangeIR(lengths.at(idx)).streamMap(_ => entries.at(idx)) + }) ) - ) - .mapGlobals('global.dropFields('__lengths)) + } + } + .mapGlobals(_.drop("lengths")) case MatrixAggregateRowsByKey(child, entryExpr, rowExpr) => - lower(ctx, child, liftedRelationalLets) - .aggregateByKey( - let( - __cols = 'global(colsField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - ) in (aggLet( - __cols = 'global(colsField), - __entries = 'row(entriesField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) in lower(ctx, rowExpr, liftedRelationalLets).insertFields( - entriesField -> - irRange(0, '__cols.len) - .aggElements('__element_idx, '__result_idx, Some('__cols.len))( - let(sa = '__cols('__result_idx)) in - (aggLet(sa = '__cols('__element_idx), g = '__entries('__element_idx)) in - aggFilter(!'g.isNA, lower(ctx, entryExpr, liftedRelationalLets))) - ) - )) - ) + lower(ctx, child, liftedRelationalLets).aggregateByKey { (global, row) => + M.eval { + for { + cols <- global.get(colsField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + } yield M.agg { + for { + aggCols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + } yield lower(ctx, rowExpr, liftedRelationalLets).insert( + entriesField -> + rangeIR(aggCols.len).aggElements(Some(cols.len)) { (elem, index) => + M.eval { + (MatrixIR.colName -> cols.at(index)) >> + M.agg { + for { + _ <- MatrixIR.colName -> cols.at(elem) + g <- entries.at(elem) + } yield AggFilter( + !g.isNA, + lower(ctx, entryExpr, liftedRelationalLets), + isScan = false, + ) + } + } + } + ) + } + } + } case MatrixCollectColsByKey(child) => lower(ctx, child, liftedRelationalLets) - .mapGlobals( - let(__cols = 'global(colsField)) in - 'global.insertFields( - '__new_col_idx -> - irRange(0, '__cols.len) - .map('i ~> makeTuple('__cols('i).selectFields(child.typ.colKey: _*), 'i)) + .mapGlobals { global => + bindIR(global.get(colsField)) { cols => + global.insert( + "__new_col_idx" -> + rangeIR(cols.len) + .streamMap(i => maketuple(cols.at(i).select(child.typ.colKey), i)) .groupByKey .toArray ) - ) - .mapRows( - let(__entries = 'row(entriesField)) in - 'row.insertFields( - entriesField -> - 'global('__new_col_idx).map { - 'kv ~> - makeStruct(child.typ.entryType.fieldNames.map { f => - val s = Symbol(f) - s -> 'kv('value).map('i ~> '__entries('i)(s)) - }: _*) - } - ) - ) - .mapGlobals( - let(__cols = 'global(colsField)) in - 'global - .insertFields( - colsField -> - 'global('__new_col_idx).map('kv ~> - 'kv('key).insertFields( - child.typ.colValueStruct.fieldNames.map { f => - val s = Symbol(f) - s -> 'kv('value).map('i ~> '__cols('i)(s)) - }: _* - )) + } + } + .mapRows { (global, row) => + row.update(entriesField) { entries => + mapArray(global.get("__new_col_idx")) { kv => + MakeStruct(child.typ.entryType.fieldNames.map { f => + f -> mapArray(kv.get("value"))(i => entries.at(i).get(f)) + }) + } + } + } + .mapGlobals { global => + global.update(colsField) { cols => + mapArray(global.get("__new_col_idx")) { kv => + InsertFields( + kv.get("key"), + child.typ.colValueStruct.fieldNames.map { f => + f -> mapArray(kv.get("value"))(i => cols.at(i).get(f)) + }, ) - .dropFields('__new_col_idx) - ) + } + } + .drop("__new_col_idx") + } case MatrixExplodeRows(child, path) => TableExplode(lower(ctx, child, liftedRelationalLets), path) @@ -809,53 +836,77 @@ object LowerMatrixIR { case MatrixAggregateColsByKey(child, entryExpr, colExpr) => lower(ctx, child, liftedRelationalLets) - .mapGlobals( - let(__cols = 'global(colsField)) in - 'global.insertFields( - '__key_map -> - irRange(0, '__cols.len) - .map('__old_col_idx ~> - (let(__elem = '__cols('__old_col_idx)) in - makeStruct( - 'key -> '__elem.selectFields(child.typ.colKey: _*), - 'value -> '__old_col_idx, - ))) + .mapGlobals { global => + bindIR(global.get(colsField)) { cols => + global.insert( + "__key_map" -> + rangeIR(cols.len) + .streamMap { idx => + makestruct( + "__key" -> cols.at(idx).select(child.typ.colKey), + "__value" -> idx, + ) + } .groupByKey .toArray ) - ) - .mapRows( - let( - __key_map = 'global('__key_map), - __cols = 'global(colsField), - __entries = 'row(entriesField), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - va = 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) in 'row.insertFields( - entriesField -> - irRange(0, '__key_map.len).map('__new_col_idx ~> - '__key_map('__new_col_idx)('value).streamAgg('__agg_idx ~> - (aggLet(sa = '__cols('__agg_idx), g = '__entries('__agg_idx)) in - aggFilter(!'g.isNA, lower(ctx, entryExpr, liftedRelationalLets))))) - ) - ) - .mapGlobals( - let( - __cols = 'global(colsField), - __key_map = 'global('__key_map), - global = 'global.selectFields(child.typ.globalType.fieldNames: _*), - ) in 'global.insertFields( - colsField -> - irRange(0, '__key_map.len).map('__new_col_idx ~> - (let(__elem = '__key_map('__new_col_idx)) in - concatStructs( - '__elem('key), - '__elem('value).streamAgg('__agg_idx ~> - (aggLet(sa = '__cols('__agg_idx)) in - lower(ctx, colExpr, liftedRelationalLets))), - ))) - ) - ) + } + } + .mapRows { (global, row) => + M.eval { + for { + keyMap <- global.get("__key_map") + cols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + } yield row.insert( + entriesField -> + ToArray(rangeIR(keyMap.len).streamMap { idx => + keyMap.at(idx).get("__value").streamAgg { aggIdx => + M.agg { + for { + _ <- MatrixIR.colName -> cols.at(aggIdx) + g <- MatrixIR.entryName -> entries.at(aggIdx) + } yield AggFilter( + !g.isNA, + lower(ctx, entryExpr, liftedRelationalLets), + isScan = false, + ) + } + } + }) + ) + } + } + .mapGlobals { global => + M.eval { + for { + cols <- global.get(colsField) + keyMap <- global.get("__key_map") + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + } yield global.insert( + colsField -> + ToArray(rangeIR(keyMap.len).streamMap { idx => + M.eval { + for { + elem <- keyMap.at(idx) + key <- elem.get("__key") + value <- + elem.get("__value").stream.streamAgg { aggIdx => + M.agg { + (MatrixIR.colName -> cols.at(aggIdx)) >> + lower(ctx, colExpr, liftedRelationalLets) + } + } + } yield key.insert( + value.typ.asInstanceOf[TStruct].fieldNames.map(f => f -> value.get(f)) + ) + } + }) + ) + } + } case MatrixLiteral(_, tl) => tl } @@ -871,94 +922,96 @@ object LowerMatrixIR { val lowered = tir match { case CastMatrixToTable(child, entries, cols) => lower(ctx, child, ab) - .mapRows('row.selectFields(child.typ.rowType.fieldNames :+ entriesFieldName: _*)) - .mapGlobals('global.selectFields(child.typ.globalType.fieldNames :+ colsFieldName: _*)) - .rename(Map(entriesFieldName -> entries), Map(colsFieldName -> cols)) + .mapRows((_, row) => row.select(child.typ.rowType.fieldNames :+ entriesField)) + .mapGlobals(_.select(child.typ.globalType.fieldNames :+ colsField)) + .rename(Map(entriesField -> entries), Map(colsField -> cols)) case x @ MatrixEntriesTable(child) => val lc = lower(ctx, child, ab) if (child.typ.rowKey.nonEmpty && child.typ.colKey.nonEmpty) { lc - .mapGlobals( - let(__cols = 'global(colsField)) in - 'global.insertFields( - '__old_col_idx -> - irRange(0, '__cols.len) - .map('__col_idx ~> - makeStruct( - 'key -> '__cols('__col_idx).selectFields(child.typ.colKey: _*), - 'value -> '__col_idx, - )) + .mapGlobals { global => + bindIR(global.get(colsField)) { cols => + global.insert( + "__old_col_idx" -> + rangeIR(cols.len) + .streamMap(idx => maketuple(cols.at(idx).select(child.typ.colKey), idx)) .sort(ascending = true, onKey = true) - .map('__elem ~> '__elem('value)) - ) - ) - .aggregateByKey(makeStruct( - '__values -> - applyAggOp( - Collect(), - seqOpArgs = FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*)), + .stream + .streamMap(_.get(1)) + .toArray ) - )) - .mapRows( - let(__cols = 'global(colsField)) in - 'row.dropFields('__values).insertFields( - '__explode -> - 'global('__old_col_idx).flatMap('__old_col_idx ~> - (let(__col = '__cols('__old_col_idx)) in - 'row('__values) - .filter('__v ~> !'__v(entriesField)('__old_col_idx).isNA) - .map('__v ~> - (let(__entry = '__v(entriesField)('__old_col_idx)) in - makeStruct( - child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => - f -> '__v(f) - ) ++ - child.typ.colType.fieldNames.map(Symbol(_)).map(f => - f -> '__col(f) - ) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => - f -> '__entry(f) - ): _* - ))))) + } + } + .aggregateByKey { (_, row) => + makestruct( + "__values" -> + ApplyAggOp(Collect())(row.select(lc.typ.valueType.fieldNames)) + ) + } + .mapRows { (global, row) => + bindIR(global.get(colsField)) { cols => + row.drop("__values").insert( + "__explode" -> + ToArray(global.get("__old_col_idx").stream.streamFlatMap { oldColIndex => + bindIR(cols.at(oldColIndex)) { col => + row.get("__values").stream.streamFlatMap { v => + bindIR(v.get(entriesField).at(oldColIndex)) { entry => + val newRow = + MakeStruct( + child.typ.rowValueStruct.fieldNames.map(f => f -> v.get(f)) ++ + child.typ.colType.fieldNames.map(f => f -> col.get(f)) ++ + child.typ.entryType.fieldNames.map(f => f -> entry.get(f)) + ) + + If(IsNA(entry), MakeArray.empty(newRow.typ), MakeArray(newRow)) + } + } + } + }) ) - ) - .explode('__explode) - .mapRows( - let(__exploded = 'row('__explode)) in - makeStruct(x.typ.rowType.fieldNames.map { f => - val fd = Symbol(f) - (fd, if (child.typ.rowKey.contains(f)) 'row(fd) else '__exploded(fd)) - }: _*) - ) - .mapGlobals('global.dropFields(colsField, '__old_col_idx)) + } + } + .explode("__explode") + .mapRows { (_, row) => + bindIR(row.get("__explode")) { exploded => + MakeStruct(x.typ.rowType.fieldNames.map { f => + f -> (if (child.typ.rowKey.contains(f)) row.get(f) else exploded.get(f)) + }) + } + } + .mapGlobals(_.drop(colsField, "__old_col_idx")) .keyBy(child.typ.rowKey ++ child.typ.colKey, isSorted = true) } else { val result = lc - .mapRows( - let(__entries = 'row(entriesField)) in - 'row.insertFields( - '__col_idx -> - irRange(0, 'global(colsField).len) - .filter('__idx ~> !'__entries('__idx).isNA) + .mapRows { (global, row) => + bindIR(row.get(entriesField)) { entries => + row.insert( + "__col_idx" -> + ToArray(rangeIR(global.get(colsField).len).filter(!entries.at(_).isNA)) ) - ) - .explode('__col_idx) - .mapRows { - val newFields = - child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct(f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct(f)) - - let( - __col_struct = 'global(colsField)('row('__col_idx)), - __entry_struct = 'row(entriesField)('row('__col_idx)), - ) in 'row - .dropFields(entriesField, '__col_idx) - .insertFieldsList(newFields, ordering = Some(x.typ.rowType.fieldNames)) + } + } + .explode("__col_idx") + .mapRows { (global, row) => + M.eval { + for { + colIdx <- row.get("__col_idx") + colStruct <- global.get(colsField).at(colIdx) + entryStruct <- row.get(entriesField).at(colIdx) + + newFields = + child.typ.colType.fieldNames.map(f => f -> colStruct.get(f)) ++ + child.typ.entryType.fieldNames.map(f => f -> entryStruct.get(f)) + + } yield row + .drop(entriesField, "__col_idx") + .insert(newFields, ordering = Some(x.typ.rowType.fieldNames)) + } } - .mapGlobals('global.dropFields(colsField)) + .mapGlobals(_.drop(colsField)) if (child.typ.colKey.isEmpty) result else { @@ -974,34 +1027,34 @@ object LowerMatrixIR { function.lower() .getOrElse(WrappedMatrixToTableFunction( function, - colsFieldName, - entriesFieldName, + colsField, + entriesField, child.typ.colKey, )), ) case MatrixRowsTable(child) => lower(ctx, child, ab) - .mapGlobals('global.dropFields(colsField)) - .mapRows('row.dropFields(entriesField)) + .mapGlobals(_.drop(colsField)) + .mapRows((_, row) => row.drop(entriesField)) case MatrixColsTable(child) => val colKey = child.typ.colKey - val sortedCols = - if (colKey.isEmpty) '__cols_and_global(colsField) - else '__cols_and_global(colsField) - .map('__cols_element ~> - makeStruct( - // key struct - '_1 -> '__cols_element.selectFields(colKey: _*), - '_2 -> '__cols_element, - )) - .sort(true, onKey = true) - .map('elt ~> 'elt('_2)) - - (let(__cols_and_global = lower(ctx, child, ab).getGlobals) in - makeStruct('rows -> sortedCols, 'global -> '__cols_and_global.dropFields(colsField))) + bindIR(lower(ctx, child, ab).getGlobals) { global => + val sortedCols = + if (colKey.isEmpty) global.get(colsField) + else global + .get(colsField) + .stream + .streamMap(elem => maketuple(elem.select(colKey), elem)) + .sort(ascending = true, onKey = true) + .stream + .streamMap(_.get(1)) + .toArray + + makestruct("rows" -> sortedCols, "global" -> global.drop(colsField)) + } .parallelize(None) .keyBy(child.typ.colKey) @@ -1027,8 +1080,8 @@ object LowerMatrixIR { function.lower().getOrElse( WrappedMatrixToValueFunction( function, - colsFieldName, - entriesFieldName, + colsField, + entriesField, child.typ.colKey, ) ), @@ -1036,7 +1089,7 @@ object LowerMatrixIR { case MatrixWrite(child, writer) => TableWrite( lower(ctx, child, ab), - WrappedMatrixWriter(writer, colsFieldName, entriesFieldName, child.typ.colKey), + WrappedMatrixWriter(writer, colsField, entriesField, child.typ.colKey), ) case MatrixMultiWrite(children, writer) => TableMultiWrite( @@ -1044,41 +1097,39 @@ object LowerMatrixIR { WrappedMatrixNativeMultiWriter(writer, children.head.typ.colKey), ) case MatrixCount(child) => - lower(ctx, child, ab) - .aggregate(makeTuple(applyAggOp(Count(), FastSeq(), FastSeq()), 'global(colsField).len)) + lower(ctx, child, ab).aggregate { (global, _) => + maketuple(ApplyAggOp(Count())(), global.get(colsField).len) + } case MatrixAggregate(child, query) => - val lc = lower(ctx, child, ab) - TableAggregate( - lc, - Let( - FastSeq(MatrixIR.globalName -> globals(lc)), - M.agg { - for { - cols <- Name("__cols") -> colVals(lc) - entries <- Name("__entries") -> entries(lc) - _ <- MatrixIR.globalName -> globals(lc) - _ <- MatrixIR.rowName -> rowVal(lc) - } yield aggExplodeIR( - filterIR( - zip2( - ToStream(cols), - ToStream(entries), - ArrayZipBehavior.AssertSameLength, - ) { - (c, e) => maybeIR(e)(e => maketuple(c, e)) - } - )(r => ApplyUnaryPrimOp(Bang, IsNA(r))) - ) { explodedTuple => - M.agg { - (MatrixIR.colName -> GetTupleElement(explodedTuple, 0)) >> - (MatrixIR.entryName -> GetTupleElement(explodedTuple, 1)) >> - query + lower(ctx, child, ab).aggregate { (global, row) => + M.eval { + (MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames)) >> + M.agg { + for { + cols <- global.get(colsField) + entries <- row.get(entriesField) + _ <- MatrixIR.globalName -> global.select(child.typ.globalType.fieldNames) + _ <- MatrixIR.rowName -> row.select(child.typ.rowType.fieldNames) + } yield zip2( + ToStream(cols), + ToStream(entries), + ArrayZipBehavior.AssertSameLength, + ) { + (c, e) => maybeIR(e)(e => maketuple(c, e)) } + .filter(!_.isNA) + .aggExplode { explodedTuple => + M.agg { + (MatrixIR.colName -> GetTupleElement(explodedTuple, 0)) >> + (MatrixIR.entryName -> GetTupleElement(explodedTuple, 1)) >> + query + } + } } - }, - ), - ) - case _ => lowerChildren(ctx, ir, ab).asInstanceOf[IR] + } + } + case _ => + lowerChildren(ctx, ir, ab).asInstanceOf[IR] } assertTypeUnchanged(ir, lowered) lowered diff --git a/hail/hail/src/is/hail/expr/ir/MatrixIR.scala b/hail/hail/src/is/hail/expr/ir/MatrixIR.scala index 1d06d866ce1..dbec89da3a6 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixIR.scala @@ -156,15 +156,14 @@ trait MatrixReader { TableType( rowType = if (mt.rowType.hasField(rowUIDFieldName)) mt.rowType.deleteKey(rowUIDFieldName) - .appendKey(LowerMatrixIR.entriesFieldName, TArray(mt.entryType)) + .appendKey(LowerMatrixIR.entriesField, TArray(mt.entryType)) .appendKey(TableReader.uidFieldName, mt.rowType.fieldType(rowUIDFieldName)) else - mt.rowType.appendKey(LowerMatrixIR.entriesFieldName, TArray(mt.entryType)), + mt.rowType.appendKey(LowerMatrixIR.entriesField, TArray(mt.entryType)), key = mt.rowKey, - globalType = if (includeColsArray) - mt.globalType.appendKey(LowerMatrixIR.colsFieldName, TArray(mt.colType)) - else - mt.globalType, + globalType = + if (includeColsArray) mt.globalType.appendKey(LowerMatrixIR.colsField, TArray(mt.colType)) + else mt.globalType, ) } @@ -197,14 +196,14 @@ abstract class MatrixHybridReader extends TableReaderWithExtraUID with MatrixRea tr, InsertFields( Ref(TableIR.rowName, tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray.empty(requestedType.entryType)), + FastSeq(LowerMatrixIR.entriesField -> MakeArray.empty(requestedType.entryType)), ), ) tr = TableMapGlobals( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray.empty(requestedType.colType)), + FastSeq(LowerMatrixIR.colsField -> MakeArray.empty(requestedType.colType)), ), ) } @@ -293,14 +292,14 @@ class MatrixNativeReader( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray.empty(requestedType.colType)), + FastSeq(LowerMatrixIR.colsField -> MakeArray.empty(requestedType.colType)), ), ) TableMapRows( tr, InsertFields( Ref(TableIR.rowName, tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray.empty(requestedType.entryType)), + FastSeq(LowerMatrixIR.entriesField -> MakeArray.empty(requestedType.entryType)), ), ) } else { @@ -338,7 +337,7 @@ class MatrixNativeReader( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> ToArray(cols)), + FastSeq(LowerMatrixIR.colsField -> ToArray(cols)), ), ) } @@ -417,30 +416,35 @@ case class MatrixRangeReader( .rename(Map("idx" -> "row_idx")) ht = if (requestedType.colType.hasField(colUIDFieldName)) - ht.mapGlobals(makeStruct( - LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> - makeStruct( - 'col_idx -> 'i, - Symbol(colUIDFieldName) -> 'i.toL, - )) - )) + ht.mapGlobals { _ => + makestruct( + LowerMatrixIR.colsField -> + ToArray(rangeIR(nColsAdj).streamMap { i => + makestruct("col_idx" -> i, colUIDFieldName -> i.toL) + }) + ) + } else - ht.mapGlobals(makeStruct( - LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i)) - )) + ht.mapGlobals { _ => + makestruct( + LowerMatrixIR.colsField -> + rangeIR(nColsAdj).streamMap(i => makestruct("col_idx" -> i)).toArray + ) + } if (requestedType.rowType.hasField(rowUIDFieldName)) - ht.mapRows('row.insertFields( - LowerMatrixIR.entriesField -> irRange(0, nColsAdj).map('i ~> makestruct()), - Symbol(rowUIDFieldName) -> 'row('row_idx).toL, - )) + ht.mapRows { (_, row) => + row.insert( + LowerMatrixIR.entriesField -> rangeIR(nColsAdj).streamMap(_ => makestruct()).toArray, + rowUIDFieldName -> row.get("row_idx").toL, + ) + } else - ht.mapRows('row.insertFields( - LowerMatrixIR.entriesField -> - irRange(0, nColsAdj).map('i ~> makestruct()) - )) + ht.mapRows { (_, row) => + row.insert( + LowerMatrixIR.entriesField -> rangeIR(nColsAdj).streamMap(_ => makestruct()).toArray + ) + } } override def toJValue: JValue = { diff --git a/hail/hail/src/is/hail/expr/ir/MatrixValue.scala b/hail/hail/src/is/hail/expr/ir/MatrixValue.scala index c145b7aee01..34d89b9bcf5 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixValue.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixValue.scala @@ -24,20 +24,20 @@ case class MatrixValue( typ: MatrixType, tv: TableValue, ) extends Logging { - val colFieldType = tv.globals.t.fieldType(LowerMatrixIR.colsFieldName).asInstanceOf[PArray] + val colFieldType = tv.globals.t.fieldType(LowerMatrixIR.colsField).asInstanceOf[PArray] assert(colFieldType.required) assert(colFieldType.elementType.required) lazy val globals: BroadcastRow = { val prevGlobals = tv.globals - val newT = prevGlobals.t.deleteField(LowerMatrixIR.colsFieldName) + val newT = prevGlobals.t.deleteField(LowerMatrixIR.colsField) val rvb = new RegionValueBuilder(HailStateManager(Map.empty), prevGlobals.value.region) rvb.start(newT) rvb.startStruct() rvb.addFields( prevGlobals.t, prevGlobals.value, - prevGlobals.t.fields.filter(_.name != LowerMatrixIR.colsFieldName).map(_.index), + prevGlobals.t.fields.filter(_.name != LowerMatrixIR.colsField).map(_.index), ) rvb.endStruct() BroadcastRow(tv.ctx, RegionValue(prevGlobals.value.region, rvb.end()), newT) @@ -45,7 +45,7 @@ case class MatrixValue( lazy val colValues: BroadcastIndexedSeq = { val prevGlobals = tv.globals - val field = prevGlobals.t.field(LowerMatrixIR.colsFieldName) + val field = prevGlobals.t.field(LowerMatrixIR.colsField) val t = field.typ.asInstanceOf[PArray] BroadcastIndexedSeq( tv.ctx, @@ -361,7 +361,7 @@ object MatrixValue { colValues: IndexedSeq[Row], rvd: RVD, ): MatrixValue = { - val globalsType = typ.globalType.appendKey(LowerMatrixIR.colsFieldName, TArray(typ.colType)) + val globalsType = typ.globalType.appendKey(LowerMatrixIR.colsField, TArray(typ.colType)) val globalsPType = PType.canonical(globalsType).asInstanceOf[PStruct] val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) rvb.start(globalsPType) diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index 8e2a362d6d8..e3f310bf0aa 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -391,8 +391,8 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow def toMatrixValue( colKey: IndexedSeq[String], - colsFieldName: String = LowerMatrixIR.colsFieldName, - entriesFieldName: String = LowerMatrixIR.entriesFieldName, + colsFieldName: String = LowerMatrixIR.colsField, + entriesFieldName: String = LowerMatrixIR.entriesField, ): MatrixValue = { val (colType, colsFieldIdx) = typ.globalType.field(colsFieldName) match { @@ -429,8 +429,8 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow MatrixValue( mType, newTV.rename( - Map(colsFieldName -> LowerMatrixIR.colsFieldName), - Map(entriesFieldName -> LowerMatrixIR.entriesFieldName), + Map(colsFieldName -> LowerMatrixIR.colsField), + Map(entriesFieldName -> LowerMatrixIR.entriesField), ), ) } diff --git a/hail/hail/src/is/hail/expr/ir/TableWriter.scala b/hail/hail/src/is/hail/expr/ir/TableWriter.scala index c770165403a..7b40fc60f81 100644 --- a/hail/hail/src/is/hail/expr/ir/TableWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/TableWriter.scala @@ -1147,7 +1147,7 @@ case class WrappedMatrixNativeMultiWriter( writer.lower( ctx, ts.map { case (ts, rt) => - (LowerMatrixIR.colsFieldName, LowerMatrixIR.entriesFieldName, colKey, ts, rt) + (LowerMatrixIR.colsField, LowerMatrixIR.entriesField, colKey, ts, rt) }, ) diff --git a/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala index 9ca456af5b5..0b82e8651c5 100644 --- a/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala @@ -5,11 +5,7 @@ import is.hail.ExecStrategy.ExecStrategy import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir.DeprecatedIRBuilder._ -import is.hail.expr.ir.defs.{ - AggFilter, AggGroupBy, ApplyAggOp, ApplyBinaryPrimOp, ArrayRef, Cast, GetField, I32, InsertFields, - MakeStruct, MakeTuple, Ref, Str, StreamAgg, StreamAggScan, StreamRange, TableAggregate, ToArray, - ToStream, -} +import is.hail.expr.ir.defs.{AggFilter, AggGroupBy, ApplyAggOp, ApplyBinaryPrimOp, ArrayRef, GetField, I32, InsertFields, MakeStruct, MakeTuple, Ref, Str, StreamAgg, StreamAggScan, StreamRange, TableAggregate, ToArray, ToStream} import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.types.virtual._ import is.hail.variant.Call2 @@ -1222,17 +1218,10 @@ class AggregatorsSuite extends HailSuite { @Test def testArrayElementsAggregator(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly - def getAgg(n: Int, m: Int): IR = { - val ht = TableRange(10, 3) - .mapRows('row.insertFields('aRange -> irRange(0, m, 1))) - - TableAggregate( - ht, - aggArrayPerElement( - GetField(Ref(TableIR.rowName, ht.typ.rowType), "aRange") - )((elt, _) => ApplyAggOp(FastSeq(), FastSeq(Cast(elt, TInt64)), Sum())), - ) - } + def getAgg(n: Int, m: Int): IR = + TableRange(n, 3) + .mapRows((_, row) => row.insert("aRange" -> rangeIR(m).toArray)) + .aggregate((_, row) => row.aggElements()((elt, _) => ApplyAggOp(Sum())(elt.toL))) assertEvalsTo(getAgg(10, 10), IndexedSeq.range(0, 10).map(_ * 10L)) } @@ -1241,18 +1230,12 @@ class AggregatorsSuite extends HailSuite { implicit val execStrats = ExecStrategy.interpretOnly def getAgg(n: Int, m: Int, knownLength: Option[IR]): IR = { - val ht = TableRange(10, 3) - .mapRows('row.insertFields('aRange -> irRange(0, m, 1))) - .mapGlobals('global.insertFields('m -> m)) - .filter(false) - - TableAggregate( - ht, - aggArrayPerElement( - GetField(Ref(TableIR.rowName, ht.typ.rowType), "aRange"), - knownLength, - )((elt, _) => ApplyAggOp(FastSeq(), FastSeq(Cast(elt, TInt64)), Sum())), - ) + TableRange(n, 3) + .mapRows((_, row) => row.insert("aRange" -> rangeIR(m).toArray)) + .filter((_, _) => false) + .aggregate { (_, row) => + row.get("aRange").aggElements(knownLength)((elt, _) => ApplyAggOp(Sum())(elt.toL)) + } } assertEvalsTo(getAgg(10, 10, None), null)