From a34828c99529f80a4792d28022c4abf6f7470888 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 1 Jun 2026 22:34:40 -0400 Subject: [PATCH] [query] no-sharing in various utils --- hail/hail/ir-gen/src/Main.scala | 7 ++- hail/hail/src/is/hail/expr/ir/Binds.scala | 48 ++++++++++--------- .../src/is/hail/expr/ir/ForwardLets.scala | 33 +++++++------ .../hail/expr/ir/ForwardRelationalLets.scala | 2 +- hail/hail/src/is/hail/expr/ir/IR.scala | 3 ++ .../src/is/hail/expr/ir/NormalizeNames.scala | 2 +- .../expr/expressions/typed_expressions.py | 2 +- 7 files changed, 54 insertions(+), 43 deletions(-) diff --git a/hail/hail/ir-gen/src/Main.scala b/hail/hail/ir-gen/src/Main.scala index 3ef96830e8d..a935fabeb30 100644 --- a/hail/hail/ir-gen/src/Main.scala +++ b/hail/hail/ir-gen/src/Main.scala @@ -582,7 +582,8 @@ object Main { r += node("I64", in("x", att("Long"))).withTraits(Atom) r += node("F32", in("x", att("Float"))).withTraits(Atom) r += node("F64", in("x", att("Double"))).withTraits(Atom) - r += node("Str", in("x", att("String"))).withTraits(Atom) + // Making Str < Atom would lead to code bloat + r += node("Str", in("x", att("String"))) .withPreamble( "override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\"" ): @nowarn("msg=possible missing interpolator") @@ -671,7 +672,9 @@ object Main { ) r += node("MakeArray", in("args", child.*), _typ("TArray")).withCompanionExtension - r += node("MakeStream", in("args", child.*), _typ("TStream"), mmPerElt).withCompanionExtension + r += node("MakeStream", in("args", child.*), _typ("TStream"), mmPerElt) + .typed("TStream") + .withCompanionExtension r += node("ArrayRef", in("a", child), in("i", child), errorID) r += node( "ArraySlice", diff --git a/hail/hail/src/is/hail/expr/ir/Binds.scala b/hail/hail/src/is/hail/expr/ir/Binds.scala index 7adee29ede9..fb919db4e01 100644 --- a/hail/hail/src/is/hail/expr/ir/Binds.scala +++ b/hail/hail/src/is/hail/expr/ir/Binds.scala @@ -224,38 +224,40 @@ object Bindings { private def childEnvValue(ir: IR, i: Int): Bindings[Type] = ir match { case Block(bindings, _) => - val bindingsTypes = bindings.view.take(i).map(b => b.name -> b.value.typ).to(ArraySeq) + val types = ArraySeq.newBuilder[(Name, Type)] + types.sizeHint(i) + val eval = ArraySeq.newBuilder[Int] + eval.sizeHint(i) // most likely binding in eval + val agg = ArraySeq.newBuilder[Int] val scan = ArraySeq.newBuilder[Int] - for (k <- 0 until i) bindings(k) match { - case Binding(_, _, Scope.EVAL) => - eval += k - case Binding(_, _, Scope.AGG) => - agg += k - case Binding(_, _, Scope.SCAN) => - scan += k - } - if (i < bindings.length) bindings(i).scope match { - case Scope.EVAL => - Bindings( - bindingsTypes, - eval.result(), - AggEnv.bindOrNoOp(agg.result()), - AggEnv.bindOrNoOp(scan.result()), - ) - case Scope.AGG => - Bindings(bindingsTypes, agg.result(), AggEnv.Promote, AggEnv.bindOrNoOp(scan.result())) - case Scope.SCAN => - Bindings(bindingsTypes, scan.result(), AggEnv.bindOrNoOp(agg.result()), AggEnv.Promote) + + for (k <- 0 until i) { + val Binding(name, value, scope) = bindings(k) + types += name -> value.typ + scope match { + case Scope.EVAL => + eval += k + case Scope.AGG => + agg += k + case Scope.SCAN => + scan += k + } } - else + + if (i == bindings.length || bindings(i).scope == Scope.EVAL) Bindings( - bindingsTypes, + types.result(), eval.result(), AggEnv.bindOrNoOp(agg.result()), AggEnv.bindOrNoOp(scan.result()), ) + else if (bindings(i).scope == Scope.AGG) + Bindings(types.result(), agg.result(), AggEnv.Promote, AggEnv.bindOrNoOp(scan.result())) + else // SCAN + Bindings(types.result(), scan.result(), AggEnv.bindOrNoOp(agg.result()), AggEnv.Promote) + case TailLoop(name, args, resultType, _) if i == args.length => Bindings( args.map { case (name, ir) => name -> ir.typ } :+ diff --git a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala index 79d62533dc9..5355ba3762d 100644 --- a/hail/hail/src/is/hail/expr/ir/ForwardLets.scala +++ b/hail/hail/src/is/hail/expr/ir/ForwardLets.scala @@ -2,24 +2,22 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext import is.hail.collection.compat.immutable.ArraySeq -import is.hail.expr.ir.defs.{BaseRef, Binding, Block, In, Ref, Str} -import is.hail.types.virtual.TVoid +import is.hail.expr.ir.defs.{Atom, BaseRef, Binding, Block, Ref} +import is.hail.utils.Logging import scala.collection.Set -object ForwardLets { +object ForwardLets extends Logging { def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T = ctx.time { val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0) - val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false) + val UsesAndDefs(uses, _, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false) val nestingDepth = NestingDepth(ctx, ir1) def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Scope) : Boolean = IsPure(value) && ( - value.isInstanceOf[Ref] || - value.isInstanceOf[In] || - (IsConstant(value) && !value.isInstanceOf[Str]) || + value.isInstanceOf[Atom] || refs.isEmpty || (refs.size == 1 && nestingDepth.lookupRef(refs.head) == nestingDepth.lookupBinding(base, scope) && @@ -36,11 +34,16 @@ object ForwardLets { val newEnv = l.bindings.foldLeft(env) { case (env, Binding(name, value, scope)) => val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR] - if ( - rewriteValue.typ != TVoid - && shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope) - ) { - env.bindInScope(name, rewriteValue, scope) + val refs_ = refs.filter(_.t.name == name) + if (shouldForward(rewriteValue, refs_, l, scope)) { + if (refs_.nonEmpty) env.bindInScope(name, rewriteValue, scope) + else { + logger.info( + f"Eliminating unused binding:\n" + + f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value, preserveNames = true).trim}" + ) + env + } } else { keep += Binding(name, rewriteValue, scope) env @@ -55,9 +58,9 @@ object ForwardLets { case x @ Ref(name, _) => env.eval .lookupOption(name) - .map { forwarded => - if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy - else forwarded + .map { + case forwarded: Atom => forwarded.ir + case big => big } .getOrElse(x) case _ => diff --git a/hail/hail/src/is/hail/expr/ir/ForwardRelationalLets.scala b/hail/hail/src/is/hail/expr/ir/ForwardRelationalLets.scala index 306237f96e1..655bf81b376 100644 --- a/hail/hail/src/is/hail/expr/ir/ForwardRelationalLets.scala +++ b/hail/hail/src/is/hail/expr/ir/ForwardRelationalLets.scala @@ -10,7 +10,7 @@ object ForwardRelationalLets { ctx.time { val uses = mutable.HashMap.empty[Name, (Int, Int)] val nestingDepth = NestingDepth(ctx, ir0) - IRTraversal.levelOrder(ir0).foreach { + IRTraversal.preOrder(ir0).foreach { case x @ RelationalRef(name, _) => val (n, nd) = uses.getOrElseUpdate(name, (0, 0)) uses(name) = (n + 1, math.max(nd, nestingDepth.lookupRef(x))) diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index 18d2a01fa50..e91756b1d81 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -634,6 +634,9 @@ package defs { requiresMemoryManagementPerElement, ) } + + def single(x: IR): IR with TypedIR[TStream] = + MakeStream(FastSeq(x), TStream(x.typ)) } abstract class ArraySortCompanionExt { diff --git a/hail/hail/src/is/hail/expr/ir/NormalizeNames.scala b/hail/hail/src/is/hail/expr/ir/NormalizeNames.scala index bf13bc714df..1a85a759f15 100644 --- a/hail/hail/src/is/hail/expr/ir/NormalizeNames.scala +++ b/hail/hail/src/is/hail/expr/ir/NormalizeNames.scala @@ -74,7 +74,7 @@ class NormalizeNames(freeVariables: Set[Name]) { case Ref(name, typ) => val newName = env.eval.lookupOption(name).getOrElse { if (!freeVariables.contains(name)) throw new RuntimeException( - s"found free variable in normalize: $name; ${env.pretty(x => x.str)}" + s"found free variable in normalize: $name: ${typ._toPretty}; ${env.pretty(x => x.str)}" ) else name } diff --git a/hail/python/hail/expr/expressions/typed_expressions.py b/hail/python/hail/expr/expressions/typed_expressions.py index 73b9a887440..dcffdf73090 100644 --- a/hail/python/hail/expr/expressions/typed_expressions.py +++ b/hail/python/hail/expr/expressions/typed_expressions.py @@ -4399,7 +4399,7 @@ def reshape(self, *shape): for i, tuple_field_type in enumerate(shape.dtype.types): if tuple_field_type not in [hl.tint32, hl.tint64]: raise TypeError(f"Argument {i} of reshape needs to be an integer, got {tuple_field_type}.") - shape_ir = hl.or_missing(hl.is_defined(shape), hl.tuple([hl.int64(i) for i in shape]))._ir + shape_ir = hl.bind(lambda s: hl.or_missing(hl.is_defined(s), hl.tuple([hl.int64(i) for i in s])), shape)._ir ndim = len(shape) else: wrapped_shape = wrap_to_list(shape)