diff --git a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala index 84458e7dbb3..5d24fbe7053 100644 --- a/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala +++ b/hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala @@ -3,13 +3,14 @@ package is.hail.expr.ir import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.implicits.toRichIterable +import is.hail.expr.ir.Scope._ import is.hail.expr.ir.defs._ import is.hail.types.virtual._ import scala.language.dynamics object DeprecatedIRBuilder { - type E = Env[Type] + type E = BindingEnv[Type] implicit def funcToIRProxy(ir: E => IR): IRProxy = new IRProxy(ir) @@ -25,7 +26,7 @@ object DeprecatedIRBuilder { implicit def booleanToProxy(b: Boolean): IRProxy = if (b) True() else False() implicit def ref(s: Symbol): IRProxy = (env: E) => - Ref(Name(s.name), env.lookup(Name(s.name))) + Ref(Name(s.name), env.eval.lookup(Name(s.name))) implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s) @@ -56,9 +57,8 @@ object DeprecatedIRBuilder { def concatStructs(struct1: IRProxy, struct2: IRProxy): IRProxy = (env: E) => { val s2Type = struct2(env).typ.asInstanceOf[TStruct] - let(__struct2 = struct2) { - struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*) - }(env) + (let(__struct2 = struct2) in + struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*))(env) } def makeTuple(values: IRProxy*): IRProxy = (env: E) => @@ -69,34 +69,31 @@ object DeprecatedIRBuilder { initOpArgs: IndexedSeq[IRProxy] = FastSeq(), seqOpArgs: IndexedSeq[IRProxy] = FastSeq(), ): IRProxy = (env: E) => { - val i = initOpArgs.map(x => x(env)) - val s = seqOpArgs.map(x => x(env)) + val i = initOpArgs.map(x => x(env.noAgg)) + val s = seqOpArgs.map(x => x(env.promoteAgg)) ApplyAggOp(i, s, op) } def aggFilter(filterCond: IRProxy, query: IRProxy, isScan: Boolean = false): IRProxy = (env: E) => - AggFilter(filterCond(env), query(env), isScan) + AggFilter(filterCond(env.promoteAgg), query(env), isScan) class TableIRProxy(val tir: TableIR) extends AnyVal { - def empty: E = Env.empty - - def globalEnv: E = typ.globalEnv - - def env: E = typ.rowEnv + def empty: E = BindingEnv.empty def typ: TableType = tir.typ def getGlobals: IR = TableGetGlobals(tir) def mapGlobals(newGlobals: IRProxy): TableIR = - TableMapGlobals(tir, newGlobals(globalEnv)) + TableMapGlobals(tir, newGlobals(BindingEnv(typ.globalEnv))) def mapRows(newRow: IRProxy): TableIR = - TableMapRows(tir, newRow(env)) + TableMapRows(tir, newRow(BindingEnv(typ.rowEnv, scan = Some(typ.rowEnv)))) def explode(sym: Symbol): TableIR = TableExplode(tir, FastSeq(sym.name)) - def aggregateByKey(aggIR: IRProxy): TableIR = TableAggregateByKey(tir, aggIR(env)) + def aggregateByKey(aggIR: IRProxy): TableIR = + TableAggregateByKey(tir, aggIR(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv)))) def keyBy(keys: IndexedSeq[String], isSorted: Boolean = false): TableIR = TableKeyBy(tir, keys, isSorted) @@ -108,7 +105,7 @@ object DeprecatedIRBuilder { rename(Map.empty, globalMap) def filter(ir: IRProxy): TableIR = - TableFilter(tir, ir(env)) + TableFilter(tir, ir(BindingEnv(typ.rowEnv))) def distinct(): TableIR = TableDistinct(tir) @@ -129,7 +126,7 @@ object DeprecatedIRBuilder { } def aggregate(ir: IRProxy): IR = - TableAggregate(tir, ir(env)) + TableAggregate(tir, ir(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv)))) } class IRProxy(val ir: E => IR) extends AnyVal with Dynamic { @@ -214,7 +211,7 @@ object DeprecatedIRBuilder { case _: TStruct => GetField(eval, lookup.name) case _: TArray => - ArrayRef(ir(env), ref(lookup)(env)) + ArrayRef(eval, ref(lookup)(env)) } } @@ -259,65 +256,70 @@ object DeprecatedIRBuilder { def isNA: IRProxy = (env: E) => IsNA(ir(env)) - def orElse(alt: IRProxy): IRProxy = { env: E => - val uid = freshName() - val eir = ir(env) - Let(FastSeq(uid -> eir), If(IsNA(Ref(uid, eir.typ)), alt(env), Ref(uid, eir.typ))) - } + def orElse(alt: IRProxy): IRProxy = + (env: E) => bindIR(ir(env))(x => If(IsNA(x), alt(env), x)) - def filter(pred: LambdaProxy): IRProxy = (env: E) => { + def filter(pred: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(pred.s.name) -> TIterable.elementType(array.typ) ToArray(StreamFilter( ToStream(array), - Name(pred.s.name), - pred.body(env.bind(Name(pred.s.name) -> eltType)), + binding._1, + pred.body(env.bindEval(binding)), )) } - def map(f: LambdaProxy): IRProxy = (env: E) => { + def map(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) ToArray(StreamMap( ToStream(array), - Name(f.s.name), - f.body(env.bind(Name(f.s.name) -> eltType)), + binding._1, + f.body(env.bindEval(binding)), )) } - def aggExplode(f: LambdaProxy): IRProxy = (env: E) => { - val array = ir(env) + def aggExplode(f: LambdaProxy): IRProxy = { env: E => + val array = ir(env.promoteAgg) + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) AggExplode( ToStream(array), - Name(f.s.name), - f.body(env.bind(Name(f.s.name), array.typ.asInstanceOf[TArray].elementType)), + binding._1, + f.body(env.bindEval(binding).bindAgg(binding)), isScan = false, ) } - def flatMap(f: LambdaProxy): IRProxy = (env: E) => { + def flatMap(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) ToArray(StreamFlatMap( ToStream(array), - Name(f.s.name), - ToStream(f.body(env.bind(Name(f.s.name) -> eltType))), + binding._1, + ToStream(f.body(env.bindEval(binding))), )) } - def streamAgg(f: LambdaProxy): IRProxy = (env: E) => { + def flatten: IRProxy = + flatMap('a ~> 'a) + + def streamAgg(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType - StreamAgg(ToStream(array), Name(f.s.name), f.body(env.bind(Name(f.s.name) -> eltType))) + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) + StreamAgg( + ToStream(array), + binding._1, + f.body(env.bindEval(binding).createAgg), + ) } - def streamAggScan(f: LambdaProxy): IRProxy = (env: E) => { + def streamAggScan(f: LambdaProxy): IRProxy = { env: E => val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + val binding = Name(f.s.name) -> TIterable.elementType(array.typ) ToArray(StreamAggScan( ToStream(array), - Name(f.s.name), - f.body(env.bind(Name(f.s.name) -> eltType)), + binding._1, + f.body(env.bindEval(binding).createScan), )) } @@ -338,14 +340,20 @@ object DeprecatedIRBuilder { knownLength: Option[IRProxy], )( aggBody: IRProxy - ): IRProxy = (env: E) => { - val array = ir(env) - val eltType = array.typ.asInstanceOf[TArray].elementType + ): IRProxy = { env: E => + val array = ir(env.promoteAgg) + + val bindings = + FastSeq( + Name(elementsSym.name) -> TIterable.elementType(array.typ), + Name(indexSym.name) -> TInt32, + ) + AggArrayPerElement( array, - Name(elementsSym.name), - Name(indexSym.name), - aggBody.apply(env.bind(Name(elementsSym.name) -> eltType, Name(indexSym.name) -> TInt32)), + bindings(0)._1, + bindings(1)._1, + aggBody(env.bindEval(bindings: _*).bindAgg(bindings: _*)), knownLength.map(_(env)), isScan = false, ) @@ -361,18 +369,17 @@ object DeprecatedIRBuilder { def toDict: IRProxy = (env: E) => ToDict(ToStream(ir(env))) def parallelize(nPartitions: Option[Int] = None): TableIR = - TableParallelize(ir(Env.empty), nPartitions) + TableParallelize(ir(BindingEnv.empty), nPartitions) - def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = { - val element = Symbol(genUID()) - ir - .map(element ~> + def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = + ir.map( + '__elem ~> makeTuple( - element.selectFields(keyFields: _*), - element.dropFieldList(keyFields), - )) + '__elem.selectFields(keyFields: _*), + '__elem.dropFieldList(keyFields), + ) + ) .toDict - } def tupleElement(i: Int): IRProxy = (env: E) => GetTupleElement(ir(env), i) @@ -391,8 +398,13 @@ object DeprecatedIRBuilder { def bind(bindings: IndexedSeq[BindingProxy], body: IRProxy, env: E): IR = { var newEnv = env val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) => - val resolvedValue = value(newEnv) - newEnv = newEnv.bind(Name(sym.name) -> resolvedValue.typ) + val resolvedValue = + value( + if (scope == AGG) newEnv.promoteAgg + else if (scope == SCAN) newEnv.promoteScan + else newEnv + ) + newEnv = newEnv.bindInScope(Name(sym.name), resolvedValue.typ, scope) Binding(Name(sym.name), resolvedValue, scope) } Block(resolvedBindings, body(newEnv)) @@ -412,38 +424,22 @@ object DeprecatedIRBuilder { } class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { - def apply(body: IRProxy): IRProxy = in(body) - - def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) } + def in(body: IRProxy): IRProxy = { env: E => LetProxy.bind(bindings, body, env) } } object aggLet extends Dynamic { - def applyDynamicNamed(method: String)(args: (String, IRProxy)*): AggLetProxy = { + def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { assert(method == "apply") - new AggLetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) }) + new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) }) } } - class AggLetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { - def apply(body: IRProxy): IRProxy = in(body) - - def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) } - } - - object MapIRProxy { - def apply(f: (IRProxy) => IRProxy)(x: IRProxy): IRProxy = (e: E) => - MapIR(x => f(x)(e))(x(e)) + object scanLet extends Dynamic { + def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { + assert(method == "apply") + new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.SCAN) }) + } } - def subst(x: IRProxy, env: BindingEnv[IRProxy]): IRProxy = (e: E) => - Subst( - x(e), - BindingEnv( - env.eval.mapValues(_(e)), - agg = env.agg.map(_.mapValues(_(e))), - scan = env.scan.map(_.mapValues(_(e))), - ), - ) - def lift(f: (IR) => IRProxy)(x: IRProxy): IRProxy = (e: E) => f(x(e))(e) } diff --git a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala index 5355ba3762d..91a0a203f01 100644 --- a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala +++ b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala @@ -40,7 +40,7 @@ object ForwardLets extends Logging { else { logger.info( f"Eliminating unused binding:\n" + - f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value, preserveNames = true).trim}" + f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value).trim}" ) env } diff --git a/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala b/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala index e2ec5c944ef..86a96379a1a 100644 --- a/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala @@ -6,7 +6,7 @@ import is.hail.asm4s.implicits.toRichCodeIterator import is.hail.backend.ExecuteContext import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer -import is.hail.expr.ir.defs.{Literal, PartitionReader, ReadPartition, ToStream} +import is.hail.expr.ir.defs.{Atom, Literal, PartitionReader, ReadPartition, ToStream} import is.hail.expr.ir.functions.UtilFunctions import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.expr.ir.streams.StreamProducer @@ -182,7 +182,7 @@ class GenericTableValue( def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any) : TableStage = { val globalsIR = Literal(requestedType.globalType, globals(requestedType.globalType)) - val requestedBody: (IR) => (IR) = (ctx: IR) => + val requestedBody: Atom => IR = ctx => ReadPartition( ctx, requestedType.rowType, diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index e91756b1d81..3ae64f37ca5 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -97,20 +97,16 @@ package defs { } object Let { - def apply(bindings: IndexedSeq[(Name, IR)], body: IR): Block = - Block( - bindings.map { case (name, value) => Binding(name, value) }, - body, - ) + def apply(bindings: IndexedSeq[(Name, IR)], body: IR): IR = + if (bindings.isEmpty) body + else Block(bindings.map { case (n, v) => Binding(n, v) }, body) - def void(bindings: IndexedSeq[(Name, IR)]): IR = { - if (bindings.isEmpty) { - Void() - } else { + def void(bindings: IndexedSeq[(Name, IR)]): IR = + if (bindings.isEmpty) Void() + else { assert(bindings.last._2.typ == TVoid) Let(bindings.init, bindings.last._2) } - } } object Begin { @@ -149,10 +145,8 @@ package defs { requiresMemoryManagement: Boolean, rightKeyIsDistinct: Boolean = false, ): IR = { - val lType = tcoerce[TStream](left.typ) - val rType = tcoerce[TStream](right.typ) - val lEltType = tcoerce[TStruct](lType.elementType) - val rEltType = tcoerce[TStruct](rType.elementType) + val lEltType = TIterable.elementType(left.typ).asInstanceOf[TStruct] + val rEltType = TIterable.elementType(right.typ).asInstanceOf[TStruct] assert(lEltType.typeAfterSelectNames(lKey) isJoinableWith rEltType.typeAfterSelectNames(rKey)) if (!rightKeyIsDistinct) { @@ -169,19 +163,10 @@ package defs { } } - val rElt = Ref(freshName(), tcoerce[TStream](rightGrouped.typ).elementType) - val lElt = Ref(freshName(), lEltType) - val makeTupleFromJoin = MakeStruct(FastSeq("left" -> lElt, "rightGroup" -> rElt)) - val joined = StreamJoinRightDistinct( - left, - rightGrouped, - lKey, - rKey, - lElt.name, - rElt.name, - makeTupleFromJoin, - joinType, - ) + val joined = + joinRightDistinctIR(left, rightGrouped, lKey, rKey, joinType) { (l, r) => + makestruct("left" -> l, "rightGroup" -> r) + } // joined is a stream of {leftElement, rightGroup} bindIR(MakeArray(NA(rEltType))) { missingSingleton => @@ -538,6 +523,9 @@ package defs { MakeArray(args.toFastSeq, TArray(args.head.typ)) } + def empty(elementType: Type): MakeArray = + MakeArray(FastSeq.empty[IR], TArray(elementType)) + def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null) : MakeArray = { assert(requestedType != null || args.nonEmpty) @@ -637,6 +625,9 @@ package defs { def single(x: IR): IR with TypedIR[TStream] = MakeStream(FastSeq(x), TStream(x.typ)) + + def empty(elemType: Type): IR with TypedIR[TStream] = + MakeStream(FastSeq.empty, TStream(elemType)) } abstract class ArraySortCompanionExt { diff --git a/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala b/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala index ecab3fef958..940f80452a3 100644 --- a/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/LowerMatrixIR.scala @@ -4,6 +4,7 @@ import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.compat.mutable.Growable +import is.hail.expr.ir.{Memoized => M} import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.{WrappedMatrixToTableFunction, WrappedMatrixToValueFunction} import is.hail.types.virtual._ @@ -15,44 +16,40 @@ object LowerMatrixIR { val colsField: Symbol = Symbol(colsFieldName) val entriesField: Symbol = Symbol(entriesFieldName) - def apply(ctx: ExecuteContext, ir: IR): IR = { + def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = { val ab = ArraySeq.newBuilder[(Name, IR)] - val l1 = lower(ctx, ir, ab) - ab.result().foldRight[IR](l1) { case ((ident, value), body) => - RelationalLet(ident, value, body) - } - } - def apply(ctx: ExecuteContext, tir: TableIR): TableIR = { - val ab = ArraySeq.newBuilder[(Name, IR)] - val l1 = lower(ctx, tir, ab) - ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => - RelationalLetTable(ident, value, body) - } - } - - def apply(ctx: ExecuteContext, mir: MatrixIR): TableIR = { - val ab = ArraySeq.newBuilder[(Name, IR)] - - val l1 = lower(ctx, mir, ab) - ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => - RelationalLetTable(ident, value, body) - } - } - - def apply(ctx: ExecuteContext, bmir: BlockMatrixIR): BlockMatrixIR = { - val ab = ArraySeq.newBuilder[(Name, IR)] + val lowered = + ir0 match { + case ir: IR => + val l1 = lower(ctx, ir, ab) + ab.result().foldRight[IR](l1) { case ((ident, value), body) => + RelationalLet(ident, value, body) + } + case tir: TableIR => + val l1 = lower(ctx, tir, ab) + ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => + RelationalLetTable(ident, value, body) + } + case mir: MatrixIR => + val l1 = lower(ctx, mir, ab) + ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => + RelationalLetTable(ident, value, body) + } + case bmir: BlockMatrixIR => + val l1 = lower(ctx, bmir, ab) + assert(ab.result().isEmpty) + l1 + } - val l1 = lower(ctx, bmir, ab) - assert(ab.result().isEmpty) - l1 + NormalizeNames()(ctx, lowered) } - private[this] def lowerChildren( + private def lowerChildren( ctx: ExecuteContext, ir: BaseIR, ab: Growable[(Name, IR)], - ): BaseIR = { + ): BaseIR = ir.mapChildren { case tir: TableIR => lower(ctx, tir, ab) case mir: MatrixIR => throw new RuntimeException(s"expect specialized lowering rule for " + @@ -60,59 +57,38 @@ object LowerMatrixIR { case bmir: BlockMatrixIR => lower(ctx, bmir, ab) case vir: IR => lower(ctx, vir, ab) } - } def colVals(tir: TableIR): IR = GetField(Ref(TableIR.globalName, tir.typ.globalType), colsFieldName) - def globals(tir: TableIR): IR = + def globals(tir: TableIR): IR = { + val globalType = tir.typ.globalType SelectFields( - Ref(TableIR.globalName, tir.typ.globalType), - tir.typ.globalType.fieldNames.diff(FastSeq(colsFieldName)), + Ref(TableIR.globalName, globalType), + globalType.fieldNames.diff(FastSeq(colsFieldName)), ) + } - def nCols(tir: TableIR): IR = ArrayLen(colVals(tir)) + def rowVal(tir: TableIR): IR = { + val rowType = tir.typ.rowType + SelectFields( + Ref(TableIR.rowName, rowType), + rowType.fieldNames.diff(FastSeq(entriesFieldName)), + ) + } def entries(tir: TableIR): IR = GetField(Ref(TableIR.rowName, tir.typ.rowType), entriesFieldName) import is.hail.expr.ir.DeprecatedIRBuilder._ - def matrixSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { - val e = Env[IRProxy]( - MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*), - MatrixIR.rowName -> 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) - BindingEnv(e, agg = Some(e), scan = Some(e)) - } - - def matrixGlobalSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { - val e = - Env[IRProxy](MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*)) - BindingEnv(e, agg = Some(e), scan = Some(e)) - } - - def matrixSubstEnvIR(child: MatrixIR, lowered: TableIR): BindingEnv[IR] = { - val e = Env[IR]( - MatrixIR.globalName -> SelectFields( - Ref(TableIR.globalName, lowered.typ.globalType), - child.typ.globalType.fieldNames, - ), - MatrixIR.rowName -> SelectFields( - Ref(TableIR.rowName, lowered.typ.rowType), - child.typ.rowType.fieldNames, - ), - ) - BindingEnv(e, agg = Some(e), scan = Some(e)) - } - private def bindingsToStruct(bindings: IndexedSeq[(Name, IR)]): MakeStruct = MakeStruct(bindings.map { case (n, ir) => n.str -> ir }) - private def unwrapStruct(bindings: IndexedSeq[(Name, IR)], struct: Atom): IndexedSeq[(Name, IR)] = + private def unwrapStruct(bindings: IndexedSeq[(Name, _)], struct: Atom): IndexedSeq[(Name, IR)] = bindings.map { case (name, _) => name -> GetField(struct, name.str) } - private[this] def lower( + private def lower( ctx: ExecuteContext, mir: MatrixIR, liftedRelationalLets: Growable[(Name, IR)], @@ -127,7 +103,7 @@ object LowerMatrixIR { case CastTableToMatrix(child, entries, cols, _) => val lc = lower(ctx, child, liftedRelationalLets) - val row = Ref(TableIR.rowName, lc.typ.rowType) + val row: Atom = Ref(TableIR.rowName, lc.typ.rowType) val glob = Ref(TableIR.globalName, lc.typ.globalType) TableMapRows( lc, @@ -141,14 +117,12 @@ object LowerMatrixIR { entriesLen cne colsLen, Die( strConcat( - Str( - "length mismatch between entry array and column array in 'to_matrix_table_row_major': " - ), - invoke("str", TString, entriesLen), - Str(" entries, "), - invoke("str", TString, colsLen), - Str(" cols, at "), - invoke("str", TString, SelectFields(row, child.typ.key)), + "length mismatch between entry array and column array in 'to_matrix_table_row_major': ", + entriesLen, + " entries, ", + colsLen, + " cols, at ", + SelectFields(row, child.typ.key), ), row.typ, -1, @@ -169,18 +143,16 @@ object LowerMatrixIR { if (colMap.nonEmpty) { val newColsType = TArray(child.typ.colType.rename(colMap)) - t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq(( - colsFieldName, - newColsType, - ))))) + t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq( + colsFieldName -> newColsType + )))) } if (entryMap.nonEmpty) { val newEntriesType = TArray(child.typ.entryType.rename(entryMap)) - t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq(( - entriesFieldName, - newEntriesType, - ))))) + t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq( + entriesFieldName -> newEntriesType + )))) } t @@ -190,22 +162,36 @@ object LowerMatrixIR { case MatrixFilterRows(child, pred) => lower(ctx, child, liftedRelationalLets) - .filter(subst(lower(ctx, pred, liftedRelationalLets), matrixSubstEnv(child))) + .filter( + let( + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in lower(ctx, pred, liftedRelationalLets) + ) case MatrixFilterCols(child, pred) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('newColIdx -> - irRange(0, 'global(colsField).len) - .filter('i ~> - (let(sa = 'global(colsField)('i)) - in subst(lower(ctx, pred, liftedRelationalLets), matrixGlobalSubstEnv(child)))))) - .mapRows('row.insertFields( - entriesField -> 'global('newColIdx).map('i ~> 'row(entriesField)('i)) - )) - .mapGlobals('global - .insertFields(colsField -> - 'global('newColIdx).map('i ~> 'global(colsField)('i))) - .dropFields('newColIdx)) + .mapGlobals( + 'global.insertFields( + '__new_col_idx -> + (let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in irRange(0, '__cols.len).filter('__col_idx ~> + (let(sa = '__cols('__col_idx)) in + lower(ctx, pred, liftedRelationalLets)))) + ) + ) + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields(entriesField -> 'global('__new_col_idx).map('i ~> '__entries('i))) + ) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global + .insertFields(colsField -> 'global('__new_col_idx).map('i ~> '__cols('i))) + .dropFields('__new_col_idx) + ) case MatrixAnnotateRowsTable(child, table, root, product) => val kt = table.typ.keyType @@ -225,44 +211,48 @@ object LowerMatrixIR { case MatrixChooseCols(child, oldIndices) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('newColIdx -> oldIndices.map(I32))) - .mapRows('row.insertFields( - entriesField -> 'global('newColIdx).map('i ~> 'row(entriesField)('i)) - )) - .mapGlobals('global - .insertFields(colsField -> 'global('newColIdx).map('i ~> 'global(colsField)('i))) - .dropFields('newColIdx)) + .mapGlobals('global.insertFields('__new_col_idx -> Literal(TArray(TInt32), oldIndices))) + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields(entriesField -> 'global('__new_col_idx).map('i ~> '__entries('i))) + ) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global + .insertFields(colsField -> 'global('__new_col_idx).map('i ~> '__cols('i))) + .dropFields('__new_col_idx) + ) case MatrixAnnotateColsTable(child, table, root) => - val col = Symbol(genUID()) - val colKey = makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => - Symbol(tk) -> col(Symbol(mck)) - }: _*) lower(ctx, child, liftedRelationalLets) - .mapGlobals(let(__dictfield = - lower(ctx, table, liftedRelationalLets) - .keyBy(FastSeq()) - .collect() - .apply('rows) - .arrayStructToDict(table.typ.key) - ) { - 'global.insertFields(colsField -> - 'global(colsField).map(col ~> col.insertFields(Symbol(root) -> '__dictfield.invoke( - "get", - table.typ.valueType, - colKey, - )))) - }) + .mapGlobals( + let( + __dictfield = + lower(ctx, table, liftedRelationalLets) + .keyBy(FastSeq()) + .collect() + .apply('rows) + .arrayStructToDict(table.typ.key) + ) in 'global.insertFields( + colsField -> { + val key = + makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => + Symbol(tk) -> '__cols(Symbol(mck)) + }: _*) + + 'global(colsField).map('__cols ~> + '__cols.insertFields( + Symbol(root) -> '__dictfield.invoke("get", table.typ.valueType, key) + )) + } + ) + ) case MatrixMapGlobals(child, newGlobals) => lower(ctx, child, liftedRelationalLets) .mapGlobals( - subst( - lower(ctx, newGlobals, liftedRelationalLets), - BindingEnv(Env[IRProxy]( - TableIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*) - )), - ) + (let(global = 'global.selectFields(child.typ.globalType.fieldNames: _*)) in + lower(ctx, newGlobals, liftedRelationalLets)) .insertFields(colsField -> 'global(colsField)) ) @@ -271,12 +261,12 @@ object LowerMatrixIR { def lift(ir: IR, builder: Growable[(Name, IR)]): IR = ir match { case a: ApplyScanOp => val s = freshName() - builder += ((s, a)) + builder += (s -> a) Ref(s, a.typ) case a @ AggFold(_, _, _, _, _, true) => val s = freshName() - builder += ((s, a)) + builder += (s -> a) Ref(s, a.typ) case AggFilter(filt, body, true) => @@ -305,14 +295,14 @@ object LowerMatrixIR { val aggIR = AggGroupBy(a, bindingsToStruct(aggs), true) val uid = Ref(freshName(), aggIR.typ) builder += (uid.name -> aggIR) - val elementType = aggIR.typ.asInstanceOf[TDict].elementType - val valueType = elementType.types(1) - val valueUID = Ref(freshName(), valueType) + ToDict(mapIR(ToStream(uid)) { eltUID => - Let( - (valueUID.name -> GetField(eltUID, "value")) +: unwrapStruct(aggs, valueUID), - MakeTuple.ordered(FastSeq(GetField(eltUID, "key"), liftedBody)), - ) + bindIR(GetField(eltUID, "value")) { value => + Let( + unwrapStruct(aggs, value), + maketuple(GetField(eltUID, "key"), liftedBody), + ) + } }) case AggArrayPerElement(a, elementName, indexName, body, knownLength, true) => @@ -332,9 +322,8 @@ object LowerMatrixIR { case Block(bindings, body) => val newBindings = ArraySeq.newBuilder[Binding] def go(i: Int, builder: Growable[(Name, IR)]): IR = { - if (i == bindings.length) { - lift(body, builder) - } else bindings(i) match { + if (i == bindings.length) lift(body, builder) + else bindings(i) match { case Binding(name, value, Scope.SCAN) => val ab = ArraySeq.newBuilder[(Name, IR)] val liftedBody = go(i + 1, ab) @@ -361,56 +350,49 @@ object LowerMatrixIR { val ab = ArraySeq.newBuilder[(Name, IR)] val b0 = lift(ir, ab) - val scans = ab.result() - val scanStruct = MakeStruct(scans.map { case (n, ir) => n.str -> ir }) - - val scanResultRef = Ref(freshName(), scanStruct.typ) - - val b1 = if (ContainsAgg(b0)) { - irRange(0, 'row(entriesField).len) - .filter('i ~> !'row(entriesField)('i).isNA) - .streamAgg('i ~> - (aggLet(sa = 'global(colsField)('i), g = 'row(entriesField)('i)) - in b0)) - } else + val b1 = if (ContainsAgg(b0)) + irRange(0, '__entries.len) + .filter('i ~> !'__entries('i).isNA) + .streamAgg('i ~> (aggLet(sa = '__cols('i), g = '__entries('i)) in b0)) + else irToProxy(b0) - letDyn( - ((scanResultRef.name, irToProxy(scanStruct)) - +: scans.map { case (name, _) => - name -> irToProxy(GetField(scanResultRef, name.str)) - }): _* - )(b1) + scanLet( + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in (letDyn(ab.result().map { case (name, expr) => name -> irToProxy(expr) }: _*) in b1) } - val lc = lower(ctx, child, liftedRelationalLets) - lc.mapRows(let(n_cols = 'global(colsField).len) { - liftScans(Subst(lower(ctx, newRow, liftedRelationalLets), matrixSubstEnvIR(child, lc))) + lower(ctx, child, liftedRelationalLets).mapRows( + (let( + __cols = 'global(colsField), + __entries = 'row(entriesField), + n_cols = '__cols.len, + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in liftScans(lower(ctx, newRow, liftedRelationalLets))) .insertFields(entriesField -> 'row(entriesField)) - }) + ) case MatrixMapCols(child, newCol, _) => - val loweredChild = lower(ctx, child, liftedRelationalLets) + val lc = lower(ctx, child, liftedRelationalLets) def lift(ir: IR, scanBindings: Growable[(Name, IR)], aggBindings: Growable[(Name, IR)]) : IR = ir match { case a: ApplyScanOp => val s = freshName() - scanBindings += ((s, a)) + scanBindings += (s -> a) Ref(s, a.typ) case a: ApplyAggOp => val s = freshName() - aggBindings += ((s, a)) + aggBindings += (s -> a) Ref(s, a.typ) case a @ AggFold(_, _, _, _, _, isScan) => val s = freshName() - if (isScan) { - scanBindings += ((s, a)) - } else { - aggBindings += ((s, a)) - } + if (isScan) scanBindings += (s -> a) + else aggBindings += (s -> a) Ref(s, a.typ) case AggFilter(filt, body, isScan) => @@ -420,11 +402,11 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val structResult = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val structResult = bindingsToStruct(aggs) val uid = Ref(freshName(), structResult.typ) builder += (uid.name -> AggFilter(filt, structResult, isScan)) - Let(aggs.map { case (name, _) => name -> GetField(uid, name.str) }, liftedBody) + Let(unwrapStruct(aggs, uid), liftedBody) case AggExplode(a, name, body, isScan) => val ab = ArraySeq.newBuilder[(Name, IR)] @@ -433,10 +415,10 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val structResult = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val structResult = bindingsToStruct(aggs) val uid = Ref(freshName(), structResult.typ) builder += (uid.name -> AggExplode(a, name, structResult, isScan)) - Let(aggs.map { case (name, _) => name -> GetField(uid, name.str) }, liftedBody) + Let(unwrapStruct(aggs, uid), liftedBody) case AggGroupBy(a, body, isScan) => val ab = ArraySeq.newBuilder[(Name, IR)] @@ -445,24 +427,15 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val aggIR = AggGroupBy(a, MakeStruct(aggs.map { case (n, ir) => n.str -> ir }), isScan) + val aggIR = AggGroupBy(a, bindingsToStruct(aggs), isScan) val uid = Ref(freshName(), aggIR.typ) builder += (uid.name -> aggIR) - val valueUID = freshName() - val elementType = aggIR.typ.asInstanceOf[TDict].elementType - val valueType = elementType.types(1) ToDict(mapIR(ToStream(uid)) { eltUID => - MakeTuple.ordered( - FastSeq( - GetField(eltUID, "key"), - Let( - (valueUID -> GetField(eltUID, "value")) +: - aggs.map { case (name, _) => - name -> GetField(Ref(valueUID, valueType), name.str) - }, - liftedBody, - ), - ) + maketuple( + GetField(eltUID, "key"), + bindIR(GetField(eltUID, "value")) { value => + Let(unwrapStruct(aggs, value), liftedBody) + }, ) }) @@ -473,22 +446,19 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val aggBody = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val aggBody = bindingsToStruct(aggs) val aggIR = AggArrayPerElement(a, elementName, indexName, aggBody, knownLength, isScan) val uid = Ref(freshName(), aggIR.typ) builder += (uid.name -> aggIR) - ToArray(mapIR(ToStream(uid)) { eltUID => - Let(aggs.map { case (name, _) => name -> GetField(eltUID, name.str) }, liftedBody) - }) + ToArray(mapIR(ToStream(uid))(eltUID => Let(unwrapStruct(aggs, eltUID), liftedBody))) case Block(bindings, body) => val newBindings = ArraySeq.newBuilder[Binding] def go(i: Int, scanBindings: Growable[(Name, IR)], aggBindings: Growable[(Name, IR)]) - : IR = { - if (i == bindings.length) { - lift(body, scanBindings, aggBindings) - } else bindings(i) match { + : IR = + if (i == bindings.length) lift(body, scanBindings, aggBindings) + else bindings(i) match { case Binding(name, value, Scope.EVAL) => val lifted = lift(value, scanBindings, aggBindings) val liftedBody = go(i + 1, scanBindings, aggBindings) @@ -496,24 +466,23 @@ object LowerMatrixIR { liftedBody case Binding(name, value, scope) => val ab = ArraySeq.newBuilder[(Name, IR)] - val liftedBody = if (scope == Scope.SCAN) - go(i + 1, ab, aggBindings) - else - go(i + 1, scanBindings, ab) + val liftedBody = + if (scope == Scope.SCAN) go(i + 1, ab, aggBindings) + else go(i + 1, scanBindings, ab) val builder = if (scope == Scope.SCAN) scanBindings else aggBindings val aggs = ab.result() - val structResult = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) + val structResult = bindingsToStruct(aggs) - val uid = freshName() - builder += (uid -> Block(FastSeq(Binding(name, value, scope)), structResult)) - newBindings ++= aggs.map { case (name, _) => - Binding(name, GetField(Ref(uid, structResult.typ), name.str), Scope.EVAL) - } + val uid = Ref(freshName(), structResult.typ) + builder += (uid.name -> Block(FastSeq(Binding(name, value, scope)), structResult)) + newBindings ++= unwrapStruct(aggs, uid).map(b => + Binding(b._1, b._2, Scope.EVAL) + ) liftedBody } - } + val newBody = go(0, scanBindings, aggBindings) Block(newBindings.result().reverse, newBody) @@ -528,207 +497,176 @@ object LowerMatrixIR { val aggBuilder = ArraySeq.newBuilder[(Name, IR)] val b0 = lift( - Subst(lower(ctx, newCol, liftedRelationalLets), matrixSubstEnvIR(child, loweredChild)), + lower(ctx, newCol, liftedRelationalLets), scanBuilder, aggBuilder, ) + val aggs = aggBuilder.result() val scans = scanBuilder.result() - val idx = Ref(freshName(), TInt32) - val idxSym = Symbol(idx.name.str) - - val noOp: (IRProxy => IRProxy, IRProxy => IRProxy) = (identity[IRProxy], identity[IRProxy]) + val noOp: (IRProxy => IRProxy, IRProxy => IRProxy) = + (identity[IRProxy], identity[IRProxy]) val ( aggOutsideTransformer: (IRProxy => IRProxy), aggInsideTransformer: (IRProxy => IRProxy), - ) = if (aggs.isEmpty) - noOp - else { - val aggStruct = MakeStruct(aggs.map { case (n, ir) => n.str -> ir }) - - val aggResult = loweredChild.aggregate( - aggLet(va = 'row.selectFields(child.typ.rowType.fieldNames: _*)) { - makeStruct( - ('count, applyAggOp(Count(), FastSeq(), FastSeq())), - ( - 'array_aggs, - irRange(0, 'global(colsField).len) - .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( - let(sa = 'global(colsField)('__result_idx)) { - aggLet( - sa = 'global(colsField)('__element_idx), - g = 'row(entriesField)('__element_idx), - ) { - aggFilter(!'g.isNA, aggStruct) - } - } - ), - ), + ) = + if (aggs.isEmpty) noOp + else { + val aggResult = + lc.deepCopy.aggregate( + let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in (aggLet( + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in makeStruct( + 'n_rows -> + applyAggOp(Count(), FastSeq(), FastSeq()), + 'array_aggs -> + irRange(0, '__cols.len) + .aggElements('__element_idx, '__result_idx, Some('__cols.len))( + let(sa = '__cols('__result_idx)) in + (aggLet(sa = '__cols('__element_idx), g = '__entries('__element_idx)) in + aggFilter(!'g.isNA, bindingsToStruct(aggs))) + ), + )) ) - } - ) - val ident = freshName() - liftedRelationalLets += ((ident, aggResult)) + val ident = freshName() + liftedRelationalLets += (ident -> aggResult) - val aggResultRef = Ref(freshName(), aggResult.typ) - val aggResultElementRef = Ref( - freshName(), - aggResult.typ.asInstanceOf[TStruct] - .fieldType("array_aggs") - .asInstanceOf[TArray].elementType, - ) + val bindResult: IRProxy => IRProxy = + let( + __agg_result = RelationalRef(ident, aggResult.typ), + __array_aggs = '__agg_result('array_aggs), + n_rows = '__agg_result('n_rows), + ) in _ - val bindResult: IRProxy => IRProxy = letDyn(( - aggResultRef.name, - irToProxy(RelationalRef(ident, aggResult.typ)), - )).apply(_) - val bodyResult: IRProxy => IRProxy = (x: IRProxy) => - letDyn(( - aggResultRef.name, - irToProxy(RelationalRef(ident, aggResult.typ)), - )) - .apply(let( - n_rows = Symbol(aggResultRef.name.str)('count), - array_aggs = Symbol(aggResultRef.name.str)('array_aggs), - ) { - letDyn((aggResultElementRef.name, 'array_aggs(idx))) { - aggs.foldLeft[IRProxy](x) { case (acc, (name, _)) => - letDyn((name, GetField(aggResultElementRef, name.str)))(acc) - } - } - }) - (bindResult, bodyResult) - } + def bodyResult(body: IRProxy): IRProxy = + let(__agg_elem = '__array_aggs('__col_idx)) in + (letDyn(aggs.map { case (n, _) => n -> '__agg_elem(Symbol(n.str)) }: _*) in + body) + + (bindResult, bodyResult _) + } val ( scanOutsideTransformer: (IRProxy => IRProxy), scanInsideTransformer: (IRProxy => IRProxy), - ) = if (scans.isEmpty) - noOp - else { - val scanStruct = bindingsToStruct(scans) - - val scanResultArray = ToArray(StreamAggScan( - ToStream(GetField(Ref(TableIR.globalName, loweredChild.typ.globalType), colsFieldName)), - MatrixIR.colName, - scanStruct, - )) - - val scanResultRef = Ref(freshName(), scanResultArray.typ) - val scanResultElementRef = - Ref(freshName(), scanResultArray.typ.asInstanceOf[TArray].elementType) - - val bindResult: IRProxy => IRProxy = - letDyn((scanResultRef.name, scanResultArray)).apply(_) - val bodyResult: IRProxy => IRProxy = (x: IRProxy) => - letDyn(( - scanResultElementRef.name, - ArrayRef(scanResultRef, idx), - ))( - scans.foldLeft[IRProxy](x) { case (acc, (name, _)) => - letDyn((name, GetField(scanResultElementRef, name.str)))(acc) - } - ) - (bindResult, bodyResult) - } + ) = + if (scans.isEmpty) noOp + else { + val scanStruct = bindingsToStruct(scans) + + val bindResult: IRProxy => IRProxy = + let(__scan_result = '__cols.streamAggScan('sa ~> scanStruct)) in _ - loweredChild.mapGlobals('global.insertFields(colsField -> - aggOutsideTransformer(scanOutsideTransformer(irRange(0, 'global(colsField).len).map( - idxSym ~> let(__cols_array = 'global(colsField), sa = '__cols_array(idxSym)) { - aggInsideTransformer(scanInsideTransformer(b0)) - } - ))))) + def bodyResult(body: IRProxy): IRProxy = + let(__scan_elem = '__scan_result('__col_idx)) in + (letDyn(scans.map { case (n, _) => n -> '__scan_elem(Symbol(n.str)) }: _*) in + body) + + (bindResult, bodyResult _) + } + + lc.mapGlobals( + let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in 'global.insertFields( + colsField -> + aggOutsideTransformer( + scanOutsideTransformer( + irRange(0, '__cols.len).map('__col_idx ~> + (let(sa = '__cols('__col_idx)) in + aggInsideTransformer(scanInsideTransformer(b0)))) + ) + ) + ) + ) case MatrixFilterEntries(child, pred) => - val lc = lower(ctx, child, liftedRelationalLets) - lc.mapRows('row.insertFields(entriesField -> - irRange(0, 'global(colsField).len).map { - 'i ~> - let(g = 'row(entriesField)('i)) { - irIf(let(sa = 'global(colsField)('i)) - in !subst(lower(ctx, pred, liftedRelationalLets), matrixSubstEnv(child))) { - NA(child.typ.entryType) - } { - 'g - } - } - })) + val mtype = child.typ + lower(ctx, child, liftedRelationalLets) + .mapRows( + let( + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(mtype.rowType.fieldNames: _*), + ) in 'row.insertFields( + entriesField -> + irRange(0, '__cols.len).map('i ~> + (let(sa = '__cols('i), g = '__entries('i)) in + irIf(lower(ctx, pred, liftedRelationalLets))('g)(NA(mtype.entryType)))) + ) + ) case MatrixUnionCols(left, right, joinType) => - val rightEntries = genUID() - val rightCols = genUID() - val ll = lower(ctx, left, liftedRelationalLets).distinct() def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy = - if (joinType == "inner") - 'row(entries) - else - irIf('row(entries).isNA) { - irRange(0, 'global(cols).len) - .map('a ~> irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => - (f, NA(right.typ.entryType.fieldType(f))) - )))) - } { - 'row(entries) - } + if (joinType == "inner") 'row(entries) + else let(__entries = 'row(entries)) in + irIf(!'__entries.isNA)('__entries)( + irRange(0, 'global(cols).len).map('a ~> + MakeStruct(right.typ.entryType.fields.map(f => (f.name, NA(f.typ))))) + ) + + val ll = lower(ctx, left, liftedRelationalLets).distinct() val rr = lower(ctx, right, liftedRelationalLets).distinct() TableJoin( ll, - rr.mapRows('row.castRename(rr.typ.rowType.rename(Map(entriesFieldName -> rightEntries)))) + rr.mapRows( + 'row.castRename(rr.typ.rowType.rename(Map(entriesFieldName -> '__right_entries.name))) + ) .mapGlobals('global - .insertFields(Symbol(rightCols) -> 'global(colsField)) - .selectFields(rightCols)), + .insertFields('__right_cols -> 'global(colsField)) + .selectFields('__right_cols.name)), joinType, ) .mapRows('row - .insertFields(entriesField -> - makeArray( - handleMissingEntriesArray(entriesField, colsField), - handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols)), - ) - .flatMap('a ~> 'a)) - .dropFields(Symbol(rightEntries))) + .insertFields( + entriesField -> { + val ls = handleMissingEntriesArray(entriesField, colsField) + val rs = handleMissingEntriesArray('__right_entries, '__right_cols) + makeArray(ls, rs).flatten + } + ) + .dropFields('__right_entries)) .mapGlobals('global - .insertFields(colsField -> - makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a)) - .dropFields(Symbol(rightCols))) + .insertFields( + colsField -> + makeArray('global(colsField), 'global('__right_cols)).flatten + ) + .dropFields('__right_cols)) case MatrixMapEntries(child, newEntries) => - val loweredChild = lower(ctx, child, liftedRelationalLets) - val rt = loweredChild.typ.rowType - val gt = loweredChild.typ.globalType + val lc = lower(ctx, child, liftedRelationalLets) TableMapRows( - loweredChild, - InsertFields( - Ref(TableIR.rowName, rt), - FastSeq( - entriesFieldName -> ToArray( - zip2( - ToStream(GetField(Ref(TableIR.rowName, rt), entriesFieldName)), - ToStream(GetField(Ref(TableIR.globalName, gt), colsFieldName)), - ArrayZipBehavior.AssumeSameLength, - ) { (entries, cols) => - Subst( + lc, + M.eval { + for { + cols <- Name("__cols") -> colVals(lc) + entries <- Name("__entries") -> entries(lc) + _ <- MatrixIR.globalName -> globals(lc) + row <- MatrixIR.rowName -> rowVal(lc) + } yield InsertFields( + row, + FastSeq( + entriesFieldName -> + ToArray(StreamZip( + FastSeq(ToStream(cols), ToStream(entries)), + FastSeq(MatrixIR.colName, MatrixIR.entryName), lower(ctx, newEntries, liftedRelationalLets), - BindingEnv(Env( - MatrixIR.globalName -> SelectFields( - Ref(TableIR.globalName, gt), - child.typ.globalType.fieldNames, - ), - MatrixIR.rowName -> SelectFields( - Ref(TableIR.rowName, rt), - child.typ.rowType.fieldNames, - ), - MatrixIR.colName -> cols, - MatrixIR.entryName -> entries, - )), - ) - } - ) - ), - ), + ArrayZipBehavior.AssumeSameLength, + )) + ), + ) + }, ) case MatrixRepartition(child, n, shuffle) => @@ -749,109 +687,120 @@ object LowerMatrixIR { case MatrixRowsHead(child, n) => TableHead(lower(ctx, child, liftedRelationalLets), n) case MatrixRowsTail(child, n) => TableTail(lower(ctx, child, liftedRelationalLets), n) - case MatrixColsHead(child, n) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields(colsField -> 'global(colsField).arraySlice( - 0, - Some(n), - 1, - ))) + case MatrixColsHead(child, n) => + lower(ctx, child, liftedRelationalLets) + .mapGlobals('global.insertFields('__cols -> 'global('__cols).arraySlice(0, Some(n), 1))) .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(0, Some(n), 1))) - case MatrixColsTail(child, n) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields(colsField -> 'global(colsField).arraySlice(-n, None, 1))) + case MatrixColsTail(child, n) => + lower(ctx, child, liftedRelationalLets) + .mapGlobals('global.insertFields('__cols -> 'global('__cols).arraySlice(-n, None, 1))) .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(-n, None, 1))) case MatrixExplodeCols(child, path) => - val loweredChild = lower(ctx, child, liftedRelationalLets) - val lengths = Symbol(genUID()) - val colIdx = Symbol(genUID()) - val nestedIdx = Symbol(genUID()) - val colElementUID1 = Symbol(genUID()) - - val nestedRefs = - path.init.scanLeft('global(colsField)(colIdx): IRProxy)((irp, name) => irp(Symbol(name))) - val postExplodeSelector = path.zip(nestedRefs).zipWithIndex.foldRight[IRProxy](nestedIdx) { - case (((field, ref), i), arg) => - ref.insertFields(Symbol(field) -> - (if (i == nestedRefs.length - 1) - ref(Symbol(field)).toArray(arg) - else - arg)) - } - - val arrayIR = path.foldLeft[IRProxy](colElementUID1) { case (irp, fieldName) => - irp(Symbol(fieldName)) - } - loweredChild - .mapGlobals('global.insertFields(lengths -> 'global(colsField).map({ - colElementUID1 ~> arrayIR.len.orElse(0) - }))) - .mapGlobals('global.insertFields(colsField -> - irRange(0, 'global(colsField).len, 1) - .flatMap({ - colIdx ~> - irRange(0, 'global(lengths)(colIdx), 1) - .map({ - nestedIdx ~> postExplodeSelector + lower(ctx, child, liftedRelationalLets) + .mapGlobals( + let( + __cols = + 'global(colsField), + __lengths = + '__cols.map('__elem ~> + path + .foldLeft[IRProxy]('__elem) { case (irp, f) => irp(Symbol(f)) } + .len + .orElse(0)), + ) in 'global.insertFields( + '__cols -> + irRange(0, '__cols.len).flatMap('__col_idx ~> { + val nestedRefs = + path.init.scanLeft('__cols('__col_idx))((irp, name) => irp(Symbol(name))) + + irRange(0, '__lengths('__col_idx)).map('__length_idx ~> + path.zip(nestedRefs).zipWithIndex.foldRight[IRProxy]('__length_idx) { + case (((field, ref), i), arg) => + val s = Symbol(field) + ref.insertFields( + s -> (if (i == nestedRefs.length - 1) ref(s).toArray(arg) else arg) + ) }) - }))) - .mapRows('row.insertFields(entriesField -> - irRange(0, 'row(entriesField).len, 1) - .flatMap(colIdx ~> - irRange(0, 'global(lengths)(colIdx), 1).map( - Symbol(genUID()) ~> 'row(entriesField)(colIdx) - )))) - .mapGlobals('global.dropFields(lengths)) + }), + '__lengths -> + '__lengths, + ) + ) + .mapRows( + let(__entries = 'row(entriesField), __lengths = 'global('__lengths)) in + 'row.insertFields( + entriesField -> + irRange(0, '__entries.len).flatMap('__col_idx ~> + irRange(0, '__lengths('__col_idx)).map('__unused ~> + '__entries('__col_idx))) + ) + ) + .mapGlobals('global.dropFields('__lengths)) case MatrixAggregateRowsByKey(child, entryExpr, rowExpr) => - val substEnv = matrixSubstEnv(child) - val eeSub = subst(lower(ctx, entryExpr, liftedRelationalLets), substEnv) - val reSub = subst(lower(ctx, rowExpr, liftedRelationalLets), substEnv) lower(ctx, child, liftedRelationalLets) .aggregateByKey( - reSub.insertFields(entriesField -> irRange(0, 'global(colsField).len) - .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( - let(sa = 'global(colsField)('__result_idx)) { - aggLet( - sa = 'global(colsField)('__element_idx), - g = 'row(entriesField)('__element_idx), - ) { - aggFilter(!'g.isNA, eeSub) - } - } - )) + let( + __cols = 'global(colsField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in (aggLet( + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in lower(ctx, rowExpr, liftedRelationalLets).insertFields( + entriesField -> + irRange(0, '__cols.len) + .aggElements('__element_idx, '__result_idx, Some('__cols.len))( + let(sa = '__cols('__result_idx)) in + (aggLet(sa = '__cols('__element_idx), g = '__entries('__element_idx)) in + aggFilter(!'g.isNA, lower(ctx, entryExpr, liftedRelationalLets))) + ) + )) ) case MatrixCollectColsByKey(child) => lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields('newColIdx -> - irRange(0, 'global(colsField).len).map { - 'i ~> - makeTuple('global(colsField)('i).selectFields(child.typ.colKey: _*), 'i) - }.groupByKey.toArray)) - .mapRows('row.insertFields(entriesField -> - 'global('newColIdx).map { - 'kv ~> - makeStruct(child.typ.entryType.fieldNames.map { s => - ( - Symbol(s), - 'kv('value).map { - 'i ~> 'row(entriesField)('i)(Symbol(s)) - }, - ) - }: _*) - })) - .mapGlobals('global - .insertFields(colsField -> - 'global('newColIdx).map { - 'kv ~> - 'kv('key).insertFields( - child.typ.colValueStruct.fieldNames.map { s => - (Symbol(s), 'kv('value).map('i ~> 'global(colsField)('i)(Symbol(s)))) - }: _* - ) - }) - .dropFields('newColIdx)) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global.insertFields( + '__new_col_idx -> + irRange(0, '__cols.len) + .map('i ~> makeTuple('__cols('i).selectFields(child.typ.colKey: _*), 'i)) + .groupByKey + .toArray + ) + ) + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields( + entriesField -> + 'global('__new_col_idx).map { + 'kv ~> + makeStruct(child.typ.entryType.fieldNames.map { f => + val s = Symbol(f) + s -> 'kv('value).map('i ~> '__entries('i)(s)) + }: _*) + } + ) + ) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global + .insertFields( + colsField -> + 'global('__new_col_idx).map('kv ~> + 'kv('key).insertFields( + child.typ.colValueStruct.fieldNames.map { f => + val s = Symbol(f) + s -> 'kv('value).map('i ~> '__cols('i)(s)) + }: _* + )) + ) + .dropFields('__new_col_idx) + ) case MatrixExplodeRows(child, path) => TableExplode(lower(ctx, child, liftedRelationalLets), path) @@ -859,63 +808,53 @@ object LowerMatrixIR { case mr: MatrixRead => mr.lower(ctx) case MatrixAggregateColsByKey(child, entryExpr, colExpr) => - val colKey = child.typ.colKey - - val originalColIdx = Symbol(genUID()) - val newColIdx1 = Symbol(genUID()) - val newColIdx2 = Symbol(genUID()) - val colsAggIdx = Symbol(genUID()) - val keyMap = Symbol(genUID()) - val aggElementIdx = Symbol(genUID()) - - val e1 = Env[IRProxy]( - MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*), - MatrixIR.rowName -> 'row.selectFields(child.typ.rowType.fieldNames: _*), - ) - val e2 = Env[IRProxy]( - MatrixIR.globalName -> 'global.selectFields(child.typ.globalType.fieldNames: _*) - ) - val ceSub = - subst(lower(ctx, colExpr, liftedRelationalLets), BindingEnv(e2, agg = Some(e2))) - val eeSub = - subst(lower(ctx, entryExpr, liftedRelationalLets), BindingEnv(e1, agg = Some(e1))) - lower(ctx, child, liftedRelationalLets) - .mapGlobals('global.insertFields(keyMap -> - let(__cols_field = 'global(colsField)) { - irRange(0, '__cols_field.len) - .map(originalColIdx ~> let(__cols_field_element = '__cols_field(originalColIdx)) { - makeStruct( - 'key -> '__cols_field_element.selectFields(colKey: _*), - 'value -> originalColIdx, - ) - }) - .groupByKey - .toArray - })) - .mapRows('row.insertFields(entriesField -> - let(__entries = 'row(entriesField), __key_map = 'global(keyMap)) { - irRange(0, '__key_map.len) - .map(newColIdx1 ~> '__key_map(newColIdx1) - .apply('value) - .streamAgg(aggElementIdx ~> - aggLet(g = '__entries(aggElementIdx), sa = 'global(colsField)(aggElementIdx)) { - aggFilter(!'g.isNA, eeSub) - })) - })) .mapGlobals( - 'global.insertFields(colsField -> - let(__key_map = 'global(keyMap)) { - irRange(0, '__key_map.len) - .map(newColIdx2 ~> + let(__cols = 'global(colsField)) in + 'global.insertFields( + '__key_map -> + irRange(0, '__cols.len) + .map('__old_col_idx ~> + (let(__elem = '__cols('__old_col_idx)) in + makeStruct( + 'key -> '__elem.selectFields(child.typ.colKey: _*), + 'value -> '__old_col_idx, + ))) + .groupByKey + .toArray + ) + ) + .mapRows( + let( + __key_map = 'global('__key_map), + __cols = 'global(colsField), + __entries = 'row(entriesField), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + va = 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) in 'row.insertFields( + entriesField -> + irRange(0, '__key_map.len).map('__new_col_idx ~> + '__key_map('__new_col_idx)('value).streamAgg('__agg_idx ~> + (aggLet(sa = '__cols('__agg_idx), g = '__entries('__agg_idx)) in + aggFilter(!'g.isNA, lower(ctx, entryExpr, liftedRelationalLets))))) + ) + ) + .mapGlobals( + let( + __cols = 'global(colsField), + __key_map = 'global('__key_map), + global = 'global.selectFields(child.typ.globalType.fieldNames: _*), + ) in 'global.insertFields( + colsField -> + irRange(0, '__key_map.len).map('__new_col_idx ~> + (let(__elem = '__key_map('__new_col_idx)) in concatStructs( - '__key_map(newColIdx2)('key), - '__key_map(newColIdx2)('value) - .streamAgg(colsAggIdx ~> aggLet(sa = 'global(colsField)(colsAggIdx)) { - ceSub - }), - )) - }).dropFields(keyMap) + '__elem('key), + '__elem('value).streamAgg('__agg_idx ~> + (aggLet(sa = '__cols('__agg_idx)) in + lower(ctx, colExpr, liftedRelationalLets))), + ))) + ) ) case MatrixLiteral(_, tl) => tl @@ -928,85 +867,100 @@ object LowerMatrixIR { lowered } - private[this] def lower(ctx: ExecuteContext, tir: TableIR, ab: Growable[(Name, IR)]): TableIR = { + private def lower(ctx: ExecuteContext, tir: TableIR, ab: Growable[(Name, IR)]): TableIR = { val lowered = tir match { case CastMatrixToTable(child, entries, cols) => lower(ctx, child, ab) - .mapRows('row.selectFields(child.typ.rowType.fieldNames ++ Array(entriesFieldName): _*)) - .mapGlobals('global.selectFields( - child.typ.globalType.fieldNames ++ Array(colsFieldName): _* - )) + .mapRows('row.selectFields(child.typ.rowType.fieldNames :+ entriesFieldName: _*)) + .mapGlobals('global.selectFields(child.typ.globalType.fieldNames :+ colsFieldName: _*)) .rename(Map(entriesFieldName -> entries), Map(colsFieldName -> cols)) case x @ MatrixEntriesTable(child) => val lc = lower(ctx, child, ab) if (child.typ.rowKey.nonEmpty && child.typ.colKey.nonEmpty) { - val oldColIdx = Symbol(genUID()) - val lambdaIdx1 = Symbol(genUID()) - val lambdaIdx2 = Symbol(genUID()) - val lambdaIdx3 = Symbol(genUID()) - val toExplode = Symbol(genUID()) - val values = Symbol(genUID()) lc - .mapGlobals('global.insertFields(oldColIdx -> - irRange(0, 'global(colsField).len) - .map(lambdaIdx1 ~> makeStruct( - 'key -> 'global(colsField)(lambdaIdx1).selectFields(child.typ.colKey: _*), - 'value -> lambdaIdx1, - )) - .sort(ascending = true, onKey = true) - .map(lambdaIdx1 ~> lambdaIdx1('value)))) - .aggregateByKey(makeStruct(values -> applyAggOp( - Collect(), - seqOpArgs = FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*)), - ))) - .mapRows('row.dropFields(values).insertFields(toExplode -> - 'global(oldColIdx) - .flatMap(lambdaIdx1 ~> 'row(values) - .filter(lambdaIdx2 ~> !lambdaIdx2(entriesField)(lambdaIdx1).isNA) - .map(lambdaIdx3 ~> let( - __col = 'global(colsField)(lambdaIdx1), - __entry = lambdaIdx3(entriesField)(lambdaIdx1), - ) { - makeStruct( - child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => - f -> lambdaIdx3(f) - ) ++ - child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col(f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry(f)): _* - ) - })))) - .explode(toExplode) - .mapRows(makeStruct(x.typ.rowType.fieldNames.map { f => - val fd = Symbol(f) - (fd, if (child.typ.rowKey.contains(f)) 'row(fd) else 'row(toExplode)(fd)) - }: _*)) - .mapGlobals('global.dropFields(colsField, oldColIdx)) + .mapGlobals( + let(__cols = 'global(colsField)) in + 'global.insertFields( + '__old_col_idx -> + irRange(0, '__cols.len) + .map('__col_idx ~> + makeStruct( + 'key -> '__cols('__col_idx).selectFields(child.typ.colKey: _*), + 'value -> '__col_idx, + )) + .sort(ascending = true, onKey = true) + .map('__elem ~> '__elem('value)) + ) + ) + .aggregateByKey(makeStruct( + '__values -> + applyAggOp( + Collect(), + seqOpArgs = FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*)), + ) + )) + .mapRows( + let(__cols = 'global(colsField)) in + 'row.dropFields('__values).insertFields( + '__explode -> + 'global('__old_col_idx).flatMap('__old_col_idx ~> + (let(__col = '__cols('__old_col_idx)) in + 'row('__values) + .filter('__v ~> !'__v(entriesField)('__old_col_idx).isNA) + .map('__v ~> + (let(__entry = '__v(entriesField)('__old_col_idx)) in + makeStruct( + child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => + f -> '__v(f) + ) ++ + child.typ.colType.fieldNames.map(Symbol(_)).map(f => + f -> '__col(f) + ) ++ + child.typ.entryType.fieldNames.map(Symbol(_)).map(f => + f -> '__entry(f) + ): _* + ))))) + ) + ) + .explode('__explode) + .mapRows( + let(__exploded = 'row('__explode)) in + makeStruct(x.typ.rowType.fieldNames.map { f => + val fd = Symbol(f) + (fd, if (child.typ.rowKey.contains(f)) 'row(fd) else '__exploded(fd)) + }: _*) + ) + .mapGlobals('global.dropFields(colsField, '__old_col_idx)) .keyBy(child.typ.rowKey ++ child.typ.colKey, isSorted = true) } else { - val colIdx = Symbol(genUID()) - val lambdaIdx = Symbol(genUID()) - val result = lc - .mapRows('row.insertFields(colIdx -> irRange(0, 'global(colsField).len) - .filter(lambdaIdx ~> !'row(entriesField)(lambdaIdx).isNA))) - .explode(colIdx) - .mapRows(let( - __col_struct = 'global(colsField)('row(colIdx)), - __entry_struct = 'row(entriesField)('row(colIdx)), - ) { - val newFields = - child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct(f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct(f)) - - 'row.dropFields(entriesField, colIdx).insertFieldsList( - newFields, - ordering = Some(x.typ.rowType.fieldNames), + val result = + lc + .mapRows( + let(__entries = 'row(entriesField)) in + 'row.insertFields( + '__col_idx -> + irRange(0, 'global(colsField).len) + .filter('__idx ~> !'__entries('__idx).isNA) + ) ) - }) - .mapGlobals('global.dropFields(colsField)) - if (child.typ.colKey.isEmpty) - result + .explode('__col_idx) + .mapRows { + val newFields = + child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct(f)) ++ + child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct(f)) + + let( + __col_struct = 'global(colsField)('row('__col_idx)), + __entry_struct = 'row(entriesField)('row('__col_idx)), + ) in 'row + .dropFields(entriesField, '__col_idx) + .insertFieldsList(newFields, ordering = Some(x.typ.rowType.fieldNames)) + } + .mapGlobals('global.dropFields(colsField)) + + if (child.typ.colKey.isEmpty) result else { assert(child.typ.rowKey.isEmpty) result.keyBy(child.typ.colKey) @@ -1033,23 +987,23 @@ object LowerMatrixIR { case MatrixColsTable(child) => val colKey = child.typ.colKey - let(__cols_and_globals = lower(ctx, child, ab).getGlobals) { - val sortedCols = if (colKey.isEmpty) - '__cols_and_globals(colsField) - else - '__cols_and_globals(colsField).map { - '__cols_element ~> - makeStruct( - // key struct - '_1 -> '__cols_element.selectFields(colKey: _*), - '_2 -> '__cols_element, - ) - }.sort(true, onKey = true) - .map { - 'elt ~> 'elt('_2) - } - makeStruct('rows -> sortedCols, 'global -> '__cols_and_globals.dropFields(colsField)) - }.parallelize(None).keyBy(child.typ.colKey) + + val sortedCols = + if (colKey.isEmpty) '__cols_and_global(colsField) + else '__cols_and_global(colsField) + .map('__cols_element ~> + makeStruct( + // key struct + '_1 -> '__cols_element.selectFields(colKey: _*), + '_2 -> '__cols_element, + )) + .sort(true, onKey = true) + .map('elt ~> 'elt('_2)) + + (let(__cols_and_global = lower(ctx, child, ab).getGlobals) in + makeStruct('rows -> sortedCols, 'global -> '__cols_and_global.dropFields(colsField))) + .parallelize(None) + .keyBy(child.typ.colKey) case table => lowerChildren(ctx, table, ab).asInstanceOf[TableIR] } @@ -1058,26 +1012,26 @@ object LowerMatrixIR { lowered } - private[this] def lower(ctx: ExecuteContext, bmir: BlockMatrixIR, ab: Growable[(Name, IR)]) + private def lower(ctx: ExecuteContext, bmir: BlockMatrixIR, ab: Growable[(Name, IR)]) : BlockMatrixIR = { - val lowered = bmir match { - case noMatrixChildren => lowerChildren(ctx, noMatrixChildren, ab).asInstanceOf[BlockMatrixIR] - } + val lowered = lowerChildren(ctx, bmir, ab).asInstanceOf[BlockMatrixIR] assertTypeUnchanged(bmir, lowered) lowered } - private[this] def lower(ctx: ExecuteContext, ir: IR, ab: Growable[(Name, IR)]): IR = { + private def lower(ctx: ExecuteContext, ir: IR, ab: Growable[(Name, IR)]): IR = { val lowered = ir match { - case MatrixToValueApply(child, function) => TableToValueApply( + case MatrixToValueApply(child, function) => + TableToValueApply( lower(ctx, child, ab), - function.lower() - .getOrElse(WrappedMatrixToValueFunction( + function.lower().getOrElse( + WrappedMatrixToValueFunction( function, colsFieldName, entriesFieldName, child.typ.colKey, - )), + ) + ), ) case MatrixWrite(child, writer) => TableWrite( @@ -1096,29 +1050,33 @@ object LowerMatrixIR { val lc = lower(ctx, child, ab) TableAggregate( lc, - aggExplodeIR( - filterIR( - zip2( - ToStream(GetField(Ref(TableIR.rowName, lc.typ.rowType), entriesFieldName)), - ToStream(GetField(Ref(TableIR.globalName, lc.typ.globalType), colsFieldName)), - ArrayZipBehavior.AssertSameLength, - ) { case (e, c) => - MakeTuple.ordered(FastSeq(e, c)) + Let( + FastSeq(MatrixIR.globalName -> globals(lc)), + M.agg { + for { + cols <- Name("__cols") -> colVals(lc) + entries <- Name("__entries") -> entries(lc) + _ <- MatrixIR.globalName -> globals(lc) + _ <- MatrixIR.rowName -> rowVal(lc) + } yield aggExplodeIR( + filterIR( + zip2( + ToStream(cols), + ToStream(entries), + ArrayZipBehavior.AssertSameLength, + ) { + (c, e) => maybeIR(e)(e => maketuple(c, e)) + } + )(r => ApplyUnaryPrimOp(Bang, IsNA(r))) + ) { explodedTuple => + M.agg { + (MatrixIR.colName -> GetTupleElement(explodedTuple, 0)) >> + (MatrixIR.entryName -> GetTupleElement(explodedTuple, 1)) >> + query + } } - )(filterTuple => ApplyUnaryPrimOp(Bang, IsNA(GetTupleElement(filterTuple, 0)))) - ) { explodedTuple => - AggLet( - MatrixIR.entryName, - GetTupleElement(explodedTuple, 0), - AggLet( - MatrixIR.colName, - GetTupleElement(explodedTuple, 1), - Subst(query, matrixSubstEnvIR(child, lc)), - isScan = false, - ), - isScan = false, - ) - }, + }, + ), ) case _ => lowerChildren(ctx, ir, ab).asInstanceOf[IR] } @@ -1126,7 +1084,7 @@ object LowerMatrixIR { lowered } - private[this] def assertTypeUnchanged(original: BaseIR, lowered: BaseIR): Unit = + private def assertTypeUnchanged(original: BaseIR, lowered: BaseIR): Unit = if (lowered.typ != original.typ) fatal( s"lowering changed type:\n before: ${original.typ}\n after: ${lowered.typ}\n ${original.getClass.getName} => ${lowered.getClass.getName}" diff --git a/hail/hail/src/is/hail/expr/ir/MatrixIR.scala b/hail/hail/src/is/hail/expr/ir/MatrixIR.scala index 6c86056729a..1d06d866ce1 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixIR.scala @@ -4,7 +4,6 @@ import is.hail.annotations._ import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq -import is.hail.expr.ir.DeprecatedIRBuilder._ import is.hail.expr.ir.analyses.{ColumnCount, PartitionCounts} import is.hail.expr.ir.defs._ import is.hail.expr.ir.functions.MatrixToMatrixFunction @@ -198,20 +197,14 @@ abstract class MatrixHybridReader extends TableReaderWithExtraUID with MatrixRea tr, InsertFields( Ref(TableIR.rowName, tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.entryType), - )), + FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray.empty(requestedType.entryType)), ), ) tr = TableMapGlobals( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.colType), - )), + FastSeq(LowerMatrixIR.colsFieldName -> MakeArray.empty(requestedType.colType)), ), ) } @@ -261,7 +254,7 @@ case class MatrixNativeReaderParameters( class MatrixNativeReader( val params: MatrixNativeReaderParameters, - spec: AbstractMatrixTableSpec, + val spec: AbstractMatrixTableSpec, ) extends MatrixReader { override def pathsUsed: Seq[String] = FastSeq(params.path) @@ -300,20 +293,14 @@ class MatrixNativeReader( tr, InsertFields( Ref(TableIR.globalName, tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.colType), - )), + FastSeq(LowerMatrixIR.colsFieldName -> MakeArray.empty(requestedType.colType)), ), ) TableMapRows( tr, InsertFields( Ref(TableIR.rowName, tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray( - FastSeq(), - TArray(requestedType.entryType), - )), + FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray.empty(requestedType.entryType)), ), ) } else { @@ -325,36 +312,27 @@ class MatrixNativeReader( spec.rowsSpec, spec.entriesSpec, ) - val tr: TableIR = TableRead(tt, dropRows, trdr) + val tr = TableRead(tt, dropRows, trdr) val colsRVDSpec = spec.colsSpec.rowsSpec val partFiles = colsRVDSpec.absolutePartPaths(spec.colsSpec.rowsComponent.absolutePath(colsPath)) - val cols = if (partFiles.length == 1) { + def readCols(index: IR, path: IR): IR = ReadPartition( - MakeStruct(ArraySeq("partitionIndex" -> I64(0), "partitionPath" -> Str(partFiles.head))), + makestruct("partitionIndex" -> index.toL, "partitionPath" -> path), requestedType.colType, PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName), ) - } else { - val contextType = TStruct("partitionIndex" -> TInt64, "partitionPath" -> TString) - val partNames = MakeArray( - partFiles.zipWithIndex.map { case (path, idx) => - MakeStruct(ArraySeq("partitionIndex" -> I64(idx.toLong), "partitionPath" -> Str(path))) - }, - TArray(contextType), - ) - val elt = Ref(freshName(), contextType) - StreamFlatMap( - partNames, - elt.name, - ReadPartition( - elt, - requestedType.colType, - PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName), - ), + + val cols = + if (partFiles.length == 1) readCols(0, Str(partFiles.head)) + else flatten( + zip2( + iota(0, 1), + ToStream(Literal(TArray(TString), partFiles)), + ArrayZipBehavior.TakeMinLength, + )(readCols(_, _)) ) - } TableMapGlobals( tr, @@ -377,8 +355,6 @@ class MatrixNativeReader( case that: MatrixNativeReader => params == that.params case _ => false } - - def getSpec(): AbstractMatrixTableSpec = this.spec } object MatrixRangeReader { @@ -402,13 +378,13 @@ object MatrixRangeReader { case class MatrixRangeReaderParameters(nRows: Int, nCols: Int, nPartitions: Option[Int]) case class MatrixRangeReader( - val params: MatrixRangeReaderParameters, + params: MatrixRangeReaderParameters, nPartitionsAdj: Int, ) extends MatrixReader { override def pathsUsed: Seq[String] = FastSeq() - override def rowUIDType = TInt64 - override def colUIDType = TInt64 + override def rowUIDType: Type = TInt64 + override def colUIDType: Type = TInt64 override def fullMatrixTypeWithoutUIDs: MatrixType = MatrixType( globalType = TStruct.empty, @@ -432,32 +408,39 @@ case class MatrixRangeReader( dropCols: Boolean, dropRows: Boolean, ): TableIR = { + import DeprecatedIRBuilder._ + val nRowsAdj = if (dropRows) 0 else params.nRows val nColsAdj = if (dropCols) 0 else params.nCols var ht = TableRange(nRowsAdj, params.nPartitions.getOrElse(ctx.backend.defaultParallelism)) .rename(Map("idx" -> "row_idx")) - if (requestedType.colType.hasField(colUIDFieldName)) - ht = ht.mapGlobals(makeStruct(LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> makeStruct( - 'col_idx -> 'i, - Symbol(colUIDFieldName) -> 'i.toL, - )))) + + ht = if (requestedType.colType.hasField(colUIDFieldName)) + ht.mapGlobals(makeStruct( + LowerMatrixIR.colsField -> + irRange(0, nColsAdj).map('i ~> + makeStruct( + 'col_idx -> 'i, + Symbol(colUIDFieldName) -> 'i.toL, + )) + )) else - ht = ht.mapGlobals(makeStruct(LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i)))) + ht.mapGlobals(makeStruct( + LowerMatrixIR.colsField -> + irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i)) + )) + if (requestedType.rowType.hasField(rowUIDFieldName)) - ht = ht.mapRows('row.insertFields( - LowerMatrixIR.entriesField -> irRange(0, nColsAdj).map('i ~> makeStruct()), + ht.mapRows('row.insertFields( + LowerMatrixIR.entriesField -> irRange(0, nColsAdj).map('i ~> makestruct()), Symbol(rowUIDFieldName) -> 'row('row_idx).toL, )) else - ht = ht.mapRows('row.insertFields( + ht.mapRows('row.insertFields( LowerMatrixIR.entriesField -> - irRange(0, nColsAdj).map('i ~> makeStruct()) + irRange(0, nColsAdj).map('i ~> makestruct()) )) - - ht } override def toJValue: JValue = { @@ -484,14 +467,6 @@ object MatrixRead { !reader.fullMatrixTypeWithoutUIDs.colType.hasField(MatrixReader.colUIDFieldName)) new MatrixRead(typ, dropCols, dropRows, reader) } - - def preserveExistingUIDs( - typ: MatrixType, - dropCols: Boolean, - dropRows: Boolean, - reader: MatrixReader, - ): MatrixRead = - new MatrixRead(typ, dropCols, dropRows, reader) } case class MatrixRead( diff --git a/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala b/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala index 3c7739630e0..dd7dfcd7b6d 100644 --- a/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala +++ b/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala @@ -7,6 +7,7 @@ import is.hail.backend.ExecuteContext import is.hail.collection.{ByteArrayBuilder, FastSeq} import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.{JSONAnnotationImpex, Nat} +import is.hail.expr.ir.{Memoized => M} import is.hail.expr.ir.defs._ import is.hail.expr.ir.lowering.TableStage import is.hail.expr.ir.streams.StreamProducer @@ -87,8 +88,8 @@ sealed trait MatrixWriterComponents { def stage: TableStage def setup: IR def writePartitionType: Type - def writePartition(rows: IR, ctx: Atom): IR - def finalizeWrite(parts: IR, globals: IR): IR + def writePartition(rows: Atom, ctx: Atom): IR + def finalizeWrite(parts: Atom, globals: Atom): IR } object MatrixNativeWriter { @@ -143,12 +144,6 @@ object MatrixNativeWriter { val partitioner = lowered.partitioner val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - val emptyWriter = - PartitionNativeWriter(emptySpec, IndexedSeq(), s"$path/globals/globals/parts/", None, None) - val globalWriter = - PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/rows/parts/", None, None) - val colWriter = - PartitionNativeWriter(colSpec, IndexedSeq(), s"$path/cols/rows/parts/", None, None) val rowWriter = SplitPartitionNativeWriter( rowSpec, s"$path/rows/rows/parts/", @@ -161,39 +156,6 @@ object MatrixNativeWriter { else None, ) - val globalTableWriter = TableSpecWriter( - s"$path/globals", - TableType(tm.globalType, FastSeq(), TStruct.empty), - "rows", - "globals", - "../references", - log = false, - ) - val colTableWriter = TableSpecWriter( - s"$path/cols", - tm.colsTableType.copy(key = FastSeq[String]()), - "rows", - "../globals/rows", - "../references", - log = false, - ) - val rowTableWriter = TableSpecWriter( - s"$path/rows", - tm.rowsTableType, - "rows", - "../globals/rows", - "../references", - log = false, - ) - val entriesTableWriter = TableSpecWriter( - s"$path/entries", - TableType(tm.entriesRVType, FastSeq(), tm.globalType), - "rows", - "../globals/rows", - "../references", - log = false, - ) - new MatrixWriterComponents { override val stage: TableStage = @@ -205,7 +167,7 @@ object MatrixNativeWriter { oldCtx, ToStream(Literal(TArray(TString), partFiles)), ArrayZipBehavior.AssertSameLength, - )((ctxElt, pf) => MakeStruct(FastSeq("oldCtx" -> ctxElt, "writeCtx" -> pf))) + )((ctxElt, pf) => makestruct("oldCtx" -> ctxElt, "writeCtx" -> pf)) }(GetField(_, "oldCtx")) override val setup: IR = @@ -223,145 +185,166 @@ object MatrixNativeWriter { override def writePartitionType: Type = rowWriter.returnType - override def writePartition(rows: IR, ctx: Atom): IR = + override def writePartition(rows: Atom, ctx: Atom): IR = WritePartition(rows, GetField(ctx, "writeCtx") + UUID4(), rowWriter) - override def finalizeWrite(parts: IR, globals: IR): IR = { - // parts is array of partition results - val writeEmpty = WritePartition( - MakeStream(FastSeq(makestruct()), TStream(TStruct.empty)), - Str(partFile(1, 0)), - emptyWriter, - ) - val writeCols = - WritePartition(ToStream(GetField(globals, colsFieldName)), Str(partFile(1, 0)), colWriter) - val writeGlobals = WritePartition( - MakeStream( - FastSeq(SelectFields(globals, tm.globalType.fieldNames)), - TStream(tm.globalType), - ), - Str(partFile(1, 0)), - globalWriter, - ) + override def finalizeWrite(parts: Atom, globals: Atom): IR = + M.eval { + for { + partFile <- Str(partFile(1, 0)) + // parts is array of partition results + writeEmpty <- WritePartition( + MakeStream.single(makestruct()), + partFile, + PartitionNativeWriter( + emptySpec, + FastSeq(), + s"$path/globals/globals/parts/", + None, + None, + ), + ) - val matrixWriter = MatrixSpecWriter(path, tm, "rows/rows", "globals/rows", "cols/rows", - "entries/rows", "references", log = true) + colInfo <- WritePartition( + ToStream(GetField(globals, colsFieldName)), + partFile, + PartitionNativeWriter(colSpec, FastSeq(), s"$path/cols/rows/parts/", None, None), + ) - val rowsIndexSpec = IndexSpec.defaultAnnotation(ctx, "../../index", tcoerce[PStruct](pKey)) - val entriesIndexSpec = - IndexSpec.defaultAnnotation( - ctx, - "../../index", - tcoerce[PStruct](pKey), - withOffsetField = true, - ) + writeGlobals <- WritePartition( + MakeStream.single(SelectFields(globals, tm.globalType.fieldNames)), + partFile, + PartitionNativeWriter(globalSpec, FastSeq(), s"$path/globals/rows/parts/", None, None), + ) - bindIR(writeCols) { colInfo => - bindIR(parts) { partInfo => - Begin(FastSeq( - WriteMetadata( - MakeArray(GetField(writeEmpty, "filePath")), - RVDSpecWriter( - s"$path/globals/globals", - RVDSpecMaker(emptySpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), - ), + _ <- WriteMetadata( + MakeArray(GetField(writeEmpty, "filePath")), + RVDSpecWriter( + s"$path/globals/globals", + RVDSpecMaker(emptySpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), ), - WriteMetadata( - MakeArray(GetField(writeGlobals, "filePath")), - RVDSpecWriter( - s"$path/globals/rows", - RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), - ), + ) + + _ <- WriteMetadata( + MakeArray(GetField(writeGlobals, "filePath")), + RVDSpecWriter( + s"$path/globals/rows", + RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), ), - WriteMetadata( - MakeArray(MakeStruct(FastSeq( - "partitionCounts" -> I64(1), - "distinctlyKeyed" -> True(), - "firstKey" -> MakeStruct(FastSeq()), - "lastKey" -> MakeStruct(FastSeq()), - ))), - globalTableWriter, + ) + + _ <- WriteMetadata( + MakeArray(makestruct( + "partitionCounts" -> I64(1), + "distinctlyKeyed" -> True(), + "firstKey" -> makestruct(), + "lastKey" -> makestruct(), + )), + TableSpecWriter( + s"$path/globals", + TableType(tm.globalType, FastSeq(), TStruct.empty), + "rows", + "globals", + "../references", + log = false, ), - WriteMetadata( - MakeArray(GetField(colInfo, "filePath")), - RVDSpecWriter( - s"$path/cols/rows", - RVDSpecMaker(colSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), - ), + ) + + _ <- WriteMetadata( + MakeArray(GetField(colInfo, "filePath")), + RVDSpecWriter( + s"$path/cols/rows", + RVDSpecMaker(colSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), ), - WriteMetadata( - MakeArray(SelectFields( - colInfo, - IndexedSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), - )), - colTableWriter, + ) + + _ <- WriteMetadata( + MakeArray( + selectIR(colInfo, "partitionCounts", "distinctlyKeyed", "firstKey", "lastKey") ), - bindIR(ToArray(mapIR(ToStream(partInfo))(fc => GetField(fc, "filePath")))) { - files => - Begin(FastSeq( - WriteMetadata( - files, - RVDSpecWriter( - s"$path/rows/rows", - RVDSpecMaker(rowSpec, lowered.partitioner, rowsIndexSpec), - ), - ), - WriteMetadata( - files, - RVDSpecWriter( - s"$path/entries/rows", - RVDSpecMaker( - entrySpec, - RVDPartitioner.unkeyed(ctx.stateManager, lowered.numPartitions), - entriesIndexSpec, - ), - ), - ), - )) + TableSpecWriter( + s"$path/cols", + tm.colsTableType.copy(key = FastSeq[String]()), + "rows", + "../globals/rows", + "../references", + log = false, + ), + ) + + files <- mapArray(parts)(GetField(_, "filePath")) + + rowsIndexSpec = IndexSpec.defaultAnnotation(ctx, "../../index", pKey) + _ <- WriteMetadata( + files, + RVDSpecWriter( + s"$path/rows/rows", + RVDSpecMaker(rowSpec, lowered.partitioner, rowsIndexSpec), + ), + ) + + entriesIndexSpec = + IndexSpec.defaultAnnotation(ctx, "../../index", pKey, withOffsetField = true) + + _ <- WriteMetadata( + files, + RVDSpecWriter( + s"$path/entries/rows", + RVDSpecMaker( + entrySpec, + RVDPartitioner.unkeyed(ctx.stateManager, lowered.numPartitions), + entriesIndexSpec, + ), + ), + ) + + _ <- WriteMetadata( + mapArray(parts) { part => + selectIR(part, "partitionCounts", "distinctlyKeyed", "firstKey", "lastKey") }, - bindIR(ToArray(mapIR(ToStream(partInfo)) { fc => - SelectFields( - fc, - FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), + TableSpecWriter( + s"$path/rows", + tm.rowsTableType, + "rows", + "../globals/rows", + "../references", + log = false, + ), + ) + + _ <- WriteMetadata( + mapArray(parts) { part => + insertIR( + selectIR(part, "partitionCounts", "distinctlyKeyed"), + "firstKey" -> makestruct(), + "lastKey" -> makestruct(), ) - })) { countsAndKeyInfo => - Begin(FastSeq( - WriteMetadata(countsAndKeyInfo, rowTableWriter), - WriteMetadata( - ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKeyInfo => - InsertFields( - SelectFields( - countAndKeyInfo, - IndexedSeq("partitionCounts", "distinctlyKeyed"), - ), - IndexedSeq( - "firstKey" -> MakeStruct(FastSeq()), - "lastKey" -> MakeStruct(FastSeq()), - ), - ) - }), - entriesTableWriter, - ), - WriteMetadata( - makestruct( - "cols" -> GetField(colInfo, "partitionCounts"), - "rows" -> ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKey => - GetField(countAndKey, "partitionCounts") - }), - ), - matrixWriter, - ), - )) }, - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(path)), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/globals")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/cols")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/rows")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/entries")), - )) - } + TableSpecWriter( + s"$path/entries", + TableType(tm.entriesRVType, FastSeq(), tm.globalType), + "rows", + "../globals/rows", + "../references", + log = false, + ), + ) + + _ <- WriteMetadata( + makestruct( + "cols" -> GetField(colInfo, "partitionCounts"), + "rows" -> mapArray(parts)(part => GetField(part, "partitionCounts")), + ), + MatrixSpecWriter(path, tm, "rows/rows", "globals/rows", "cols/rows", "entries/rows"), + ) + + _ <- WriteMetadata(makestruct(), RelationalCommit(path)) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/globals")) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/cols")) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/rows")) + _ <- WriteMetadata(makestruct(), RelationalCommit(s"$path/entries")) + } yield Void() } - } } } } @@ -407,9 +390,6 @@ case class SplitPartitionNativeWriter( stageFolder: Option[Path], ) extends PartitionWriter { - val filenameType = PCanonicalString(required = true) - def pContextType = PCanonicalString() - val keyType = spec1.encodedVirtualType.asInstanceOf[TStruct].select(keyFieldNames)._1 override def ctxType: Type = TString @@ -501,11 +481,9 @@ case class SplitPartitionNativeWriter( val pCount = mb.newLocal[Long]("partition_count") cb.assign(pCount, 0L) + // True until proven otherwise, if there's a key to care about all. val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed") - cb.assign( - distinctlyKeyed, - !keyFieldNames.isEmpty, - ) // True until proven otherwise, if there's a key to care about all. + cb.assign(distinctlyKeyed, keyFieldNames.nonEmpty) val keyEmitType = EmitType(spec1.decodedPType(keyType).sType, false) @@ -562,7 +540,7 @@ case class SplitPartitionNativeWriter( keyType.fields.map(f => EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))): _* ) - if (!keyFieldNames.isEmpty) { + if (keyFieldNames.nonEmpty) { cb.if_( distinctlyKeyed, { lastSeenSettable.loadI(cb).consume( @@ -650,9 +628,7 @@ class MatrixSpecHelper( globalRelPath: String, colRelPath: String, entryRelPath: String, - refRelPath: String, typ: MatrixType, - log: Boolean, ) extends Logging with Serializable { def write(fs: FS, nCols: Long, partCounts: Array[Long]): Unit = { val spec = MatrixTableSpecParameters( @@ -686,8 +662,6 @@ case class MatrixSpecWriter( globalRelPath: String, colRelPath: String, entryRelPath: String, - refRelPath: String, - log: Boolean, ) extends MetadataWriter { override def annotationType: Type = TStruct("cols" -> TInt64, "rows" -> TArray(TInt64)) @@ -712,7 +686,7 @@ case class MatrixSpecWriter( }, ) cb += cb.emb.getObject(new MatrixSpecHelper(path, rowRelPath, globalRelPath, colRelPath, - entryRelPath, refRelPath, typ, log)) + entryRelPath, typ)) .invoke[FS, Long, Array[Long], Unit]( "write", cb.emb.getFS, @@ -787,25 +761,18 @@ case class MatrixVCFWriter( ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq( - "oldCtx" -> ctxElt, - "partFile" -> pf, - )) + makestruct("__old_ctx" -> ctxElt, "__part_file" -> pf) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { - (rows, ctxRef) => - val partFile = GetField(ctxRef, "partFile") + UUID4() + Str(ext) - val ctx = MakeStruct(FastSeq( - "cols" -> GetField(ts.globals, colsFieldName), - "partFile" -> partFile, - )) + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_vcf_writer") { (rows, ctxRef) => + val partFile = GetField(ctxRef, "__part_file") + UUID4() + Str(ext) + val ctx = makestruct("cols" -> GetField(ts.globals, colsFieldName), "partFile" -> partFile) WritePartition(rows, ctx, lineWriter) - } { (parts, globals) => - val ctx = - MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "partFiles" -> parts)) - val commit = VCFExportFinalizer(tm, path, appendStr, metadata, exportType, tabix) - Begin(FastSeq(WriteMetadata(ctx, commit))) - } + } { (parts, globals) => + val ctx = makestruct("cols" -> GetField(globals, colsFieldName), "partFiles" -> parts) + val commit = VCFExportFinalizer(tm, path, appendStr, metadata, exportType, tabix) + WriteMetadata(ctx, commit) + } } private def getAppendHeaderValue(fs: FS): Option[String] = append.map { f => @@ -1421,10 +1388,6 @@ case class MatrixGENWriter( r: RTable, ): IR = { val tm = MatrixType.fromTableType(ts.tableType, colsFieldName, entriesFieldName, colKey) - - val sampleWriter = new GenSampleWriter - - val lineWriter = GenVariantWriter(tm, entriesFieldName, precision) val folder = ctx.createTmpPath("export-gen") ts.mapContexts { oldCtx => @@ -1434,25 +1397,22 @@ case class MatrixGENWriter( ArraySeq.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-"), ) - zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq( - "oldCtx" -> ctxElt, - "partFile" -> pf, - )) + zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctx, pf) => + makestruct("__old_ctx" -> ctx, "__part_file" -> pf) + } + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_gen_writer") { (rows, ctxRef) => + val ctx = GetField(ctxRef, "__part_file") + UUID4() + WritePartition(rows, ctx, GenVariantWriter(tm, entriesFieldName, precision)) + } { (parts, globals) => + val cols = ToStream(GetField(globals, colsFieldName)) + val sampleFileName = Str(s"$path.sample") + val writeSamples = WritePartition(cols, sampleFileName, new GenSampleWriter) + val commitSamples = SimpleMetadataWriter(TString) + + val commit = TableTextFinalizer(s"$path.gen", ts.rowType, " ", header = false) + Begin(FastSeq(WriteMetadata(writeSamples, commitSamples), WriteMetadata(parts, commit))) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_gen_writer") { - (rows, ctxRef) => - val ctx = GetField(ctxRef, "partFile") + UUID4() - WritePartition(rows, ctx, lineWriter) - } { (parts, globals) => - val cols = ToStream(GetField(globals, colsFieldName)) - val sampleFileName = Str(s"$path.sample") - val writeSamples = WritePartition(cols, sampleFileName, sampleWriter) - val commitSamples = SimpleMetadataWriter(TString) - - val commit = TableTextFinalizer(s"$path.gen", ts.rowType, " ", header = false) - Begin(FastSeq(WriteMetadata(writeSamples, commitSamples), WriteMetadata(parts, commit))) - } } } @@ -1633,38 +1593,34 @@ case class MatrixBGENWriter( ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) + + // hint: don't writeHeader + val variantCounts = + if (writeHeader) ToStream(ts.countPerPartition().deepCopy) + else mapIR(rangeIR(ts.numPartitions))(_ => NA(TInt64)) + val partFiles = ToStream(Literal( TArray(TString), ArraySeq.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-"), )) - val numVariants = if (writeHeader) ToStream(ts.countPerPartition()) - else ToStream(MakeArray(ArraySeq.tabulate(ts.numPartitions)(_ => NA(TInt64)): _*)) - - val ctxElt = Ref(freshName(), tcoerce[TStream](oldCtx.typ).elementType) - val pf = Ref(freshName(), tcoerce[TStream](partFiles.typ).elementType) - val nv = Ref(freshName(), tcoerce[TStream](numVariants.typ).elementType) - - StreamZip( - FastSeq(oldCtx, partFiles, numVariants), - FastSeq(ctxElt.name, pf.name, nv.name), - MakeStruct(FastSeq("oldCtx" -> ctxElt, "numVariants" -> nv, "partFile" -> pf)), - ArrayZipBehavior.AssertSameLength, - ) - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { - (rows, ctxRef) => - val partFile = GetField(ctxRef, "partFile") + UUID4() - val ctx = MakeStruct(FastSeq( + + zipIR(FastSeq(oldCtx, variantCounts, partFiles), ArrayZipBehavior.AssertSameLength) { + case Seq(ctx, vc, pf) => + makestruct("__old_ctx" -> ctx, "__num_variants" -> vc, "__part_file" -> pf) + } + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_vcf_writer") { (rows, ctx) => + val writeCtx = makestruct( "cols" -> GetField(ts.globals, colsFieldName), - "numVariants" -> GetField(ctxRef, "numVariants"), - "partFile" -> partFile, - )) - WritePartition(rows, ctx, partWriter) - } { (results, globals) => - val ctx = - MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "results" -> results)) - val commit = BGENExportFinalizer(tm, path, exportType, compressionInt) - Begin(FastSeq(WriteMetadata(ctx, commit))) - } + "numVariants" -> GetField(ctx, "__num_variants"), + "partFile" -> (GetField(ctx, "__part_file") + UUID4()), + ) + WritePartition(rows, writeCtx, partWriter) + } { (results, globals) => + val ctx = makestruct("cols" -> GetField(globals, colsFieldName), "results" -> results) + val commit = BGENExportFinalizer(tm, path, exportType, compressionInt) + WriteMetadata(ctx, commit) + } } } @@ -1676,7 +1632,7 @@ case class BGENPartitionWriter( ) extends PartitionWriter { require(typ.entryType.hasField("GP") && typ.entryType.fieldType("GP") == TArray(TFloat64)) - val ctxType: Type = + override val ctxType: Type = TStruct("cols" -> TArray(typ.colType), "numVariants" -> TInt64, "partFile" -> TString) override def returnType: TStruct = @@ -2134,31 +2090,27 @@ case class MatrixPLINKWriter( ) zip2(oldCtx, ToStream(files), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq( - "oldCtx" -> ctxElt, - "file" -> pf, - )) + makestruct("__old_ctx" -> ctxElt, "__files" -> pf) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_plink_writer") { - (rows, ctxRef) => - val id = UUID4() - val bedFile = GetTupleElement(GetField(ctxRef, "file"), 0) + id - val bimFile = GetTupleElement(GetField(ctxRef, "file"), 1) + id - val ctx = MakeStruct(FastSeq("bedFile" -> bedFile, "bimFile" -> bimFile)) - WritePartition(rows, ctx, lineWriter) - } { (parts, globals) => - val commit = PLINKExportFinalizer(tm, path, tmpBedDir + "/header") - val famWriter = TableTextPartitionWriter(tm.colsTableType.rowType, "\t", writeHeader = false) - val famPath = Str(path + ".fam") - val cols = ToStream(GetField(globals, colsFieldName)) - val writeFam = WritePartition(cols, famPath, famWriter) - bindIR(writeFam) { fpath => + }(GetField(_, "__old_ctx")) + .mapCollectWithContextsAndGlobals("matrix_plink_writer") { (rows, ctxRef) => + bindIRs(UUID4(), GetField(ctxRef, "__files")) { case Seq(id, files) => + val bedFile = GetTupleElement(files, 0) + id + val bimFile = GetTupleElement(files, 1) + id + val ctx = makestruct("bedFile" -> bedFile, "bimFile" -> bimFile) + WritePartition(rows, ctx, lineWriter) + } + } { (parts, globals) => + val commit = PLINKExportFinalizer(tm, path, tmpBedDir + "/header") + val famWriter = + TableTextPartitionWriter(tm.colsTableType.rowType, "\t", writeHeader = false) + val cols = ToStream(GetField(globals, colsFieldName)) + val fpath = WritePartition(cols, Str(path + ".fam"), famWriter) Begin(FastSeq( WriteMetadata(parts, commit), WriteMetadata(fpath, SimpleMetadataWriter(fpath.typ)), )) } - } } } @@ -2166,10 +2118,10 @@ case class PLINKPartitionWriter(typ: MatrixType, entriesFieldName: String) exten val ctxType = TStruct("bedFile" -> TString, "bimFile" -> TString) override def returnType = TStruct("bedFile" -> TString, "bimFile" -> TString) - val locusIdx = typ.rowType.fieldIdx("locus") - val allelesIdx = typ.rowType.fieldIdx("alleles") - val varidIdx = typ.rowType.fieldIdx("varid") - val cmPosIdx = typ.rowType.fieldIdx("cm_position") + private[this] val locusIdx = typ.rowType.fieldIdx("locus") + private[this] val allelesIdx = typ.rowType.fieldIdx("alleles") + private[this] val varidIdx = typ.rowType.fieldIdx("varid") + private[this] val cmPosIdx = typ.rowType.fieldIdx("cm_position") override def unionTypeRequiredness( r: TypeWithRequiredness, @@ -2356,9 +2308,7 @@ case class MatrixBlockMatrixWriter( val numBlockCols: Int = BlockMatrixType.numBlocks(numCols.toLong, blockSize) val lastBlockNumCols = (numCols - 1) % blockSize + 1 - val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(paritionIR => - StreamLen(paritionIR) - ) + val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(StreamLen(_)) val inputRowCountPerPartition: IndexedSeq[Int] = CompileAndEvaluate[IndexedSeq[Int]](ctx, rowCountIR) val inputPartStartsPlusLast = inputRowCountPerPartition.scanLeft(0L)(_ + _) @@ -2369,32 +2319,33 @@ case class MatrixBlockMatrixWriter( val numBlockRows: Int = BlockMatrixType.numBlocks(numRows, blockSize) // Zip contexts with partition starts and ends - val zippedWithStarts = ts.mapContexts { oldContextsStream => - zipIR( - IndexedSeq( - oldContextsStream, - ToStream(Literal(TArray(TInt64), inputPartStarts)), - ToStream(Literal(TArray(TInt64), inputPartStops)), - ), - ArrayZipBehavior.AssertSameLength, - ) { case IndexedSeq(oldCtx, partStart, partStop) => - MakeStruct(FastSeq( - "mwOld" -> oldCtx, - "mwStartIdx" -> Cast(partStart, TInt32), - "mwStopIdx" -> Cast(partStop, TInt32), - )) - } - }(newCtx => GetField(newCtx, "mwOld")) + val zippedWithStarts = + ts.mapContexts { oldContextsStream => + zipIR( + FastSeq( + oldContextsStream, + ToStream(Literal(TArray(TInt64), inputPartStarts)), + ToStream(Literal(TArray(TInt64), inputPartStops)), + ), + ArrayZipBehavior.AssertSameLength, + ) { case Seq(oldCtx, partStart, partStop) => + makestruct( + "__mw_old_ctx" -> oldCtx, + "__mw_start_idx" -> partStart.toI, + "__mw_stop_idx" -> partStop.toI, + ) + } + }(GetField(_, "__mw_old_ctx")) // Now label each row with its idx. - val perRowIdxId = genUID() - val partsZippedWithIdx = zippedWithStarts.mapPartitionWithContext { (part, ctx) => - zip2( - part, - rangeIR(GetField(ctx, "mwStartIdx"), GetField(ctx, "mwStopIdx")), - ArrayZipBehavior.AssertSameLength, - )((partRow, idx) => insertIR(partRow, (perRowIdxId, idx))) - } + val partsZippedWithIdx = + zippedWithStarts.mapPartitionWithContext { (part, ctx) => + zip2( + part, + rangeIR(GetField(ctx, "__mw_start_idx"), GetField(ctx, "__mw_stop_idx")), + ArrayZipBehavior.AssertSameLength, + )((partRow, idx) => insertIR(partRow, "__per_row_idx" -> idx)) + } /* Two steps, make a partitioner that works currently based on row_idx splits, then resplit * accordingly. */ @@ -2404,84 +2355,93 @@ case class MatrixBlockMatrixWriter( } val rowIdxPartitioner = - new RVDPartitioner(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals) + new RVDPartitioner(ctx.stateManager, TStruct("__per_row_idx" -> TInt32), inputRowIntervals) val keyedByRowIdx = partsZippedWithIdx.changePartitionerNoRepartition(rowIdxPartitioner) // Now create a partitioner that makes appropriately sized blocks - val desiredRowStarts = (0 until numBlockRows).map(_ * blockSize) + val desiredRowStarts = ArraySeq.tabulate(numBlockRows)(_ * blockSize) val desiredRowStops = desiredRowStarts.drop(1) :+ numRows.toInt - val desiredRowIntervals = desiredRowStarts.zip(desiredRowStops).map { - case (intervalStart, intervalEnd) => - Interval(Row(intervalStart), Row(intervalEnd), true, false) - } + val desiredRowIntervals = + desiredRowStarts + .view + .zip(desiredRowStops) + .map { case (start, end) => Interval(Row(start), Row(end), true, false) } + .to(ArraySeq) val blockSizeGroupsPartitioner = - RVDPartitioner.generate(ctx.stateManager, TStruct((perRowIdxId, TInt32)), desiredRowIntervals) + RVDPartitioner.generate( + ctx.stateManager, + TStruct("__per_row_idx" -> TInt32), + desiredRowIntervals, + ) val rowsInBlockSizeGroups: TableStage = keyedByRowIdx.repartitionNoShuffle(ctx, blockSizeGroupsPartitioner) - def createBlockMakingContexts(tablePartsStreamIR: IR): IR = { + def createBlockMakingContexts(tablePartsStreamIR: Atom): IR = flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { - case (tableSinglePartCtx, blockRowIdx) => - mapIR(rangeIR(I32(numBlockCols))) { blockColIdx => - MakeStruct(FastSeq( - "oldTableCtx" -> tableSinglePartCtx, - "blockStart" -> (blockColIdx * I32(blockSize)), - "blockSize" -> If( + (tableSinglePartCtx, blockRowIdx) => + mapIR(rangeIR(numBlockCols)) { blockColIdx => + makestruct( + "__old_table_ctx" -> tableSinglePartCtx, + "__block_start" -> (blockColIdx * I32(blockSize)), + "__block_size" -> If( blockColIdx ceq I32(numBlockCols - 1), - I32(lastBlockNumCols), - I32(blockSize), + lastBlockNumCols, + blockSize, ), - "blockColIdx" -> blockColIdx, - "blockRowIdx" -> blockRowIdx, - )) + "__block_col_idx" -> blockColIdx, + "__block_row_idx" -> blockRowIdx, + ) } }) - } - val tableOfNDArrays = rowsInBlockSizeGroups.mapContexts(createBlockMakingContexts)(ir => - GetField(ir, "oldTableCtx") - ).mapPartitionWithContext { (partIr, ctxRef) => - bindIR(GetField(ctxRef, "blockStart")) { blockStartRef => - val numColsOfBlock = GetField(ctxRef, "blockSize") - val arrayOfSlicesAndIndices = ToArray(mapIR(partIr) { singleRow => - val mappedSlice = ToArray(mapIR(ToStream(sliceArrayIR( - GetField(singleRow, entriesFieldName), - blockStartRef, - blockStartRef + numColsOfBlock, - )))(entriesStructRef => - GetField(entriesStructRef, entryField) - )) - MakeStruct(FastSeq( - perRowIdxId -> GetField(singleRow, perRowIdxId), - "rowOfData" -> mappedSlice, - )) - }) - bindIR(arrayOfSlicesAndIndices) { arrayOfSlicesAndIndicesRef => - val idxOfResult = GetField(ArrayRef(arrayOfSlicesAndIndicesRef, I32(0)), perRowIdxId) - val ndarrayData = ToArray(flatMapIR(ToStream(arrayOfSlicesAndIndicesRef)) { idxAndSlice => - ToStream(GetField(idxAndSlice, "rowOfData")) - }) - val numRowsOfBlock = ArrayLen(arrayOfSlicesAndIndicesRef) - val shape = maketuple(Cast(numRowsOfBlock, TInt64), Cast(numColsOfBlock, TInt64)) - val ndarray = MakeNDArray(ndarrayData, shape, True(), ErrorIDs.NO_ERROR) - MakeStream( - FastSeq(MakeStruct(FastSeq( - perRowIdxId -> idxOfResult, - "blockRowIdx" -> GetField(ctxRef, "blockRowIdx"), - "blockColIdx" -> GetField(ctxRef, "blockColIdx"), - "ndBlock" -> ndarray, - ))), - TStream(TStruct( - perRowIdxId -> TInt32, - "blockRowIdx" -> TInt32, - "blockColIdx" -> TInt32, - "ndBlock" -> ndarray.typ, - )), - ) + val tableOfNDArrays = + rowsInBlockSizeGroups + .mapContexts(createBlockMakingContexts)(GetField(_, "__old_table_ctx")) + .mapPartitionWithContext { (part, ctx) => + M.eval { + for { + blockStart <- GetField(ctx, "__block_start") + blockSize <- GetField(ctx, "__block_size") + blockEnd <- blockStart + blockSize + + data <- streamAggIR(part) { row => + makestruct( + "__num_rows" -> + ApplyAggOp(Count())(), + "__result_idx" -> + ArrayRef( + ApplyAggOp( + FastSeq(I32(1)), + FastSeq(GetField(row, "__per_row_idx")), + Take(), + ), + I32(0), + ), + "__block_data" -> { + val slices = sliceArrayIR(GetField(row, entriesFieldName), blockStart, blockEnd) + val elem = Ref(Name("__elem"), TIterable.elementType(slices.typ)) + val collect = ApplyAggOp(Collect())(GetField(elem, entryField)) + AggExplode(ToStream(slices), elem.name, collect, isScan = false) + }, + ) + } + + numRowsOfBlock <- GetField(data, "__num_rows") + idxOfResult <- GetField(data, "__result_idx") + ndarrayData <- GetField(data, "__block_data") + shape <- maketuple(numRowsOfBlock, blockSize.toL) + ndarray <- MakeNDArray(ndarrayData, shape, True(), ErrorIDs.NO_ERROR) + } yield MakeStream.single( + makestruct( + "__per_row_idx" -> idxOfResult, + "__block_row_idx" -> GetField(ctx, "__block_row_idx"), + "__block_col_idx" -> GetField(ctx, "__block_col_idx"), + "__block" -> ndarray, + ) + ) + } } - } - } val elementType = tm.entryType.fieldType(entryField) val etype = EBlockMatrixNDArray( @@ -2499,25 +2459,24 @@ case class MatrixBlockMatrixWriter( val pathsWithColMajorIndices = tableOfNDArrays.mapCollect("matrix_block_matrix_writer") { partition => ToArray(mapIR(partition) { singleNDArrayTuple => - bindIR(GetField(singleNDArrayTuple, "blockRowIdx") + (GetField( - singleNDArrayTuple, - "blockColIdx", - ) * numBlockRows)) { colMajorIndex => - val blockPath = - Str(s"$path/parts/part-") + - invoke("str", TString, colMajorIndex) + Str("-") + UUID4() - maketuple( - colMajorIndex, - WriteValue(GetField(singleNDArrayTuple, "ndBlock"), blockPath, writer), - ) + M.eval { + for { + rowIdx <- GetField(singleNDArrayTuple, "__block_row_idx") + colIdx <- GetField(singleNDArrayTuple, "__block_col_idx") + colMajorIndex <- rowIdx + (colIdx * numBlockRows) + blockPath <- strConcat(s"$path/parts/part-", colMajorIndex, "-", UUID4()) + ndarray <- GetField(singleNDArrayTuple, "__block") + } yield maketuple(colMajorIndex, WriteValue(ndarray, blockPath, writer)) } }) } - val flatPathsAndIndices = flatMapIR(ToStream(pathsWithColMajorIndices))(ToStream(_)) - val sortedColMajorPairs = sortIR(flatPathsAndIndices) { case (l, r) => - ApplyComparisonOp(LT, GetTupleElement(l, 0), GetTupleElement(r, 0)) - } - val flatPaths = ToArray(mapIR(ToStream(sortedColMajorPairs))(GetTupleElement(_, 1))) + + val sortedColMajorPairs = + sortIR(flatten(pathsWithColMajorIndices)) { (l, r) => + GetTupleElement(l, 0) < GetTupleElement(r, 0) + } + + val flatPaths = mapArray(sortedColMajorPairs)(GetTupleElement(_, 1)) val bmt = BlockMatrixType( elementType, numRows, @@ -2527,7 +2486,7 @@ case class MatrixBlockMatrixWriter( ) RelationalWriter.scoped(path, overwrite, None)(WriteMetadata( flatPaths, - BlockMatrixNativeMetadataWriter(path, false, bmt), + BlockMatrixNativeMetadataWriter(path, stageLocally = false, bmt), )) } } @@ -2563,20 +2522,20 @@ case class MatrixNativeMultiWriter( require(tables.map(_._4.tableType.keyType).distinct.length == 1) val unionType = TTuple(components.map(c => TIterable.elementType(c.stage.contexts.typ)): _*) - val contextUnionType = TStruct("matrixId" -> TInt32, "options" -> unionType) + val contextUnionType = TStruct("__matrix_id" -> TInt32, "__options" -> unionType) val emptyUnionIRs: IndexedSeq[(Int, IR)] = - IndexedSeq.tabulate(unionType.size)(i => i -> NA(unionType.types(i))) + ArraySeq.tabulate(unionType.size)(i => i -> NA(unionType.types(i))) val concatenatedContexts = flatten( MakeArray( components.zipWithIndex.map { case (c, matrixId) => ToArray(mapIR(c.stage.contexts) { ctx => - MakeStruct(FastSeq( - "matrixId" -> I32(matrixId), - "options" -> MakeTuple(emptyUnionIRs.updated(matrixId, matrixId -> ctx.ir)), - )) + makestruct( + "__matrix_id" -> I32(matrixId), + "__options" -> MakeTuple(emptyUnionIRs.updated(matrixId, matrixId -> ctx.ir)), + ) }) }, TArray(TArray(contextUnionType)), @@ -2587,42 +2546,49 @@ case class MatrixNativeMultiWriter( n.str -> ir }) - Begin(FastSeq( - Begin(components.map(_.setup)), - Let( - components.flatMap(_.stage.letBindings), - bindIR(cdaIR(concatenatedContexts, allBroadcasts, "matrix_multi_writer") { - case (ctx, globals) => - bindIR(GetField(ctx, "options")) { options => - Switch( - GetField(ctx, "matrixId"), - default = Die("MatrixId exceeds matrix count", components.head.writePartitionType), - cases = components.zipWithIndex.map { case (component, i) => - val binds = component.stage.broadcastVals.map { case (name, _) => - name -> GetField(globals, name.str) - } - - Let( - binds, - bindIR(GetTupleElement(options, i)) { ctxRef => - component.writePartition(component.stage.partition(ctxRef), ctxRef) - }, - ) - }, - ) - } - }) { cdaResult => - val partitionCountScan = - components.map(_.stage.numPartitions).scanLeft(0)(_ + _) - - Begin(components.zipWithIndex.map { case (c, i) => - c.finalizeWrite( - ArraySlice(cdaResult, partitionCountScan(i), Some(partitionCountScan(i + 1))), - c.stage.globals, - ) + M.eval { + for { + _ <- M.defer { b => + components.foreach(c => b.memoize(c.setup)) + components.foreach(_.stage.letBindings.foreach { case (name, value) => + b.strictMemoize(value, name) }) - }, - ), - )) + Void() + } + + result <- cdaIR(concatenatedContexts, allBroadcasts, "matrix_multi_writer") { + (ctx, globals) => + Switch( + GetField(ctx, "__matrix_id"), + default = Die("MatrixId exceeds matrix count", components.head.writePartitionType), + cases = components.zipWithIndex.map { case (component, i) => + val binds = component.stage.broadcastVals.map { case (name, _) => + name -> GetField(globals, name.str) + } + + Let( + binds, + IRBuilder.scoped { b => + val options = GetField(ctx, "__options") + val ctxRef = b.memoize(GetTupleElement(options, i)) + val rows = b.memoize(component.stage.partition(ctxRef)) + component.writePartition(rows, ctxRef) + }, + ) + }, + ) + } + + partCounts = components.map(_.stage.numPartitions).scanLeft(0)(_ + _) + + _ <- M.defer { b => + components.zipWithIndex.foreach { case (c, i) => + val part = b.memoize(sliceArrayIR(result, partCounts(i), partCounts(i + 1))) + b.memoize(c.finalizeWrite(part, c.stage.globals)) + } + Void() + } + } yield Void() + } } } diff --git a/hail/hail/src/is/hail/expr/ir/Optimize.scala b/hail/hail/src/is/hail/expr/ir/Optimize.scala index 9bdd92a30f9..9d0fbee4296 100644 --- a/hail/hail/src/is/hail/expr/ir/Optimize.scala +++ b/hail/hail/src/is/hail/expr/ir/Optimize.scala @@ -54,7 +54,7 @@ object Optimize { catch { case NonFatal(e) => fatal( - s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir, preserveNames = true)}", + s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir)}", e, ) } diff --git a/hail/hail/src/is/hail/expr/ir/Pretty.scala b/hail/hail/src/is/hail/expr/ir/Pretty.scala index a25d72fa8d5..24fa206b16f 100644 --- a/hail/hail/src/is/hail/expr/ir/Pretty.scala +++ b/hail/hail/src/is/hail/expr/ir/Pretty.scala @@ -28,7 +28,7 @@ object Pretty { elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false, - preserveNames: Boolean = false, + preserveNames: Boolean = true, ): String = { val useSSA = ctx != null && ctx.getFlag("use_ssa_logs") != null val pretty = @@ -56,7 +56,7 @@ object Pretty { elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false, - preserveNames: Boolean = false, + preserveNames: Boolean = true, ): String = { val pretty = new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA = true, @@ -773,7 +773,7 @@ class Pretty( if (i == 1) some(MatrixIR.globalName -> "g") else None case _: MatrixMapRows => - if (i == 1) matrixBlockArgs map { _ :+ (Name("n_rows") -> "n_rows") } + if (i == 1) matrixBlockArgs map { _ :+ (Name("n_cols") -> "n_cols") } else None case NDArrayMap(_, name, _) => if (i == 1) some(name -> "elt") diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala index 787de9983fb..ef59cb51e5c 100644 --- a/hail/hail/src/is/hail/expr/ir/TableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala @@ -103,7 +103,7 @@ object TableReader { object LoweredTableReader extends Logging { type LoweredTableReaderCoercer = - (ExecuteContext, IR, Type, IndexedSeq[Any], IR => IR) => TableStage + (ExecuteContext, IR, Type, IndexedSeq[Any], Atom => IR) => TableStage def makeCoercer( ctx: ExecuteContext, @@ -360,7 +360,7 @@ object LoweredTableReader extends Logging { globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR, + body: Atom => IR, ) => { val partOrigIndex = sortedPartData.map(_.getInt(6)) @@ -399,7 +399,7 @@ object LoweredTableReader extends Logging { globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR, + body: Atom => IR, ) => { val partOrigIndex = sortedPartData.map(_.getInt(6)) @@ -444,7 +444,7 @@ object LoweredTableReader extends Logging { globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR, + body: Atom => IR, ) => { val partOrigIndex = sortedPartData.map(_.getInt(6)) @@ -531,7 +531,7 @@ abstract class TableReader { RVDPartitioner.empty(ctx, requestedType.keyType), TableStageDependency.none, MakeStream(FastSeq(), TStream(TStruct.empty)), - (_: Ref) => MakeStream(FastSeq(), TStream(requestedType.rowType)), + _ => MakeStream(FastSeq(), TStream(requestedType.rowType)), ) } else { lower(ctx, requestedType) diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index 8e2a362d6d8..c71e8d1898d 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -579,9 +579,9 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow ref, FastSeq(field -> (if (i == refs.length - 1) - ArrayRef(CastToArray(GetField(ref, field)), arg) + ArrayRef(CastToArray(GetField(ref.ir, field)), arg) else - Let(FastSeq(refs(i + 1).name -> GetField(ref, field)), arg))), + Let(FastSeq(refs(i + 1).name -> GetField(ref.ir, field)), arg))), ) }.asInstanceOf[InsertFields] } diff --git a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala index faf4a2a6dfb..474e07ed64d 100644 --- a/hail/hail/src/is/hail/expr/ir/agg/Extract.scala +++ b/hail/hail/src/is/hail/expr/ir/agg/Extract.scala @@ -708,7 +708,10 @@ object Extract { val init = Begin(initOps) initBuilder += InitOp( i, - knownLength.fold(ArraySeq(init))(ArraySeq(_, init)), + knownLength.fold(ArraySeq(init)) { ir => + bindInitArgRefs(FastSeq(ir)) + ArraySeq(ir, init) + }, checkSig, ) diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala index dcf82b6481d..50fcffc7766 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala @@ -898,14 +898,14 @@ class BlockMatrixStage2 private ( RVDPartitioner.unkeyed(ctx.stateManager, bmTyp.nDefinedBlocks), TableStageDependency.none, contextsIR, - { newCtxRef => - val s = makestruct( - "blockRow" -> GetTupleElement(newCtxRef, 0), - "blockCol" -> GetTupleElement(newCtxRef, 1), - "block" -> Let(FastSeq(ctxRefName -> GetTupleElement(newCtxRef, 2)), _blockIR), - ) - MakeStream(FastSeq(s), TStream(s.typ)) - }, + newCtxRef => + MakeStream.single( + makestruct( + "blockRow" -> GetTupleElement(newCtxRef, 0), + "blockCol" -> GetTupleElement(newCtxRef, 1), + "block" -> Let(FastSeq(ctxRefName -> GetTupleElement(newCtxRef, 2)), _blockIR), + ) + ), ) } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala index eee29399281..1a9426931a3 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -3,9 +3,9 @@ package is.hail.expr.ir.lowering import is.hail.backend.ExecuteContext import is.hail.collection.FastSeq import is.hail.collection.compat.immutable.ArraySeq -import is.hail.collection.implicits.toRichIterable -import is.hail.expr.ir.{agg, TableNativeWriter, _} -import is.hail.expr.ir.agg.AggExecuteContextExtensions +import is.hail.collection.implicits.toRichArray +import is.hail.expr.ir.{Memoized => M, _} +import is.hail.expr.ir.agg.{AggExecuteContextExtensions, Extract} import is.hail.expr.ir.analyses.PartitionCounts import is.hail.expr.ir.defs._ import is.hail.expr.ir.defs.ArrayZipBehavior.AssertSameLength @@ -30,13 +30,13 @@ object TableStage { partitioner: RVDPartitioner, dependency: TableStageDependency, contexts: IR, - body: (Ref) => IR, + body: Atom => IR, ): TableStage = { val globalsRef = Ref(freshName(), globals.typ) TableStage( FastSeq(globalsRef.name -> globals), FastSeq(globalsRef.name -> globalsRef), - globalsRef, + globalsRef.ir, partitioner, dependency, contexts, @@ -47,11 +47,11 @@ object TableStage { def apply( letBindings: IndexedSeq[(Name, IR)], broadcastVals: IndexedSeq[(Name, IR)], - globals: Ref, + globals: Atom, partitioner: RVDPartitioner, dependency: TableStageDependency, contexts: IR, - partition: Ref => IR, + partition: Atom => IR, ): TableStage = { val ctxType = contexts.typ.asInstanceOf[TStream].elementType val ctxRef = Ref(freshName(), ctxType) @@ -74,17 +74,16 @@ object TableStage { assert(children.forall(_.kType == keyType)) val ctxType = TTuple(children.map(_.ctxType): _*) - val ctxArrays = children.view.zipWithIndex.map { case (child, idx) => + val ctxArrays = children.zipWithIndex.map { case (child, idx) => ToArray(mapIR(child.contexts) { ctx => MakeTuple.ordered(children.indices.map { idx2 => if (idx == idx2) ctx.ir else NA(children(idx2).ctxType) }) }) - }.toFastSeq - val ctxs = flatMapIR(MakeStream(ctxArrays, TStream(TArray(ctxType)))) { ctxArray => - ToStream(ctxArray) } + val ctxs = flatten(MakeStream(ctxArrays, TStream(TArray(ctxType)))) + val newGlobals = children.head.globals val globalsRef = Ref(freshName(), newGlobals.typ) val newPartitioner = @@ -93,24 +92,23 @@ object TableStage { TableStage( children.flatMap(_.letBindings) :+ globalsRef.name -> newGlobals.ir, children.flatMap(_.broadcastVals) :+ globalsRef.name -> globalsRef, - globalsRef, + globalsRef.ir, newPartitioner, TableStageDependency.union(children.map(_.dependency)), ctxs, - (ctxRef: Ref) => { + ctxRef => StreamMultiMerge( children.indices.map { i => bindIR(GetTupleElement(ctxRef, i)) { ctx => If( IsNA(ctx), - MakeStream(IndexedSeq(), TStream(children(i).rowType)), + MakeStream(FastSeq(), TStream(children(i).rowType)), children(i).partition(ctx), ) } }, - IndexedSeq(), - ) - }, + FastSeq(), + ), ) } } @@ -125,7 +123,7 @@ object TableStage { class TableStage( val letBindings: IndexedSeq[(Name, IR)], val broadcastVals: IndexedSeq[(Name, IR)], - val globals: Ref, + val globals: Atom, val partitioner: RVDPartitioner, val dependency: TableStageDependency, val contexts: IR, @@ -134,25 +132,12 @@ class TableStage( ) extends Logging { self => - // useful for debugging, but should be disabled in production code due to N^2 complexity - // typecheckPartition() - contexts.typ match { case TStream(t) if t.isRealizable => case t => throw new IllegalArgumentException(s"TableStage constructed with illegal context type $t") } - def typecheckPartition(ctx: ExecuteContext): Unit = - TypeCheck( - ctx, - partitionIR, - BindingEnv(Env[Type](((letBindings ++ broadcastVals).map { case (s, x) => (s, x.typ) }) - ++ FastSeq[(Name, Type)]( - (ctxRefName, contexts.typ.asInstanceOf[TStream].elementType) - ): _*)), - ) - def upcast(ctx: ExecuteContext, newType: TableType): TableStage = { val newRowType = newType.rowType val newGlobalType = newType.globalType @@ -175,12 +160,16 @@ class TableStage( TableType(rowType, key, globalType) assert(kType.isSubsetOf(rowType), s"Key type $kType is not a subset of $rowType") - assert(broadcastVals.exists { case (name, value) => name == globals.name && value == globals }) + + assert(globals match { + case r: Ref => broadcastVals.exists { case (n, v) => n == r.name && v == r } + case _ => false + }) def copy( letBindings: IndexedSeq[(Name, IR)] = letBindings, broadcastVals: IndexedSeq[(Name, IR)] = broadcastVals, - globals: Ref = globals, + globals: Atom = globals, partitioner: RVDPartitioner = partitioner, dependency: TableStageDependency = dependency, contexts: IR = contexts, @@ -190,14 +179,26 @@ class TableStage( new TableStage(letBindings, broadcastVals, globals, partitioner, dependency, contexts, ctxRefName, partitionIR) - def partition(ctx: IR): IR = { + def deepCopy: TableStage = + new TableStage( + letBindings.map { case (n, r) => n -> r.deepCopy }, + broadcastVals.map { case (n, r) => n -> r.deepCopy }, + globals, + partitioner, + dependency, + contexts.deepCopy, + ctxRefName, + partitionIR.deepCopy, + ) + + def partition(ctx: Atom): IR = { require(ctx.typ == ctxType) Let(FastSeq(ctxRefName -> ctx), partitionIR) } def numPartitions: Int = partitioner.numPartitions - def mapPartition(newKey: Option[IndexedSeq[String]])(f: IR => IR): TableStage = { + def mapPartition(newKey: Option[IndexedSeq[String]])(f: Atom => IR): TableStage = { val part = newKey match { case Some(k) => if (!partitioner.kType.fieldNames.startsWith(k)) @@ -207,27 +208,15 @@ class TableStage( partitioner.coarsen(k.length) case None => partitioner } - copy(partitionIR = f(partitionIR), partitioner = part) + copy(partitionIR = bindIR(partitionIR)(f), partitioner = part) } - def zipPartitions(right: TableStage, newGlobals: (IR, IR) => IR, body: (IR, IR) => IR) - : TableStage = { + def zipPartitions( + right: TableStage, + newGlobals: (Atom, Atom) => IR, + body: (Atom, Atom) => IR, + ): TableStage = { val left = this - val leftCtxTyp = left.ctxType - val rightCtxTyp = right.ctxType - - val leftCtxRef = Ref(freshName(), leftCtxTyp) - val rightCtxRef = Ref(freshName(), rightCtxTyp) - - val leftCtxStructField = genUID() - val rightCtxStructField = genUID() - - val zippedCtxs = StreamZip( - FastSeq(left.contexts, right.contexts), - FastSeq(leftCtxRef.name, rightCtxRef.name), - MakeStruct(FastSeq(leftCtxStructField -> leftCtxRef, rightCtxStructField -> rightCtxRef)), - ArrayZipBehavior.AssertSameLength, - ) val globals = newGlobals(left.globals, right.globals) val globalsRef = Ref(freshName(), globals.typ) @@ -235,73 +224,69 @@ class TableStage( TableStage( left.letBindings ++ right.letBindings :+ (globalsRef.name -> globals), left.broadcastVals ++ right.broadcastVals :+ (globalsRef.name -> globalsRef), - globalsRef, + globalsRef.ir, left.partitioner, left.dependency.union(right.dependency), - zippedCtxs, - (ctxRef: Ref) => - bindIR(left.partition(GetField(ctxRef, leftCtxStructField))) { lPart => - bindIR(right.partition(GetField(ctxRef, rightCtxStructField))) { rPart => - body(lPart, rPart) - } + zip2(left.contexts, right.contexts, AssertSameLength)(maketuple(_, _)), + ctxRef => + M.eval { + for { + lctx <- GetTupleElement(ctxRef, 0) + lpart <- left.partition(lctx) + rctx <- GetTupleElement(ctxRef, 1) + rpart <- right.partition(rctx) + } yield body(lpart, rpart) }, ) } - def mapPartitionWithContext(f: (IR, Ref) => IR): TableStage = - copy(partitionIR = f(partitionIR, Ref(ctxRefName, ctxType))) + def mapPartitionWithContext(f: (Atom, Atom) => IR): TableStage = + copy(partitionIR = bindIR(partitionIR)(f(_, Ref(ctxRefName, ctxType)))) - def mapContexts(f: IR => IR)(getOldContext: IR => IR): TableStage = { - val newContexts = f(contexts) - TableStage( - letBindings, - broadcastVals, - globals, - partitioner, - dependency, - newContexts, - ctxRef => bindIR(getOldContext(ctxRef))(partition(_)), + def mapContexts(f: Atom => IR)(getOldContext: Atom => IR): TableStage = { + val newContexts = bindIR(contexts)(f) + val newCtxRef = Ref(freshName(), TIterable.elementType(newContexts.typ)) + copy( + contexts = newContexts, + ctxRefName = newCtxRef.name, + partitionIR = bindIR(getOldContext(newCtxRef))(partition), ) } - def zipContextsWithIdx(): TableStage = { - def getOldContext(ctx: IR) = GetField(ctx, "elt") - mapContexts(zipWithIndex)(getOldContext) - } - - def mapGlobals(f: IR => IR): TableStage = { + def mapGlobals(f: Atom => IR): TableStage = { val newGlobals = f(globals) val globalsRef = Ref(freshName(), newGlobals.typ) copy( letBindings = letBindings :+ globalsRef.name -> newGlobals, broadcastVals = broadcastVals :+ globalsRef.name -> globalsRef, - globals = globalsRef, + globals = globalsRef.ir, ) } - def mapCollect(staticID: String, dynamicID: IR = NA(TString))(f: IR => IR): IR = - mapCollectWithGlobals(staticID, dynamicID)(f)((parts, globals) => parts) + def mapCollect(staticID: String, dynamicID: IR = NA(TString))(f: Atom => IR): IR = + mapCollectWithGlobals(staticID, dynamicID)(f)((parts, _) => parts) def mapCollectWithGlobals( staticID: String, dynamicID: IR = NA(TString), )( - mapF: IR => IR + mapF: Atom => IR )( - body: (IR, IR) => IR + body: (Atom, Atom) => IR ): IR = - mapCollectWithContextsAndGlobals(staticID, dynamicID)((part, ctx) => mapF(part))(body) + mapCollectWithContextsAndGlobals(staticID, dynamicID)((part, _) => mapF(part))(body) // mapf is (part, ctx) => ???, body is (parts, globals) => ??? def mapCollectWithContextsAndGlobals( staticID: String, dynamicID: IR = NA(TString), )( - mapF: (IR, Ref) => IR + mapF: (Atom, Atom) => IR )( - body: (IR, IR) => IR + body: (Atom, Atom) => IR ): IR = { + val broadcastRefs = MakeStruct(broadcastVals.map { case (n, ir) => n.str -> ir }) val glob = Ref(freshName(), broadcastRefs.typ) @@ -311,8 +296,8 @@ class TableStage( ctxRefName, glob.name, Let( - broadcastVals.map { case (name, _) => name -> GetField(glob, name.str) }, - mapF(partitionIR, Ref(ctxRefName, ctxType)), + broadcastVals.map { case (name, _) => name -> GetField(glob.ir, name.str) }, + bindIR(partitionIR)(mapF(_, Ref(ctxRefName, ctxType))), ), dynamicID, staticID, @@ -323,11 +308,11 @@ class TableStage( } def collectWithGlobals(staticID: String, dynamicID: IR = NA(TString)): IR = - mapCollectWithGlobals(staticID, dynamicID)(ToArray) { (parts, globals) => - MakeStruct(FastSeq( + mapCollectWithGlobals(staticID, dynamicID)(ToArray(_)) { (parts, globals) => + makestruct( "rows" -> ToArray(flatMapIR(ToStream(parts))(ToStream(_))), "global" -> globals, - )) + ) } def countPerPartition(): IR = @@ -377,7 +362,7 @@ class TableStage( } ) { val newToOld = startAndEnd.groupBy(_._1._1).map { case (newIdx, values) => - (newIdx, values.map(_._2).sorted.toIndexedSeq) + (newIdx, values.map(_._2).sorted) } val (oldPartIndices, newPartitionerFilt) = @@ -412,7 +397,7 @@ class TableStage( newPartitionerFilt, dependency, newContexts, - (ctx: Ref) => flatMapIR(ToStream(ctx, true))(oldCtx => partition(oldCtx)), + ctx => flatMapIR(ToStream(ctx, true))(oldCtx => partition(oldCtx)), ) } @@ -428,38 +413,20 @@ class TableStage( "parentPartitions" -> TArray(TInt32), ) - val prevContextUID = freshName() - val mappingUID = freshName() - val idxUID = freshName() - val newContexts = Let( - FastSeq(prevContextUID -> ToArray(contexts)), - StreamMap( - ToStream( - Literal( - TArray(partitionMappingType), - partitionMapping, - ) - ), - mappingUID, - MakeStruct( - FastSeq( - "partitionBound" -> GetField(Ref(mappingUID, partitionMappingType), "partitionBound"), + val newContexts = + bindIR(ToArray(contexts)) { ctx => + mapIR(ToStream(Literal(TArray(partitionMappingType), partitionMapping))) { mapping => + makestruct( + "partitionBound" -> GetField(mapping, "partitionBound"), "oldContexts" -> ToArray( - StreamMap( - ToStream(GetField(Ref(mappingUID, partitionMappingType), "parentPartitions")), - idxUID, - ArrayRef( - Ref(prevContextUID, TArray(contexts.typ.asInstanceOf[TStream].elementType)), - Ref(idxUID, TInt32), - ), - ) + mapIR(ToStream(GetField(mapping, "parentPartitions"))) { + idx => ArrayRef(ctx, idx) + } ), ) - ), - ), - ) + } + } - val prevContextUIDPartition = freshName() TableStage( letBindings, broadcastVals, @@ -467,19 +434,11 @@ class TableStage( newPartitioner, dependency, newContexts, - (ctxRef: Ref) => { - val body = self.partition(Ref( - prevContextUIDPartition, - self.contexts.typ.asInstanceOf[TStream].elementType, - )) + { ctxRef => bindIR(GetField(ctxRef, "partitionBound")) { interval => takeWhile( dropWhile( - StreamFlatMap( - ToStream(GetField(ctxRef, "oldContexts"), true), - prevContextUIDPartition, - body, - ) + flatMapIR(ToStream(GetField(ctxRef, "oldContexts"), true))(partition) ) { elt => invoke( "pointLessThanPartitionIntervalLeftEndpoint", @@ -488,7 +447,6 @@ class TableStage( invoke("start", boundType.pointType, interval), invoke("includesStart", TBoolean, interval), ) - } ) { elt => invoke( @@ -551,8 +509,8 @@ class TableStage( right: TableStage, joinKey: Int, joinType: String, - globalJoiner: (IR, IR) => IR, - joiner: (Ref, Ref) => IR, + globalJoiner: (Atom, Atom) => IR, + joiner: (Atom, Atom) => IR, rightKeyIsDistinct: Boolean = false, ): TableStage = { assert(this.kType.truncate(joinKey).isJoinableWith(right.kType.truncate(joinKey))) @@ -574,29 +532,16 @@ class TableStage( } val repartitionedLeft: TableStage = repartitionNoShuffle(ec, newPartitioner) - val partitionJoiner: (IR, IR) => IR = (lPart, rPart) => { - val lEltType = lPart.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] - val rEltType = rPart.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] - - val lKey = this.kType.fieldNames.take(joinKey) - val rKey = right.kType.fieldNames.take(joinKey) - - val lEltRef = Ref(freshName(), lEltType) - val rEltRef = Ref(freshName(), rEltType) - - StreamJoin( + val partitionJoiner: (Atom, Atom) => IR = (lPart, rPart) => + joinIR( lPart, rPart, - lKey, - rKey, - lEltRef.name, - rEltRef.name, - joiner(lEltRef, rEltRef), + this.kType.fieldNames.take(joinKey), + right.kType.fieldNames.take(joinKey), joinType, requiresMemoryManagement = true, rightKeyIsDistinct = rightKeyIsDistinct, - ) - } + )(joiner) val newKey = kType.fieldNames ++ right.kType.fieldNames.drop(joinKey) @@ -616,8 +561,8 @@ class TableStage( ec: ExecuteContext, right: TableStage, joinKey: Int, - globalJoiner: (IR, IR) => IR, - joiner: (IR, IR) => IR, + globalJoiner: (Atom, Atom) => IR, + joiner: (Atom, Atom) => IR, ): TableStage = { require(joinKey <= kType.size) require(joinKey <= right.kType.size) @@ -639,8 +584,8 @@ class TableStage( ctx: ExecuteContext, right: TableStage, rightRowRType: RStruct, - globalJoiner: (IR, IR) => IR, - joiner: (IR, IR) => IR, + globalJoiner: (Atom, Atom) => IR, + joiner: (Atom, Atom) => IR, ): TableStage = { require(right.kType.size == 1) val rightKeyType = right.kType.fields.head.typ @@ -649,54 +594,61 @@ class TableStage( val irPartitioner = partitioner.coarsen(1).partitionBoundsIRRepresentation - val rightWithPartNums = right.mapPartition(None) { partStream => - flatMapIR(partStream) { row => - val interval = bindIR(GetField(row, right.key.head)) { interval => - invoke( - "Interval", - TInterval(TTuple(kType.typeAfterSelect(ArraySeq(0)), TInt32)), - MakeTuple.ordered(FastSeq( - MakeStruct(FastSeq(kType.fieldNames.head -> invoke( - "start", - kType.types.head, - interval, - ))), - I32(1), - )), - MakeTuple.ordered(FastSeq( - MakeStruct(FastSeq(kType.fieldNames.head -> invoke( - "end", - kType.types.head, - interval, - ))), - I32(1), - )), - invoke("includesStart", TBoolean, interval), - invoke("includesEnd", TBoolean, interval), - ) - } + val rightWithPartNums = + right.mapPartition(None) { partStream => + flatMapIR(partStream) { row => + M.eval { + for { + interval <- GetField(row, right.key.head) - bindIR(invoke( - "partitionerFindIntervalRange", - TTuple(TInt32, TInt32), - irPartitioner, - interval, - )) { range => - val rangeStream = StreamRange( - GetTupleElement(range, 0), - GetTupleElement(range, 1), - I32(1), - requiresMemoryManagementPerElement = true, - ) - mapIR(rangeStream)(partNum => InsertFields(row, FastSeq("__partNum" -> partNum))) + interval <- + invoke( + "Interval", + TInterval(TTuple(kType.typeAfterSelect(ArraySeq(0)), TInt32)), + maketuple( + makestruct( + kType.fieldNames.head -> + invoke("start", kType.types.head, interval) + ), + 1, + ), + maketuple( + makestruct( + kType.fieldNames.head -> + invoke("end", kType.types.head, interval) + ), + 1, + ), + invoke("includesStart", TBoolean, interval), + invoke("includesEnd", TBoolean, interval), + ) + + range <- + invoke( + "partitionerFindIntervalRange", + TTuple(TInt32, TInt32), + irPartitioner, + interval, + ) + + rangeStream = + StreamRange( + GetTupleElement(range, 0), + GetTupleElement(range, 1), + 1, + requiresMemoryManagementPerElement = true, + ) + } yield mapIR(rangeStream)(partNum => + InsertFields(row, FastSeq("__partNum" -> partNum)) + ) + } } } - } val rightRowRTypeWithPartNum = - IndexedSeq("__partNum" -> TypeWithRequiredness(TInt32)) ++ rightRowRType.fields.map(rField => - rField.name -> rField.typ - ) + ArraySeq("__partNum" -> TypeWithRequiredness(TInt32)) ++ + rightRowRType.fields.map(f => f.name -> f.typ) + val rightTableRType = RTable(rightRowRTypeWithPartNum, FastSeq(), right.key) val sortedReader = ctx.backend.lowerDistributedSort( ctx, @@ -710,8 +662,8 @@ class TableStage( ctx.stateManager, Some(1), TStruct.concat(TStruct("__partNum" -> TInt32), right.kType), - ArraySeq.tabulate[Interval](partitioner.numPartitions)(i => - Interval(Row(i), Row(i), true, true) + ArraySeq.tabulate(partitioner.numPartitions)(i => + Interval(Row(i), Row(i), includesStart = true, includesEnd = true) ), ) val repartitioned = sorted.repartitionNoShuffle(ctx, newRightPartitioner) @@ -757,55 +709,54 @@ object LowerTableIR extends Logging { val samplesPerPartition = sampleSize / math.max(1, stage.numPartitions) val keyType = child.typ.keyType - bindIR(flatten(stage.mapCollect("table_calculate_new_partitions") { rows => - streamAggIR(mapIR(rows)(row => SelectFields(row, keyType.fieldNames))) { elt => - ToArray(flatMapIR(ToStream( - MakeArray( - ApplyAggOp(FastSeq(I32(samplesPerPartition)), FastSeq(elt), ReservoirSample()), - ApplyAggOp(FastSeq(I32(1)), FastSeq(elt, elt), TakeBy()), - ApplyAggOp(FastSeq(I32(1)), FastSeq(elt, elt), TakeBy(Descending)), - ) - ))(inner => ToStream(inner))) - } - })) { partData => - val sorted = sortIR(partData)((l, r) => ApplyComparisonOp(LT, l, r)) - bindIR(ToArray(flatMapIR(StreamGroupByKey( - ToStream(sorted), - keyType.fieldNames, - missingEqual = true, - ))(groupRef => StreamTake(groupRef, 1)))) { boundsArray => - bindIR(ArrayLen(boundsArray)) { nBounds => - bindIR(minIR(nBounds, nPartitions)) { nParts => - If( - nParts.ceq(0), - MakeArray(FastSeq(), TArray(TInterval(keyType))), - bindIR((nBounds + (nParts - 1)) floorDiv nParts) { stepSize => - ToArray(mapIR(StreamRange(0, nBounds, stepSize)) { i => - If( - (i + stepSize) < (nBounds - 1), - invoke( - "Interval", - TInterval(keyType), - ArrayRef(boundsArray, i), - ArrayRef(boundsArray, i + stepSize), - True(), - False(), - ), - invoke( - "Interval", - TInterval(keyType), - ArrayRef(boundsArray, i), - ArrayRef(boundsArray, nBounds - 1), - True(), - True(), - ), - ) - }) - }, - ) - } - } - } + M.eval { + for { + partData <- + flatten(stage.mapCollect("table_calculate_new_partitions") { rows => + streamAggIR(mapIR(rows)(row => SelectFields(row, keyType.fieldNames))) { elt => + ToArray(flatten(ToStream( + MakeArray( + ApplyAggOp( + FastSeq(I32(samplesPerPartition)), + FastSeq(elt), + ReservoirSample(), + ), + ApplyAggOp(FastSeq(I32(1)), FastSeq(elt, elt), TakeBy()), + ApplyAggOp(FastSeq(I32(1)), FastSeq(elt, elt), TakeBy(Descending)), + ) + ))) + } + }) + + sorted <- sortIR(partData)(ApplyComparisonOp(LT, _, _)) + + boundsArray <- + ToArray(flatMapIR(StreamGroupByKey( + ToStream(sorted), + keyType.fieldNames, + missingEqual = true, + ))(groupRef => StreamTake(groupRef, 1))) + + nBounds <- ArrayLen(boundsArray) + nParts <- minIR(nBounds, nPartitions) + } yield If( + nParts ceq 0, + MakeArray(FastSeq(), TArray(TInterval(keyType))), + bindIR((nBounds + nParts - 1) floorDiv nParts) { stepSize => + ToArray(mapIR(StreamRange(0, nBounds, stepSize)) { i => + bindIR((i + stepSize) < (nBounds - 1)) { closed => + invoke( + "Interval", + TInterval(keyType), + ArrayRef(boundsArray, i), + ArrayRef(boundsArray, If(closed, i + stepSize, nBounds - 1)), + True(), + !closed, + ) + } + }) + }, + ) } case TableGetGlobals(child) => @@ -815,7 +766,7 @@ object LowerTableIR extends Logging { lower(child).collectWithGlobals("table_collect") case TableAggregate(child, query) => - val aggs = agg.Extract(ctx, query, analyses.requirednessAnalysis).independent + val aggs = Extract(ctx, query, analyses.requirednessAnalysis).independent val aggSigs = aggs.sigs val lc = lower(child) @@ -824,13 +775,14 @@ object LowerTableIR extends Logging { FastSeq(TableIR.globalName -> lc.globals), RunAgg(aggs.init, aggSigs.valuesOp, aggSigs.states), ) + val initStateRef = Ref(freshName(), initState.typ) val lcWithInitBinding = lc.copy( - letBindings = lc.letBindings ++ FastSeq((initStateRef.name, initState)), - broadcastVals = lc.broadcastVals ++ FastSeq((initStateRef.name, initStateRef)), + letBindings = lc.letBindings ++ FastSeq(initStateRef.name -> initState), + broadcastVals = lc.broadcastVals ++ FastSeq(initStateRef.name -> initStateRef), ) - val initFromSerializedStates = aggSigs.initFromSerializedValueOp(initStateRef) + def initFromSerializedStates = aggSigs.initFromSerializedValueOp(initStateRef) val branchFactor = ctx.branchingFactor val useTreeAggregate = aggSigs.shouldTreeAggregate && branchFactor < lc.numPartitions @@ -848,7 +800,7 @@ object LowerTableIR extends Logging { ) val writer = ETypeValueWriter(codecSpec) val reader = ETypeValueReader(codecSpec) - lcWithInitBinding.mapCollectWithGlobals("table_aggregate")({ part: IR => + lcWithInitBinding.mapCollectWithGlobals("table_aggregate") { part => Let( FastSeq(TableIR.globalName -> lc.globals), RunAgg( @@ -860,16 +812,11 @@ object LowerTableIR extends Logging { aggSigs.states, ), ) - }) { case (collected, globals) => - val treeAggFunction = freshName() - val currentAggStates = Ref(freshName(), TArray(TString)) - val iterNumber = Ref(freshName(), TInt32) - - def combineGroup(partArrayRef: IR, useInitStates: Boolean): IR = { + } { case (collected, globals) => + def combineGroup(partArrayRef: Atom, useInitStates: Boolean): IR = Begin(FastSeq( - if (useInitStates) { - initFromSerializedStates - } else { + if (useInitStates) initFromSerializedStates + else { bindIR(ReadValue( ArrayRef(partArrayRef, 0), reader, @@ -889,44 +836,41 @@ object LowerTableIR extends Logging { ))(aggSigs.combOpValues) }, )) - } - val loopBody = If( - ArrayLen(currentAggStates) <= I32(branchFactor), - currentAggStates, - Recur( - treeAggFunction, - FastSeq( - cdaIR( - mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => - ToArray(x) - ), - MakeStruct(FastSeq()), - "table_tree_aggregate", - strConcat( - Str("iteration="), - invoke("str", TString, iterNumber), - Str(", n_states="), - invoke("str", TString, ArrayLen(currentAggStates)), + val treeAggregation = + tailLoop(TArray(TString), collected, 0) { + case (recur, Seq(currentAggStates, iterNumber)) => + If( + ArrayLen(currentAggStates) <= I32(branchFactor), + currentAggStates, + recur( + FastSeq( + cdaIR( + mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))( + ToArray(_) + ), + makestruct(), + "table_tree_aggregate", + strConcat( + Str("iteration="), + invoke("str", TString, iterNumber), + Str(", n_states="), + invoke("str", TString, ArrayLen(currentAggStates)), + ), + ) { (context, _) => + RunAgg( + combineGroup(context, false), + WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), + aggSigs.states, + ) + }, + iterNumber + 1, + ) ), - )((context, _) => - RunAgg( - combineGroup(context, false), - WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), - aggSigs.states, - ) - ), - iterNumber + 1, - ), - currentAggStates.typ, - ), - ) - bindIR(TailLoop( - treeAggFunction, - FastSeq[(Name, IR)](currentAggStates.name -> collected, iterNumber.name -> I32(0)), - loopBody.typ, - loopBody, - )) { finalParts => + ) + } + + bindIR(treeAggregation) { finalParts => RunAgg( combineGroup(finalParts, true), Let( @@ -938,7 +882,7 @@ object LowerTableIR extends Logging { } } } else { - lcWithInitBinding.mapCollectWithGlobals("table_aggregate_singlestage")({ part: IR => + lcWithInitBinding.mapCollectWithGlobals("table_aggregate_singlestage") { part => Let( FastSeq(TableIR.globalName -> lc.globals), RunAgg( @@ -950,7 +894,7 @@ object LowerTableIR extends Logging { aggSigs.states, ), ) - }) { case (collected, globals) => + } { case (collected, globals) => Let( FastSeq(TableIR.globalName -> globals), RunAgg( @@ -990,7 +934,8 @@ object LowerTableIR extends Logging { s"IR nodes with TableIR children must be defined explicitly: \n${Pretty(ctx, node)}" ) } - lowered + + NormalizeNames()(ctx, lowered) } def applyTable( @@ -1022,34 +967,36 @@ object LowerTableIR extends Logging { val loweredRowsAndGlobal = lowerIR(rowsAndGlobal) val loweredRowsAndGlobalRef = Ref(freshName(), loweredRowsAndGlobal.typ) - val context = bindIR(ArrayLen(GetField(loweredRowsAndGlobalRef, "rows"))) { numRowsRef => - bindIR(invoke( - "extend", - TArray(TInt32), - ToArray(mapIR(rangeIR(nPartitionsAdj)) { partIdx => - (partIdx * numRowsRef) floorDiv nPartitionsAdj - }), - MakeArray((numRowsRef)), - )) { indicesArray => - bindIR(GetField(loweredRowsAndGlobalRef, "rows")) { rows => - mapIR(rangeIR(nPartitionsAdj)) { partIdx => - ToArray(mapIR(rangeIR( - ArrayRef(indicesArray, partIdx), - ArrayRef(indicesArray, partIdx + 1), - ))(rowIdx => ArrayRef(rows, rowIdx))) - } + val context = + M.eval { + for { + rows <- GetField(loweredRowsAndGlobalRef, "rows") + numRowsRef <- ArrayLen(rows) + indicesArray <- + invoke( + "extend", + TArray(TInt32), + ToArray(mapIR(rangeIR(nPartitionsAdj)) { partIdx => + (partIdx * numRowsRef) floorDiv nPartitionsAdj + }), + MakeArray(numRowsRef), + ) + } yield mapIR(rangeIR(nPartitionsAdj)) { partIdx => + ToArray(mapIR(rangeIR( + ArrayRef(indicesArray, partIdx), + ArrayRef(indicesArray, partIdx + 1), + ))(rowIdx => ArrayRef(rows, rowIdx))) } } - } val globalsRef = Ref(freshName(), typ.globalType) TableStage( FastSeq( loweredRowsAndGlobalRef.name -> loweredRowsAndGlobal, - globalsRef.name -> GetField(loweredRowsAndGlobalRef, "global"), + globalsRef.name -> GetField(loweredRowsAndGlobalRef.ir, "global"), ), FastSeq(globalsRef.name -> globalsRef), - globalsRef, + globalsRef.ir, RVDPartitioner.unkeyed(ctx.stateManager, nPartitionsAdj), TableStageDependency.none, context, @@ -1086,10 +1033,10 @@ object LowerTableIR extends Logging { val partitionIdx = StreamRange(I32(0), I32(partitioner.numPartitions), I32(1)) val bounds = Literal( TArray(TInterval(partitioner.kType)), - partitioner.rangeBounds.toIndexedSeq, + partitioner.rangeBounds, ) zipIR(FastSeq(partitionIdx, ToStream(bounds), ctxs), AssertSameLength, errorId)( - refs => MakeTuple.ordered(refs.map(_.ir)) + elems => MakeTuple.ordered(elems.map(_.ir)) ) } } @@ -1129,21 +1076,20 @@ object LowerTableIR extends Logging { val ranges = ArraySeq.tabulate(nPartitionsAdj)(i => partStarts(i) -> partStarts(i + 1)) TableStage( - MakeStruct(FastSeq()), + makestruct(), new RVDPartitioner( ctx.stateManager, Array("idx"), tir.typ.rowType, - ranges.map { - case (start, end) => - Interval(Row(start), Row(end), includesStart = true, includesEnd = false) + ranges.map { case (start, end) => + Interval(Row(start), Row(end), includesStart = true, includesEnd = false) }, ), TableStageDependency.none, ToStream(Literal(TArray(contextType), ranges.map(Row.fromTuple))), - (ctxRef: Ref) => + ctxRef => mapIR(StreamRange(GetField(ctxRef, "start"), GetField(ctxRef, "end"), I32(1), true)) { - i => MakeStruct(FastSeq("idx" -> i)) + i => makestruct("idx" -> i) }, ) @@ -1175,11 +1121,10 @@ object LowerTableIR extends Logging { ), // FIXME: would prefer a First() agg op expr, ) { case Seq(key, value) => - MakeStruct(child.typ.key.map(k => - (k, GetField(key, k)) - ) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => - (f, GetField(value, f)) - }) + MakeStruct( + child.typ.key.map(k => k -> GetField(key, k)) ++ + expr.typ.asInstanceOf[TStruct].fieldNames.map(f => f -> GetField(value, f)) + ) }, ) }, @@ -1229,49 +1174,42 @@ object LowerTableIR extends Logging { RVDPartitioner.intervalToIRRepresentation(i, kt.size) } - val (newRangeBounds, includedIndices, startAndEndInterval, f) = if (keep) { - val (newRangeBounds, includedIndices, startAndEndInterval) = - part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => - if (filterPartitioner.overlaps(interval)) { - Some(( - interval, - i, - ( - filterPartitioner.lowerBoundInterval(interval), - filterPartitioner.upperBoundInterval(interval), - ), - )) - } else None - }.unzip3 - - def f(partitionIntervals: IR, key: IR): IR = - invoke("partitionerContains", TBoolean, partitionIntervals, key) - - (newRangeBounds, includedIndices, startAndEndInterval, f _) - } else { - // keep = False - val (newRangeBounds, includedIndices, startAndEndInterval) = - part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => - val lowerBound = filterPartitioner.lowerBoundInterval(interval) - val upperBound = filterPartitioner.upperBoundInterval(interval) - if ( - (lowerBound until upperBound).map(filterPartitioner.rangeBounds).exists { - filterInterval => - iord.compareNonnull( - filterInterval.left, - interval.left, - ) <= 0 && iord.compareNonnull(filterInterval.right, interval.right) >= 0 - } - ) - None - else Some((interval, i, (lowerBound, upperBound))) - }.unzip3 - - def f(partitionIntervals: IR, key: IR): IR = - !invoke("partitionerContains", TBoolean, partitionIntervals, key) - - (newRangeBounds, includedIndices, startAndEndInterval, f _) - } + val ((newRangeBounds, includedIndices, startAndEndInterval), f) = + if (keep) + ( + part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => + if (filterPartitioner.overlaps(interval)) { + Some(( + interval, + i, + ( + filterPartitioner.lowerBoundInterval(interval), + filterPartitioner.upperBoundInterval(interval), + ), + )) + } else None + }.unzip3, + (intervals: IR, key: IR) => + invoke("partitionerContains", TBoolean, intervals, key), + ) + else + ( + part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => + val lowerBound = filterPartitioner.lowerBoundInterval(interval) + val upperBound = filterPartitioner.upperBoundInterval(interval) + if ( + (lowerBound until upperBound).exists { i => + val filterInterval = filterPartitioner.rangeBounds(i) + iord.compareNonnull(filterInterval.left, interval.left) <= 0 && + iord.compareNonnull(filterInterval.right, interval.right) >= 0 + } + ) + None + else Some((interval, i, (lowerBound, upperBound))) + }.unzip3, + (intervals: IR, key: IR) => + !invoke("partitionerContains", TBoolean, intervals, key), + ) val newPart = new RVDPartitioner(ctx.stateManager, kt, newRangeBounds) @@ -1284,48 +1222,45 @@ object LowerTableIR extends Logging { loweredChild.globals, newPart, loweredChild.dependency, - contexts = bindIRs( - ToArray(loweredChild.contexts), - Literal( - TArray(TTuple(TInt32, TInt32)), - startAndEndInterval.map(Row.fromTuple), - ), - ) { case Seq(prevContexts, bounds) => - zip2( - ToStream(Literal(TArray(TInt32), includedIndices)), - ToStream(bounds), - ArrayZipBehavior.AssumeSameLength, - ) { (idx, bound) => - MakeStruct(FastSeq(("prevContext", ArrayRef(prevContexts, idx)), ("bounds", bound))) - } - }, - { (part: Ref) => - val oldPart = loweredChild.partition(GetField(part, "prevContext")) - bindIR(GetField(part, "bounds")) { bounds => - bindIRs(GetTupleElement(bounds, 0), GetTupleElement(bounds, 1)) { - case Seq(startIntervalIdx, endIntervalIdx) => - bindIR(ToArray(mapIR(rangeIR(startIntervalIdx, endIntervalIdx)) { i => + contexts = + bindIR(ToArray(loweredChild.contexts)) { prevContexts => + zip2( + ToStream(Literal(TArray(TInt32), includedIndices)), + ToStream(Literal( + TArray(TTuple(TInt32, TInt32)), + startAndEndInterval.map(Row.fromTuple), + )), + ArrayZipBehavior.AssumeSameLength, + ) { (idx, bound) => + makestruct("prevContext" -> ArrayRef(prevContexts, idx), "bounds" -> bound) + } + }, + part => + M.eval { + for { + ctx <- GetField(part, "prevContext") + rows <- loweredChild.partition(ctx) + bounds <- GetField(part, "bounds") + startIntervalIdx <- GetTupleElement(bounds, 0) + endIntervalIdx <- GetTupleElement(bounds, 1) + partitionIntervals <- + ToArray(mapIR(rangeIR(startIntervalIdx, endIntervalIdx)) { i => ArrayRef(filterIntervalsRef, i) - })) { partitionIntervals => - filterIR(oldPart) { row => - bindIR(SelectFields(row, child.typ.key))(key => f(partitionIntervals, key)) - } - } + }) + } yield filterIR(rows) { row => + f(partitionIntervals, SelectFields(row, child.typ.key)) } - } - }, + }, ) case TableHead(child, targetNumRows) => val loweredChild = lower(child) - def streamLenOrMax(a: IR): IR = - if (targetNumRows <= Integer.MAX_VALUE) - StreamLen(StreamTake(a, targetNumRows.toInt)) - else - StreamLen(a) + def streamLenOrMax(a: Atom): IR = + if (targetNumRows <= Integer.MAX_VALUE) StreamLen(StreamTake(a, targetNumRows.toInt)) + else StreamLen(a) - def partitionSizeArray(childContexts: Atom): IR = { + def partitionSizeArray(childContexts: Atom): IR = PartitionCounts(child) match { case Some(partCounts) => var idx = 0 @@ -1338,108 +1273,76 @@ object LowerTableIR extends Logging { val finalParts = partsToKeep.map(partSize => partSize.toInt) Literal(TArray(TInt32), finalParts) case None => - val partitionSizeArrayFunc = freshName() - val howManyPartsToTryRef = Ref(freshName(), TInt32) - val howManyPartsToTry = if (targetNumRows == 1L) 1 else 4 - val iteration = Ref(freshName(), TInt32) - - val loopBody = bindIR( - loweredChild - .mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef)) { - ctx: IR => ctx + tailLoop(TArray(TInt32), if (targetNumRows == 1L) 1 else 4, 0) { + case (recur, Seq(numPartsToTry, iteration)) => + bindIR( + loweredChild + .mapContexts(_ => StreamTake(ToStream(childContexts), numPartsToTry))( + identity + ) + .mapCollect( + "table_head_recursive_count", + strConcat("iteration=", iteration, ",nParts=", numPartsToTry), + )(streamLenOrMax) + ) { counts => + If( + (streamSumIR(ToStream(counts)).toL >= targetNumRows) || + (ArrayLen(childContexts) <= ArrayLen(counts)), + counts, + recur(FastSeq(numPartsToTry * 4, iteration + 1)), + ) } - .mapCollect( - "table_head_recursive_count", - strConcat( - Str("iteration="), - invoke("str", TString, iteration), - Str(",nParts="), - invoke("str", TString, howManyPartsToTryRef), - ), - )(streamLenOrMax) - ) { counts => - If( - (Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) - || (ArrayLen(childContexts) <= ArrayLen(counts)), - counts, - Recur( - partitionSizeArrayFunc, - FastSeq(howManyPartsToTryRef * 4, iteration + 1), - TArray(TInt32), - ), - ) } - TailLoop( - partitionSizeArrayFunc, - FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), - loopBody.typ, - loopBody, - ) } - } def answerTuple(partitionSizeArrayRef: Atom): IR = bindIR(ArrayLen(partitionSizeArrayRef)) { numPartitions => - val howManyPartsToKeep = freshName() - val i = Ref(freshName(), TInt32) - val numLeft = Ref(freshName(), TInt64) - def makeAnswer(howManyParts: IR, howManyFromLast: IR) = - MakeTuple(FastSeq((0, howManyParts), (1, howManyFromLast))) - - val loopBody = If( - (i ceq numPartitions - 1) || ((numLeft - Cast( - ArrayRef(partitionSizeArrayRef, i), - TInt64, - )) <= 0L), - makeAnswer(i + 1, numLeft), - Recur( - howManyPartsToKeep, - FastSeq( - i + 1, - numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64), - ), - TTuple(TInt32, TInt64), - ), - ) If( numPartitions ceq 0, - makeAnswer(0, 0L), - TailLoop( - howManyPartsToKeep, - FastSeq(i.name -> 0, numLeft.name -> targetNumRows), - loopBody.typ, - loopBody, - ), + maketuple(0, 0L), + tailLoop(TTuple(TInt32, TInt64), 0, targetNumRows) { + case (recur, Seq(i, numLeft)) => + If( + (i ceq numPartitions - 1) || + ((numLeft - ArrayRef(partitionSizeArrayRef, i).toL) <= 0L), + maketuple(i + 1, numLeft), + recur( + FastSeq( + i + 1, + numLeft - ArrayRef(partitionSizeArrayRef, i).toL, + ) + ), + ) + }, ) } - val newCtxs = bindIR(ToArray(loweredChild.contexts)) { childContexts => - bindIR(partitionSizeArray(childContexts)) { partitionSizeArrayRef => - bindIR(answerTuple(partitionSizeArrayRef)) { answerTupleRef => - val numParts = GetTupleElement(answerTupleRef, 0) - val numElementsFromLastPart = GetTupleElement(answerTupleRef, 1) - val onlyNeededPartitions = StreamTake(ToStream(childContexts), numParts) - val howManyFromEachPart = mapIR(rangeIR(numParts)) { idxRef => - If( - idxRef ceq (numParts - 1), - Cast(numElementsFromLastPart, TInt32), - ArrayRef(partitionSizeArrayRef, idxRef), - ) - } - - zipIR( - FastSeq(onlyNeededPartitions, howManyFromEachPart), - ArrayZipBehavior.AssumeSameLength, - ) { case Seq(part, howMany) => - MakeStruct(FastSeq("numberToTake" -> howMany, "old" -> part)) - } - } + val newCtxs = + M.eval { + for { + childContexts <- ToArray(loweredChild.contexts) + partitionSizeArrayRef <- partitionSizeArray(childContexts) + answerTupleRef <- answerTuple(partitionSizeArrayRef) + numParts <- GetTupleElement(answerTupleRef, 0) + numElementsFromLastPart <- GetTupleElement(answerTupleRef, 1) + howManyFromEachPart <- + mapIR(rangeIR(numParts)) { i => + If( + i ceq (numParts - 1), + numElementsFromLastPart.toI, + ArrayRef(partitionSizeArrayRef, i), + ) + } + } yield zipIR( + FastSeq(StreamTake(ToStream(childContexts), numParts), howManyFromEachPart), + ArrayZipBehavior.AssumeSameLength, + ) { case Seq(part, howMany) => makestruct("numberToTake" -> howMany, "old" -> part) } } - } val bindRelationLetsNewCtx = Let(loweredChild.letBindings, ToArray(newCtxs)) val newCtxSeq = CompileAndEvaluate[IndexedSeq[Any]](ctx, bindRelationLetsNewCtx) + val numNewParts = newCtxSeq.length val newIntervals = loweredChild.partitioner.rangeBounds.slice(0, numNewParts) val newPartitioner = loweredChild.partitioner.copy(rangeBounds = newIntervals) @@ -1451,17 +1354,19 @@ object LowerTableIR extends Logging { newPartitioner, loweredChild.dependency, ToStream(Literal(bindRelationLetsNewCtx.typ, newCtxSeq)), - (ctxRef: Ref) => - StreamTake( - loweredChild.partition(GetField(ctxRef, "old")), - GetField(ctxRef, "numberToTake"), - ), + ctxRef => + bindIR(GetField(ctxRef, "old")) { old => + StreamTake( + loweredChild.partition(old), + GetField(ctxRef, "numberToTake"), + ) + }, ) case TableTail(child, targetNumRows) => val loweredChild = lower(child) - def partitionSizeArray(childContexts: Atom, totalNumPartitions: Atom): IR = { + def partitionSizeArray(childContexts: Atom, totalNumPartitions: Atom): IR = PartitionCounts(child) match { case Some(partCounts) => var idx = partCounts.length @@ -1474,116 +1379,77 @@ object LowerTableIR extends Logging { Literal(TArray(TInt32), finalParts) case None => - val partitionSizeArrayFunc = freshName() - val howManyPartsToTryRef = Ref(freshName(), TInt32) - val howManyPartsToTry = if (targetNumRows == 1L) 1 else 4 - - val iteration = Ref(freshName(), TInt32) - - val loopBody = bindIR( - loweredChild - .mapContexts(_ => - StreamDrop( - ToStream(childContexts), - maxIR(totalNumPartitions - howManyPartsToTryRef, 0), + tailLoop(TArray(TInt32), if (targetNumRows == 1L) 1 else 4, 0) { + case (recur, Seq(numPartsToTry, iteration)) => + bindIR( + loweredChild + .mapContexts(_ => + StreamDrop( + ToStream(childContexts), + maxIR(totalNumPartitions - numPartsToTry, 0), + ) + )(identity) + .mapCollect( + "table_tail_recursive_count", + strConcat("iteration=", iteration, ", nParts=", numPartsToTry), + )(StreamLen(_)) + ) { counts => + If( + (streamSumIR(ToStream(counts)).toL >= targetNumRows) || + (totalNumPartitions <= ArrayLen(counts)), + counts, + recur(FastSeq(numPartsToTry * 4, iteration + 1)), ) - ) { ctx: IR => ctx } - .mapCollect( - "table_tail_recursive_count", - strConcat( - Str("iteration="), - invoke("str", TString, iteration), - Str(", nParts="), - invoke("str", TString, howManyPartsToTryRef), - ), - )(StreamLen) - ) { counts => - If( - (Cast( - streamSumIR(ToStream(counts)), - TInt64, - ) >= targetNumRows) || (totalNumPartitions <= ArrayLen(counts)), - counts, - Recur( - partitionSizeArrayFunc, - FastSeq(howManyPartsToTryRef * 4, iteration + 1), - TArray(TInt32), - ), - ) + } } - TailLoop( - partitionSizeArrayFunc, - FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), - loopBody.typ, - loopBody, - ) } - } - /* First element is how many partitions to keep from the right partitionSizeArrayRef, second - * is how many to keep from first kept element. */ - def answerTuple(partitionSizeArrayRef: Atom): IR = { - bindIR(ArrayLen(partitionSizeArrayRef)) { numPartitions => - val howManyPartsToDrop = freshName() - val i = Ref(freshName(), TInt32) - val nRowsToRight = Ref(freshName(), TInt64) - def makeAnswer(howManyParts: IR, howManyFromLast: IR) = - MakeTuple.ordered(FastSeq(howManyParts, howManyFromLast)) - - val loopBody = If( - (i ceq numPartitions) || ((nRowsToRight + Cast( - ArrayRef(partitionSizeArrayRef, numPartitions - i), - TInt64, - )) >= targetNumRows), - makeAnswer( - i, - maxIR( - 0L, - Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64) - (I64( - targetNumRows - ) - nRowsToRight), - ).toI, - ), - Recur( - howManyPartsToDrop, - FastSeq( - i + 1, - nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64), - ), - TTuple(TInt32, TInt32), - ), - ) - If( - numPartitions ceq 0, - makeAnswer(0, 0), - TailLoop( - howManyPartsToDrop, - FastSeq(i.name -> 1, nRowsToRight.name -> 0L), - loopBody.typ, - loopBody, - ), - ) - } - } - - val newCtxs = bindIR(ToArray(loweredChild.contexts)) { childContexts => - bindIR(ArrayLen(childContexts)) { totalNumPartitions => - bindIR(partitionSizeArray(childContexts, totalNumPartitions)) { partitionSizeArrayRef => - bindIR(answerTuple(partitionSizeArrayRef)) { answerTupleRef => - val numPartsToKeepFromRight = GetTupleElement(answerTupleRef, 0) - val nToDropFromFirst = GetTupleElement(answerTupleRef, 1) - bindIR(totalNumPartitions - numPartsToKeepFromRight) { startIdx => - mapIR(rangeIR(numPartsToKeepFromRight)) { idx => - MakeStruct(FastSeq( - "numberToDrop" -> If(idx ceq 0, nToDropFromFirst, 0), - "old" -> ArrayRef(childContexts, idx + startIdx), - )) - } + // First element is how many partitions to keep from the right partitionSizeArray + // Second element is how many to keep from first kept element. + def answerTuple(partitionSizeArray: Atom, nPartitions: Atom): IR = + If( + nPartitions ceq 0, + maketuple(0, 0), + tailLoop(TTuple(TInt32, TInt32), 1, 0L) { + case (recur, Seq(i, nRowsToRight)) => + bindIR(ArrayRef(partitionSizeArray, nPartitions - i).toL) { keep => + If( + (i ceq nPartitions) || (nRowsToRight + keep) >= targetNumRows, + maketuple( + i, + maxIR(0L, keep - (I64(targetNumRows) - nRowsToRight)).toI, + ), + recur( + FastSeq( + i + 1, + nRowsToRight + keep, + ) + ), + ) } - } + }, + ) + + val newCtxs = + M.eval { + for { + childContexts <- ToArray(loweredChild.contexts) + nContexts <- ArrayLen(childContexts) + + partitionSizeArrayRef <- partitionSizeArray(childContexts, nContexts) + nPartitions <- ArrayLen(partitionSizeArrayRef) + answerTupleRef <- answerTuple(partitionSizeArrayRef, nPartitions) + + numPartsToKeepFromRight <- GetTupleElement(answerTupleRef, 0) + nToDropFromFirst <- GetTupleElement(answerTupleRef, 1) + startIdx <- nContexts - numPartsToKeepFromRight + } yield mapIR(rangeIR(numPartsToKeepFromRight)) { idx => + makestruct( + "numberToDrop" -> If(idx ceq 0, nToDropFromFirst, 0), + "old" -> ArrayRef(childContexts, idx + startIdx), + ) } } - } val letBindNewCtx = Let(loweredChild.letBindings, ToArray(newCtxs)) val newCtxSeq = CompileAndEvaluate[IndexedSeq[Any]](ctx, letBindNewCtx) @@ -1598,7 +1464,7 @@ object LowerTableIR extends Logging { newPartitioner, loweredChild.dependency, ToStream(Literal(letBindNewCtx.typ, newCtxSeq)), - (ctxRef: Ref) => + ctxRef => bindIR(GetField(ctxRef, "old")) { oldRef => StreamDrop(loweredChild.partition(oldRef), GetField(ctxRef, "numberToDrop")) }, @@ -1610,12 +1476,12 @@ object LowerTableIR extends Logging { lc.mapPartition(Some(child.typ.key)) { rows => Let( FastSeq(TableIR.globalName -> lc.globals), - mapIR(rows)(row => Let(FastSeq(TableIR.rowName -> row), newRow)), + StreamMap(rows, TableIR.rowName, newRow), ) } } else { - val aggs = - agg.Extract(ctx, newRow, analyses.requirednessAnalysis, isScan = true).independent + + val aggs = Extract(ctx, newRow, analyses.requirednessAnalysis, isScan = true).independent val aggSigs = aggs.sigs val initState = RunAgg( @@ -1623,296 +1489,273 @@ object LowerTableIR extends Logging { aggSigs.valuesOp, aggSigs.states, ) + val initStateRef = Ref(freshName(), initState.typ) - val lcWithInitBinding = lc.copy( - letBindings = lc.letBindings ++ FastSeq((initStateRef.name, initState)), - broadcastVals = lc.broadcastVals ++ FastSeq((initStateRef.name, initStateRef)), - ) - val initFromSerializedStates = aggSigs.initFromSerializedValueOp(initStateRef) + val lcWithInitBinding = { + val lc_ = lc.deepCopy // no-sharing + lc_ + .copy( + letBindings = FastSeq(initStateRef.name -> initState), + broadcastVals = lc_.broadcastVals ++ FastSeq(initStateRef.name -> initStateRef.ir), + ) + } + + def initFromSerializedStates = aggSigs.initFromSerializedValueOp(initStateRef.ir) val branchFactor = ctx.branchingFactor - val big = aggSigs.shouldTreeAggregate && branchFactor < lc.numPartitions - val (partitionPrefixSumValues, transformPrefixSum): (IR, IR => IR) = if (big) { - val tmpDir = ctx.createTmpPath("aggregate_intermediates/") - - val codecSpec = TypedCodecSpec( - ctx, - PCanonicalTuple(true, Seq.fill(aggSigs.nAggs)(PCanonicalBinary(true)): _*), - BufferSpec.wireSpec, - ) - val writer = ETypeValueWriter(codecSpec) - val reader = ETypeValueReader(codecSpec) - val partitionPrefixSumFiles = - lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums")({ part: IR => - Let( - FastSeq(TableIR.globalName -> lcWithInitBinding.globals), - RunAgg( - Begin(FastSeq( - initFromSerializedStates, - StreamFor(part, TableIR.rowName, aggs.seqPerElt), - )), - WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), - aggSigs.states, - ), + + val (partitionPrefixSumValues, transformPrefixSum): (IR, Atom => IR) = + if (aggSigs.shouldTreeAggregate && branchFactor < lc.numPartitions) { + val tmpDir = ctx.createTmpPath("aggregate_intermediates/") + + val codecSpec = + TypedCodecSpec( + ctx, + PCanonicalTuple(true, Seq.fill(aggSigs.nAggs)(PCanonicalBinary(true)): _*), + BufferSpec.wireSpec, ) - // Collected is TArray of TString - }) { case (collected, _) => - def combineGroup(partArrayRef: IR): IR = { - Begin(FastSeq( - bindIR(ReadValue( - ArrayRef(partArrayRef, 0), - reader, - reader.spec.encodedVirtualType, - ))(aggSigs.initFromSerializedValueOp), - forIR(StreamRange( - 1, - ArrayLen(partArrayRef), - 1, - requiresMemoryManagementPerElement = true, - )) { fileIdx => + + val reader = ETypeValueReader(codecSpec) + val writer = ETypeValueWriter(codecSpec) + + val partitionPrefixSumFiles = + lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums") { part => + Let( + FastSeq(TableIR.globalName -> lcWithInitBinding.globals), + RunAgg( + Begin(FastSeq( + initFromSerializedStates, + StreamFor(part, TableIR.rowName, aggs.seqPerElt), + )), + WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), + aggSigs.states, + ), + ) + // Collected is TArray of TString + } { case (collected, _) => + def combineGroup(partArrayRef: Atom): IR = + Begin(FastSeq( bindIR(ReadValue( - ArrayRef(partArrayRef, fileIdx), + ArrayRef(partArrayRef, 0), reader, reader.spec.encodedVirtualType, - ))(aggSigs.combOpValues) - }, - )) - } + ))(aggSigs.initFromSerializedValueOp), + forIR(StreamRange(1, ArrayLen(partArrayRef), 1, true)) { fileIdx => + bindIR(ReadValue( + ArrayRef(partArrayRef, fileIdx), + reader, + reader.spec.encodedVirtualType, + ))(aggSigs.combOpValues) + }, + )) - // Return Array[Array[String]], length is log_b(num_partitions) - // The upward pass starts with partial aggregations from each partition, - // and aggregates these in a tree parameterized by the branching factor. - // The tree ends when the number of partial aggregations is less than or - // equal to the branching factor. - - // The upward pass returns the full tree of results as an array of arrays, - // where the first element is partial aggregations per partition of the - // input. - def upPass(): IR = { - val aggStack = Ref(freshName(), TArray(TArray(TString))) - val iteration = Ref(freshName(), TInt32) - val loopName = freshName() - - val loopBody = bindIR(ArrayRef(aggStack, (ArrayLen(aggStack) - 1))) { states => - bindIR(ArrayLen(states)) { statesLen => - If( - statesLen > branchFactor, - bindIR((statesLen + branchFactor - 1) floorDiv branchFactor) { nCombines => - val contexts = mapIR(rangeIR(nCombines)) { outerIdxRef => - sliceArrayIR( - states, - outerIdxRef * branchFactor, - (outerIdxRef + 1) * branchFactor, + // Return Array[Array[String]], length is log_b(num_partitions) + // The upward pass starts with partial aggregations from each partition, + // and aggregates these in a tree parameterized by the branching factor. + // The tree ends when the number of partial aggregations is less than or + // equal to the branching factor. + // The upward pass returns the full tree of results as an array of arrays, + // where the first element is partial aggregations per partition of the + // input. + val upPass = + tailLoop(TArray(TArray(TString)), MakeArray(collected), 0) { + case (recur, Seq(aggStack, iteration)) => + bindIR(ArrayRef(aggStack, ArrayLen(aggStack) - 1)) { states => + bindIR(ArrayLen(states)) { statesLen => + If( + statesLen > branchFactor, { + val nCombines = + (statesLen + branchFactor - 1) floorDiv branchFactor + + val contexts = + mapIR(rangeIR(nCombines)) { outerIdxRef => + sliceArrayIR( + states, + outerIdxRef * branchFactor, + (outerIdxRef + 1) * branchFactor, + ) + } + + val cdaResult = + cdaIR( + contexts, + makestruct(), + "table_scan_up_pass", + strConcat("iteration=", iteration, ", nStates=", statesLen), + ) { case (contexts, _) => + RunAgg( + combineGroup(contexts), + WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), + aggSigs.states, + ) + } + + recur( + FastSeq( + invoke( + "extend", + TArray(TArray(TString)), + aggStack, + MakeArray(cdaResult), + ), + iteration + 1, + ) + ) + }, + aggStack, ) } - val cdaResult = cdaIR( - contexts, - MakeStruct(FastSeq()), - "table_scan_up_pass", - strConcat( - Str("iteration="), - invoke("str", TString, iteration), - Str(", nStates="), - invoke("str", TString, statesLen), - ), - ) { case (contexts, _) => - RunAgg( - combineGroup(contexts), - WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), - aggSigs.states, + } + } + + // The downward pass traverses the tree from root to leaves, computing partial + // scan sums as it goes. The two pieces of state transmitted between iterations + // are: + // - the level (an integer) referring to a position in the array `aggStack`, + // - and `last`, the partial sums from the last iteration. + // + // The starting state for `last` is an array of a single empty aggregation state. + bindIR(upPass) { aggStack => + val freshState = WriteValue(initStateRef.ir, Str(tmpDir) + UUID4(), writer) + tailLoop(TArray(TString), ArrayLen(aggStack) - 1, MakeArray(freshState), 0) { + case (recur, Seq(level, last, iteration)) => + If( + level < 0, + last, + bindIR(ArrayRef(aggStack, level)) { aggsArray => + val groups = + mapIR( + zipWithIndex( + mapIR(StreamGrouped(ToStream(aggsArray), I32(branchFactor)))( + ToArray(_) + ) + ) + ) { eltAndIdx => + makestruct( + "prev" -> ArrayRef(last, GetField(eltAndIdx, "idx")), + "partialSums" -> GetField(eltAndIdx, "elt"), + ) + } + + val results = + cdaIR( + groups, + maketuple(), + "table_scan_down_pass", + strConcat("iteration=", iteration, ", level=", level), + ) { case (context, _) => + val elt = Ref(freshName(), TString) + ToArray(RunAggScan( + ToStream( + GetField(context, "partialSums"), + requiresMemoryManagementPerElement = true, + ), + elt.name, + bindIR(ReadValue( + GetField(context, "prev"), + reader, + reader.spec.encodedVirtualType, + ))( + aggSigs.initFromSerializedValueOp + ), + bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType))( + aggSigs.combOpValues + ), + WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), + aggSigs.states, + )) + } + + recur( + FastSeq( + level - 1, + ToArray(flatten(ToStream(results))), + iteration + 1, + ) ) - } - Recur( - loopName, - IndexedSeq( - invoke( - "extend", - TArray(TArray(TString)), - aggStack, - MakeArray(cdaResult), - ), - iteration + 1, - ), - TArray(TArray(TString)), - ) - }, - aggStack, - ) + }, + ) } } - TailLoop( - loopName, - IndexedSeq((aggStack.name, MakeArray(collected)), (iteration.name, I32(0))), - loopBody.typ, - loopBody, - ) } - // The downward pass traverses the tree from root to leaves, computing partial scan - // sums as it goes. The two pieces of state transmitted between iterations are the - // level (an integer) referring to a position in the array `aggStack`, and `last`, - // the partial sums from the last iteration. The starting state for `last` is an - // array of a single empty aggregation state. - bindIR(upPass()) { aggStack => - val downPassLoopName = freshName() - val iteration = Ref(freshName(), TInt32) - - val level = Ref(freshName(), TInt32) - val last = Ref(freshName(), TArray(TString)) - - bindIR(WriteValue(initState, Str(tmpDir) + UUID4(), writer)) { freshState => - val loopBody = If( - level < 0, - last, - bindIR(ArrayRef(aggStack, level)) { aggsArray => - val groups = mapIR(zipWithIndex(mapIR(StreamGrouped( - ToStream(aggsArray), - I32(branchFactor), - ))(x => ToArray(x)))) { eltAndIdx => - MakeStruct(FastSeq( - ("prev", ArrayRef(last, GetField(eltAndIdx, "idx"))), - ("partialSums", GetField(eltAndIdx, "elt")), - )) - } - - val results = cdaIR( - groups, - MakeTuple.ordered(FastSeq()), - "table_scan_down_pass", - strConcat( - Str("iteration="), - invoke("str", TString, iteration), - Str(", level="), - invoke("str", TString, level), - ), - ) { case (context, _) => - bindIR(GetField(context, "prev")) { prev => - val elt = Ref(freshName(), TString) - ToArray(RunAggScan( - ToStream( - GetField(context, "partialSums"), - requiresMemoryManagementPerElement = true, - ), - elt.name, - bindIR(ReadValue(prev, reader, reader.spec.encodedVirtualType))( - aggSigs.initFromSerializedValueOp - ), - bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType))( - aggSigs.combOpValues - ), - WriteValue(aggSigs.valuesOp, Str(tmpDir) + UUID4(), writer), - aggSigs.states, - )) - } - } - Recur( - downPassLoopName, - IndexedSeq( - level - 1, - ToArray(flatten(ToStream(results))), - iteration + 1, - ), - TArray(TString), - ) - }, - ) - TailLoop( - downPassLoopName, - IndexedSeq( - (level.name, ArrayLen(aggStack) - 1), - (last.name, MakeArray(freshState)), - (iteration.name, I32(0)), + ( + partitionPrefixSumFiles, + file => ReadValue(file, reader, reader.spec.encodedVirtualType), + ) + } else { + val partitionAggs = + lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage") { + rows => + Let( + FastSeq(TableIR.globalName -> lc.globals), + RunAgg( + Begin(FastSeq( + initFromSerializedStates, + StreamFor(rows, TableIR.rowName, aggs.seqPerElt), + )), + aggSigs.valuesOp, + aggSigs.states, ), - loopBody.typ, - loopBody, ) - } - } - } - ( - partitionPrefixSumFiles, - { (file: IR) => ReadValue(file, reader, reader.spec.encodedVirtualType) }, - ) - - } else { - val partitionAggs = - lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage")({ - part: IR => + } { case (collected, globals) => Let( - FastSeq(TableIR.globalName -> lc.globals), - RunAgg( - Begin(FastSeq( - initFromSerializedStates, - StreamFor(part, TableIR.rowName, aggs.seqPerElt), - )), - aggSigs.valuesOp, - aggSigs.states, + FastSeq(TableIR.globalName -> globals), + ToArray( + StreamTake( + streamScanIR( + ToStream(collected, requiresMemoryManagementPerElement = true), + initStateRef.ir, + ) { + (acc, value) => + RunAgg( + Begin(FastSeq( + aggSigs.initFromSerializedValueOp(acc), + aggSigs.combOpValues(value), + )), + aggSigs.valuesOp, + aggSigs.states, + ) + }, + ArrayLen(collected), + ) ), ) - }) { case (collected, globals) => - Let( - FastSeq(TableIR.globalName -> globals), - ToArray(StreamTake( - { - val acc = Ref(freshName(), initStateRef.typ) - val value = Ref(freshName(), collected.typ.asInstanceOf[TArray].elementType) - StreamScan( - ToStream(collected, requiresMemoryManagementPerElement = true), - initStateRef, - acc.name, - value.name, - RunAgg( - Begin(FastSeq( - aggSigs.initFromSerializedValueOp(acc), - aggSigs.combOpValues(value), - )), - aggSigs.valuesOp, - aggSigs.states, - ), - ) - }, - ArrayLen(collected), - )), - ) - } - (partitionAggs, identity[IR]) - } + } + + (partitionAggs, identity(_)) + } val partitionPrefixSumsRef = Ref(freshName(), partitionPrefixSumValues.typ) - val zipOldContextRef = Ref(freshName(), lc.contexts.typ.asInstanceOf[TStream].elementType) - val zipPartAggUID = - Ref(freshName(), partitionPrefixSumValues.typ.asInstanceOf[TArray].elementType) - TableStage.apply( + TableStage( letBindings = - lc.letBindings ++ FastSeq((partitionPrefixSumsRef.name, partitionPrefixSumValues)), + lc.letBindings :+ (partitionPrefixSumsRef.name -> partitionPrefixSumValues), broadcastVals = lc.broadcastVals, partitioner = lc.partitioner, dependency = lc.dependency, globals = lc.globals, - contexts = StreamZip( + contexts = zipIR( FastSeq(lc.contexts, ToStream(partitionPrefixSumsRef)), - FastSeq(zipOldContextRef.name, zipPartAggUID.name), - MakeStruct(FastSeq(("oldContext", zipOldContextRef), ("scanState", zipPartAggUID))), ArrayZipBehavior.AssertSameLength, - ), - partition = { (partitionRef: Ref) => - bindIRs(GetField(partitionRef, "oldContext"), GetField(partitionRef, "scanState")) { - case Seq(oldContext, rawPrefixSum) => - bindIR(transformPrefixSum(rawPrefixSum)) { scanState => - Let( - FastSeq(TableIR.globalName -> lc.globals), - RunAggScan( - lc.partition(oldContext), - TableIR.rowName, - aggSigs.initFromSerializedValueOp(scanState), - aggs.seqPerElt, - aggs.result, - aggSigs.states, - ), - ) - } - } + ) { case Seq(oldContext, scanState) => + makestruct("oldContext" -> oldContext, "scanState" -> scanState) }, + partition = partitionRef => + M.eval { + for { + oldContext <- GetField(partitionRef, "oldContext") + rawPrefixSum <- GetField(partitionRef, "scanState") + scanState <- transformPrefixSum(rawPrefixSum) + _ <- TableIR.globalName -> lc.globals.ir + } yield RunAggScan( + lc.partition(oldContext), + TableIR.rowName, + aggSigs.initFromSerializedValueOp(scanState), + aggs.seqPerElt, + aggs.result, + aggSigs.states, + ) + }, ) } @@ -1941,22 +1784,18 @@ object LowerTableIR extends Logging { commonKeyLength, (lGlobals, _) => lGlobals, (leftPart, rightPart) => { - val leftElementRef = Ref(freshName(), left.typ.rowType) - val rightElementRef = Ref(freshName(), right.typ.rowType) + val (typeOfRootStruct, _) = + right.typ.rowType.filterSet(right.typ.key.toSet, include = false) - val (typeOfRootStruct, _) = right.typ.rowType.filterSet(right.typ.key.toSet, false) - val rootStruct = SelectFields(rightElementRef, typeOfRootStruct.fieldNames.toIndexedSeq) - val joiningOp = InsertFields(leftElementRef, FastSeq(root -> rootStruct)) - StreamJoinRightDistinct( + joinRightDistinctIR( leftPart, rightPart, left.typ.key.take(commonKeyLength), right.typ.key, - leftElementRef.name, - rightElementRef.name, - joiningOp, "left", - ) + ) { (l, r) => + InsertFields(l, FastSeq(root -> SelectFields(r, typeOfRootStruct.fieldNames))) + } }, ) @@ -1967,39 +1806,22 @@ object LowerTableIR extends Logging { analyses.requirednessAnalysis.lookup(right).asInstanceOf[RTable].rowType, (lGlobals, _) => lGlobals, { (lstream, rstream) => - val lref = Ref(freshName(), left.typ.rowType) - if (product) { - val rref = Ref(freshName(), TArray(right.typ.rowType)) - StreamLeftIntervalJoin( + if (product) + leftIntervalJoinIR( lstream, rstream, left.typ.key.head, right.typ.keyType.fields(0).name, - lref.name, - rref.name, + ) { (l, r) => InsertFields( - lref, - FastSeq( - root -> mapArray(rref)(SelectFields(_, right.typ.valueType.fieldNames)) - ), - ), - ) - } else { - val rref = Ref(freshName(), right.typ.rowType) - StreamJoinRightDistinct( - lstream, - rstream, - left.typ.key, - right.typ.key, - lref.name, - rref.name, - InsertFields( - lref, - FastSeq(root -> SelectFields(rref, right.typ.valueType.fieldNames)), - ), - "left", - ) - } + l, + FastSeq(root -> mapArray(r)(SelectFields(_, right.typ.valueType.fieldNames))), + ) + } + else + joinRightDistinctIR(lstream, rstream, left.typ.key, right.typ.key, "left") { (l, r) => + InsertFields(l, FastSeq(root -> SelectFields(r, right.typ.valueType.fieldNames))) + } }, ) @@ -2034,7 +1856,7 @@ object LowerTableIR extends Logging { ctxRef => StreamMultiMerge( repartitioned.indices.map(i => - repartitioned(i).partition(GetTupleElement(ctxRef, i)) + bindIR(GetTupleElement(ctxRef, i))(ctx => repartitioned(i).partition(ctx)) ), keyType.fieldNames, ), @@ -2052,7 +1874,7 @@ object LowerTableIR extends Logging { val repartitioned = lowered.map(_.repartitionNoShuffle(ctx, newPartitioner)) val newGlobals = MakeStruct(FastSeq( globalName -> MakeArray( - repartitioned.map(_.globals), + repartitioned.map(_.globals.ir), TArray(repartitioned.head.globalType), ) )) @@ -2067,7 +1889,7 @@ object LowerTableIR extends Logging { TableStage( repartitioned.flatMap(_.letBindings) :+ globalsRef.name -> newGlobals, repartitioned.flatMap(_.broadcastVals) :+ globalsRef.name -> globalsRef, - globalsRef, + globalsRef.ir, newPartitioner, TableStageDependency.union(repartitioned.map(_.dependency)), zipIR(repartitioned.map(_.contexts), ArrayZipBehavior.AssumeSameLength) { ctxRefs => @@ -2076,7 +1898,7 @@ object LowerTableIR extends Logging { ctxRef => StreamZipJoin( repartitioned.indices.map(i => - repartitioned(i).partition(GetTupleElement(ctxRef, i)) + bindIR(GetTupleElement(ctxRef, i))(ctx => repartitioned(i).partition(ctx)) ), keyType.fieldNames, keyRef.name, @@ -2095,22 +1917,23 @@ object LowerTableIR extends Logging { case TableExplode(child, path) => lower(child).mapPartition(Some(child.typ.key.takeWhile(k => k != path(0)))) { rows => - flatMapIR(rows) { case row: Ref => - val refsBuffer = Array.fill[Ref](path.length + 1)(null) - val rootsBuffer = Array.fill[IR](path.length)(null) - var i = 0 - refsBuffer(0) = row - while (i < path.length) { - rootsBuffer(i) = GetField(refsBuffer(i), path(i)) - refsBuffer(i + 1) = Ref(freshName(), rootsBuffer(i).typ) - i += 1 + flatMapIR(rows) { row => + val N = path.length + + val bindings = new Array[Binding](N) + val refs = new Array[Atom](N) + val last = (0 until N).foldLeft(row) { (ref, i) => + refs(i) = ref + val root = GetField(ref, path(i)) + val next = Ref(freshName(), root.typ) + bindings(i) = Binding(next.name, root) + next } - val refs = ArraySeq.unsafeWrapArray(refsBuffer) - val roots = ArraySeq.unsafeWrapArray(rootsBuffer) - Let( - refs.tail.zip(roots).map { case (ref, root) => ref.name -> root }, - mapIR(ToStream(refs.last, true)) { elt => - path.zip(refs.init).foldRight[IR](elt) { case ((p, ref), inserted) => + + Block( + bindings.unsafeToArraySeq, + mapIR(ToStream(last, requiresMemoryManagementPerElement = true)) { elt => + path.zip(refs.unsafeToArraySeq).foldRight[IR](elt) { case ((p, ref), inserted) => InsertFields(ref, FastSeq(p -> inserted)) } }, @@ -2135,7 +1958,7 @@ object LowerTableIR extends Logging { ), dependency = lc.dependency, contexts = mapIR(StreamGrouped(lc.contexts, groupSize))(group => ToArray(group)), - partition = (r: Ref) => flatMapIR(ToStream(r))(prevCtx => lc.partition(prevCtx)), + partition = r => flatMapIR(ToStream(r))(prevCtx => lc.partition(prevCtx)), ) case TableRename(child, rowMap, globalMap) => @@ -2150,11 +1973,11 @@ object LowerTableIR extends Logging { TableStage( loweredChild.letBindings :+ newGlobalsRef.name -> newGlobals, loweredChild.broadcastVals :+ newGlobalsRef.name -> newGlobalsRef, - newGlobalsRef, + newGlobalsRef.ir, loweredChild.partitioner.copy(kType = loweredChild.kType.rename(rowMap)), loweredChild.dependency, loweredChild.contexts, - (ctxRef: Ref) => + ctxRef => mapIR(loweredChild.partition(ctxRef)) { row => CastRename(row, row.typ.asInstanceOf[TStruct].rename(rowMap)) }, @@ -2172,41 +1995,33 @@ object LowerTableIR extends Logging { case TableToTableApply(child, TableFilterPartitions(seq, keep)) => val lc = lower(child) - - val arr = seq.sorted - val keptSet = seq.toSet - val lit = Literal(TSet(TInt32), keptSet) - if (keep) { - def lookupRangeBound(idx: Int): Interval = { - try - lc.partitioner.rangeBounds(idx) - catch { - case exc: ArrayIndexOutOfBoundsException => - fatal(s"_filter_partitions: no partition with index $idx", exc) - } - } - - lc.copy( - partitioner = lc.partitioner.copy(rangeBounds = arr.map(lookupRangeBound)), - contexts = mapIR( - filterIR( - zipWithIndex(lc.contexts) - )(t => invoke("contains", TBoolean, lit, GetField(t, "idx"))) - )(t => GetField(t, "elt")), - ) - } else { - lc.copy( - partitioner = - lc.partitioner.copy(rangeBounds = lc.partitioner.rangeBounds.zipWithIndex.filter { - case (_, idx) => !keptSet.contains(idx) - }.map(_._1)), - contexts = mapIR( - filterIR( - zipWithIndex(lc.contexts) - )(t => !invoke("contains", TBoolean, lit, GetField(t, "idx"))) - )(t => GetField(t, "elt")), - ) - } + val keepSet = seq.toSet + lc.copy( + partitioner = + lc.partitioner.copy( + rangeBounds = + lc.partitioner + .rangeBounds + .zipWithIndex + .flatMap { case (interval, idx) => + if (keep == keepSet.contains(idx)) Some(interval) + else None + } + ), + contexts = + bindIR(Literal(TSet(TInt32), keepSet)) { keepSet => + flatten(zip2(lc.contexts, iota(0, 1), ArrayZipBehavior.TakeMinLength) { + (c, idx) => + val `contains?` = invoke("contains", TBoolean, keepSet, idx) + val elt = GetField(c, "elt") + If( + if (keep) `contains?` else !`contains?`, + MakeStream.single(elt), + MakeStream.empty(elt.typ), + ) + }) + }, + ) case TableToTableApply( child, @@ -2229,33 +2044,29 @@ object LowerTableIR extends Logging { case BlockMatrixToTable(bmir) => val ts = LowerBlockMatrixIR.lowerToTableStage(bmir, typesToLower, ctx, analyses) // I now have an unkeyed table of (blockRow, blockCol, block). - ts.mapPartitionWithContext { (partition, ctxRef) => - flatMapIR(partition)(singleRowRef => - bindIR(GetField(singleRowRef, "block")) { singleNDRef => - bindIR(NDArrayShape(singleNDRef)) { shapeTupleRef => - flatMapIR(rangeIR(Cast(GetTupleElement(shapeTupleRef, 0), TInt32))) { - withinNDRowIdx => - mapIR(rangeIR(Cast(GetTupleElement(shapeTupleRef, 1), TInt32))) { - withinNDColIdx => - val entry = NDArrayRef( - singleNDRef, - IndexedSeq(Cast(withinNDRowIdx, TInt64), Cast(withinNDColIdx, TInt64)), - ErrorIDs.NO_ERROR, - ) - val blockStartRow = GetField(singleRowRef, "blockRow") * bmir.typ.blockSize - val blockStartCol = GetField(singleRowRef, "blockCol") * bmir.typ.blockSize - makestruct( - "i" -> Cast(withinNDRowIdx + blockStartRow, TInt64), - "j" -> Cast(withinNDColIdx + blockStartCol, TInt64), - "entry" -> entry, - ) - } + ts.mapPartitionWithContext { (partition, _) => + flatMapIR(partition) { row => + M.eval { + for { + block <- GetField(row, "block") + rowOffset <- GetField(row, "blockRow").toL * I64(bmir.typ.blockSize.toLong) + colOffset <- GetField(row, "blockCol").toL * I64(bmir.typ.blockSize.toLong) + + shape <- NDArrayShape(block) + numRows <- GetTupleElement(shape, 0).toI + numCols <- GetTupleElement(shape, 1).toI + } yield flatMapIR(mapIR(rangeIR(numRows))(_.toL)) { rowIdx => + mapIR(mapIR(rangeIR(numCols))(_.toL)) { colIdx => + makestruct( + "i" -> (rowIdx + rowOffset), + "j" -> (colIdx + colOffset), + "entry" -> NDArrayRef(block, FastSeq(rowIdx, colIdx), ErrorIDs.NO_ERROR), + ) } } } - ) + } } - case node => throw new LowererUnsupportedOperation(s"undefined: \n${Pretty(ctx, node)}") } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala index 69985bba295..1f9293ac6de 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala @@ -32,12 +32,11 @@ object LowerTableIRHelpers { loweredRight, joinKey, joinType, - (lGlobals, rGlobals) => { - val rGlobalType = rGlobals.typ.asInstanceOf[TStruct] - bindIR(rGlobals) { rGlobalRef => - InsertFields(lGlobals, rGlobalType.fieldNames.map(f => f -> GetField(rGlobalRef, f))) - } - }, + (lGlobals, rGlobals) => + InsertFields( + lGlobals, + rGlobals.typ.asInstanceOf[TStruct].fieldNames.map(f => f -> GetField(rGlobals, f)), + ), (lEltRef, rEltRef) => { MakeStruct( lKeyFields.lazyZip(rKeyFields).map { (lKey, rKey) => diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala b/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala index b49d351a931..b4f0db213d1 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LoweringPass.scala @@ -61,15 +61,9 @@ case class OptimizePass(_context: String) extends LoweringPass { case object LowerMatrixToTablePass extends LoweringPass { override val context: String = "LowerMatrixToTable" - override def before: Invariant = AnyIR + override def before: Invariant = LowerableIR override def after: Invariant = before and NoMatrixIR - - override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = ir match { - case x: IR => LowerMatrixIR(ctx, x) - case x: TableIR => LowerMatrixIR(ctx, x) - case x: MatrixIR => LowerMatrixIR(ctx, x) - case x: BlockMatrixIR => LowerMatrixIR(ctx, x) - } + override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = LowerMatrixIR(ctx, ir) } case object LiftRelationalValuesToRelationalLets extends LoweringPass { @@ -100,8 +94,8 @@ case object LowerOrInterpretNonCompilablePass extends LoweringPass { case class LowerToDistributedArrayPass(t: DArrayLowering.Type) extends LoweringPass { override val context: String = "LowerToDistributedArray" - override def before: Invariant = AnyIR and NoMatrixIR - override def after: Invariant = AnyIR and CompilableIR + override def before: Invariant = LowerableIR and NoMatrixIR + override def after: Invariant = LowerableIR and CompilableIR override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = LowerToCDA(ir.asInstanceOf[IR], t, ctx) diff --git a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala index 8de56a85f15..84044c7e54e 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -6,9 +6,7 @@ import is.hail.backend.{BroadcastValue, ExecuteContext} import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} import is.hail.collection.FastSeq import is.hail.expr.ir._ -import is.hail.expr.ir.defs.{ - GetField, In, Let, MakeStruct, ReadPartition, Ref, StreamRange, ToArray, -} +import is.hail.expr.ir.defs.{GetField, In, Let, MakeStruct, ReadPartition, StreamRange, ToArray} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs.FS import is.hail.rvd.{RVD, RVDType} @@ -123,7 +121,7 @@ object TableStageToRVD { partitioner = _ts.partitioner, dependency = _ts.dependency, contexts = mapIR(_ts.contexts)(c => MakeStruct(FastSeq("context" -> c))), - partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }, + partition = ctx => bindIR(GetField(ctx, "context"))(_ts.partition), ) val sparkContext = ctx.backend.asSpark.sc diff --git a/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala b/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala index 8c6373c1961..6772fbcbf01 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/invariant/package.scala @@ -41,7 +41,7 @@ package invariant { IRTraversal.trace(ir).foreach { case trace @ ir :: _ => if (!invariant(ir)) throw new UnsatisfiedInvariantError( s"""Invariant ${E.value} forbids - |${trace.take(5).map(Pretty(ctx, _, preserveNames = true)).mkString("\nin\n")} + |${trace.take(5).map(Pretty(ctx, _)).mkString("\nin\n")} |""".stripMargin ) } @@ -102,9 +102,9 @@ package object invariant { !newNames.add(name) || names.put(name, ir).forall { orig => throw new UnsatisfiedInvariantError( s"""Invariant ${implicitly[Enclosing].value} forbids redefinition of '$name' in - |${Pretty.ssaStyle(ir, preserveNames = true)} + |${Pretty.ssaStyle(ir)} |Originally bound in - |${Pretty.ssaStyle(orig, preserveNames = true)}""".stripMargin + |${Pretty.ssaStyle(orig)}""".stripMargin ) } } diff --git a/hail/hail/src/is/hail/expr/ir/package.scala b/hail/hail/src/is/hail/expr/ir/package.scala index 98d3f949ff6..125197a1067 100644 --- a/hail/hail/src/is/hail/expr/ir/package.scala +++ b/hail/hail/src/is/hail/expr/ir/package.scala @@ -165,20 +165,20 @@ package object ir extends CompileOps with LowerPriorityImplicits { def guardIR(condition: IR)(body: IR): IR = If(condition, body, NA(body.typ)) - def mapIR(stream: IR)(f: Atom => IR): IR = { + def mapIR(stream: IR)(f: Atom => IR): IR with TypedIR[TStream] = { val ref = Ref(freshName(), tcoerce[TStream](stream.typ).elementType) StreamMap(stream, ref.name, f(ref)) } - def mapArray(array: IR)(f: Atom => IR): IR = + def mapArray(array: IR)(f: Atom => IR): IR with TypedIR[TArray] = ToArray(mapIR(ToStream(array))(f)) - def flatMapIR(stream: IR)(f: Atom => IR): IR = { + def flatMapIR(stream: IR)(f: Atom => IR): IR with TypedIR[TStream] = { val ref = Ref(freshName(), tcoerce[TStream](stream.typ).elementType) StreamFlatMap(stream, ref.name, f(ref)) } - def flatten(stream: IR): IR = + def flatten(stream: IR): IR with TypedIR[TStream] = flatMapIR(if (stream.typ.isInstanceOf[TStream]) stream else ToStream(stream)) { elt => if (elt.typ.isInstanceOf[TStream]) elt else ToStream(elt) } @@ -235,8 +235,8 @@ package object ir extends CompileOps with LowerPriorityImplicits { )( f: (Atom, Atom) => IR ): IR = { - val lRef = Ref(freshName(), left.typ.asInstanceOf[TStream].elementType) - val rRef = Ref(freshName(), right.typ.asInstanceOf[TStream].elementType) + val lRef = Ref(freshName(), TIterable.elementType(left.typ)) + val rRef = Ref(freshName(), TIterable.elementType(right.typ)) StreamJoin( left, right, @@ -266,12 +266,33 @@ package object ir extends CompileOps with LowerPriorityImplicits { joinType: String, )( f: (Atom, Atom) => IR - ): IR = { - val lRef = Ref(freshName(), left.typ.asInstanceOf[TStream].elementType) - val rRef = Ref(freshName(), right.typ.asInstanceOf[TStream].elementType) + ): IR with TypedIR[TStream] = { + val lRef = Ref(freshName(), TIterable.elementType(left.typ)) + val rRef = Ref(freshName(), TIterable.elementType(right.typ)) StreamJoinRightDistinct(left, right, lkey, rkey, lRef.name, rRef.name, f(lRef, rRef), joinType) } + def leftIntervalJoinIR( + left: IR, + right: IR, + keyField: String, + intervalField: String, + )( + f: (Atom, Atom) => IR + ): IR with TypedIR[TStream] = { + val lRef = Ref(freshName(), TIterable.elementType(left.typ)) + val rRef = Ref(freshName(), TArray(TIterable.elementType(right.typ))) + StreamLeftIntervalJoin( + left, + right, + keyField, + intervalField, + lRef.name, + rRef.name, + f(lRef, rRef), + ) + } + def streamSumIR(stream: IR): IR = foldIR(stream, 0) { case (accum, elt) => accum + elt } diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index e6928ba3f9f..f5e297e4d22 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -286,7 +286,7 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { } def assertBMEvalsTo( - bm: BlockMatrixIR, + bm0: BlockMatrixIR, expected: DenseMatrix[Double], )(implicit execStrats: Set[ExecStrategy] ): Unit = { @@ -296,6 +296,8 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { logger.info("skipping interpret and non-lowering compile steps on non-spark backend") execStrats.intersect(ExecStrategy.backendOnly) } + + val bm = bm0.deepCopy filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => try { val res = diff --git a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala index c5055b0dedb..f92431f57c1 100644 --- a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala @@ -990,27 +990,27 @@ class Aggregators2Suite extends HailSuite { TStruct("row_idx" -> TInt32), TStruct.empty, ) - val ir = TableCollect(MatrixColsTable(MatrixMapCols( - MatrixRead(t, false, false, MatrixRangeReader(ctx, 10, 10, None)), - InsertFields( - Ref(MatrixIR.colName, t.colType), - FastSeq(( - "foo", - bindIR(GetField(Ref(MatrixIR.colName, t.colType), "col_idx") + I32(1)) { bar => - AggFilter( - GetField(Ref(MatrixIR.rowName, t.rowType), "row_idx") < I32(5), - bar.toL + bar.toL + ApplyAggOp( - FastSeq(), - FastSeq(GetField(Ref(MatrixIR.rowName, t.rowType), "row_idx").toL), - Sum(), - ), - false, - ) - }, - )), - ), - Some(FastSeq()), - ))) + + val col: Atom = Ref(MatrixIR.colName, t.colType) + val row: Atom = Ref(MatrixIR.rowName, t.rowType) + + val ir = TableCollect( + MatrixColsTable( + MatrixMapCols( + MatrixRead(t, false, false, MatrixRangeReader(ctx, 10, 10, None)), + insertIR( + col, + "foo" -> bindIR(GetField(col, "col_idx").toL + 1L) { colIdx => + aggBindIR(GetField(row, "row_idx")) { rowIdx => + AggFilter(rowIdx < 5, colIdx + colIdx + ApplyAggOp(Sum())(rowIdx.toL), false) + } + }, + ), + Some(FastSeq()), + ) + ) + ) + assertEvalsTo(ir, Row((0 until 10).map(i => Row(i, 2L * i + 12L)), Row()))( ExecStrategy.interpretOnly ) diff --git a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala index d69831ecaad..a6ee73035ee 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala @@ -124,7 +124,7 @@ class MatrixIRSuite extends HailSuite { val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) val newRow = - InsertFields(oldRow, FastSeq("range" -> IRScanCollect(GetField(oldRow, "row_idx")))) + InsertFields(oldRow.ir, FastSeq("range" -> IRScanCollect(GetField(oldRow, "row_idx")))) val newMatrix = MatrixMapRows(mt, newRow) val rows = getRows(newMatrix) @@ -138,7 +138,7 @@ class MatrixIRSuite extends HailSuite { val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) val newRow = InsertFields( - oldRow, + oldRow.ir, FastSeq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldRow, "row_idx").toL)), ) @@ -165,7 +165,7 @@ class MatrixIRSuite extends HailSuite { val oldCol = Ref(MatrixIR.colName, mt.typ.colType) val newCol = - InsertFields(oldCol, FastSeq("range" -> IRScanCollect(GetField(oldCol, "col_idx")))) + InsertFields(oldCol.ir, FastSeq("range" -> IRScanCollect(GetField(oldCol, "col_idx")))) val newMatrix = MatrixMapCols(mt, newCol, None) val cols = getCols(newMatrix) @@ -179,7 +179,7 @@ class MatrixIRSuite extends HailSuite { val oldCol = Ref(MatrixIR.colName, mt.typ.colType) val newCol = InsertFields( - oldCol, + oldCol.ir, FastSeq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldCol, "col_idx").toL)), ) @@ -199,7 +199,7 @@ class MatrixIRSuite extends HailSuite { MatrixKeyRowsBy(baseRange, FastSeq()), InsertFields( row, - FastSeq("row_idx" -> (GetField(row, "row_idx") + start)), + FastSeq("row_idx" -> (GetField(row.ir, "row_idx") + start)), ), ), FastSeq("row_idx"), diff --git a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala index 61cd1e34bcb..4d5a6e449cc 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala @@ -1474,7 +1474,7 @@ class TableIRSuite extends HailSuite { val entriesPath = getTestResource("sample.vcf.mt/entries") val mnr = MatrixNativeReader(fs, getTestResource("sample.vcf.mt")) - val mnrSpec = mnr.getSpec() + val mnrSpec = mnr.spec val reader = TableNativeZippedReader(rowsPath, entriesPath, None, mnrSpec.rowsSpec, mnrSpec.entriesSpec) diff --git a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala index 96c3c12fff8..00f8b717439 100644 --- a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala @@ -6,10 +6,9 @@ import is.hail.collection.FastSeq import is.hail.expr.ir._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{ - ApplyBinaryPrimOp, Atom, ErrorIDs, GetField, MakeStream, MakeStruct, Ref, Str, StreamRange, - TableAggregate, TableGetGlobals, + Atom, ErrorIDs, GetField, MakeStream, MakeStruct, Ref, Str, StreamRange, TableAggregate, + TableGetGlobals, } -import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ import is.hail.utils.{HailException, Interval} @@ -111,21 +110,19 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testLowering(): Unit = { - val table = collect(mkTableGen()) - val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) - assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) + val rows = collect(mkTableGen()) + assertEvalsTo(rows, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } @Test(groups = Array("lowering")) def testNumberOfContextsMatchesPartitions(): Unit = { val errorId = 42 - val table = collect(mkTableGen( + val rows = collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), errorId = Some(errorId), )) - val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[HailException] { - loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) + loweredExecute(ctx, rows, Env.empty, FastSeq(), None) } ex.errorId shouldBe errorId ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") @@ -134,7 +131,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testRowsAreCorrectlyKeyed(): Unit = { val errorId = 56 - val table = collect(mkTableGen( + val rows = collect(mkTableGen( partitioner = Some(new RVDPartitioner( ctx.stateManager, TStruct("a" -> TInt32), @@ -145,9 +142,8 @@ class TableGenSuite extends HailSuite { )), errorId = Some(errorId), )) - val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[SparkException] { - loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) + loweredExecute(ctx, rows, Env.empty, FastSeq(), None) }.getCause.asInstanceOf[HailException] ex.errorId shouldBe errorId @@ -195,19 +191,13 @@ class TableGenSuite extends HailSuite { body: Option[(Atom, Atom) => IR] = None, partitioner: Option[RVDPartitioner] = None, errorId: Option[Int] = None, - ): TableGen = { + ): TableGen = tableGen( contexts.getOrElse(StreamRange(0, 2, 1)), - globals.getOrElse(MakeStruct(IndexedSeq("g" -> 0))), + globals.getOrElse(makestruct("g" -> 0)), partitioner.getOrElse(RVDPartitioner.unkeyed(ctx.stateManager, 2)), errorId.getOrElse(ErrorIDs.NO_ERROR), )( - body.getOrElse { (c, g) => - val elem = MakeStruct(IndexedSeq( - "a" -> ApplyBinaryPrimOp(Multiply(), c, GetField(g, "g")) - )) - MakeStream(IndexedSeq(elem), TStream(elem.typ)) - } + body.getOrElse((c, g) => MakeStream.single(makestruct("a" -> c * GetField(g, "g")))) ) - } }