Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 6 additions & 21 deletions hail/hail/ir-gen/src/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ trait IRDSL {
val BaseRef: Trait
def TypedIR(t: String): Trait
val NDArrayIR: Trait
// AbstractApplyNodeUnseededMissingness{Aware, Oblivious}JVMFunction
def ApplyNode(missingnessAware: Boolean = false): Trait
val ApplyNode: Trait

// Implicits for common names

Expand Down Expand Up @@ -121,12 +120,7 @@ object IRDSL_Impl extends IRDSL {
override val BaseRef: Trait = Trait("BaseRef")
override def TypedIR(typ: String): Trait = Trait(s"TypedIR[$typ]")
override val NDArrayIR: Trait = Trait("NDArrayIR")

override def ApplyNode(missingnessAware: Boolean = false): Trait = {
val t =
s"AbstractApplyNode[UnseededMissingness${if (missingnessAware) "Aware" else "Oblivious"}JVMFunction]"
Trait(t)
}
override val ApplyNode: Trait = Trait("AbstractApplyNode")

trait Repr[+T] {
def typ: Type[T]
Expand Down Expand Up @@ -1109,7 +1103,7 @@ object Main {
in("args", child.*),
in("returnType", att("Type")),
errorID,
).withTraits(ApplyNode())
).withTraits(ApplyNode)

r += node(
"ApplySeeded",
Expand All @@ -1118,19 +1112,10 @@ object Main {
in("rngState", child),
in("staticUID", att("Long")),
in("returnType", att("Type")),
).withTraits(ApplyNode())
).withTraits(ApplyNode)
.withPreamble("val args = rngState +: _args")
.withPreamble("val typeArgs: Seq[Type] = Seq.empty[Type]")

r += node(
"ApplySpecial",
in("function", att("String")),
in("typeArgs", att("Seq[Type]")),
in("args", child.*),
in("returnType", att("Type")),
errorID,
).withTraits(ApplyNode(missingnessAware = true))

r += node("LiftMeOut", in("child", child))

r += node("TableCount", tableChild)
Expand Down Expand Up @@ -1222,8 +1207,8 @@ object Main {
"BlockMatrixMultiWriter, ValueReader, ValueWriter}",
"is.hail.expr.ir.lowering.TableStageDependency",
"is.hail.expr.ir.agg.{PhysicalAggSig, AggStateSig}",
"is.hail.expr.ir.functions.{UnseededMissingnessAwareJVMFunction, " +
"UnseededMissingnessObliviousJVMFunction, TableToValueFunction, MatrixToValueFunction, " +
"is.hail.expr.ir.functions.{MissingnessAwareJVMFunction, " +
"MissingnessObliviousJVMFunction, TableToValueFunction, MatrixToValueFunction, " +
"BlockMatrixToValueFunction}",
"is.hail.expr.ir.defs.exts._",
"scala.collection.compat._",
Expand Down
9 changes: 1 addition & 8 deletions hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,7 @@ class BlockMatrixNativeReader(
val reader = ETypeValueReader(spec)

def blockIR(ctx: IR): IR = {
val path = Apply(
"concat",
FastSeq(),
FastSeq(Str(s"${params.path}/parts/"), ctx),
TString,
ErrorIDs.NO_ERROR,
)

val path = invoke("concat", TString, Str(s"${params.path}/parts/"), ctx)
ReadValue(path, reader, vType)
}

Expand Down
81 changes: 41 additions & 40 deletions hail/hail/src/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import is.hail.expr.ir.analyses.{
}
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs._
import is.hail.expr.ir.functions.{MissingnessAwareJVMFunction, MissingnessObliviousJVMFunction}
import is.hail.expr.ir.lowering.TableStageDependency
import is.hail.expr.ir.ndarrays.EmitNDArray
import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils}
Expand Down Expand Up @@ -2771,14 +2772,53 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
val rvAgg = agg.Extract.getAgg(sig)
rvAgg.result(cb, sc.states(idx), region)

case ir @ Apply(fn, typeArgs, args, rt, errorID) =>
ir.implementation match {
case impl: MissingnessObliviousJVMFunction =>
val unified = impl.unify(typeArgs, args.map(_.typ), rt)
assert(unified)

IEmitCode.multiMap(
cb,
args.map(arg => (cb: EmitCodeBuilder) => emitInNewBuilder(cb, arg)),
) { codeArgs =>
val argSTypes = codeArgs.map(_.st)
val retType = impl.computeStrictReturnEmitType(ir.typ, argSTypes)
val k = (fn, typeArgs, argSTypes, retType)
val meth =
methods.get(k) match {
case Some(funcMB) =>
funcMB
case None =>
val funcMB = impl.getAsMethod(mb.ecb, retType, typeArgs, argSTypes: _*)
methods.update(k, funcMB)
funcMB
}
cb.invokeSCode(
meth,
FastSeq[Param](cb.this_, CodeParam(region), CodeParam(errorID)) ++ codeArgs.map(
pc =>
pc: Param
): _*
)
}

case impl: MissingnessAwareJVMFunction =>
val codeArgs = args.map(a => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, a)))
val unified = impl.unify(typeArgs, args.map(_.typ), rt)
assert(unified)
val retType = impl.computeReturnEmitType(ir.typ, codeArgs.map(_.emitType))
impl.apply(cb, region, retType.st, typeArgs, errorID, codeArgs: _*)
}

case x @ ApplySeeded(_, args, rngState, staticUID, rt) =>
val codeArgs = args.map(a => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, a)))
val codeArgsMem = codeArgs.map(_.memoize(cb, "ApplySeeded_arg"))
val state = emitI(rngState).getOrAssert(cb)
val impl = x.implementation
assert(impl.unify(Array.empty[Type], x.argTypes, rt))
val newState = EmitCode.present(mb, state.asRNGState.splitStatic(cb, staticUID))
impl.applyI(
impl.asInstanceOf[MissingnessObliviousJVMFunction].applyI(
region,
cb,
impl.computeReturnEmitType(x.typ, newState.emitType +: codeArgs.map(_.emitType)).st,
Expand Down Expand Up @@ -3497,45 +3537,6 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
)
ev.load

case ir @ Apply(fn, typeArgs, args, rt, errorID) =>
val impl = ir.implementation
val unified = impl.unify(typeArgs, args.map(_.typ), rt)
assert(unified)

val emitArgs = args.map(a => EmitCode.fromI(mb)(emitI(a, _))).toFastSeq

val argSTypes = emitArgs.map(_.st)
val retType = impl.computeStrictReturnEmitType(ir.typ, argSTypes)
val k = (fn, typeArgs, argSTypes, retType)
val meth =
methods.get(k) match {
case Some(funcMB) =>
funcMB
case None =>
val funcMB = impl.getAsMethod(mb.ecb, retType, typeArgs, argSTypes: _*)
methods.update(k, funcMB)
funcMB
}
EmitCode.fromI(mb) { cb =>
val emitArgs = args.map(a => EmitCode.fromI(cb.emb)(emitI(a, _))).toFastSeq
IEmitCode.multiMapEmitCodes(cb, emitArgs) { codeArgs =>
cb.invokeSCode(
meth,
FastSeq[Param](cb.this_, CodeParam(region), CodeParam(errorID)) ++ codeArgs.map(pc =>
pc: Param
): _*
)
}
}

case x @ ApplySpecial(_, typeArgs, args, rt, errorID) =>
val codeArgs = args.map(a => emit(a))
val impl = x.implementation
val unified = impl.unify(typeArgs, args.map(_.typ), rt)
assert(unified)
val retType = impl.computeReturnEmitType(x.typ, codeArgs.map(_.emitType))
impl.apply(mb, region, retType.st, typeArgs, errorID, codeArgs: _*)

case WritePartition(stream, pctx, writer) =>
val ctxCode = emit(pctx)
val streamCode = emitStream(stream, region)
Expand Down
10 changes: 5 additions & 5 deletions hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -809,12 +809,12 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) extends Logg
},
x.typ,
)
case ApplySpecial("lor", _, _, _, _) => children match {
case Apply("lor", _, _, _, _) => children match {
case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) =>
ConstantValue(l || r, TBoolean)
case _ => AbstractLattice.top
}
case ApplySpecial("land", _, _, _, _) => children match {
case Apply("land", _, _, _, _) => children match {
case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) =>
ConstantValue(l && r, TBoolean)
case _ => AbstractLattice.top
Expand Down Expand Up @@ -842,7 +842,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) extends Logg
.restrict(keySet)
case (IsNA(_), Seq(b: BoolValue)) => b.isNA.restrict(keySet)
// collection contains
case (ApplySpecial("contains", _, _, _, _), Seq(ConstantValue(intervalVal), queryVal)) =>
case (Apply("contains", _, _, _, _), Seq(ConstantValue(intervalVal), queryVal)) =>
(intervalVal: @unchecked) match {
case null => BoolValue.allNA(keySet)
case i: Interval => queryVal match {
Expand Down Expand Up @@ -891,9 +891,9 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) extends Logg
}
case (ApplyComparisonOp(op, _, _), Seq(l, r)) =>
AbstractLattice.compare(l, r, op, keySet)
case (ApplySpecial("lor", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) =>
case (Apply("lor", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) =>
BoolValue.or(l, r)
case (ApplySpecial("land", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) =>
case (Apply("land", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) =>
BoolValue.and(l, r)
case (ApplyUnaryPrimOp(Bang, _), Seq(x: BoolValue)) =>
BoolValue.not(x)
Expand Down
11 changes: 6 additions & 5 deletions hail/hail/src/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ package defs {
}
}

trait AbstractApplyNode[F <: JVMFunction] extends IR {
trait AbstractApplyNode extends IR {
def function: String

def args: Seq[IR]
Expand All @@ -231,9 +231,10 @@ package defs {

def argTypes: Seq[Type] = args.map(_.typ)

lazy val implementation: F =
lazy val implementation: JVMFunction =
IRFunctionRegistry.lookupFunctionOrFail(function, returnType, typeArgs, argTypes)
.asInstanceOf[F]

def strictArgs: Boolean = implementation.isInstanceOf[MissingnessObliviousJVMFunction]
}

object PartitionReader {
Expand Down Expand Up @@ -779,15 +780,15 @@ package defs {

def all(element: IR): IR =
aggFoldIR(True()) { accum =>
ApplySpecial(
Apply(
"land",
Seq.empty[Type],
FastSeq(accum, element),
TBoolean,
ErrorIDs.NO_ERROR,
)
} { (accum1, accum2) =>
ApplySpecial(
Apply(
"land",
Seq.empty[Type],
FastSeq(accum1, accum2),
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/InferType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ object InferType {
case _ => TBoolean
}
case a: ApplyIR => a.returnType
case a: AbstractApplyNode[_] =>
case a: AbstractApplyNode =>
val typeArgs = a.typeArgs
val argTypes = a.args.map(_.typ)
assert(a.implementation.unify(typeArgs, argTypes, a.returnType))
Expand Down
6 changes: 3 additions & 3 deletions hail/hail/src/is/hail/expr/ir/Interpret.scala
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ object Interpret extends Logging {
interpret(result)
case ir @ ApplyIR(_, _, _, _, _) =>
interpret(ir.explicitNode, env, args)
case ApplySpecial("lor", _, Seq(left_, right_), _, _) =>
case Apply("lor", _, Seq(left_, right_), _, _) =>
val left = interpret(left_)
if (left == true)
true
Expand All @@ -850,7 +850,7 @@ object Interpret extends Logging {
null
else false
}
case ApplySpecial("land", _, Seq(left_, right_), _, _) =>
case Apply("land", _, Seq(left_, right_), _, _) =>
val left = interpret(left_)
if (left == false)
false
Expand All @@ -862,7 +862,7 @@ object Interpret extends Logging {
null
else true
}
case ir: AbstractApplyNode[_] =>
case ir: AbstractApplyNode =>
val argTuple =
PType.canonical(TTuple(ir.args.map(_.typ): _*)).setRequired(true).asInstanceOf[PTuple]
ctx.r.pool.scopedRegion { region =>
Expand Down
2 changes: 0 additions & 2 deletions hail/hail/src/is/hail/expr/ir/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1384,8 +1384,6 @@ object IRParser {
} yield ApplySeeded(function, args, rngState, staticUID, rt)
case "ApplyIR" =>
apply_like(ctx, ApplyIR.apply)(it)
case "ApplySpecial" =>
apply_like(ctx, ApplySpecial)(it)
case "Apply" =>
apply_like(ctx, Apply)(it)
case "MatrixCount" =>
Expand Down
2 changes: 0 additions & 2 deletions hail/hail/src/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,6 @@ class Pretty(
FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString())
case ApplySeeded(function, _, _, staticUID, t) =>
FastSeq(prettyIdentifier(function), staticUID.toString, t.parsableString())
case ApplySpecial(function, typeArgs, _, t, errorID) =>
FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString())
case SelectFields(_, fields) =>
single(fillList(fields.view.map(f => text(prettyIdentifier(f)))))
case LowerBoundOnOrderedCollection(_, _, onKey) => single(Pretty.prettyBooleanLiteral(onKey))
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/Requiredness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
requiredness.union(oldReq.required)
requiredness.unionFrom(oldReq.field(idx))
case x: ApplyIR => requiredness.unionFrom(lookup(x.body))
case x: AbstractApplyNode[_] => // FIXME: round-tripping via PTypes.
case x: AbstractApplyNode => // FIXME: round-tripping via PTypes.
val argP = x.args.map { a =>
val pt = lookup(a).canonicalPType(a.typ)
EmitType(pt.sType, pt.required)
Expand Down
13 changes: 6 additions & 7 deletions hail/hail/src/is/hail/expr/ir/Simplify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ object Simplify {
*/
private[this] def hasMissingStrictChild(x: IR): Boolean = {
x match {
case _: Apply |
_: ApplySeeded |
_: ApplyUnaryPrimOp |
case _: ApplyUnaryPrimOp |
_: ApplyBinaryPrimOp |
_: ArrayRef |
_: ArrayLen |
_: GetField |
_: GetTupleElement => x.children.exists(_.isInstanceOf[NA])
case x: AbstractApplyNode if x.strictArgs => x.children.exists(_.isInstanceOf[NA])
case ApplyComparisonOp(op, _, _) if op.strict => x.children.exists(_.isInstanceOf[NA])
case _ => false
}
Expand Down Expand Up @@ -933,7 +932,7 @@ object Simplify {
case TableFilter(TableFilter(t, p1), p2) =>
Some(TableFilter(
t,
ApplySpecial("land", Array.empty[Type], Array(p1, p2), TBoolean, ErrorIDs.NO_ERROR),
Apply("land", Array.empty[Type], Array(p1, p2), TBoolean, ErrorIDs.NO_ERROR),
))

case TableFilter(TableKeyBy(child, key, isSorted), p) =>
Expand Down Expand Up @@ -1400,22 +1399,22 @@ object Simplify {
Some(
MatrixFilterRows(
child,
ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR),
Apply("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR),
)
)

case MatrixFilterCols(MatrixFilterCols(child, pred1), pred2) =>
Some(
MatrixFilterCols(
child,
ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR),
Apply("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR),
)
)

case MatrixFilterEntries(MatrixFilterEntries(child, pred1), pred2) =>
Some(MatrixFilterEntries(
child,
ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR),
Apply("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR),
))

case MatrixMapGlobals(MatrixMapGlobals(child, ng1), ng2) =>
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/TypeCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ object TypeCheck {
case Trap(_) =>
case ConsoleLog(msg, _) => assert(msg.typ == TString)
case ApplyIR(_, _, _, _, _) =>
case x: AbstractApplyNode[_] =>
case x: AbstractApplyNode =>
assert(x.implementation.unify(x.typeArgs, x.args.map(_.typ), x.returnType))
case MatrixWrite(_, _) =>
case MatrixMultiWrite(children, _) =>
Expand Down
Loading