Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9,401 changes: 4,753 additions & 4,648 deletions pkg/sql/parsers/dialect/mysql/mysql_sql.go

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion pkg/sql/parsers/dialect/mysql/mysql_sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -11913,6 +11913,17 @@ function_call_generic:
Exprs: $3,
}
}
| INTERVAL '(' bit_expr ',' expression_list ')'
{
name := tree.NewUnresolvedColName($1)
exprs := tree.Exprs{$3}
exprs = append(exprs, $5...)
$$ = &tree.FuncExpr{
Func: tree.FuncName2ResolvableFunctionReference(name),
FuncName: tree.NewCStr($1, 1),
Exprs: exprs,
}
}
| substr_option '(' expression_list_opt ')'
{
name := tree.NewUnresolvedColName($1)
Expand Down Expand Up @@ -12626,7 +12637,7 @@ predicate:
{
$$ = tree.NewRangeCond(true, $1, $4, $6)
}
| bit_expr
| bit_expr %prec LOWER_THAN_COMMA

like_escape_opt:
{
Expand Down
33 changes: 21 additions & 12 deletions pkg/sql/plan/base_binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,26 @@ func (b *baseBinder) bindFuncExprImplByAstExpr(name string, astArgs []tree.Expr,
}
}

//promote interval expr rewrite here
if name == "interval" {
if len(astArgs) == 2 {
//interval expr like 'interval 5 day'
if _, ok := astArgs[1].(*tree.TimeUnitExpr); ok {
// rewrite interval function to ListExpr, and return directly
return &plan.Expr{
Typ: plan.Type{
Id: int32(types.T_interval),
},
Expr: &plan.Expr_List{
List: &plan.ExprList{
List: args,
},
},
}, nil
}
}
}

if b.builder != nil {
e, err := bindFuncExprAndConstFold(b.GetContext(), b.builder.compCtx.GetProcess(), name, args)
if err == nil {
Expand Down Expand Up @@ -1568,18 +1588,7 @@ func BindFuncExprImplByPlanExpr(ctx context.Context, name string, args []*Expr)
return args[0], nil
}
}
case "interval":
// rewrite interval function to ListExpr, and return directly
return &plan.Expr{
Typ: plan.Type{
Id: int32(types.T_interval),
},
Expr: &plan.Expr_List{
List: &plan.ExprList{
List: args,
},
},
}, nil

case "and", "or", "not", "xor":
// why not append cast function?
// for i := 0; i < len(args); i++ {
Expand Down
38 changes: 38 additions & 0 deletions pkg/sql/plan/base_binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/matrixorigin/matrixone/pkg/container/types"
"github.com/matrixorigin/matrixone/pkg/pb/plan"
"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -92,3 +93,40 @@ func TestBindFuncExprImplByPlanExpr_JsonValid(t *testing.T) {
require.Equal(t, int32(types.T_bool), result.Typ.Id)
})
}

func TestBindFuncExprImplByAstExpr_IntervalDisambiguation(t *testing.T) {
builder, bindCtx := genBuilderAndCtx()
whereBinder := NewWhereBinder(builder, bindCtx)

t.Run("function style keeps interval builtin", func(t *testing.T) {
args := []tree.Expr{
tree.NewNumVal(int64(5), "5", false, tree.P_int64),
tree.NewNumVal("day", "day", false, tree.P_char),
}
result, err := whereBinder.bindFuncExprImplByAstExpr("interval", args, 0)
require.NoError(t, err)
require.NotNil(t, result)

f := result.GetF()
require.NotNil(t, f, "interval(5, 'day') should bind to the interval builtin")
require.Equal(t, "interval", f.Func.GetObjName())
require.Len(t, f.Args, 2)
require.NotEqual(t, int32(types.T_interval), result.Typ.Id)
})

t.Run("interval expression rewrites to interval list", func(t *testing.T) {
args := []tree.Expr{
tree.NewNumVal(int64(5), "5", false, tree.P_int64),
tree.NewTimeUnitExpr("day"),
}
result, err := whereBinder.bindFuncExprImplByAstExpr("interval", args, 0)
require.NoError(t, err)
require.NotNil(t, result)

require.Equal(t, int32(types.T_interval), result.Typ.Id)
list := result.GetList()
require.NotNil(t, list, "INTERVAL 5 DAY should bind as an interval expression list")
require.Len(t, list.List, 2)
require.Equal(t, "day", list.List[1].GetLit().GetSval())
})
}
270 changes: 270 additions & 0 deletions pkg/sql/plan/function/func_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,276 @@ func builtInConcat(parameters []*vector.Vector, result vector.FunctionResultWrap
return nil
}

func builtInIntervalCheck(_ []overload, inputs []types.Type) checkResult {
if len(inputs) < 2 {
return newCheckResultWithFailure(failedFunctionParametersWrong)
}

for _, source := range inputs {
if !intervalTypeSupported(source) {
return newCheckResultWithFailure(failedFunctionParametersWrong)
}
}
return newCheckResultWithSuccess(0)
}

func intervalTypeSupported(source types.Type) bool {
return source.IsIntOrUint() ||
source.IsFloat() ||
source.Oid == types.T_decimal64 ||
source.Oid == types.T_decimal128 ||
source.Oid.IsMySQLString() ||
source.Oid == types.T_any
}

