Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 84 additions & 88 deletions hail/hail/src/is/hail/expr/ir/DeprecatedIRBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package is.hail.expr.ir
import is.hail.collection.FastSeq
import is.hail.collection.compat.immutable.ArraySeq
import is.hail.collection.implicits.toRichIterable
import is.hail.expr.ir.Scope._
import is.hail.expr.ir.defs._
import is.hail.types.virtual._

import scala.language.dynamics

object DeprecatedIRBuilder {
type E = Env[Type]
type E = BindingEnv[Type]

implicit def funcToIRProxy(ir: E => IR): IRProxy = new IRProxy(ir)

Expand All @@ -25,7 +26,7 @@ object DeprecatedIRBuilder {
implicit def booleanToProxy(b: Boolean): IRProxy = if (b) True() else False()

implicit def ref(s: Symbol): IRProxy = (env: E) =>
Ref(Name(s.name), env.lookup(Name(s.name)))
Ref(Name(s.name), env.eval.lookup(Name(s.name)))

implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s)

Expand Down Expand Up @@ -56,9 +57,8 @@ object DeprecatedIRBuilder {

def concatStructs(struct1: IRProxy, struct2: IRProxy): IRProxy = (env: E) => {
val s2Type = struct2(env).typ.asInstanceOf[TStruct]
let(__struct2 = struct2) {
struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*)
}(env)
(let(__struct2 = struct2) in
struct1.insertFields(s2Type.fieldNames.map(f => Symbol(f) -> '__struct2(Symbol(f))): _*))(env)
}

def makeTuple(values: IRProxy*): IRProxy = (env: E) =>
Expand All @@ -69,34 +69,31 @@ object DeprecatedIRBuilder {
initOpArgs: IndexedSeq[IRProxy] = FastSeq(),
seqOpArgs: IndexedSeq[IRProxy] = FastSeq(),
): IRProxy = (env: E) => {
val i = initOpArgs.map(x => x(env))
val s = seqOpArgs.map(x => x(env))
val i = initOpArgs.map(x => x(env.noAgg))
val s = seqOpArgs.map(x => x(env.promoteAgg))
ApplyAggOp(i, s, op)
}

def aggFilter(filterCond: IRProxy, query: IRProxy, isScan: Boolean = false): IRProxy = (env: E) =>
AggFilter(filterCond(env), query(env), isScan)
AggFilter(filterCond(env.promoteAgg), query(env), isScan)

class TableIRProxy(val tir: TableIR) extends AnyVal {
def empty: E = Env.empty

def globalEnv: E = typ.globalEnv

def env: E = typ.rowEnv
def empty: E = BindingEnv.empty

def typ: TableType = tir.typ

def getGlobals: IR = TableGetGlobals(tir)

def mapGlobals(newGlobals: IRProxy): TableIR =
TableMapGlobals(tir, newGlobals(globalEnv))
TableMapGlobals(tir, newGlobals(BindingEnv(typ.globalEnv)))

def mapRows(newRow: IRProxy): TableIR =
TableMapRows(tir, newRow(env))
TableMapRows(tir, newRow(BindingEnv(typ.rowEnv, scan = Some(typ.rowEnv))))

def explode(sym: Symbol): TableIR = TableExplode(tir, FastSeq(sym.name))

def aggregateByKey(aggIR: IRProxy): TableIR = TableAggregateByKey(tir, aggIR(env))
def aggregateByKey(aggIR: IRProxy): TableIR =
TableAggregateByKey(tir, aggIR(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv))))

def keyBy(keys: IndexedSeq[String], isSorted: Boolean = false): TableIR =
TableKeyBy(tir, keys, isSorted)
Expand All @@ -108,7 +105,7 @@ object DeprecatedIRBuilder {
rename(Map.empty, globalMap)

def filter(ir: IRProxy): TableIR =
TableFilter(tir, ir(env))
TableFilter(tir, ir(BindingEnv(typ.rowEnv)))

def distinct(): TableIR = TableDistinct(tir)

Expand All @@ -129,7 +126,7 @@ object DeprecatedIRBuilder {
}

def aggregate(ir: IRProxy): IR =
TableAggregate(tir, ir(env))
TableAggregate(tir, ir(BindingEnv(typ.globalEnv, agg = Some(typ.rowEnv))))
}

class IRProxy(val ir: E => IR) extends AnyVal with Dynamic {
Expand Down Expand Up @@ -214,7 +211,7 @@ object DeprecatedIRBuilder {
case _: TStruct =>
GetField(eval, lookup.name)
case _: TArray =>
ArrayRef(ir(env), ref(lookup)(env))
ArrayRef(eval, ref(lookup)(env))
}
}

Expand Down Expand Up @@ -259,65 +256,70 @@ object DeprecatedIRBuilder {

def isNA: IRProxy = (env: E) => IsNA(ir(env))

def orElse(alt: IRProxy): IRProxy = { env: E =>
val uid = freshName()
val eir = ir(env)
Let(FastSeq(uid -> eir), If(IsNA(Ref(uid, eir.typ)), alt(env), Ref(uid, eir.typ)))
}
def orElse(alt: IRProxy): IRProxy =
(env: E) => bindIR(ir(env))(x => If(IsNA(x), alt(env), x))

def filter(pred: LambdaProxy): IRProxy = (env: E) => {
def filter(pred: LambdaProxy): IRProxy = { env: E =>
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
val binding = Name(pred.s.name) -> TIterable.elementType(array.typ)
ToArray(StreamFilter(
ToStream(array),
Name(pred.s.name),
pred.body(env.bind(Name(pred.s.name) -> eltType)),
binding._1,
pred.body(env.bindEval(binding)),
))
}

def map(f: LambdaProxy): IRProxy = (env: E) => {
def map(f: LambdaProxy): IRProxy = { env: E =>
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
val binding = Name(f.s.name) -> TIterable.elementType(array.typ)
ToArray(StreamMap(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name) -> eltType)),
binding._1,
f.body(env.bindEval(binding)),
))
}

def aggExplode(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
def aggExplode(f: LambdaProxy): IRProxy = { env: E =>
val array = ir(env.promoteAgg)
val binding = Name(f.s.name) -> TIterable.elementType(array.typ)
AggExplode(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name), array.typ.asInstanceOf[TArray].elementType)),
binding._1,
f.body(env.bindEval(binding).bindAgg(binding)),
isScan = false,
)
}

def flatMap(f: LambdaProxy): IRProxy = (env: E) => {
def flatMap(f: LambdaProxy): IRProxy = { env: E =>
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
val binding = Name(f.s.name) -> TIterable.elementType(array.typ)
ToArray(StreamFlatMap(
ToStream(array),
Name(f.s.name),
ToStream(f.body(env.bind(Name(f.s.name) -> eltType))),
binding._1,
ToStream(f.body(env.bindEval(binding))),
))
}

def streamAgg(f: LambdaProxy): IRProxy = (env: E) => {
def flatten: IRProxy =
flatMap('a ~> 'a)

def streamAgg(f: LambdaProxy): IRProxy = { env: E =>
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
StreamAgg(ToStream(array), Name(f.s.name), f.body(env.bind(Name(f.s.name) -> eltType)))
val binding = Name(f.s.name) -> TIterable.elementType(array.typ)
StreamAgg(
ToStream(array),
binding._1,
f.body(env.bindEval(binding).createAgg),
)
}

def streamAggScan(f: LambdaProxy): IRProxy = (env: E) => {
def streamAggScan(f: LambdaProxy): IRProxy = { env: E =>
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
val binding = Name(f.s.name) -> TIterable.elementType(array.typ)
ToArray(StreamAggScan(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name) -> eltType)),
binding._1,
f.body(env.bindEval(binding).createScan),
))
}

Expand All @@ -338,14 +340,20 @@ object DeprecatedIRBuilder {
knownLength: Option[IRProxy],
)(
aggBody: IRProxy
): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
): IRProxy = { env: E =>
val array = ir(env.promoteAgg)

val bindings =
FastSeq(
Name(elementsSym.name) -> TIterable.elementType(array.typ),
Name(indexSym.name) -> TInt32,
)

AggArrayPerElement(
array,
Name(elementsSym.name),
Name(indexSym.name),
aggBody.apply(env.bind(Name(elementsSym.name) -> eltType, Name(indexSym.name) -> TInt32)),
bindings(0)._1,
bindings(1)._1,
aggBody(env.bindEval(bindings: _*).bindAgg(bindings: _*)),
knownLength.map(_(env)),
isScan = false,
)
Expand All @@ -361,18 +369,17 @@ object DeprecatedIRBuilder {
def toDict: IRProxy = (env: E) => ToDict(ToStream(ir(env)))

def parallelize(nPartitions: Option[Int] = None): TableIR =
TableParallelize(ir(Env.empty), nPartitions)
TableParallelize(ir(BindingEnv.empty), nPartitions)

def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = {
val element = Symbol(genUID())
ir
.map(element ~>
def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy =
ir.map(
'__elem ~>
makeTuple(
element.selectFields(keyFields: _*),
element.dropFieldList(keyFields),
))
'__elem.selectFields(keyFields: _*),
'__elem.dropFieldList(keyFields),
)
)
.toDict
}

def tupleElement(i: Int): IRProxy = (env: E) => GetTupleElement(ir(env), i)

Expand All @@ -391,8 +398,13 @@ object DeprecatedIRBuilder {
def bind(bindings: IndexedSeq[BindingProxy], body: IRProxy, env: E): IR = {
var newEnv = env
val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) =>
val resolvedValue = value(newEnv)
newEnv = newEnv.bind(Name(sym.name) -> resolvedValue.typ)
val resolvedValue =
value(
if (scope == AGG) newEnv.promoteAgg
else if (scope == SCAN) newEnv.promoteScan
else newEnv
)
newEnv = newEnv.bindInScope(Name(sym.name), resolvedValue.typ, scope)
Binding(Name(sym.name), resolvedValue, scope)
}
Block(resolvedBindings, body(newEnv))
Expand All @@ -412,38 +424,22 @@ object DeprecatedIRBuilder {
}

class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal {
def apply(body: IRProxy): IRProxy = in(body)

def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) }
def in(body: IRProxy): IRProxy = { env: E => LetProxy.bind(bindings, body, env) }
}

