Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions hail/hail/ir-gen/src/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,8 @@ object Main {
r += node("I64", in("x", att("Long"))).withTraits(Atom)
r += node("F32", in("x", att("Float"))).withTraits(Atom)
r += node("F64", in("x", att("Double"))).withTraits(Atom)
r += node("Str", in("x", att("String"))).withTraits(Atom)
// Making Str < Atom would lead to code bloat
r += node("Str", in("x", att("String")))
.withPreamble(
"override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\""
): @nowarn("msg=possible missing interpolator")
Expand Down Expand Up @@ -671,7 +672,9 @@ object Main {
)

r += node("MakeArray", in("args", child.*), _typ("TArray")).withCompanionExtension
r += node("MakeStream", in("args", child.*), _typ("TStream"), mmPerElt).withCompanionExtension
r += node("MakeStream", in("args", child.*), _typ("TStream"), mmPerElt)
.typed("TStream")
.withCompanionExtension
r += node("ArrayRef", in("a", child), in("i", child), errorID)
r += node(
"ArraySlice",
Expand Down
48 changes: 25 additions & 23 deletions hail/hail/src/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,38 +224,40 @@ object Bindings {
private def childEnvValue(ir: IR, i: Int): Bindings[Type] =
ir match {
case Block(bindings, _) =>
val bindingsTypes = bindings.view.take(i).map(b => b.name -> b.value.typ).to(ArraySeq)
val types = ArraySeq.newBuilder[(Name, Type)]
types.sizeHint(i)

val eval = ArraySeq.newBuilder[Int]
eval.sizeHint(i) // most likely binding in eval

val agg = ArraySeq.newBuilder[Int]
val scan = ArraySeq.newBuilder[Int]
for (k <- 0 until i) bindings(k) match {
case Binding(_, _, Scope.EVAL) =>
eval += k
case Binding(_, _, Scope.AGG) =>
agg += k
case Binding(_, _, Scope.SCAN) =>
scan += k
}
if (i < bindings.length) bindings(i).scope match {
case Scope.EVAL =>
Bindings(
bindingsTypes,
eval.result(),
AggEnv.bindOrNoOp(agg.result()),
AggEnv.bindOrNoOp(scan.result()),
)
case Scope.AGG =>
Bindings(bindingsTypes, agg.result(), AggEnv.Promote, AggEnv.bindOrNoOp(scan.result()))
case Scope.SCAN =>
Bindings(bindingsTypes, scan.result(), AggEnv.bindOrNoOp(agg.result()), AggEnv.Promote)

for (k <- 0 until i) {
val Binding(name, value, scope) = bindings(k)
types += name -> value.typ
scope match {
case Scope.EVAL =>
eval += k
case Scope.AGG =>
agg += k
case Scope.SCAN =>
scan += k
}
}
else

if (i == bindings.length || bindings(i).scope == Scope.EVAL)
Bindings(
bindingsTypes,
types.result(),
eval.result(),
AggEnv.bindOrNoOp(agg.result()),
AggEnv.bindOrNoOp(scan.result()),
)
else if (bindings(i).scope == Scope.AGG)
Bindings(types.result(), agg.result(), AggEnv.Promote, AggEnv.bindOrNoOp(scan.result()))
else // SCAN
Bindings(types.result(), scan.result(), AggEnv.bindOrNoOp(agg.result()), AggEnv.Promote)

case TailLoop(name, args, resultType, _) if i == args.length =>
Bindings(
args.map { case (name, ir) => name -> ir.typ } :+
Expand Down
33 changes: 18 additions & 15 deletions hail/hail/src/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,22 @@ package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.collection.compat.immutable.ArraySeq
import is.hail.expr.ir.defs.{BaseRef, Binding, Block, In, Ref, Str}
import is.hail.types.virtual.TVoid
import is.hail.expr.ir.defs.{Atom, BaseRef, Binding, Block, Ref}
import is.hail.utils.Logging

import scala.collection.Set

object ForwardLets {
object ForwardLets extends Logging {
def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
ctx.time {
val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val UsesAndDefs(uses, _, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ctx, ir1)

def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Scope)
: Boolean =
IsPure(value) && (
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
(IsConstant(value) && !value.isInstanceOf[Str]) ||
value.isInstanceOf[Atom] ||
Copy link
Copy Markdown
Member Author

@ehigham ehigham Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inlining Constants can lead to significant IR bloat from Literals and EncodedLiteral.
We don't have Simplifyrules that match on either literal type and so I'm unaware of any reason to inline them.
Note that we pay special attention to literals in Emit to ensure we don't duplicate them in generated code. I'm not sure that this is required anymore, but I'll leave that for another change.

refs.isEmpty ||
(refs.size == 1 &&
nestingDepth.lookupRef(refs.head) == nestingDepth.lookupBinding(base, scope) &&
Comment on lines 17 to 23
Expand All @@ -36,11 +34,16 @@ object ForwardLets {
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
val refs_ = refs.filter(_.t.name == name)
if (shouldForward(rewriteValue, refs_, l, scope)) {
if (refs_.nonEmpty) env.bindInScope(name, rewriteValue, scope)
else {
logger.info(
f"Eliminating unused binding:\n" +
f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value, preserveNames = true).trim}"
)
env
}
} else {
keep += Binding(name, rewriteValue, scope)
env
Expand All @@ -55,9 +58,9 @@ object ForwardLets {
case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy
else forwarded
.map {
case forwarded: Atom => forwarded.ir
case big => big
}
.getOrElse(x)
case _ =>
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/ForwardRelationalLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object ForwardRelationalLets {
ctx.time {
val uses = mutable.HashMap.empty[Name, (Int, Int)]
val nestingDepth = NestingDepth(ctx, ir0)
IRTraversal.levelOrder(ir0).foreach {
IRTraversal.preOrder(ir0).foreach {
case x @ RelationalRef(name, _) =>
val (n, nd) = uses.getOrElseUpdate(name, (0, 0))
uses(name) = (n + 1, math.max(nd, nestingDepth.lookupRef(x)))
Expand Down
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,9 @@ package defs {
requiresMemoryManagementPerElement,
)
}

def single(x: IR): IR with TypedIR[TStream] =
MakeStream(FastSeq(x), TStream(x.typ))
}

abstract class ArraySortCompanionExt {
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/NormalizeNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class NormalizeNames(freeVariables: Set[Name]) {
case Ref(name, typ) =>
val newName = env.eval.lookupOption(name).getOrElse {
if (!freeVariables.contains(name)) throw new RuntimeException(
s"found free variable in normalize: $name; ${env.pretty(x => x.str)}"
s"found free variable in normalize: $name: ${typ._toPretty}; ${env.pretty(x => x.str)}"
)
else name
}
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/expr/expressions/typed_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4399,7 +4399,7 @@ def reshape(self, *shape):
for i, tuple_field_type in enumerate(shape.dtype.types):
if tuple_field_type not in [hl.tint32, hl.tint64]:
raise TypeError(f"Argument {i} of reshape needs to be an integer, got {tuple_field_type}.")
shape_ir = hl.or_missing(hl.is_defined(shape), hl.tuple([hl.int64(i) for i in shape]))._ir
shape_ir = hl.bind(lambda s: hl.or_missing(hl.is_defined(s), hl.tuple([hl.int64(i) for i in s])), shape)._ir
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While not recomputing shape is desirable, not including this line lead to the following failure:

E           Java stack trace:
E           java.lang.RuntimeException: requiredness mismatch: EC=true / Analysis=false
E           SNDArrayPointer(PCNDArray[+PInt64,4])
E           
E           %1 = EncodedLiteral [Tuple[Int32,Int32,Int32,Int32]]
E           !2 = GetTupleElement(%1) [0]
E           !3 = Cast(!2) [Int64] 
E           !4 = GetTupleElement(%1) [1]
E           !5 = Cast(!4) [Int64] 
E           !6 = GetTupleElement(%1) [2]
E           !7 = Cast(!6) [Int64] 
E           !8 = GetTupleElement(%1) [3]
E           !9 = Cast(!8) [Int64]
E           !10 = MakeTuple(!3, !5, !7, !9) [(0 1 2 3)]
E           NDArrayReshape(#undefined_ref, !10) [524]
E           
E           	at is.hail.expr.ir.Emit.emitI(Emit.scala:3369)
E           	at is.hail.expr.ir.Emit.emitInNewBuilder$1(Emit.scala:1011)
E           	at is.hail.expr.ir.Emit.$anonfun$emitI$41(Emit.scala:1218)
E           	at is.hail.expr.ir.EmitCode$.fromI(Emit.scala:553)
E           	at is.hail.expr.ir.Emit.$anonfun$emitI$40(Emit.scala:1218)
E           	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
E           	at scala.collection.Iterator.foreach(Iterator.scala:943)
E           	at scala.collection.Iterator.foreach$(Iterator.scala:943)
E           	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
E           	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
E           	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
E           	at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
E           	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
E           	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
E           	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
E           	at is.hail.expr.ir.Emit.emitI(Emit.scala:1217)
E           	at is.hail.expr.ir.Emit.$anonfun$emitSplitMethod$1(Emit.scala:731)
E           	at is.hail.expr.ir.Emit.$anonfun$emitSplitMethod$1$adapted(Emit.scala:729)
E           	at is.hail.expr.ir.EmitCodeBuilder$.scoped(EmitCodeBuilder.scala:21)
E           	at is.hail.expr.ir.EmitCodeBuilder$.scopedVoid(EmitCodeBuilder.scala:31)
E           	at is.hail.expr.ir.EmitMethodBuilder.voidWithBuilder(EmitClassBuilder.scala:1263)
E           	at is.hail.expr.ir.Emit.emitSplitMethod(Emit.scala:729)
E           	at is.hail.expr.ir.Emit.emitInSeparateMethod(Emit.scala:754)
E           	at is.hail.expr.ir.Emit.emitI(Emit.scala:988)
E           	at is.hail.expr.ir.Emit.emitInNewBuilder$1(Emit.scala:1011)
E           	at is.hail.expr.ir.Emit.$anonfun$emitI$41(Emit.scala:1218)
E           	at is.hail.expr.ir.EmitCode$.fromI(Emit.scala:553)
E           	at is.hail.expr.ir.Emit.$anonfun$emitI$40(Emit.scala:1218)
E           	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
E           	at scala.collection.Iterator.foreach(Iterator.scala:943)
E           	at scala.collection.Iterator.foreach$(Iterator.scala:943)
E           	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
E           	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
E           	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
E           	at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
E           	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
E           	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
E           	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
E           	at is.hail.expr.ir.Emit.emitI(Emit.scala:1217)
E           	at is.hail.expr.ir.Emit.emitI$2(Emit.scala:1001)
E           	at is.hail.expr.ir.Emit.emitI(Emit.scala:1118)
E           	at is.hail.expr.ir.Emit.emitInNewBuilder$1(Emit.scala:1011)
E           	at is.hail.expr.ir.Emit.$anonfun$emitI$41(Emit.scala:1218)
E           	at is.hail.expr.ir.EmitCode$.fromI(Emit.scala:553)
E           	at is.hail.expr.ir.Emit.$anonfun$emitI$40(Emit.scala:1218)
E           	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
E           	at scala.collection.Iterator.foreach(Iterator.scala:943)
E           	at scala.collection.Iterator.foreach$(Iterator.scala:943)
E           	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
E           	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
E           	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
E           	at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
E           	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
E           	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
E           	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
E           	at is.hail.expr.ir.Emit.emitI(Emit.scala:1217)
E           	at is.hail.expr.ir.Emit$.$anonfun$apply$6(Emit.scala:137)
E           	at is.hail.expr.ir.EmitCodeBuilder$.scoped(EmitCodeBuilder.scala:21)
E           	at is.hail.expr.ir.EmitCodeBuilder$.scopedCode(EmitCodeBuilder.scala:26)
E           	at is.hail.expr.ir.EmitMethodBuilder.emitWithBuilder(EmitClassBuilder.scala:1260)
E           	at is.hail.expr.ir.WrappedEmitMethodBuilder.emitWithBuilder(EmitClassBuilder.scala:1311)
E           	at is.hail.expr.ir.WrappedEmitMethodBuilder.emitWithBuilder$(EmitClassBuilder.scala:1311)
E           	at is.hail.expr.ir.EmitFunctionBuilder.emitWithBuilder(EmitClassBuilder.scala:1328)
E           	at is.hail.expr.ir.Emit$.$anonfun$apply$1(Emit.scala:132)
E           	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:99)
E           	at is.hail.backend.ExecuteContext.time(ExecuteContext.scala:167)
E           	at is.hail.expr.ir.Emit$.apply(Emit.scala:111)
E           	at is.hail.expr.ir.CompileOps.$anonfun$Impl$5(Compile.scala:133)
E           	at scala.collection.mutable.MapLike.getOrElseUpdate(MapLike.scala:209)
E           	at scala.collection.mutable.MapLike.getOrElseUpdate$(MapLike.scala:206)
E           	at scala.collection.mutable.AbstractMap.getOrElseUpdate(Map.scala:85)
E           	at is.hail.expr.ir.CompileOps.$anonfun$Impl$1(Compile.scala:108)
E           	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:99)
E           	at is.hail.backend.ExecuteContext.time(ExecuteContext.scala:167)
E           	at is.hail.expr.ir.CompileOps.Impl(Compile.scala:90)
E           	at is.hail.expr.ir.CompileOps.Compile(Compile.scala:46)
E           	at is.hail.expr.ir.CompileOps.Compile$(Compile.scala:38)
E           	at is.hail.expr.ir.package$.Compile(package.scala:34)
E           	at is.hail.expr.ir.CompileAndEvaluate$.$anonfun$_apply$1(CompileAndEvaluate.scala:72)
E           	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:99)
E           	at is.hail.backend.ExecuteContext.time(ExecuteContext.scala:167)
E           	at is.hail.expr.ir.CompileAndEvaluate$._apply(CompileAndEvaluate.scala:49)
E           	at is.hail.backend.spark.SparkBackend.$anonfun$execute$1(SparkBackend.scala:363)
E           	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:99)
E           	at is.hail.backend.ExecuteContext.time(ExecuteContext.scala:167)
E           	at is.hail.backend.spark.SparkBackend.execute(SparkBackend.scala:348)
E           	at is.hail.backend.driver.BackendRpc.$anonfun$runRpc$2(BackendRpc.scala:99)
E           	at is.hail.backend.driver.BackendRpc.withRegisterSerializedFns(BackendRpc.scala:173)
E           	at is.hail.backend.driver.BackendRpc.$anonfun$runRpc$1(BackendRpc.scala:97)
E           	at is.hail.backend.ExecuteContext$.$anonfun$scoped$3(ExecuteContext.scala:94)
E           	at is.hail.utils.package$.using(package.scala:492)
E           	at is.hail.backend.ExecuteContext$.$anonfun$scoped$2(ExecuteContext.scala:94)
E           	at is.hail.utils.package$.using(package.scala:492)
E           	at is.hail.annotations.RegionPool.scopedRegion(RegionPool.scala:169)
E           	at is.hail.backend.ExecuteContext$.$anonfun$scoped$1(ExecuteContext.scala:77)
E           	at is.hail.utils.package$.using(package.scala:492)
E           	at is.hail.annotations.RegionPool$.scoped(RegionPool.scala:16)
E           	at is.hail.backend.ExecuteContext$.scoped(ExecuteContext.scala:76)
E           	at is.hail.backend.driver.Py4JQueryDriver.$anonfun$withExecuteContext$1(Py4JQueryDriver.scala:354)
E           	at is.hail.utils.ExecutionTimer$.time(ExecutionTimer.scala:16)
E           	at is.hail.backend.driver.Py4JQueryDriver.is$hail$backend$driver$Py4JQueryDriver$$withExecuteContext(Py4JQueryDriver.scala:336)
E           	at is.hail.backend.driver.Py4JQueryDriver$$anon$1$Context$.scoped(Py4JQueryDriver.scala:444)
E           	at is.hail.backend.driver.Py4JQueryDriver$$anon$1$Context$.scoped(Py4JQueryDriver.scala:442)
E           	at is.hail.backend.driver.BackendRpc.runRpc(BackendRpc.scala:83)
E           	at is.hail.backend.driver.BackendRpc.runRpc$(BackendRpc.scala:79)
E           	at is.hail.backend.driver.Py4JQueryDriver$$anon$1.runRpc(Py4JQueryDriver.scala:393)
E           	at is.hail.backend.driver.Py4JQueryDriver$$anon$1.$anonfun$new$1(Py4JQueryDriver.scala:452)
E           	at jdk.httpserver/com.sun.net.httpserver.Filter$Chain.doFilter(Filter.java:77)
E           	at jdk.httpserver/sun.net.httpserver.AuthFilter.doFilter(AuthFilter.java:82)
E           	at jdk.httpserver/com.sun.net.httpserver.Filter$Chain.doFilter(Filter.java:80)
E           	at jdk.httpserver/sun.net.httpserver.ServerImpl$Exchange$LinkHandler.handle(ServerImpl.java:848)
E           	at jdk.httpserver/com.sun.net.httpserver.Filter$Chain.doFilter(Filter.java:77)
E           	at jdk.httpserver/sun.net.httpserver.ServerImpl$Exchange.run(ServerImpl.java:817)
E           	at jdk.httpserver/sun.net.httpserver.ServerImpl$DefaultExecutor.execute(ServerImpl.java:201)
E           	at jdk.httpserver/sun.net.httpserver.ServerImpl$Dispatcher.handle(ServerImpl.java:560)
E           	at jdk.httpserver/sun.net.httpserver.ServerImpl$Dispatcher.run(ServerImpl.java:525)
E           	at java.base/java.lang.Thread.run(Thread.java:829)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this duplicate work? if anything it should do less work as we are evaluating/binding shape before getting all its elements.

ndim = len(shape)
else:
wrapped_shape = wrap_to_list(shape)
Expand Down
Loading