func builtInInterval(parameters []*vector.Vector, result vector.FunctionResultWrapper, _ *process.Process, length int, selectList *FunctionSelectList) error {
rs := vector.MustFunctionResult[int64](result)
args := make([]intervalParam, len(parameters))
for i := range parameters {
var err error
args[i], err = makeIntervalParam(parameters[i])
if err != nil {
return err
}
}

for i := uint64(0); i < uint64(length); i++ {
n, null, err := args[0].float(i)
if err != nil {
return err
}
if null {
if err := rs.Append(-1, false); err != nil {
return err
}
continue
}

var nDec types.Decimal128
useDecimalComparison := args[0].useDecimalComparison()
if useDecimalComparison {
nDec, _, err = args[0].decimal(i)
if err != nil {
return err
}
}

ret := len(args) - 1
for j := 1; j < len(args); j++ {
var cmp int
if useDecimalComparison && args[j].canCompareAsDecimal() {
var vDec types.Decimal128
vDec, null, err = args[j].decimal(i)
if err != nil {
return err
}
if null {
continue
}
cmp = types.CompareDecimal128WithScale(vDec, nDec, args[j].decimalScale(), args[0].decimalScale())
} else {
var v float64
v, null, err = args[j].float(i)
if err != nil {
return err
}
if null {
continue
}
switch {
case v > n:
cmp = 1
case v < n:
cmp = -1
default:
cmp = 0
}
}
if cmp > 0 {
ret = j - 1
break
}
}

if err := rs.Append(int64(ret), false); err != nil {
return err
}
}
return nil
}

type intervalParam struct {
float func(uint64) (float64, bool, error)
decimal func(uint64) (types.Decimal128, bool, error)
useDecimalCompare bool
canDecimalCompare bool
decScale int32
}

func makeIntervalParam(v *vector.Vector) (intervalParam, error) {
typ := *v.GetType()
p := intervalParam{
useDecimalCompare: typ.IsIntOrUint() || typ.IsDecimal(),
canDecimalCompare: typ.IsIntOrUint() || typ.Oid == types.T_decimal64 || typ.Oid == types.T_decimal128,
}
if typ.Oid == types.T_decimal64 || typ.Oid == types.T_decimal128 {
p.decScale = typ.Scale
}

switch typ.Oid {
case types.T_float64:
fp := vector.GenerateFunctionFixedTypeParameter[float64](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return v, null, nil
}
case types.T_float32:
fp := vector.GenerateFunctionFixedTypeParameter[float32](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
case types.T_int64:
fp := vector.GenerateFunctionFixedTypeParameter[int64](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(v), null, nil
}
case types.T_int32:
fp := vector.GenerateFunctionFixedTypeParameter[int32](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(int64(v)), null, nil
}
case types.T_int16:
fp := vector.GenerateFunctionFixedTypeParameter[int16](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(int64(v)), null, nil
}
case types.T_int8:
fp := vector.GenerateFunctionFixedTypeParameter[int8](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(int64(v)), null, nil
}
case types.T_uint64:
fp := vector.GenerateFunctionFixedTypeParameter[uint64](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
if v <= math.MaxInt64 {
return types.Decimal128FromInt64(int64(v)), null, nil
}
d, err := types.ParseDecimal128(strconv.FormatUint(v, 10), 38, 0)
return d, null, err
}
case types.T_uint32:
fp := vector.GenerateFunctionFixedTypeParameter[uint32](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(int64(v)), null, nil
}
case types.T_uint16:
fp := vector.GenerateFunctionFixedTypeParameter[uint16](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(int64(v)), null, nil
}
case types.T_uint8:
fp := vector.GenerateFunctionFixedTypeParameter[uint8](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromInt64(int64(v)), null, nil
}
case types.T_decimal64:
fp := vector.GenerateFunctionFixedTypeParameter[types.Decimal64](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal64ToFloat64(v, typ.Scale), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128FromDecimal64(v, typ.Scale), null, nil
}
case types.T_decimal128:
fp := vector.GenerateFunctionFixedTypeParameter[types.Decimal128](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return types.Decimal128ToFloat64(v, typ.Scale), null, nil
}
p.decimal = func(idx uint64) (types.Decimal128, bool, error) {
v, null := fp.GetValue(idx)
return v, null, nil
}
case types.T_char, types.T_varchar, types.T_text, types.T_binary, types.T_varbinary, types.T_blob:
fp := vector.GenerateFunctionStrParameter(v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetStrValue(idx)
if null {
return 0, true, nil
}
f, err := strconv.ParseFloat(string(v), 64)
if err != nil {
return 0, false, moerr.NewInvalidArgNoCtx("cast to double", fmt.Sprintf("bad value %s", string(v)))
}
return f, false, nil
}
case types.T_any:
fp := vector.GenerateFunctionFixedTypeParameter[int64](v)
p.float = func(idx uint64) (float64, bool, error) {
v, null := fp.GetValue(idx)
return float64(v), null, nil
}
default:
return p, moerr.NewInvalidInputNoCtxf("interval function have invalid input args type %s", typ.Oid.String())
}
return p, nil
}

func (p intervalParam) useDecimalComparison() bool {
return p.useDecimalCompare
}

func (p intervalParam) canCompareAsDecimal() bool {
return p.canDecimalCompare
}

func (p intervalParam) decimalScale() int32 {
return p.decScale
}

func builtInCharCheck(_ []overload, inputs []types.Type) checkResult {
// CHAR accepts one or more integer arguments
if len(inputs) < 1 {
Expand Down
Loading
Loading