object aggLet extends Dynamic {
def applyDynamicNamed(method: String)(args: (String, IRProxy)*): AggLetProxy = {
def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = {
assert(method == "apply")
new AggLetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) })
new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) })
}
}

class AggLetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal {
def apply(body: IRProxy): IRProxy = in(body)

def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) }
}

object MapIRProxy {
def apply(f: (IRProxy) => IRProxy)(x: IRProxy): IRProxy = (e: E) =>
MapIR(x => f(x)(e))(x(e))
object scanLet extends Dynamic {
def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = {
assert(method == "apply")
new LetProxy(args.toFastSeq.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.SCAN) })
}
}

def subst(x: IRProxy, env: BindingEnv[IRProxy]): IRProxy = (e: E) =>
Subst(
x(e),
BindingEnv(
env.eval.mapValues(_(e)),
agg = env.agg.map(_.mapValues(_(e))),
scan = env.scan.map(_.mapValues(_(e))),
),
)

def lift(f: (IR) => IRProxy)(x: IRProxy): IRProxy = (e: E) => f(x(e))(e)
}
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object ForwardLets extends Logging {
else {
logger.info(
f"Eliminating unused binding:\n" +
f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value, preserveNames = true).trim}"
f"$name: ${value.typ} = ($scope) ${Pretty.ssaStyle(value).trim}"
)
env
}
Expand Down
19 changes: 9 additions & 10 deletions hail/hail/src/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,16 @@ package defs {
}

object Let {
def apply(bindings: IndexedSeq[(Name, IR)], body: IR): Block =
Block(
bindings.map { case (name, value) => Binding(name, value) },
body,
)
def apply(bindings: IndexedSeq[(Name, IR)], body: IR): IR =
if (bindings.isEmpty) body
else Block(bindings.map { case (n, v) => Binding(n, v) }, body)

def void(bindings: IndexedSeq[(Name, IR)]): IR = {
if (bindings.isEmpty) {
Void()
} else {
def void(bindings: IndexedSeq[(Name, IR)]): IR =
if (bindings.isEmpty) Void()
else {
assert(bindings.last._2.typ == TVoid)
Let(bindings.init, bindings.last._2)
}
}
}

object Begin {
Expand Down Expand Up @@ -538,6 +534,9 @@ package defs {
MakeArray(args.toFastSeq, TArray(args.head.typ))
}

def empty(elementType: Type): MakeArray =
MakeArray(FastSeq.empty[IR], TArray(elementType))

def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null)
: MakeArray = {
assert(requestedType != null || args.nonEmpty)
Expand Down
Loading
Loading