diff --git a/hail/hail/ir-gen/src/Main.scala b/hail/hail/ir-gen/src/Main.scala index 910c27a3d63..50bd28c1238 100644 --- a/hail/hail/ir-gen/src/Main.scala +++ b/hail/hail/ir-gen/src/Main.scala @@ -76,8 +76,7 @@ trait IRDSL { val BaseRef: Trait def TypedIR(t: String): Trait val NDArrayIR: Trait - // AbstractApplyNodeUnseededMissingness{Aware, Oblivious}JVMFunction - def ApplyNode(missingnessAware: Boolean = false): Trait + val ApplyNode: Trait // Implicits for common names @@ -121,12 +120,7 @@ object IRDSL_Impl extends IRDSL { override val BaseRef: Trait = Trait("BaseRef") override def TypedIR(typ: String): Trait = Trait(s"TypedIR[$typ]") override val NDArrayIR: Trait = Trait("NDArrayIR") - - override def ApplyNode(missingnessAware: Boolean = false): Trait = { - val t = - s"AbstractApplyNode[UnseededMissingness${if (missingnessAware) "Aware" else "Oblivious"}JVMFunction]" - Trait(t) - } + override val ApplyNode: Trait = Trait("AbstractApplyNode") trait Repr[+T] { def typ: Type[T] @@ -1109,7 +1103,7 @@ object Main { in("args", child.*), in("returnType", att("Type")), errorID, - ).withTraits(ApplyNode()) + ).withTraits(ApplyNode) r += node( "ApplySeeded", @@ -1118,19 +1112,10 @@ object Main { in("rngState", child), in("staticUID", att("Long")), in("returnType", att("Type")), - ).withTraits(ApplyNode()) + ).withTraits(ApplyNode) .withPreamble("val args = rngState +: _args") .withPreamble("val typeArgs: Seq[Type] = Seq.empty[Type]") - r += node( - "ApplySpecial", - in("function", att("String")), - in("typeArgs", att("Seq[Type]")), - in("args", child.*), - in("returnType", att("Type")), - errorID, - ).withTraits(ApplyNode(missingnessAware = true)) - r += node("LiftMeOut", in("child", child)) r += node("TableCount", tableChild) @@ -1222,8 +1207,8 @@ object Main { "BlockMatrixMultiWriter, ValueReader, ValueWriter}", "is.hail.expr.ir.lowering.TableStageDependency", "is.hail.expr.ir.agg.{PhysicalAggSig, AggStateSig}", - "is.hail.expr.ir.functions.{UnseededMissingnessAwareJVMFunction, " + - "UnseededMissingnessObliviousJVMFunction, TableToValueFunction, MatrixToValueFunction, " + + "is.hail.expr.ir.functions.{MissingnessAwareJVMFunction, " + + "MissingnessObliviousJVMFunction, TableToValueFunction, MatrixToValueFunction, " + "BlockMatrixToValueFunction}", "is.hail.expr.ir.defs.exts._", "scala.collection.compat._", diff --git a/hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala b/hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala index 5b47d3f46da..78191b19e8a 100644 --- a/hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala +++ b/hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala @@ -163,14 +163,7 @@ class BlockMatrixNativeReader( val reader = ETypeValueReader(spec) def blockIR(ctx: IR): IR = { - val path = Apply( - "concat", - FastSeq(), - FastSeq(Str(s"${params.path}/parts/"), ctx), - TString, - ErrorIDs.NO_ERROR, - ) - + val path = invoke("concat", TString, Str(s"${params.path}/parts/"), ctx) ReadValue(path, reader, vType) } diff --git a/hail/hail/src/is/hail/expr/ir/Emit.scala b/hail/hail/src/is/hail/expr/ir/Emit.scala index bd5eb117e68..7f6a36bc2c2 100644 --- a/hail/hail/src/is/hail/expr/ir/Emit.scala +++ b/hail/hail/src/is/hail/expr/ir/Emit.scala @@ -9,6 +9,7 @@ import is.hail.expr.ir.analyses.{ } import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs._ +import is.hail.expr.ir.functions.{MissingnessAwareJVMFunction, MissingnessObliviousJVMFunction} import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.ndarrays.EmitNDArray import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils} @@ -2771,6 +2772,45 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { val rvAgg = agg.Extract.getAgg(sig) rvAgg.result(cb, sc.states(idx), region) + case ir @ Apply(fn, typeArgs, args, rt, errorID) => + ir.implementation match { + case impl: MissingnessObliviousJVMFunction => + val unified = impl.unify(typeArgs, args.map(_.typ), rt) + assert(unified) + + IEmitCode.multiMap( + cb, + args.map(arg => (cb: EmitCodeBuilder) => emitInNewBuilder(cb, arg)), + ) { codeArgs => + val argSTypes = codeArgs.map(_.st) + val retType = impl.computeStrictReturnEmitType(ir.typ, argSTypes) + val k = (fn, typeArgs, argSTypes, retType) + val meth = + methods.get(k) match { + case Some(funcMB) => + funcMB + case None => + val funcMB = impl.getAsMethod(mb.ecb, retType, typeArgs, argSTypes: _*) + methods.update(k, funcMB) + funcMB + } + cb.invokeSCode( + meth, + FastSeq[Param](cb.this_, CodeParam(region), CodeParam(errorID)) ++ codeArgs.map( + pc => + pc: Param + ): _* + ) + } + + case impl: MissingnessAwareJVMFunction => + val codeArgs = args.map(a => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, a))) + val unified = impl.unify(typeArgs, args.map(_.typ), rt) + assert(unified) + val retType = impl.computeReturnEmitType(ir.typ, codeArgs.map(_.emitType)) + impl.apply(cb, region, retType.st, typeArgs, errorID, codeArgs: _*) + } + case x @ ApplySeeded(_, args, rngState, staticUID, rt) => val codeArgs = args.map(a => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, a))) val codeArgsMem = codeArgs.map(_.memoize(cb, "ApplySeeded_arg")) @@ -2778,7 +2818,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { val impl = x.implementation assert(impl.unify(Array.empty[Type], x.argTypes, rt)) val newState = EmitCode.present(mb, state.asRNGState.splitStatic(cb, staticUID)) - impl.applyI( + impl.asInstanceOf[MissingnessObliviousJVMFunction].applyI( region, cb, impl.computeReturnEmitType(x.typ, newState.emitType +: codeArgs.map(_.emitType)).st, @@ -3497,45 +3537,6 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { ) ev.load - case ir @ Apply(fn, typeArgs, args, rt, errorID) => - val impl = ir.implementation - val unified = impl.unify(typeArgs, args.map(_.typ), rt) - assert(unified) - - val emitArgs = args.map(a => EmitCode.fromI(mb)(emitI(a, _))).toFastSeq - - val argSTypes = emitArgs.map(_.st) - val retType = impl.computeStrictReturnEmitType(ir.typ, argSTypes) - val k = (fn, typeArgs, argSTypes, retType) - val meth = - methods.get(k) match { - case Some(funcMB) => - funcMB - case None => - val funcMB = impl.getAsMethod(mb.ecb, retType, typeArgs, argSTypes: _*) - methods.update(k, funcMB) - funcMB - } - EmitCode.fromI(mb) { cb => - val emitArgs = args.map(a => EmitCode.fromI(cb.emb)(emitI(a, _))).toFastSeq - IEmitCode.multiMapEmitCodes(cb, emitArgs) { codeArgs => - cb.invokeSCode( - meth, - FastSeq[Param](cb.this_, CodeParam(region), CodeParam(errorID)) ++ codeArgs.map(pc => - pc: Param - ): _* - ) - } - } - - case x @ ApplySpecial(_, typeArgs, args, rt, errorID) => - val codeArgs = args.map(a => emit(a)) - val impl = x.implementation - val unified = impl.unify(typeArgs, args.map(_.typ), rt) - assert(unified) - val retType = impl.computeReturnEmitType(x.typ, codeArgs.map(_.emitType)) - impl.apply(mb, region, retType.st, typeArgs, errorID, codeArgs: _*) - case WritePartition(stream, pctx, writer) => val ctxCode = emit(pctx) val streamCode = emitStream(stream, region) diff --git a/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala index f7ddadfdf01..6f8a5402a73 100644 --- a/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala +++ b/hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala @@ -809,12 +809,12 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) extends Logg }, x.typ, ) - case ApplySpecial("lor", _, _, _, _) => children match { + case Apply("lor", _, _, _, _) => children match { case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) => ConstantValue(l || r, TBoolean) case _ => AbstractLattice.top } - case ApplySpecial("land", _, _, _, _) => children match { + case Apply("land", _, _, _, _) => children match { case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) => ConstantValue(l && r, TBoolean) case _ => AbstractLattice.top @@ -842,7 +842,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) extends Logg .restrict(keySet) case (IsNA(_), Seq(b: BoolValue)) => b.isNA.restrict(keySet) // collection contains - case (ApplySpecial("contains", _, _, _, _), Seq(ConstantValue(intervalVal), queryVal)) => + case (Apply("contains", _, _, _, _), Seq(ConstantValue(intervalVal), queryVal)) => (intervalVal: @unchecked) match { case null => BoolValue.allNA(keySet) case i: Interval => queryVal match { @@ -891,9 +891,9 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) extends Logg } case (ApplyComparisonOp(op, _, _), Seq(l, r)) => AbstractLattice.compare(l, r, op, keySet) - case (ApplySpecial("lor", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) => + case (Apply("lor", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) => BoolValue.or(l, r) - case (ApplySpecial("land", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) => + case (Apply("land", _, _, _, _), Seq(l: BoolValue, r: BoolValue)) => BoolValue.and(l, r) case (ApplyUnaryPrimOp(Bang, _), Seq(x: BoolValue)) => BoolValue.not(x) diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index 4074a6a314b..422b663b3c4 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -220,7 +220,7 @@ package defs { } } - trait AbstractApplyNode[F <: JVMFunction] extends IR { + trait AbstractApplyNode extends IR { def function: String def args: Seq[IR] @@ -231,9 +231,10 @@ package defs { def argTypes: Seq[Type] = args.map(_.typ) - lazy val implementation: F = + lazy val implementation: JVMFunction = IRFunctionRegistry.lookupFunctionOrFail(function, returnType, typeArgs, argTypes) - .asInstanceOf[F] + + def strictArgs: Boolean = implementation.isInstanceOf[MissingnessObliviousJVMFunction] } object PartitionReader { @@ -779,7 +780,7 @@ package defs { def all(element: IR): IR = aggFoldIR(True()) { accum => - ApplySpecial( + Apply( "land", Seq.empty[Type], FastSeq(accum, element), @@ -787,7 +788,7 @@ package defs { ErrorIDs.NO_ERROR, ) } { (accum1, accum2) => - ApplySpecial( + Apply( "land", Seq.empty[Type], FastSeq(accum1, accum2), diff --git a/hail/hail/src/is/hail/expr/ir/InferType.scala b/hail/hail/src/is/hail/expr/ir/InferType.scala index 83d525969a1..29f02bc2fc5 100644 --- a/hail/hail/src/is/hail/expr/ir/InferType.scala +++ b/hail/hail/src/is/hail/expr/ir/InferType.scala @@ -79,7 +79,7 @@ object InferType { case _ => TBoolean } case a: ApplyIR => a.returnType - case a: AbstractApplyNode[_] => + case a: AbstractApplyNode => val typeArgs = a.typeArgs val argTypes = a.args.map(_.typ) assert(a.implementation.unify(typeArgs, argTypes, a.returnType)) diff --git a/hail/hail/src/is/hail/expr/ir/Interpret.scala b/hail/hail/src/is/hail/expr/ir/Interpret.scala index 90f93b60453..7a30b373fc0 100644 --- a/hail/hail/src/is/hail/expr/ir/Interpret.scala +++ b/hail/hail/src/is/hail/expr/ir/Interpret.scala @@ -838,7 +838,7 @@ object Interpret extends Logging { interpret(result) case ir @ ApplyIR(_, _, _, _, _) => interpret(ir.explicitNode, env, args) - case ApplySpecial("lor", _, Seq(left_, right_), _, _) => + case Apply("lor", _, Seq(left_, right_), _, _) => val left = interpret(left_) if (left == true) true @@ -850,7 +850,7 @@ object Interpret extends Logging { null else false } - case ApplySpecial("land", _, Seq(left_, right_), _, _) => + case Apply("land", _, Seq(left_, right_), _, _) => val left = interpret(left_) if (left == false) false @@ -862,7 +862,7 @@ object Interpret extends Logging { null else true } - case ir: AbstractApplyNode[_] => + case ir: AbstractApplyNode => val argTuple = PType.canonical(TTuple(ir.args.map(_.typ): _*)).setRequired(true).asInstanceOf[PTuple] ctx.r.pool.scopedRegion { region => diff --git a/hail/hail/src/is/hail/expr/ir/Parser.scala b/hail/hail/src/is/hail/expr/ir/Parser.scala index 99ba3c74234..88cc398c3a4 100644 --- a/hail/hail/src/is/hail/expr/ir/Parser.scala +++ b/hail/hail/src/is/hail/expr/ir/Parser.scala @@ -1384,8 +1384,6 @@ object IRParser { } yield ApplySeeded(function, args, rngState, staticUID, rt) case "ApplyIR" => apply_like(ctx, ApplyIR.apply)(it) - case "ApplySpecial" => - apply_like(ctx, ApplySpecial)(it) case "Apply" => apply_like(ctx, Apply)(it) case "MatrixCount" => diff --git a/hail/hail/src/is/hail/expr/ir/Pretty.scala b/hail/hail/src/is/hail/expr/ir/Pretty.scala index 035a8804fe9..00b0ee733ea 100644 --- a/hail/hail/src/is/hail/expr/ir/Pretty.scala +++ b/hail/hail/src/is/hail/expr/ir/Pretty.scala @@ -429,8 +429,6 @@ class Pretty( FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) case ApplySeeded(function, _, _, staticUID, t) => FastSeq(prettyIdentifier(function), staticUID.toString, t.parsableString()) - case ApplySpecial(function, typeArgs, _, t, errorID) => - FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) case SelectFields(_, fields) => single(fillList(fields.view.map(f => text(prettyIdentifier(f))))) case LowerBoundOnOrderedCollection(_, _, onKey) => single(Pretty.prettyBooleanLiteral(onKey)) diff --git a/hail/hail/src/is/hail/expr/ir/Requiredness.scala b/hail/hail/src/is/hail/expr/ir/Requiredness.scala index dc5629ea27d..4d71a1287fb 100644 --- a/hail/hail/src/is/hail/expr/ir/Requiredness.scala +++ b/hail/hail/src/is/hail/expr/ir/Requiredness.scala @@ -822,7 +822,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.union(oldReq.required) requiredness.unionFrom(oldReq.field(idx)) case x: ApplyIR => requiredness.unionFrom(lookup(x.body)) - case x: AbstractApplyNode[_] => // FIXME: round-tripping via PTypes. + case x: AbstractApplyNode => // FIXME: round-tripping via PTypes. val argP = x.args.map { a => val pt = lookup(a).canonicalPType(a.typ) EmitType(pt.sType, pt.required) diff --git a/hail/hail/src/is/hail/expr/ir/Simplify.scala b/hail/hail/src/is/hail/expr/ir/Simplify.scala index 2d4ef2e3ced..aeb949b0a7c 100644 --- a/hail/hail/src/is/hail/expr/ir/Simplify.scala +++ b/hail/hail/src/is/hail/expr/ir/Simplify.scala @@ -63,14 +63,13 @@ object Simplify { */ private[this] def hasMissingStrictChild(x: IR): Boolean = { x match { - case _: Apply | - _: ApplySeeded | - _: ApplyUnaryPrimOp | + case _: ApplyUnaryPrimOp | _: ApplyBinaryPrimOp | _: ArrayRef | _: ArrayLen | _: GetField | _: GetTupleElement => x.children.exists(_.isInstanceOf[NA]) + case x: AbstractApplyNode if x.strictArgs => x.children.exists(_.isInstanceOf[NA]) case ApplyComparisonOp(op, _, _) if op.strict => x.children.exists(_.isInstanceOf[NA]) case _ => false } @@ -933,7 +932,7 @@ object Simplify { case TableFilter(TableFilter(t, p1), p2) => Some(TableFilter( t, - ApplySpecial("land", Array.empty[Type], Array(p1, p2), TBoolean, ErrorIDs.NO_ERROR), + Apply("land", Array.empty[Type], Array(p1, p2), TBoolean, ErrorIDs.NO_ERROR), )) case TableFilter(TableKeyBy(child, key, isSorted), p) => @@ -1400,7 +1399,7 @@ object Simplify { Some( MatrixFilterRows( child, - ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), + Apply("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), ) ) @@ -1408,14 +1407,14 @@ object Simplify { Some( MatrixFilterCols( child, - ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), + Apply("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), ) ) case MatrixFilterEntries(MatrixFilterEntries(child, pred1), pred2) => Some(MatrixFilterEntries( child, - ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), + Apply("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), )) case MatrixMapGlobals(MatrixMapGlobals(child, ng1), ng2) => diff --git a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala index 166014c34ba..7e004022b5e 100644 --- a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala +++ b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala @@ -543,7 +543,7 @@ object TypeCheck { case Trap(_) => case ConsoleLog(msg, _) => assert(msg.typ == TString) case ApplyIR(_, _, _, _, _) => - case x: AbstractApplyNode[_] => + case x: AbstractApplyNode => assert(x.implementation.unify(x.typeArgs, x.args.map(_.typ), x.returnType)) case MatrixWrite(_, _) => case MatrixMultiWrite(children, _) => diff --git a/hail/hail/src/is/hail/expr/ir/analyses/SemanticHash.scala b/hail/hail/src/is/hail/expr/ir/analyses/SemanticHash.scala index 4e0f2db011b..a3a88bebeb5 100644 --- a/hail/hail/src/is/hail/expr/ir/analyses/SemanticHash.scala +++ b/hail/hail/src/is/hail/expr/ir/analyses/SemanticHash.scala @@ -116,11 +116,6 @@ case object SemanticHash extends Logging { Bytes.fromLong(staticUID) ++= EncodeTypename(retTy): Unit - case ApplySpecial(fname, tyArgs, _, retTy, _) => - buffer ++= fname.getBytes - tyArgs.foreach(buffer ++= EncodeTypename(_)) - buffer ++= EncodeTypename(retTy) - case ApplyUnaryPrimOp(op, _) => buffer ++= Bytes.fromClass(op.getClass) diff --git a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala index 665146bc422..49f7ed31839 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala @@ -5,7 +5,7 @@ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.experimental.ExperimentalFunctions import is.hail.expr.ir._ -import is.hail.expr.ir.defs.{Apply, ApplyIR, ApplySeeded, ApplySpecial} +import is.hail.expr.ir.defs.{Apply, ApplyIR, ApplySeeded} import is.hail.io.bgen.BGENFunctions import is.hail.types.physical._ import is.hail.types.physical.stypes.{EmitType, SType, SValue} @@ -14,6 +14,7 @@ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.utils.compat.immutable.ArraySeq import is.hail.variant.{Locus, ReferenceGenome} import scala.collection.compat._ @@ -37,7 +38,7 @@ object IRFunctionRegistry { type IRFunctionSignature = (Seq[Type], Seq[Type], Type, Boolean) type IRFunctionImplementation = (Seq[Type], Seq[IR], Int) => IR - type ConcreteIRFunctionImplementation = (Seq[IR], Int) => IR + type ConcreteIRFunctionImplementation = (IndexedSeq[IR], Int) => IR val irRegistry: mutable.Map[String, mutable.Map[IRFunctionSignature, IRFunctionImplementation]] = new mutable.HashMap() @@ -203,32 +204,12 @@ object IRFunctionRegistry { ): Option[ConcreteIRFunctionImplementation] = { val validIR: Option[ConcreteIRFunctionImplementation] = lookupIR(name, typeParameters, arguments).map { _ => (args, errorID) => - ApplyIR(name, typeParameters, args.toFastSeq, returnType, errorID) + ApplyIR(name, typeParameters, args, returnType, errorID) } - val validMethods = lookupFunction(name, returnType, typeParameters, arguments) - .map { f => - { (irArguments: Seq[IR], errorID: Int) => - f match { - case _: UnseededMissingnessObliviousJVMFunction => - Apply( - name, - typeParameters, - irArguments.toFastSeq, - returnType, - errorID, - ) - case _: UnseededMissingnessAwareJVMFunction => - ApplySpecial( - name, - typeParameters, - irArguments.toFastSeq, - returnType, - errorID, - ) - } - } - } + val validMethods: Option[ConcreteIRFunctionImplementation] = + lookupFunction(name, returnType, typeParameters, arguments) + .map(_ => (args, errorID) => Apply(name, typeParameters, args, returnType, errorID)) (validIR, validMethods) match { case (None, None) => None @@ -458,126 +439,33 @@ abstract class RegistryFunctions { ) } - def registerSCode( - name: String, - valueParameterTypes: Array[Type], - returnType: Type, - calculateReturnType: (Type, Seq[SType]) => SType, - typeParameters: Array[Type] = Array.empty, - )( - impl: (Value[Region], EmitCodeBuilder, Seq[Type], SType, Array[SValue], Value[Int]) => SValue - ): Unit = { - IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, - returnType, calculateReturnType) { - override def apply( - r: Value[Region], - cb: EmitCodeBuilder, - returnSType: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: SValue* - ): SValue = - impl(r, cb, typeParameters, returnSType, args.toArray, errorID) - } - ) - } - - def registerCode( + private def registerSCode( name: String, valueParameterTypes: Array[Type], returnType: Type, calculateReturnType: (Type, Seq[SType]) => SType, typeParameters: Array[Type] = Array.empty, )( - impl: (Value[Region], EmitCodeBuilder, SType, Array[Type], Array[SValue]) => Value[_] - ): Unit = { - IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, - returnType, calculateReturnType) { - override def apply( - r: Value[Region], - cb: EmitCodeBuilder, - returnSType: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: SValue* - ): SValue = { - assert(unify(typeParameters, args.map(_.st.virtualType), returnSType.virtualType)) - val returnValue = impl(r, cb, returnSType, typeParameters.toArray, args.toArray) - returnSType.fromValues(FastSeq(returnValue)) - } - } - ) - } - - def registerEmitCode( - name: String, - valueParameterTypes: Array[Type], - returnType: Type, - calculateReturnType: (Type, Seq[EmitType]) => EmitType, - typeParameters: Array[Type] = Array.empty, - )( - impl: (EmitMethodBuilder[_], Value[Region], SType, Value[Int], Array[EmitCode]) => EmitCode - ): Unit = { + impl: (EmitCodeBuilder, Value[Region], Seq[Type], SType, Array[SValue], Value[Int]) => SValue + ): Unit = IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, - calculateReturnType) { - override def apply( - mb: EmitMethodBuilder[_], - region: Value[Region], - rpt: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: EmitCode* - ): EmitCode = { - assert(unify(typeParameters, args.map(_.st.virtualType), rpt.virtualType)) - impl(mb, region, rpt, errorID, args.toArray) - } - } + new MissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, + returnType, calculateReturnType, impl) ) - } - def registerIEmitCode( + private def registerIEmitCode( name: String, valueParameterTypes: Array[Type], returnType: Type, calculateReturnType: (Type, Seq[EmitType]) => EmitType, typeParameters: Array[Type] = Array.empty, )( - impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], Array[EmitCode]) => IEmitCode - ): Unit = { + impl: (EmitCodeBuilder, Value[Region], SType, Array[EmitCode], Value[Int]) => IEmitCode + ): Unit = IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, - calculateReturnType) { - override def apply( - cb: EmitCodeBuilder, - r: Value[Region], - rpt: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: EmitCode* - ): IEmitCode = { - val res = impl(cb, r, rpt, errorID, args.toArray) - if (res.emitType != calculateReturnType(rpt.virtualType, args.map(_.emitType))) - throw new RuntimeException( - s"type mismatch while registering ${this.name}" + - s"\n got ${res.emitType}, got ${calculateReturnType(rpt.virtualType, args.map(_.emitType))}" - ) - res - } - override def apply( - mb: EmitMethodBuilder[_], - region: Value[Region], - rpt: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: EmitCode* - ): EmitCode = - EmitCode.fromI(mb)(cb => apply(cb, region, rpt, typeParameters, errorID, args: _*)) - } + new MissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, + calculateReturnType, impl) ) - } def registerScalaFunction( name: String, @@ -589,7 +477,7 @@ abstract class RegistryFunctions { method: String, ): Unit = { registerSCode(name, valueParameterTypes, returnType, calculateReturnType) { - case (_, cb, _, rt, args, _) => + case (cb, _, _, rt, args, _) => val cts = valueParameterTypes.map(PrimitiveTypeToIRIntermediateClassTag(_).runtimeClass) val returnValue = cb.memoizeAny( @@ -646,7 +534,7 @@ abstract class RegistryFunctions { } registerSCode(name, valueParameterTypes, returnType, calculateReturnType) { - case (r, cb, _, rt, args, _) => + case (cb, r, _, rt, args, _) => val cts = valueParameterTypes.map(ct(_).runtimeClass) try { val ret = Code.invokeScalaObject(cls, method, cts, args.map(a => wrap(cb, r, a).get))( @@ -725,13 +613,14 @@ abstract class RegistryFunctions { cls: Class[_], method: String, ): Unit = { - registerCode(name, valueParameterTypes, returnType, pt) { case (_, cb, _, _, args) => + registerSCode(name, valueParameterTypes, returnType, pt) { case (cb, _, _, rSType, args, _) => val cts = valueParameterTypes.map(PrimitiveTypeToIRIntermediateClassTag(_).runtimeClass) val ct = PrimitiveTypeToIRIntermediateClassTag(returnType) - cb.memoizeAny( + val retValue = cb.memoizeAny( Code.invokeStatic(cls, method, cts, args.map(a => SType.extractPrimValue(cb, a).get))(ct), typeInfoFromClassTag(ct), ) + rSType.fromValues(ArraySeq(retValue)) } } @@ -755,7 +644,7 @@ abstract class RegistryFunctions { impl: (Value[Region], EmitCodeBuilder, SType, SValue, Value[Int]) => SValue ): Unit = registerSCode(name, Array(mt1), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1), errorID) => impl(r, cb, rt, a1, errorID) + case (cb, r, _, rt, Array(a1), errorID) => impl(r, cb, rt, a1, errorID) } def registerSCode1t( @@ -768,7 +657,7 @@ abstract class RegistryFunctions { impl: (Value[Region], EmitCodeBuilder, Seq[Type], SType, SValue, Value[Int]) => SValue ): Unit = registerSCode(name, Array(mt1), rt, unwrappedApply(pt), typeParameters = typeParams) { - case (r, cb, typeParams, rt, Array(a1), errorID) => impl(r, cb, typeParams, rt, a1, errorID) + case (cb, r, typeParams, rt, Array(a1), errorID) => impl(r, cb, typeParams, rt, a1, errorID) } def registerSCode2( @@ -781,7 +670,7 @@ abstract class RegistryFunctions { impl: (Value[Region], EmitCodeBuilder, SType, SValue, SValue, Value[Int]) => SValue ): Unit = registerSCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2), errorID) => impl(r, cb, rt, a1, a2, errorID) + case (cb, r, _, rt, Array(a1, a2), errorID) => impl(r, cb, rt, a1, a2, errorID) } def registerSCode2t( @@ -795,7 +684,7 @@ abstract class RegistryFunctions { impl: (Value[Region], EmitCodeBuilder, Seq[Type], SType, SValue, SValue, Value[Int]) => SValue ): Unit = registerSCode(name, Array(mt1, mt2), rt, unwrappedApply(pt), typeParameters = typeParams) { - case (r, cb, typeParams, rt, Array(a1, a2), errorID) => + case (cb, r, typeParams, rt, Array(a1, a2), errorID) => impl(r, cb, typeParams, rt, a1, a2, errorID) } @@ -810,7 +699,7 @@ abstract class RegistryFunctions { impl: (Value[Region], EmitCodeBuilder, SType, SValue, SValue, SValue, Value[Int]) => SValue ): Unit = registerSCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3), errorID) => impl(r, cb, rt, a1, a2, a3, errorID) + case (cb, r, _, rt, Array(a1, a2, a3), errorID) => impl(r, cb, rt, a1, a2, a3, errorID) } def registerSCode3t( @@ -835,7 +724,7 @@ abstract class RegistryFunctions { ): Unit = registerSCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt), typeParams) { case (r, cb, typeParams, rt, Array(a1, a2, a3), errorID) => - impl(r, cb, typeParams, rt, a1, a2, a3, errorID) + impl(cb, r, typeParams, rt, a1, a2, a3, errorID) } def registerSCode4( @@ -860,7 +749,7 @@ abstract class RegistryFunctions { ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt)) { case (r, cb, _, rt, Array(a1, a2, a3, a4), errorID) => - impl(r, cb, rt, a1, a2, a3, a4, errorID) + impl(cb, r, rt, a1, a2, a3, a4, errorID) } def registerSCode4t( @@ -886,7 +775,7 @@ abstract class RegistryFunctions { ) => SValue ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt), typeParams) { - case (r, cb, typeParams, rt, Array(a1, a2, a3, a4), errorID) => + case (cb, r, typeParams, rt, Array(a1, a2, a3, a4), errorID) => impl(r, cb, typeParams, rt, a1, a2, a3, a4, errorID) } @@ -913,10 +802,39 @@ abstract class RegistryFunctions { ) => SValue ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4, a5), errorID) => + case (cb, r, _, rt, Array(a1, a2, a3, a4, a5), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, errorID) } + def registerSCode5t( + name: String, + typeParams: Array[Type], + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + mt5: Type, + rt: Type, + pt: (Type, SType, SType, SType, SType, SType) => SType, + )( + impl: ( + Value[Region], + EmitCodeBuilder, + Seq[Type], + SType, + SValue, + SValue, + SValue, + SValue, + SValue, + Value[Int], + ) => SValue + ): Unit = + registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt), typeParams) { + case (cb, r, typeParams, rt, Array(a1, a2, a3, a4, a5), errorID) => + impl(r, cb, typeParams, rt, a1, a2, a3, a4, a5, errorID) + } + def registerSCode6( name: String, mt1: Type, @@ -942,7 +860,7 @@ abstract class RegistryFunctions { ) => SValue ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6), errorID) => + case (cb, r, _, rt, Array(a1, a2, a3, a4, a5, a6), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, a6, errorID) } @@ -973,35 +891,10 @@ abstract class RegistryFunctions { ) => SValue ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6, mt7), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6, a7), errorID) => + case (cb, r, _, rt, Array(a1, a2, a3, a4, a5, a6, a7), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, a6, a7, errorID) } - def registerCode1( - name: String, - mt1: Type, - rt: Type, - pt: (Type, SType) => SType, - )( - impl: (EmitCodeBuilder, Value[Region], SType, SValue) => Value[_] - ): Unit = - registerCode(name, Array(mt1), rt, unwrappedApply(pt)) { - case (r, cb, rt, _, Array(a1)) => impl(cb, r, rt, a1) - } - - def registerCode2( - name: String, - mt1: Type, - mt2: Type, - rt: Type, - pt: (Type, SType, SType) => SType, - )( - impl: (EmitCodeBuilder, Value[Region], SType, SValue, SValue) => Value[_] - ): Unit = - registerCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { - case (r, cb, rt, _, Array(a1, a2)) => impl(cb, r, rt, a1, a2) - } - def registerIEmitCode1( name: String, mt1: Type, @@ -1011,7 +904,7 @@ abstract class RegistryFunctions { impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode) => IEmitCode ): Unit = registerIEmitCode(name, Array(mt1), rt, unwrappedApply(pt)) { - case (cb, r, rt, errorID, Array(a1)) => + case (cb, r, rt, Array(a1), errorID) => impl(cb, r, rt, errorID, a1) } @@ -1025,7 +918,7 @@ abstract class RegistryFunctions { impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode) => IEmitCode ): Unit = registerIEmitCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { - case (cb, r, rt, errorID, Array(a1, a2)) => + case (cb, r, rt, Array(a1, a2), errorID) => impl(cb, r, rt, errorID, a1, a2) } @@ -1048,7 +941,7 @@ abstract class RegistryFunctions { ) => IEmitCode ): Unit = registerIEmitCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt)) { - case (cb, r, rt, errorID, Array(a1, a2, a3)) => + case (cb, r, rt, Array(a1, a2, a3), errorID) => impl(cb, r, rt, errorID, a1, a2, a3) } @@ -1073,7 +966,7 @@ abstract class RegistryFunctions { ) => IEmitCode ): Unit = registerIEmitCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt)) { - case (cb, r, rt, errorID, Array(a1, a2, a3, a4)) => + case (cb, r, rt, Array(a1, a2, a3, a4), errorID) => impl(cb, r, rt, errorID, a1, a2, a3, a4) } @@ -1100,7 +993,7 @@ abstract class RegistryFunctions { ) => IEmitCode ): Unit = registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt)) { - case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5)) => + case (cb, r, rt, Array(a1, a2, a3, a4, a5), errorID) => impl(cb, r, rt, errorID, a1, a2, a3, a4, a5) } @@ -1129,23 +1022,10 @@ abstract class RegistryFunctions { ) => IEmitCode ): Unit = registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6), rt, unwrappedApply(pt)) { - case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5, a6)) => + case (cb, r, rt, Array(a1, a2, a3, a4, a5, a6), errorID) => impl(cb, r, rt, errorID, a1, a2, a3, a4, a5, a6) } - def registerEmitCode2( - name: String, - mt1: Type, - mt2: Type, - rt: Type, - pt: (Type, EmitType, EmitType) => EmitType, - )( - impl: (EmitMethodBuilder[_], Value[Region], SType, Value[Int], EmitCode, EmitCode) => EmitCode - ): Unit = - registerEmitCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { - case (mb, r, rt, errorID, Array(a1, a2)) => impl(mb, r, rt, errorID, a1, a2) - } - def registerIR1( name: String, mt1: Type, @@ -1212,19 +1092,10 @@ sealed abstract class JVMFunction { def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType - def apply( - mb: EmitMethodBuilder[_], - region: Value[Region], - returnType: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: EmitCode* - ): EmitCode - - override def toString: String = + final override def toString: String = s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" - def unify(typeArguments: Seq[Type], valueArgumentTypes: Seq[Type], returnTypeIn: Type) + final def unify(typeArguments: Seq[Type], valueArgumentTypes: Seq[Type], returnTypeIn: Type) : Boolean = { val concrete = (typeArguments ++ valueArgumentTypes) :+ returnTypeIn val types = (typeParameters ++ valueParameterTypes) :+ returnType @@ -1235,26 +1106,22 @@ sealed abstract class JVMFunction { } } -object MissingnessObliviousJVMFunction { - def returnSType( - computeStrictReturnEmitType: (Type, Seq[SType]) => SType - )( - returnType: Type, - valueParameterTypes: Seq[SType], - ): SType = - if (computeStrictReturnEmitType == null) - SType.canonical(returnType) - else - computeStrictReturnEmitType(returnType, valueParameterTypes) -} - -abstract class UnseededMissingnessObliviousJVMFunction( +class MissingnessObliviousJVMFunction( override val name: String, override val typeParameters: Seq[Type], override val valueParameterTypes: Seq[Type], override val returnType: Type, missingnessObliviousComputeReturnType: (Type, Seq[SType]) => SType, + private val impl: ( + EmitCodeBuilder, + Value[Region], + Seq[Type], + SType, + Array[SValue], + Value[Int], + ) => SValue, ) extends JVMFunction { + override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]) : EmitType = EmitType( @@ -1263,33 +1130,22 @@ abstract class UnseededMissingnessObliviousJVMFunction( ) def computeStrictReturnEmitType(returnType: Type, valueParameterTypes: Seq[SType]): SType = - MissingnessObliviousJVMFunction.returnSType(missingnessObliviousComputeReturnType)( - returnType, - valueParameterTypes, - ) + if (missingnessObliviousComputeReturnType == null) + SType.canonical(returnType) + else + missingnessObliviousComputeReturnType(returnType, valueParameterTypes) - def apply( + private def apply( r: Value[Region], cb: EmitCodeBuilder, returnSType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: SValue* - ): SValue - - override def apply( - mb: EmitMethodBuilder[_], - region: Value[Region], - returnType: SType, - typeParameters: Seq[Type], - errorID: Value[Int], - args: EmitCode* - ): EmitCode = - EmitCode.fromI(mb)(cb => - IEmitCode.multiMapEmitCodes(cb, args.toFastSeq) { args => - apply(region, cb, returnType, typeParameters, errorID, args: _*) - } - ) + ): SValue = { + assert(unify(typeParameters, args.map(_.st.virtualType), returnSType.virtualType)) + impl(cb, r, typeParameters, returnSType, args.toArray, errorID) + } def applyI( r: Value[Region], @@ -1305,8 +1161,6 @@ abstract class UnseededMissingnessObliviousJVMFunction( def getAsMethod[C](cb: EmitClassBuilder[C], rpt: SType, typeParameters: Seq[Type], args: SType*) : EmitMethodBuilder[C] = { - val unified = unify(typeParameters, args.map(_.virtualType), rpt.virtualType) - assert(unified, name) val methodbuilder = cb.genEmitMethod( name, FastSeq[ParamType](typeInfo[Region], typeInfo[Int]) ++ args.map(_.paramType), @@ -1326,30 +1180,27 @@ abstract class UnseededMissingnessObliviousJVMFunction( } } -object MissingnessAwareJVMFunction { - def returnSType( - calculateReturnType: (Type, Seq[EmitType]) => EmitType - )( - returnType: Type, - valueParameterTypes: Seq[EmitType], - ): EmitType = - if (calculateReturnType == null) EmitType(SType.canonical(returnType), false) - else calculateReturnType(returnType, valueParameterTypes) -} - -abstract class UnseededMissingnessAwareJVMFunction( +final class MissingnessAwareJVMFunction( override val name: String, override val typeParameters: Seq[Type], override val valueParameterTypes: Seq[Type], override val returnType: Type, missingnessAwareComputeReturnSType: (Type, Seq[EmitType]) => EmitType, + private val impl: ( + EmitCodeBuilder, + Value[Region], + SType, + Array[EmitCode], + Value[Int], + ) => IEmitCode, ) extends JVMFunction { + override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]) : EmitType = - MissingnessAwareJVMFunction.returnSType(missingnessAwareComputeReturnSType)( - returnType, - valueParameterTypes, - ) + if (missingnessAwareComputeReturnSType == null) + EmitType(SType.canonical(returnType), false) + else + missingnessAwareComputeReturnSType(returnType, valueParameterTypes) def apply( cb: EmitCodeBuilder, @@ -1358,6 +1209,13 @@ abstract class UnseededMissingnessAwareJVMFunction( typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode* - ): IEmitCode = - ??? + ): IEmitCode = { + val res = impl(cb, r, rpt, args.toArray, errorID) + if (res.emitType != missingnessAwareComputeReturnSType(rpt.virtualType, args.map(_.emitType))) + throw new RuntimeException( + s"type mismatch while registering ${this.name}" + + s"\n got ${res.emitType}, got ${missingnessAwareComputeReturnSType(rpt.virtualType, args.map(_.emitType))}" + ) + res + } } diff --git a/hail/hail/src/is/hail/expr/ir/functions/LocusFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/LocusFunctions.scala index 5b0b8d65f80..ff3e707ecbc 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/LocusFunctions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/LocusFunctions.scala @@ -10,6 +10,7 @@ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.utils.compat.immutable.ArraySeq import is.hail.variant._ object LocusFunctions extends RegistryFunctions { @@ -530,50 +531,43 @@ object LocusFunctions extends RegistryFunctions { invalidMissing: EmitCode, ) => val plocus = rt.pointType.asInstanceOf[PLocus] + IEmitCode.multiFlatMap( + cb, + ArraySeq( + locusString.toI, + pos1.toI, + pos2.toI, + include1.toI, + include2.toI, + invalidMissing.toI, + ), + ) { + case Seq(locusString, pos1, pos2, include1, include2, invalidMissing) => + val interval = cb.newLocal[Interval]( + "locus_interval_interval", + Code.invokeScalaObject7[ + String, + Int, + Int, + Boolean, + Boolean, + ReferenceGenome, + Boolean, + Interval, + ]( + locusClass, + "makeInterval", + locusString.asString.loadString(cb), + pos1.asInt.value, + pos2.asInt.value, + include1.asBoolean.value, + include2.asBoolean.value, + rgCode(cb.emb, plocus.rg), + invalidMissing.asBoolean.value, + ), + ) - locusString.toI(cb).flatMap(cb) { locusString => - pos1.toI(cb).flatMap(cb) { pos1 => - pos2.toI(cb).flatMap(cb) { pos2 => - include1.toI(cb).flatMap(cb) { include1 => - include2.toI(cb).flatMap(cb) { include2 => - invalidMissing.toI(cb).flatMap(cb) { invalidMissing => - val Lmissing = CodeLabel() - val Ldefined = CodeLabel() - - val interval = cb.newLocal[Interval]( - "locus_interval_interval", - Code.invokeScalaObject7[ - String, - Int, - Int, - Boolean, - Boolean, - ReferenceGenome, - Boolean, - Interval, - ]( - locusClass, - "makeInterval", - locusString.asString.loadString(cb), - pos1.asInt.value, - pos2.asInt.value, - include1.asBoolean.value, - include2.asBoolean.value, - rgCode(cb.emb, plocus.rg), - invalidMissing.asBoolean.value, - ), - ) - - cb.if_(interval.isNull, cb.goto(Lmissing)) - - val intervalCode = emitLocusInterval(cb, r, interval, rt) - cb.goto(Ldefined) - IEmitCode(Lmissing, Ldefined, intervalCode, false) - } - } - } - } - } + IEmitCode.apply(cb, interval.isNull, emitLocusInterval(cb, r, interval, rt)) } } diff --git a/hail/hail/src/is/hail/expr/ir/functions/NDArrayFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/NDArrayFunctions.scala index c6608115720..1e443a36e41 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/NDArrayFunctions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/NDArrayFunctions.scala @@ -163,46 +163,32 @@ object NDArrayFunctions extends RegistryFunctions { (outputFinisher(cb), infoDGESVResult) } - registerIEmitCode2( + registerSCode2( "linear_solve_no_crash", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), - (t, p1, p2) => - EmitType( - PCanonicalStruct( - false, - ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), - ("failed", PBooleanRequired), - ).sType, + (_, _, _) => + PCanonicalStruct( false, - ), + ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), + ("failed", PBooleanRequired), + ).sType, ) { - case ( - cb, - region, - SBaseStructPointer(outputStructType: PCanonicalStruct), - errorID, - aec, - bec, - ) => - aec.toI(cb).flatMap(cb) { apc => - bec.toI(cb).map(cb) { bpc => - val outputNDArrayPType = outputStructType.fieldType("solution") - val (resNDPCode, info) = - linear_solve(apc.asNDArray, bpc.asNDArray, outputNDArrayPType, cb, region, errorID) - val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) - outputStructType.constructFromFields( - cb, - region, - IndexedSeq[EmitCode]( - ndEmitCode, - EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0))), - ), - false, - ) - } - } + case (r, cb, SBaseStructPointer(outputStructType: PCanonicalStruct), a, b, errorID) => + val outputNDArrayPType = outputStructType.fieldType("solution") + val (resNDPCode, info) = + linear_solve(a.asNDArray, b.asNDArray, outputNDArrayPType, cb, r, errorID) + val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) + outputStructType.constructFromFields( + cb, + r, + IndexedSeq[EmitCode]( + ndEmitCode, + EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0))), + ), + false, + ) } registerSCode2( @@ -226,57 +212,48 @@ object NDArrayFunctions extends RegistryFunctions { resPCode } - registerIEmitCode3( + registerSCode3( "linear_triangular_solve_no_crash", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TBoolean, TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), (t, p1, p2, p3) => - EmitType( - PCanonicalStruct( - false, - ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), - ("failed", PBooleanRequired), - ).sType, + PCanonicalStruct( false, - ), + ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), + ("failed", PBooleanRequired), + ).sType, ) { case ( - cb, region, + cb, SBaseStructPointer(outputStructType: PCanonicalStruct), + a, + b, + lower, errorID, - aec, - bec, - lowerec, ) => - aec.toI(cb).flatMap(cb) { apc => - bec.toI(cb).flatMap(cb) { bpc => - lowerec.toI(cb).map(cb) { lowerpc => - val outputNDArrayPType = outputStructType.fieldType("solution") - val (resNDPCode, info) = linear_triangular_solve( - apc.asNDArray, - bpc.asNDArray, - lowerpc.asBoolean, - outputNDArrayPType, - cb, - region, - errorID, - ) - val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) - outputStructType.constructFromFields( - cb, - region, - IndexedSeq[EmitCode]( - ndEmitCode, - EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0))), - ), - false, - ) - } - } - } + val outputNDArrayPType = outputStructType.fieldType("solution") + val (resNDPCode, info) = linear_triangular_solve( + a.asNDArray, + b.asNDArray, + lower.asBoolean, + outputNDArrayPType, + cb, + region, + errorID, + ) + val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) + outputStructType.constructFromFields( + cb, + region, + IndexedSeq[EmitCode]( + ndEmitCode, + EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0))), + ), + false, + ) } registerSCode3( diff --git a/hail/hail/src/is/hail/expr/ir/functions/SetFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/SetFunctions.scala index 453e0b72d40..34ba8fe518a 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/SetFunctions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/SetFunctions.scala @@ -65,7 +65,7 @@ object SetFunctions extends RegistryFunctions { registerIR2("isSubset", TSet(tv("T")), TSet(tv("T")), TBoolean) { (_, s, w, errorID) => foldIR(ToStream(s), True()) { (a, x) => // FIXME short circuit - ApplySpecial( + Apply( "land", FastSeq(), FastSeq(a, contains(w, x)), diff --git a/hail/hail/src/is/hail/expr/ir/functions/StringFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/StringFunctions.scala index a21a88aa2c3..ca74733256e 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/StringFunctions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/StringFunctions.scala @@ -707,7 +707,7 @@ object StringFunctions extends RegistryFunctions { } } - registerEmitCode2( + registerIEmitCode2( "hamming", TString, TString, @@ -715,39 +715,37 @@ object StringFunctions extends RegistryFunctions { { case (_: Type, _: EmitType, _: EmitType) => EmitType(SInt32, false) }, - ) { case (mb, _, _, _, e1, e2) => - EmitCode.fromI(mb) { cb => - e1.toI(cb).flatMap(cb) { case sc1: SStringValue => - e2.toI(cb).flatMap(cb) { case sc2: SStringValue => - val n = cb.newLocal("hamming_n", 0) - val i = cb.newLocal("hamming_i", 0) - - val v1 = cb.newLocal[String]("hamming_str_1", sc1.loadString(cb)) - val v2 = cb.newLocal[String]("hamming_str_2", sc2.loadString(cb)) - - val l1 = cb.newLocal[Int]("hamming_len_1", v1.invoke[Int]("length")) - val l2 = cb.newLocal[Int]("hamming_len_2", v2.invoke[Int]("length")) - val m = l1.cne(l2) - - IEmitCode( - cb, - m, { - cb.while_( - i < l1, { - cb.if_( - v1.invoke[Int, Char]("charAt", i).toI.cne(v2.invoke[Int, Char]( - "charAt", - i, - ).toI), - cb.assign(n, n + 1), - ) - cb.assign(i, i + 1) - }, - ) - primitive(n) - }, - ) - } + ) { case (cb, _, _, _, e1, e2) => + e1.toI(cb).flatMap(cb) { case sc1: SStringValue => + e2.toI(cb).flatMap(cb) { case sc2: SStringValue => + val n = cb.newLocal("hamming_n", 0) + val i = cb.newLocal("hamming_i", 0) + + val v1 = cb.newLocal[String]("hamming_str_1", sc1.loadString(cb)) + val v2 = cb.newLocal[String]("hamming_str_2", sc2.loadString(cb)) + + val l1 = cb.newLocal[Int]("hamming_len_1", v1.invoke[Int]("length")) + val l2 = cb.newLocal[Int]("hamming_len_2", v2.invoke[Int]("length")) + val m = l1.cne(l2) + + IEmitCode( + cb, + m, { + cb.while_( + i < l1, { + cb.if_( + v1.invoke[Int, Char]("charAt", i).toI.cne(v2.invoke[Int, Char]( + "charAt", + i, + ).toI), + cb.assign(n, n + 1), + ) + cb.assign(i, i + 1) + }, + ) + primitive(n) + }, + ) } } } @@ -779,13 +777,13 @@ object StringFunctions extends RegistryFunctions { }, )(thisClass, "strptime") - registerSCode( + registerSCode1t( "parse_json", - Array(TString), + typeParams = Array(tv("T")), + TString, TTuple(tv("T")), - (rType: Type, _: Seq[SType]) => SType.canonical(rType), - typeParameters = Array(tv("T")), - ) { case (r, cb, _, resultType, Array(s: SStringValue), _) => + (rType, _) => SType.canonical(rType), + ) { case (r, cb, _, resultType, s: SStringValue, _) => val warnCtx = cb.emb.genFieldThisRef[mutable.HashSet[String]]("parse_json_context") cb.if_(warnCtx.load().isNull, cb.assign(warnCtx, Code.newInstance[mutable.HashSet[String]]())) diff --git a/hail/hail/src/is/hail/expr/ir/functions/UtilFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/UtilFunctions.scala index e6b63f8a53c..59e225984da 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/UtilFunctions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/UtilFunctions.scala @@ -231,10 +231,10 @@ object UtilFunctions extends RegistryFunctions { ]("valuesSimilar", lb, rb, tol.asDouble.value, abs.asBoolean.value))) } - registerCode1("triangle", TInt32, TInt32, (_: Type, _: SType) => SInt32) { - case (cb, _, _, nn) => + registerSCode1("triangle", TInt32, TInt32, (_: Type, _: SType) => SInt32) { + case (_, cb, _, nn, _) => val n = nn.asInt.value - cb.memoize((n * (n + 1)) / 2) + primitive(cb.memoize((n * (n + 1)) / 2)) } registerSCode1("toInt32", TBoolean, TInt32, (_: Type, _: SType) => SInt32) { @@ -313,58 +313,70 @@ object UtilFunctions extends RegistryFunctions { } Array("min", "max").foreach { name => - registerCode2(name, TFloat32, TFloat32, TFloat32, (_: Type, _: SType, _: SType) => SFloat32) { - case (cb, _, _, v1, v2) => - cb.memoize(Code.invokeStatic2[Math, Float, Float, Float]( + registerSCode2( + name, + TFloat32, + TFloat32, + TFloat32, + (_: Type, _: SType, _: SType) => SFloat32, + ) { + case (_, cb, _, v1, v2, _) => + primitive(cb.memoize(Code.invokeStatic2[Math, Float, Float, Float]( name, v1.asFloat.value, v2.asFloat.value, - )) + ))) } - registerCode2(name, TFloat64, TFloat64, TFloat64, (_: Type, _: SType, _: SType) => SFloat64) { - case (cb, _, _, v1, v2) => - cb.memoize(Code.invokeStatic2[Math, Double, Double, Double]( + registerSCode2( + name, + TFloat64, + TFloat64, + TFloat64, + (_: Type, _: SType, _: SType) => SFloat64, + ) { + case (_, cb, _, v1, v2, _) => + primitive(cb.memoize(Code.invokeStatic2[Math, Double, Double, Double]( name, v1.asDouble.value, v2.asDouble.value, - )) + ))) } val ignoreMissingName = name + "_ignore_missing" val ignoreNanName = "nan" + name val ignoreBothName = ignoreNanName + "_ignore_missing" - registerCode2( + registerSCode2( ignoreNanName, TFloat32, TFloat32, TFloat32, (_: Type, _: SType, _: SType) => SFloat32, ) { - case (cb, _, _, v1, v2) => - cb.memoize(Code.invokeScalaObject2[Float, Float, Float]( + case (_, cb, _, v1, v2, _) => + primitive(cb.memoize(Code.invokeScalaObject2[Float, Float, Float]( thisClass, ignoreNanName, v1.asFloat.value, v2.asFloat.value, - )) + ))) } - registerCode2( + registerSCode2( ignoreNanName, TFloat64, TFloat64, TFloat64, (_: Type, _: SType, _: SType) => SFloat64, ) { - case (cb, _, _, v1, v2) => - cb.memoize(Code.invokeScalaObject2[Double, Double, Double]( + case (_, cb, _, v1, v2, _) => + primitive(cb.memoize(Code.invokeScalaObject2[Double, Double, Double]( thisClass, ignoreNanName, v1.asDouble.value, v2.asDouble.value, - )) + ))) } def ignoreMissingTriplet[T]( diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 62c4dd70800..33c5cf7b800 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -426,7 +426,7 @@ object LowerDistributedSort extends Logging { ArrayFunctions.extend(minArray, sortedSampling), maxArray, ), - "isSorted" -> ApplySpecial( + "isSorted" -> Apply( "land", Seq.empty[Type], FastSeq(GetField(aggResults, "eachPartSorted"), tuplesInSortedOrder), @@ -931,7 +931,7 @@ object LowerDistributedSort extends Logging { }, True(), ) { case (accum, elt) => - ApplySpecial("land", Seq.empty[Type], FastSeq(accum, elt), TBoolean, ErrorIDs.NO_ERROR) + Apply("land", Seq.empty[Type], FastSeq(accum, elt), TBoolean, ErrorIDs.NO_ERROR) } } } diff --git a/hail/hail/src/is/hail/expr/ir/package.scala b/hail/hail/src/is/hail/expr/ir/package.scala index 53714ab1bfe..c0491de6440 100644 --- a/hail/hail/src/is/hail/expr/ir/package.scala +++ b/hail/hail/src/is/hail/expr/ir/package.scala @@ -42,7 +42,7 @@ package object ir { def invoke(name: String, rt: Type, typeArgs: Seq[Type], errorID: Int, args: IR*): IR = IRFunctionRegistry.lookupUnseeded(name, rt, typeArgs, args.map(_.typ)) match { - case Some(f) => f(args, errorID) + case Some(f) => f(args.toFastSeq, errorID) case None => fatal( s"no conversion found for $name[${typeArgs.mkString(", ")}](${args.map(_.typ).mkString(", ")}) => $rt" ) diff --git a/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala b/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala index 0a76ab436f8..639777f915e 100644 --- a/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala +++ b/hail/hail/src/is/hail/io/bgen/StagedBGENReader.scala @@ -611,19 +611,27 @@ object BGENFunctions extends RegistryFunctions { def uuid(): String = uuid4() override def registerAll(): Unit = { - registerSCode( + registerSCode5t( "index_bgen", - Array(TString, TString, TDict(TString, TString), TBoolean, TInt32), - TInt64, - (_, _) => SInt64, Array(TVariable("locusType")), + TString, + TString, + TDict(TString, TString), + TBoolean, + TInt32, + TInt64, + (_, _, _, _, _, _) => SInt64, ) { case ( r, cb, Seq(locType), _, - Array(_path, _idxPath, _recoding, _skipInvalidLoci, _bufferSize), + _path, + _idxPath, + _recoding, + _skipInvalidLoci, + _bufferSize, err, ) => val mb = cb.emb diff --git a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala index 67ac11bd135..db8aef2f015 100644 --- a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala @@ -674,13 +674,8 @@ class EmitStreamSuite extends HailSuite { val end = if (e == null) NA(TInt32) else I32(e.asInstanceOf[Int]) val includesStart = is == '[' val includesEnd = ie == ']' - val interval = ApplySpecial( - "Interval", - FastSeq(), - FastSeq(start, end, includesStart, includesEnd), - TInterval(TInt32), - 0, - ) + val interval = + invoke("Interval", TInterval(TInt32), start, end, includesStart, includesEnd) MakeStruct(IndexedSeq("k" -> interval, "v" -> Str(v))) }, TStream(rEltType), diff --git a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala index 04fbc108720..34c7ca3173c 100644 --- a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala @@ -5,8 +5,8 @@ import is.hail.ExecStrategy.ExecStrategy import is.hail.annotations._ import is.hail.asm4s._ import is.hail.expr.ir.defs.{ - ApplyComparisonOp, ApplySpecial, ArraySort, ErrorIDs, GetField, I32, In, IsNA, Literal, - MakeStream, NA, ToArray, ToDict, ToSet, ToStream, True, + Apply, ApplyComparisonOp, ArraySort, ErrorIDs, GetField, I32, In, IsNA, Literal, MakeStream, NA, + ToArray, ToDict, ToSet, ToStream, True, } import is.hail.expr.ir.orderings.CodeOrdering import is.hail.scalacheck._ @@ -569,7 +569,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val set2 = ToSet(MakeStream(IndexedSeq(I32(9), I32(1), I32(4)), TStream(TInt32))) assertEvalsTo( foldIR(ToStream(set1), True()) { (acc, elt) => - ApplySpecial( + Apply( "land", FastSeq(), FastSeq(