diff --git a/pkg/sql/plan/function/func_binary.go b/pkg/sql/plan/function/func_binary.go index 2709513f2c380..2652af891cff7 100644 --- a/pkg/sql/plan/function/func_binary.go +++ b/pkg/sql/plan/function/func_binary.go @@ -896,11 +896,31 @@ type NormalType interface { types.Datetime | types.Decimal64 | types.Decimal128 | + types.Decimal256 | types.Timestamp | types.Uuid | constraints.Integer } +// coalesceDecimalResult derives the common decimal result type for a coalesce +// over decimal/integer branches, preserving the maximum integral width and +// scale and promoting to a wider decimal family (decimal256) when the combined +// precision overflows decimal128. It then resolves the matching overload index. +// Returns ok=false when the required precision overflows decimal256 or no +// overload matches the resulting decimal family. +func coalesceDecimalResult(overloads []overload, minOid types.T, inputs []types.Type) (types.Type, int, bool) { + target := minOid.ToType() + if !setSafeDecimalWidthAndScaleFromSource(&target, inputs) { + return types.Type{}, -1, false + } + for i, over := range overloads { + if len(over.args) == 1 && over.args[0] == target.Oid { + return target, i, true + } + } + return types.Type{}, -1, false +} + func coalesceCheck(overloads []overload, inputs []types.Type) checkResult { if len(inputs) > 0 { if retType, ok := mixedStringNumericToVarchar(inputs); ok { @@ -930,6 +950,19 @@ func coalesceCheck(overloads []overload, inputs []types.Type) checkResult { if sta == matchFailed { continue } else if sta == matchDirectly { + // Decimals that match directly still need scale/width alignment + // across branches: without it the result inherits the first + // branch's scale while carrying another branch's raw value, + // magnifying the result (issue #24565). Keep them as a candidate + // and resolve the aligned type below instead of short-circuiting. + if requireOid.IsDecimal() { + if cos < minCost { + minIndex = i + minCost = cos + minOid = requireOid + } + continue + } return newCheckResultWithSuccess(i) } else { if cos < minCost { @@ -942,6 +975,32 @@ func coalesceCheck(overloads []overload, inputs []types.Type) checkResult { if minIndex == -1 { return newCheckResultWithFailure(failedFunctionParametersWrong) } + + // Decimal branches: choose a common type that keeps the maximum integral + // width and scale (promoting to decimal256 when needed), so the result + // neither loses integer capacity nor inherits a single branch's scale. + if minOid.IsDecimal() { + target, overloadIndex, ok := coalesceDecimalResult(overloads, minOid, inputs) + if !ok { + return newCheckResultWithFailure(failedFunctionParametersWrong) + } + aligned := true + for i := range inputs { + if inputs[i].Oid != target.Oid || inputs[i].Scale != target.Scale || inputs[i].Width != target.Width { + aligned = false + break + } + } + if aligned { + return newCheckResultWithSuccess(overloadIndex) + } + castType := make([]types.Type, len(inputs)) + for i := range castType { + castType[i] = target + } + return newCheckResultWithCast(overloadIndex, castType) + } + castType := make([]types.Type, len(inputs)) for i := range castType { if minOid == inputs[i].Oid { @@ -952,7 +1011,7 @@ func coalesceCheck(overloads []overload, inputs []types.Type) checkResult { } } - if minOid.IsDecimal() || minOid.IsDateRelate() { + if minOid.IsDateRelate() { setMaxScaleForAll(castType) } return newCheckResultWithCast(minIndex, castType) diff --git a/pkg/sql/plan/function/func_compare.go b/pkg/sql/plan/function/func_compare.go index 338be973e625c..9bb57eb429126 100644 --- a/pkg/sql/plan/function/func_compare.go +++ b/pkg/sql/plan/function/func_compare.go @@ -35,7 +35,7 @@ func otherCompareOperatorSupports(typ1, typ2 types.Type) bool { case types.T_uint8, types.T_uint16, types.T_uint32, types.T_uint64: case types.T_int8, types.T_int16, types.T_int32, types.T_int64: case types.T_float32, types.T_float64: - case types.T_decimal64, types.T_decimal128: + case types.T_decimal64, types.T_decimal128, types.T_decimal256: case types.T_char, types.T_varchar: case types.T_date, types.T_datetime: case types.T_timestamp, types.T_time: @@ -60,7 +60,7 @@ func equalAndNotEqualOperatorSupports(typ1, typ2 types.Type) bool { case types.T_uint8, types.T_uint16, types.T_uint32, types.T_uint64: case types.T_int8, types.T_int16, types.T_int32, types.T_int64: case types.T_float32, types.T_float64: - case types.T_decimal64, types.T_decimal128: + case types.T_decimal64, types.T_decimal128, types.T_decimal256: case types.T_char, types.T_varchar: case types.T_date, types.T_datetime: case types.T_timestamp, types.T_time: @@ -243,6 +243,10 @@ func nullSafeEqualFn(parameters []*vector.Vector, result vector.FunctionResultWr return opBinaryFixedFixedToFixedNullSafe[types.Decimal128](parameters, rs, proc, length, func(a, b types.Decimal128) bool { return a == b }, selectList) + case types.T_decimal256: + return opBinaryFixedFixedToFixedNullSafe[types.Decimal256](parameters, rs, proc, length, func(a, b types.Decimal256) bool { + return a == b + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixedNullSafe[types.Rowid](parameters, rs, proc, length, func(a, b types.Rowid) bool { return a.EQ(&b) @@ -370,6 +374,10 @@ func equalFn(parameters []*vector.Vector, result vector.FunctionResultWrapper, p return valueDec128Compare(parameters, rs, uint64(length), func(a, b types.Decimal128) bool { return a == b }, selectList) + case types.T_decimal256: + return valueDec256Compare(parameters, rs, uint64(length), func(a, b types.Decimal256) bool { + return a == b + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixed[types.Rowid, types.Rowid, bool](parameters, rs, proc, length, func(a, b types.Rowid) bool { return a.EQ(&b) @@ -760,6 +768,231 @@ func valueDec128Compare( return nil } +func valueDec256Compare( + parameters []*vector.Vector, result *vector.FunctionResult[bool], length uint64, + cmpFn func(a, b types.Decimal256) bool, selectList *FunctionSelectList) error { + p1 := vector.GenerateFunctionFixedTypeParameter[types.Decimal256](parameters[0]) + p2 := vector.GenerateFunctionFixedTypeParameter[types.Decimal256](parameters[1]) + + m := p2.GetType().Scale - p1.GetType().Scale + + rsVec := result.GetResultVector() + rss := vector.MustFixedColWithTypeCheck[bool](rsVec) + + c1, c2 := parameters[0].IsConst(), parameters[1].IsConst() + rsNull := rsVec.GetNulls() + rsAnyNull := false + + if selectList != nil { + if selectList.IgnoreAllRow() { + nulls.AddRange(rsNull, 0, uint64(length)) + return nil + } + if !selectList.ShouldEvalAllRow() { + rsAnyNull = true + for i := range selectList.SelectList { + if selectList.Contains(uint64(i)) { + rsNull.Add(uint64(i)) + } + } + } + } + if c1 && c2 { + v1, null1 := p1.GetValue(0) + v2, null2 := p2.GetValue(0) + if null1 || null2 { + nulls.AddRange(rsNull, 0, length) + } else { + if m >= 0 { + x, err := v1.Scale(m) + if err != nil { + return err + } + for i := uint64(0); i < length; i++ { + rss[i] = cmpFn(x, v2) + } + } else { + y, err := v2.Scale(-m) + if err != nil { + return err + } + for i := uint64(0); i < length; i++ { + rss[i] = cmpFn(v1, y) + } + } + } + return nil + } + + if c1 { + v1, null1 := p1.GetValue(0) + if null1 { + nulls.AddRange(rsNull, 0, length) + } else { + if m >= 0 { + x, err := v1.Scale(m) + if err != nil { + return err + } + if p2.WithAnyNullValue() || rsAnyNull { + nulls.Or(rsNull, parameters[1].GetNulls(), rsNull) + for i := uint64(0); i < length; i++ { + if rsNull.Contains(i) { + continue + } + v2, _ := p2.GetValue(i) + rss[i] = cmpFn(x, v2) + } + } else { + for i := uint64(0); i < length; i++ { + v2, _ := p2.GetValue(i) + rss[i] = cmpFn(x, v2) + } + } + } else { + if p2.WithAnyNullValue() { + nulls.Or(rsNull, parameters[1].GetNulls(), rsNull) + for i := uint64(0); i < length; i++ { + if rsNull.Contains(i) { + continue + } + v2, _ := p2.GetValue(i) + y, err := v2.Scale(-m) + if err != nil { + return err + } + rss[i] = cmpFn(v1, y) + } + } else { + scaleMy := -m + for i := uint64(0); i < length; i++ { + v2, _ := p2.GetValue(i) + y, err := v2.Scale(scaleMy) + if err != nil { + return err + } + rss[i] = cmpFn(v1, y) + } + } + } + } + + return nil + } + + if c2 { + v2, null2 := p2.GetValue(0) + if null2 { + nulls.AddRange(rsNull, 0, length) + } else { + if m >= 0 { + if p1.WithAnyNullValue() || rsAnyNull { + nulls.Or(rsNull, parameters[0].GetNulls(), rsNull) + for i := uint64(0); i < length; i++ { + if rsNull.Contains(i) { + continue + } + v1, _ := p1.GetValue(i) + x, err := v1.Scale(m) + if err != nil { + return err + } + rss[i] = cmpFn(x, v2) + } + } else { + for i := uint64(0); i < length; i++ { + v1, _ := p1.GetValue(i) + x, err := v1.Scale(m) + if err != nil { + return err + } + rss[i] = cmpFn(x, v2) + } + } + } else { + y, err := v2.Scale(-m) + if err != nil { + return err + } + if p1.WithAnyNullValue() || rsAnyNull { + nulls.Or(rsNull, parameters[0].GetNulls(), rsNull) + for i := uint64(0); i < length; i++ { + if rsNull.Contains(i) { + continue + } + v1, _ := p1.GetValue(i) + rss[i] = cmpFn(v1, y) + } + } else { + for i := uint64(0); i < length; i++ { + v1, _ := p1.GetValue(i) + rss[i] = cmpFn(v1, y) + } + } + } + } + return nil + } + + if p1.WithAnyNullValue() || p2.WithAnyNullValue() || rsAnyNull { + nulls.Or(rsNull, parameters[0].GetNulls(), rsNull) + nulls.Or(rsNull, parameters[1].GetNulls(), rsNull) + if m >= 0 { + for i := uint64(0); i < length; i++ { + if rsNull.Contains(i) { + continue + } + v1, _ := p1.GetValue(i) + v2, _ := p2.GetValue(i) + x, err := v1.Scale(m) + if err != nil { + return err + } + rss[i] = cmpFn(x, v2) + } + } else { + scaleMy := -m + for i := uint64(0); i < length; i++ { + if rsNull.Contains(i) { + continue + } + v1, _ := p1.GetValue(i) + v2, _ := p2.GetValue(i) + y, err := v2.Scale(scaleMy) + if err != nil { + return err + } + rss[i] = cmpFn(v1, y) + } + } + return nil + } + + if m >= 0 { + for i := uint64(0); i < length; i++ { + v1, _ := p1.GetValue(i) + v2, _ := p2.GetValue(i) + x, err := v1.Scale(m) + if err != nil { + return err + } + rss[i] = cmpFn(x, v2) + } + } else { + scaleMy := -m + for i := uint64(0); i < length; i++ { + v1, _ := p1.GetValue(i) + v2, _ := p2.GetValue(i) + y, err := v2.Scale(scaleMy) + if err != nil { + return err + } + rss[i] = cmpFn(v1, y) + } + } + return nil +} + func greatThanFn(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error { // Handle type mismatch for numeric types (fallback when plan-stage casting was not applied) if shouldUseTypeMismatchPath(parameters[0], parameters[1]) { @@ -869,6 +1102,10 @@ func greatThanFn(parameters []*vector.Vector, result vector.FunctionResultWrappe return valueDec128Compare(parameters, rs, uint64(length), func(a, b types.Decimal128) bool { return a.Compare(b) > 0 }, selectList) + case types.T_decimal256: + return valueDec256Compare(parameters, rs, uint64(length), func(a, b types.Decimal256) bool { + return a.Compare(b) > 0 + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixed[types.Rowid, types.Rowid, bool](parameters, rs, proc, length, func(a, b types.Rowid) bool { return a.GT(&b) @@ -986,6 +1223,10 @@ func greatEqualFn(parameters []*vector.Vector, result vector.FunctionResultWrapp return valueDec128Compare(parameters, rs, uint64(length), func(a, b types.Decimal128) bool { return a.Compare(b) >= 0 }, selectList) + case types.T_decimal256: + return valueDec256Compare(parameters, rs, uint64(length), func(a, b types.Decimal256) bool { + return a.Compare(b) >= 0 + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixed[types.Rowid, types.Rowid, bool](parameters, rs, proc, length, func(a, b types.Rowid) bool { return a.GE(&b) @@ -1103,6 +1344,10 @@ func notEqualFn(parameters []*vector.Vector, result vector.FunctionResultWrapper return valueDec128Compare(parameters, rs, uint64(length), func(a, b types.Decimal128) bool { return a != b }, selectList) + case types.T_decimal256: + return valueDec256Compare(parameters, rs, uint64(length), func(a, b types.Decimal256) bool { + return a != b + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixed[types.Rowid, types.Rowid, bool](parameters, rs, proc, length, func(a, b types.Rowid) bool { return !a.EQ(&b) @@ -1220,6 +1465,10 @@ func lessThanFn(parameters []*vector.Vector, result vector.FunctionResultWrapper return valueDec128Compare(parameters, rs, uint64(length), func(a, b types.Decimal128) bool { return a.Compare(b) < 0 }, selectList) + case types.T_decimal256: + return valueDec256Compare(parameters, rs, uint64(length), func(a, b types.Decimal256) bool { + return a.Compare(b) < 0 + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixed[types.Rowid, types.Rowid, bool](parameters, rs, proc, length, func(a, b types.Rowid) bool { return a.LT(&b) @@ -1337,6 +1586,10 @@ func lessEqualFn(parameters []*vector.Vector, result vector.FunctionResultWrappe return valueDec128Compare(parameters, rs, uint64(length), func(a, b types.Decimal128) bool { return a.Compare(b) <= 0 }, selectList) + case types.T_decimal256: + return valueDec256Compare(parameters, rs, uint64(length), func(a, b types.Decimal256) bool { + return a.Compare(b) <= 0 + }, selectList) case types.T_Rowid: return opBinaryFixedFixedToFixed[types.Rowid, types.Rowid, bool](parameters, rs, proc, length, func(a, b types.Rowid) bool { return a.LE(&b) diff --git a/pkg/sql/plan/function/func_compare_decimal256_test.go b/pkg/sql/plan/function/func_compare_decimal256_test.go new file mode 100644 index 0000000000000..2fac9555df9be --- /dev/null +++ b/pkg/sql/plan/function/func_compare_decimal256_test.go @@ -0,0 +1,579 @@ +// Copyright 2021 - 2022 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "strings" + "testing" + + "github.com/matrixorigin/matrixone/pkg/container/nulls" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/matrixorigin/matrixone/pkg/testutil" + "github.com/stretchr/testify/require" +) + +// issue #24565: a CASE/IFF whose branches need more than 38 digits is promoted +// to decimal256; comparing such a value with a decimal128 must work instead of +// failing with "bad value [DECIMAL256 DECIMAL128]". + +func Test_FixedTypeCastRule1_Decimal256Promotion(t *testing.T) { + // decimal256 vs decimal128 -> both promoted to decimal256, each keeps scale/width. + has, t1, t2 := fixedTypeCastRule1( + types.New(types.T_decimal256, 58, 20), + types.New(types.T_decimal128, 38, 20), + ) + require.True(t, has) + require.Equal(t, types.T_decimal256, t1.Oid) + require.Equal(t, types.T_decimal256, t2.Oid) + require.Equal(t, int32(20), t1.Scale) + require.Equal(t, int32(58), t1.Width) + require.Equal(t, int32(20), t2.Scale) + require.Equal(t, int32(38), t2.Width) + + // decimal128 vs decimal256 (reversed) -> still both decimal256. + has, t1, t2 = fixedTypeCastRule1( + types.New(types.T_decimal128, 38, 2), + types.New(types.T_decimal256, 58, 20), + ) + require.True(t, has) + require.Equal(t, types.T_decimal256, t1.Oid) + require.Equal(t, types.T_decimal256, t2.Oid) + + // decimal256 vs decimal256 -> both decimal256. + has, t1, t2 = fixedTypeCastRule1( + types.New(types.T_decimal256, 58, 20), + types.New(types.T_decimal256, 40, 5), + ) + require.True(t, has) + require.Equal(t, types.T_decimal256, t1.Oid) + require.Equal(t, types.T_decimal256, t2.Oid) + + // decimal256 vs a non-decimal/non-integer type (varchar) must NOT be + // handled by the promotion shortcut. + has, _, _ = fixedTypeCastRule1( + types.New(types.T_decimal256, 58, 20), + types.T_varchar.ToType(), + ) + require.False(t, has) +} + +// Test_FixedTypeCastRule1_Decimal256VsInteger covers the decimal256-vs-integer +// promotion path: a bare integer literal compared with a decimal256 value must +// promote both sides to decimal256 so the comparison binds (issue #24565 +// review). The integer side becomes decimal256(integralWidth, 0). +func Test_FixedTypeCastRule1_Decimal256VsInteger(t *testing.T) { + has, t1, t2 := fixedTypeCastRule1( + types.New(types.T_decimal256, 58, 20), + types.T_int64.ToType(), + ) + require.True(t, has) + require.Equal(t, types.T_decimal256, t1.Oid) + require.Equal(t, int32(20), t1.Scale) + require.Equal(t, types.T_decimal256, t2.Oid) + require.Equal(t, int32(0), t2.Scale) + require.Equal(t, int32(19), t2.Width) // int64 integral width + + // reversed: integer on the left. + has, t1, t2 = fixedTypeCastRule1( + types.T_uint32.ToType(), + types.New(types.T_decimal256, 40, 5), + ) + require.True(t, has) + require.Equal(t, types.T_decimal256, t1.Oid) + require.Equal(t, int32(0), t1.Scale) + require.Equal(t, int32(10), t1.Width) // uint32 integral width + require.Equal(t, types.T_decimal256, t2.Oid) + require.Equal(t, int32(5), t2.Scale) +} + +// Test_FixedTypeCastRule1_Decimal256VsFloat covers decimal256 comparisons +// against approximate numeric values. Existing decimal64/decimal128 vs float +// rules cast both sides to float64; decimal256 must bind the same way. +func Test_FixedTypeCastRule1_Decimal256VsFloat(t *testing.T) { + has, t1, t2 := fixedTypeCastRule1( + types.New(types.T_decimal256, 58, 20), + types.T_float64.ToType(), + ) + require.True(t, has) + require.Equal(t, types.T_float64, t1.Oid) + require.Equal(t, types.T_float64, t2.Oid) + + has, t1, t2 = fixedTypeCastRule1( + types.T_float32.ToType(), + types.New(types.T_decimal256, 58, 20), + ) + require.True(t, has) + require.Equal(t, types.T_float64, t1.Oid) + require.Equal(t, types.T_float64, t2.Oid) +} + +// Test_IsNumericType_Decimal256 verifies decimal256 is recognised by the +// runtime type-mismatch fallback path (issue #24565 review). +func Test_IsNumericType_Decimal256(t *testing.T) { + require.True(t, isNumericType(types.T_decimal256)) + require.True(t, isNumericType(types.T_decimal128)) + require.False(t, isNumericType(types.T_varchar)) +} + +// Test_DecimalHelpers covers the small decimal helper funcs used by the BETWEEN +// and comparison binding paths. +func Test_DecimalHelpers(t *testing.T) { + require.True(t, isDecimalOrInteger(types.New(types.T_decimal128, 38, 2))) + require.True(t, isDecimalOrInteger(types.T_int64.ToType())) + require.False(t, isDecimalOrInteger(types.T_float64.ToType())) + require.False(t, isDecimalOrInteger(types.T_varchar.ToType())) + + require.Equal(t, types.T_decimal64, widestDecimalFamily([]types.Type{ + types.New(types.T_decimal64, 18, 2), types.T_int64.ToType()})) + require.Equal(t, types.T_decimal128, widestDecimalFamily([]types.Type{ + types.New(types.T_decimal64, 18, 2), types.New(types.T_decimal128, 38, 2)})) + require.Equal(t, types.T_decimal256, widestDecimalFamily([]types.Type{ + types.New(types.T_decimal128, 38, 2), types.New(types.T_decimal256, 76, 2)})) + + // decimal256FromSource: decimal256 unchanged, decimal preserves width/scale, + // integer becomes (integralWidth, 0). + d256 := types.New(types.T_decimal256, 50, 10) + require.Equal(t, d256, decimal256FromSource(d256)) + got := decimal256FromSource(types.New(types.T_decimal128, 38, 6)) + require.Equal(t, types.T_decimal256, got.Oid) + require.Equal(t, int32(38), got.Width) + require.Equal(t, int32(6), got.Scale) + got = decimal256FromSource(types.T_int64.ToType()) + require.Equal(t, types.T_decimal256, got.Oid) + require.Equal(t, int32(19), got.Width) + require.Equal(t, int32(0), got.Scale) +} + +// Test_BetweenCheckFn_Decimal256Promotion verifies the BETWEEN checkFn computes +// a common decimal type preserving max integral width + max scale and promotes +// to decimal256 when decimal128 would overflow (issue #24565 review). Example: +// DECIMAL(38,0) BETWEEN DECIMAL(38,30) AND DECIMAL(38,30) needs 38 integral + 30 +// scale = 68 digits, i.e. DECIMAL256(68,30) (scale 30 is the 3.0-dev max). +func Test_BetweenCheckFn_Decimal256Promotion(t *testing.T) { + var betweenFunc *FuncNew + for i := range supportedOperators { + if supportedOperators[i].functionId == BETWEEN { + betweenFunc = &supportedOperators[i] + break + } + } + require.NotNil(t, betweenFunc, "BETWEEN operator should be defined") + + // DECIMAL(38,0) BETWEEN DECIMAL(38,30) AND DECIMAL(38,30) -> DECIMAL256(68,30). + res := betweenFunc.checkFn(betweenFunc.Overloads, []types.Type{ + types.New(types.T_decimal128, 38, 0), + types.New(types.T_decimal128, 38, 30), + types.New(types.T_decimal128, 38, 30), + }) + require.Equal(t, succeedWithCast, res.status) + require.Len(t, res.finalType, 3) + for _, ft := range res.finalType { + require.Equal(t, types.T_decimal256, ft.Oid) + require.Equal(t, int32(30), ft.Scale) + require.Equal(t, int32(68), ft.Width) + } + + // mixed family (dec64 value, dec128 bounds) must not be rejected; all three + // aligned to a common decimal type. + res = betweenFunc.checkFn(betweenFunc.Overloads, []types.Type{ + types.New(types.T_decimal64, 18, 2), + types.New(types.T_decimal128, 38, 4), + types.New(types.T_decimal128, 38, 4), + }) + require.Equal(t, succeedWithCast, res.status) + require.Len(t, res.finalType, 3) + for _, ft := range res.finalType { + require.True(t, ft.Oid.IsDecimal()) + require.Equal(t, res.finalType[0].Oid, ft.Oid) + require.Equal(t, int32(4), ft.Scale) + } + + // decimal value with integer bounds stays numeric and binds. + res = betweenFunc.checkFn(betweenFunc.Overloads, []types.Type{ + types.New(types.T_decimal128, 20, 2), + types.T_int64.ToType(), + types.T_int64.ToType(), + }) + require.Equal(t, succeedWithCast, res.status) + require.Len(t, res.finalType, 3) + for _, ft := range res.finalType { + require.True(t, ft.Oid.IsDecimal()) + } + + // decimal256 compared with approximate numeric bounds follows the existing + // decimal-vs-float rule and casts all operands to float64. + res = betweenFunc.checkFn(betweenFunc.Overloads, []types.Type{ + types.New(types.T_decimal256, 58, 20), + types.T_float32.ToType(), + types.T_float64.ToType(), + }) + require.Equal(t, succeedWithCast, res.status) + require.Len(t, res.finalType, 3) + for _, ft := range res.finalType { + require.Equal(t, types.T_float64, ft.Oid) + } +} + +func Test_CompareOperatorSupports_Decimal256(t *testing.T) { + d256 := types.New(types.T_decimal256, 58, 20) + require.True(t, equalAndNotEqualOperatorSupports(d256, d256)) + require.True(t, otherCompareOperatorSupports(d256, d256)) +} + +func Test_EqualFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + + // equal, same scale + tc := tcTemp{ + info: "decimal256 = decimal256 (same scale)", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(6)}, + []bool{false, false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(7)}, + []bool{false, false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, false}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, equalFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_EqualFn_Decimal256_DifferentScale(t *testing.T) { + proc := testutil.NewProcess(t) + + // 5 (scale 0) compared to 500 (scale 2) == 5.00 -> equal. + tc := tcTemp{ + info: "decimal256 = decimal256 (different scale)", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5)}, []bool{false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 2), + []types.Decimal256{types.Decimal256FromInt64(500)}, []bool{false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true}, []bool{false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, equalFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_GreatThanFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + tc := tcTemp{ + info: "decimal256 > decimal256", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(1)}, + []bool{false, false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(1), types.Decimal256FromInt64(5)}, + []bool{false, false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, false}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, greatThanFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_NotEqualFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + tc := tcTemp{ + info: "decimal256 != decimal256", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(6)}, + []bool{false, false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(7)}, + []bool{false, false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{false, true}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, notEqualFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_NullSafeEqualFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + tc := tcTemp{ + info: "decimal256 <=> decimal256 with nulls", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(0)}, + []bool{false, true}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(0)}, + []bool{false, true}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, true}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, nullSafeEqualFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_Decimal256TypeMismatchCompare(t *testing.T) { + proc := testutil.NewProcess(t) + + t.Run("decimal256 equals float64", func(t *testing.T) { + tc := tcTemp{ + info: "decimal256 = float64 type mismatch", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 65, 2), + []types.Decimal256{ + types.Decimal256FromInt64(100), + types.Decimal256FromInt64(250), + types.Decimal256FromInt64(300), + }, nil), + NewFunctionTestInput(types.T_float64.ToType(), + []float64{1.0, 2.0, 2.5}, nil), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, false, false}, nil), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, equalFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) + }) + + t.Run("decimal256 greater than decimal128", func(t *testing.T) { + tc := tcTemp{ + info: "decimal256 > decimal128 type mismatch", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 65, 2), + []types.Decimal256{ + types.Decimal256FromInt64(100), + types.Decimal256FromInt64(250), + types.Decimal256FromInt64(300), + }, nil), + NewFunctionTestInput(types.New(types.T_decimal128, 38, 2), + []types.Decimal128{ + {B0_63: 100}, + {B0_63: 200}, + {B0_63: 350}, + }, nil), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{false, true, false}, nil), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, greatThanFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) + }) +} + +// Test_ValueDec256Compare_AllBranches exercises the const/variable, null and +// scale-direction (m>=0 / m<0) branches of valueDec256Compare directly, so the +// decimal256 comparison helper is covered beyond the simple variable path. +func Test_ValueDec256Compare_AllBranches(t *testing.T) { + proc := testutil.NewProcess(t) + mp := proc.Mp() + eq := func(a, b types.Decimal256) bool { return a == b } + + mkVar := func(scale int32, vals []int64, nullList []bool) *vector.Vector { + vec := vector.NewVec(types.New(types.T_decimal256, 76, scale)) + ds := make([]types.Decimal256, len(vals)) + for i, v := range vals { + ds[i] = types.Decimal256FromInt64(v) + } + require.NoError(t, vector.AppendFixedList(vec, ds, nullList, mp)) + return vec + } + mkConst := func(scale int32, val int64, length int) *vector.Vector { + vec, err := vector.NewConstFixed(types.New(types.T_decimal256, 76, scale), + types.Decimal256FromInt64(val), length, mp) + require.NoError(t, err) + return vec + } + run := func(p1, p2 *vector.Vector, length int) ([]bool, *nulls.Nulls) { + w := vector.NewFunctionResultWrapper(types.T_bool.ToType(), mp) + require.NoError(t, w.PreExtendAndReset(length)) + rs := vector.MustFunctionResult[bool](w) + require.NoError(t, valueDec256Compare([]*vector.Vector{p1, p2}, rs, uint64(length), eq, nil)) + rsVec := rs.GetResultVector() + col := vector.MustFixedColWithTypeCheck[bool](rsVec) + out := make([]bool, length) + copy(out, col) + return out, rsVec.GetNulls() + } + + // const vs const: m>=0 then m<0 (5 == 5.00). + r, _ := run(mkConst(0, 5, 2), mkConst(2, 500, 2), 2) + require.Equal(t, []bool{true, true}, r) + r, _ = run(mkConst(2, 500, 2), mkConst(0, 5, 2), 2) + require.Equal(t, []bool{true, true}, r) + + // c1 (p1 const) m>=0: with null then without null. + r, ns := run(mkConst(0, 5, 2), mkVar(2, []int64{500, 0}, []bool{false, true}), 2) + require.True(t, r[0]) + require.True(t, ns.Contains(1)) + r, _ = run(mkConst(0, 5, 2), mkVar(2, []int64{500, 400}, nil), 2) + require.Equal(t, []bool{true, false}, r) + + // c1 m<0: with null then without null. + r, ns = run(mkConst(2, 500, 2), mkVar(0, []int64{5, 0}, []bool{false, true}), 2) + require.True(t, r[0]) + require.True(t, ns.Contains(1)) + r, _ = run(mkConst(2, 500, 2), mkVar(0, []int64{5, 6}, nil), 2) + require.Equal(t, []bool{true, false}, r) + + // c2 (p2 const) m>=0: with null then without null. + r, ns = run(mkVar(0, []int64{5, 0}, []bool{false, true}), mkConst(2, 500, 2), 2) + require.True(t, r[0]) + require.True(t, ns.Contains(1)) + r, _ = run(mkVar(0, []int64{5, 4}, nil), mkConst(2, 500, 2), 2) + require.Equal(t, []bool{true, false}, r) + + // c2 m<0: with null then without null. + r, ns = run(mkVar(2, []int64{500, 0}, []bool{false, true}), mkConst(0, 5, 2), 2) + require.True(t, r[0]) + require.True(t, ns.Contains(1)) + r, _ = run(mkVar(2, []int64{500, 600}, nil), mkConst(0, 5, 2), 2) + require.Equal(t, []bool{true, false}, r) + + // var vs var with null: m>=0 then m<0. + r, ns = run(mkVar(0, []int64{5, 0}, []bool{false, true}), mkVar(2, []int64{500, 500}, nil), 2) + require.True(t, r[0]) + require.True(t, ns.Contains(1)) + r, ns = run(mkVar(2, []int64{500, 0}, []bool{false, true}), mkVar(0, []int64{5, 5}, nil), 2) + require.True(t, r[0]) + require.True(t, ns.Contains(1)) + + // var vs var without null: m>=0 then m<0. + r, _ = run(mkVar(0, []int64{5, 6}, nil), mkVar(2, []int64{500, 500}, nil), 2) + require.Equal(t, []bool{true, false}, r) + r, _ = run(mkVar(2, []int64{500, 600}, nil), mkVar(0, []int64{5, 7}, nil), 2) + require.Equal(t, []bool{true, false}, r) +} + +// Test_ValueDec256Compare_ScaleOverflowError verifies that an overflow from +// Decimal256.Scale() (extreme scale span) is propagated as an error instead of +// comparing un-scaled raw values (issue #24565 review). +func Test_ValueDec256Compare_ScaleOverflowError(t *testing.T) { + proc := testutil.NewProcess(t) + mp := proc.Mp() + big, err := types.ParseDecimal256(strings.Repeat("9", 76), 76, 0) + require.NoError(t, err) + p1, err := vector.NewConstFixed(types.New(types.T_decimal256, 76, 0), big, 1, mp) + require.NoError(t, err) + p2, err := vector.NewConstFixed(types.New(types.T_decimal256, 76, 30), types.Decimal256FromInt64(1), 1, mp) + require.NoError(t, err) + w := vector.NewFunctionResultWrapper(types.T_bool.ToType(), mp) + require.NoError(t, w.PreExtendAndReset(1)) + rs := vector.MustFunctionResult[bool](w) + err = valueDec256Compare([]*vector.Vector{p1, p2}, rs, 1, + func(a, b types.Decimal256) bool { return a == b }, nil) + require.Error(t, err) +} + +func Test_LessThanFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + tc := tcTemp{ + info: "decimal256 < decimal256", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(1), types.Decimal256FromInt64(5)}, + []bool{false, false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(1)}, + []bool{false, false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, false}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, lessThanFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_LessEqualFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + tc := tcTemp{ + info: "decimal256 <= decimal256", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(6)}, + []bool{false, false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(1)}, + []bool{false, false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, false}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, lessEqualFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +func Test_GreatEqualFn_Decimal256(t *testing.T) { + proc := testutil.NewProcess(t) + tc := tcTemp{ + info: "decimal256 >= decimal256", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(1)}, + []bool{false, false}), + NewFunctionTestInput(types.New(types.T_decimal256, 58, 0), + []types.Decimal256{types.Decimal256FromInt64(5), types.Decimal256FromInt64(5)}, + []bool{false, false}), + }, + expect: NewFunctionTestResult(types.T_bool.ToType(), false, + []bool{true, false}, []bool{false, false}), + } + fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, greatEqualFn) + s, info := fcTC.Run() + require.True(t, s, info, tc.info) +} + +// Test_BetweenImpl_Decimal256_NonConst exercises the runtime decimal256 path of +// BETWEEN with a non-constant left operand, which previously fell through to +// panic("unreached code") because betweenImpl had no decimal256 branch +// (issue #24565 review). +func Test_BetweenImpl_Decimal256_NonConst(t *testing.T) { + proc := testutil.NewProcess(t) + mp := proc.Mp() + + // p0: non-constant decimal256 column; p1/p2: constant bounds [1, 100]. + p0 := vector.NewVec(types.New(types.T_decimal256, 76, 0)) + require.NoError(t, vector.AppendFixedList(p0, []types.Decimal256{ + types.Decimal256FromInt64(5), + types.Decimal256FromInt64(50), + types.Decimal256FromInt64(500), + }, nil, mp)) + p1, err := vector.NewConstFixed(types.New(types.T_decimal256, 76, 0), types.Decimal256FromInt64(1), 3, mp) + require.NoError(t, err) + p2, err := vector.NewConstFixed(types.New(types.T_decimal256, 76, 0), types.Decimal256FromInt64(100), 3, mp) + require.NoError(t, err) + + w := vector.NewFunctionResultWrapper(types.T_bool.ToType(), mp) + require.NoError(t, w.PreExtendAndReset(3)) + require.NoError(t, betweenImpl([]*vector.Vector{p0, p1, p2}, w, proc, 3, nil)) + + rs := vector.MustFunctionResult[bool](w) + col := vector.MustFixedColWithTypeCheck[bool](rs.GetResultVector()) + require.Equal(t, []bool{true, true, false}, col[:3]) +} diff --git a/pkg/sql/plan/function/func_compare_fix.go b/pkg/sql/plan/function/func_compare_fix.go index c521db1be5f97..4e8375b931d37 100644 --- a/pkg/sql/plan/function/func_compare_fix.go +++ b/pkg/sql/plan/function/func_compare_fix.go @@ -562,6 +562,17 @@ func getAsCompareValueSlice(v *vector.Vector) []numericCompareValue { result[i] = numericCompareValue{kind: numericCompareFinite, rat: r} } return result + case types.T_decimal256: + cols := vector.MustFixedColNoTypeCheck[types.Decimal256](v) + result := make([]numericCompareValue, len(cols)) + for i, val := range cols { + r, ok := new(big.Rat).SetString(val.Format(t.Scale)) + if !ok { + panic("invalid decimal256 value for numeric comparison") + } + result[i] = numericCompareValue{kind: numericCompareFinite, rat: r} + } + return result default: return nil } @@ -649,6 +660,13 @@ func getAsFloat64Slice(v *vector.Vector) []float64 { result[i] = types.Decimal128ToFloat64(val, t.Scale) } return result + case types.T_decimal256: + cols := vector.MustFixedColNoTypeCheck[types.Decimal256](v) + result := make([]float64, len(cols)) + for i, val := range cols { + result[i] = types.Decimal256ToFloat64(val, t.Scale) + } + return result default: return nil } @@ -671,7 +689,7 @@ func isNumericType(t types.T) bool { case types.T_int8, types.T_int16, types.T_int32, types.T_int64, types.T_uint8, types.T_uint16, types.T_uint32, types.T_uint64, types.T_float32, types.T_float64, - types.T_decimal64, types.T_decimal128: + types.T_decimal64, types.T_decimal128, types.T_decimal256: return true default: return false diff --git a/pkg/sql/plan/function/func_compare_fix_test.go b/pkg/sql/plan/function/func_compare_fix_test.go index 52c02d35eb808..91ebd7ae51bbf 100644 --- a/pkg/sql/plan/function/func_compare_fix_test.go +++ b/pkg/sql/plan/function/func_compare_fix_test.go @@ -15,6 +15,7 @@ package function import ( + "math/big" "testing" "github.com/matrixorigin/matrixone/pkg/common/mpool" @@ -46,6 +47,7 @@ func TestShouldUseTypeMismatchPath(t *testing.T) { {"decimal64 vs decimal128 - mismatch", types.T_decimal64, types.T_decimal128, true}, {"decimal128 vs decimal128 - same type", types.T_decimal128, types.T_decimal128, false}, {"decimal64 vs float64 - mismatch", types.T_decimal64, types.T_float64, true}, + {"decimal256 vs float64 - mismatch", types.T_decimal256, types.T_float64, true}, {"varchar vs int64 - not numeric", types.T_varchar, types.T_int64, false}, {"int64 vs varchar - not numeric", types.T_int64, types.T_varchar, false}, {"varchar vs varchar - not numeric", types.T_varchar, types.T_varchar, false}, @@ -86,6 +88,7 @@ func TestIsNumericType(t *testing.T) { {types.T_datetime, false}, {types.T_decimal64, true}, {types.T_decimal128, true}, + {types.T_decimal256, true}, {types.T_bool, false}, } @@ -200,6 +203,39 @@ func TestGetAsFloat64Slice(t *testing.T) { result := getAsFloat64Slice(v) require.InDeltaSlice(t, []float64{1.00, 2.50, 0.50}, result, 0.0001) }) + + t.Run("decimal256 to float64", func(t *testing.T) { + typ := types.New(types.T_decimal256, 65, 2) + v := vector.NewVec(typ) + defer v.Free(mp) + require.NoError(t, vector.AppendFixedList(v, []types.Decimal256{ + types.Decimal256FromInt64(100), + types.Decimal256FromInt64(250), + types.Decimal256FromInt64(50), + }, nil, mp)) + + result := getAsFloat64Slice(v) + require.InDeltaSlice(t, []float64{1.00, 2.50, 0.50}, result, 0.0001) + }) +} + +func TestGetAsCompareValueSliceDecimal256(t *testing.T) { + proc := testutil.NewProcess(t) + mp := proc.Mp() + typ := types.New(types.T_decimal256, 65, 2) + v := vector.NewVec(typ) + defer v.Free(mp) + require.NoError(t, vector.AppendFixedList(v, []types.Decimal256{ + types.Decimal256FromInt64(100), + types.Decimal256FromInt64(250), + }, nil, mp)) + + result := getAsCompareValueSlice(v) + require.Len(t, result, 2) + require.Equal(t, numericCompareFinite, result[0].kind) + require.Equal(t, 0, result[0].rat.Cmp(big.NewRat(1, 1))) + require.Equal(t, numericCompareFinite, result[1].kind) + require.Equal(t, 0, result[1].rat.Cmp(big.NewRat(5, 2))) } // TestLessEqualFnWithTypeMismatch tests lessEqualFn with mismatched numeric types diff --git a/pkg/sql/plan/function/list_operator.go b/pkg/sql/plan/function/list_operator.go index 81114b496c354..f9c099edbe376 100644 --- a/pkg/sql/plan/function/list_operator.go +++ b/pkg/sql/plan/function/list_operator.go @@ -272,29 +272,40 @@ var supportedOperators = []FuncNew{ return newCheckResultWithFailure(failedFunctionParametersWrong) } + // All-numeric decimal BETWEEN: betweenImpl compares raw decimal + // values without rescaling at runtime, so the three operands must + // share one identical decimal type. Compute a common type that + // preserves the max integral width AND max scale across the three + // operands (issue #24565), promoting to decimal256 when decimal128 + // would overflow, instead of only aligning scale (which could keep + // an under-sized width / family and overflow) or rejecting + // mixed-scale/mixed-family decimals. + anyDecimal := inputs[0].Oid.IsDecimal() || inputs[1].Oid.IsDecimal() || inputs[2].Oid.IsDecimal() + allDecimalOrInt := isDecimalOrInteger(inputs[0]) && isDecimalOrInteger(inputs[1]) && isDecimalOrInteger(inputs[2]) + if anyDecimal && allDecimalOrInt { + target := widestDecimalFamily(inputs).ToType() + if !setSafeDecimalWidthAndScaleFromSource(&target, inputs) { + return newCheckResultWithFailure(failedFunctionParametersWrong) + } + if !otherCompareOperatorSupports(target, target) { + return newCheckResultWithFailure(failedFunctionParametersWrong) + } + return newCheckResultWithCast(0, []types.Type{target, target, target}) + } + has0, t01, t1 := fixedTypeCastRule1(inputs[0], inputs[1]) has1, t02, t2 := fixedTypeCastRule1(inputs[0], inputs[2]) if t01.Oid != t02.Oid { return newCheckResultWithFailure(failedFunctionParametersWrong) } - - if t01.Oid == types.T_decimal64 || t01.Oid == types.T_decimal128 || t01.Oid == types.T_decimal256 { - if t01.Scale != t1.Scale || t02.Scale != t2.Scale { - return newCheckResultWithFailure(failedFunctionParametersWrong) - } + if !otherCompareOperatorSupports(t01, t1) || !otherCompareOperatorSupports(t01, t2) { + return newCheckResultWithFailure(failedFunctionParametersWrong) } if has0 || has1 { - if otherCompareOperatorSupports(t01, t1) && otherCompareOperatorSupports(t01, t2) { - return newCheckResultWithCast(0, []types.Type{t01, t1, t2}) - } - } else { - if otherCompareOperatorSupports(t01, t1) && otherCompareOperatorSupports(t01, t2) { - return newCheckResultWithSuccess(0) - } + return newCheckResultWithCast(0, []types.Type{t01, t1, t2}) } - - return newCheckResultWithFailure(failedFunctionParametersWrong) + return newCheckResultWithSuccess(0) }, Overloads: []overload{ @@ -2522,6 +2533,16 @@ var supportedOperators = []FuncNew{ return CoalesceStr }, }, + { + overloadId: 25, + args: []types.T{types.T_decimal256}, + retType: func(parameters []types.Type) types.Type { + return parameters[0] + }, + newOp: func() executeLogicOfOverload { + return CoalesceGeneral[types.Decimal256] + }, + }, }, }, diff --git a/pkg/sql/plan/function/operatorSet_test.go b/pkg/sql/plan/function/operatorSet_test.go index 3a324f230d9c7..da64a0c48adf0 100644 --- a/pkg/sql/plan/function/operatorSet_test.go +++ b/pkg/sql/plan/function/operatorSet_test.go @@ -704,6 +704,97 @@ func Test_CoalesceCheck_MixedStringNumeric(t *testing.T) { } } +// issue #24565: COALESCE over decimal branches with different scales must align +// scale/width across all branches, otherwise the result inherits the first +// branch's scale while carrying another branch's raw value (magnified result). +func Test_CoalesceCheck_DecimalScaleAlignment(t *testing.T) { + overloads := []overload{ + {args: []types.T{types.T_decimal64}}, + {args: []types.T{types.T_decimal128}}, + } + inputs := []types.Type{ + types.New(types.T_decimal128, 23, 2), + types.New(types.T_decimal128, 38, 7), + } + result := coalesceCheck(overloads, inputs) + require.Equal(t, succeedWithCast, result.status) + require.Equal(t, 1, result.idx) // decimal128 overload + require.Len(t, result.finalType, len(inputs)) + for _, typ := range result.finalType { + require.Equal(t, types.T_decimal128, typ.Oid) + require.Equal(t, int32(7), typ.Scale) + require.Equal(t, int32(38), typ.Width) + } +} + +func Test_CoalesceCheck_DecimalAligned_NoCast(t *testing.T) { + overloads := []overload{ + {args: []types.T{types.T_decimal64}}, + {args: []types.T{types.T_decimal128}}, + } + inputs := []types.Type{ + types.New(types.T_decimal128, 20, 4), + types.New(types.T_decimal128, 20, 4), + } + result := coalesceCheck(overloads, inputs) + // Already aligned -> no alignment cast needed, plain direct match. + require.Equal(t, succeedMatched, result.status) + require.Equal(t, 1, result.idx) +} + +// issue #24565 review: when the combined integral width + scale overflows +// decimal128, coalesce must promote the common type to decimal256 instead of +// keeping decimal128 (which would drop integer capacity and overflow on cast). +func Test_CoalesceCheck_DecimalPromoteToDecimal256(t *testing.T) { + overloads := []overload{ + {args: []types.T{types.T_decimal64}}, + {args: []types.T{types.T_decimal128}}, + {args: []types.T{types.T_decimal256}}, + } + inputs := []types.Type{ + types.New(types.T_decimal128, 38, 0), // 38 integral digits + types.New(types.T_decimal128, 38, 38), // 38 fractional digits + } + result := coalesceCheck(overloads, inputs) + require.Equal(t, succeedWithCast, result.status) + require.Equal(t, 2, result.idx) // decimal256 overload + require.Len(t, result.finalType, len(inputs)) + for _, typ := range result.finalType { + require.Equal(t, types.T_decimal256, typ.Oid) + require.Equal(t, int32(38), typ.Scale) + require.Equal(t, int32(76), typ.Width) // 38 integral + 38 scale + } +} + +// required precision overflows decimal256 -> coalesce fails instead of +// silently truncating. +func Test_CoalesceCheck_DecimalOverflowFails(t *testing.T) { + overloads := []overload{ + {args: []types.T{types.T_decimal128}}, + {args: []types.T{types.T_decimal256}}, + } + inputs := []types.Type{ + types.New(types.T_decimal256, 76, 0), + types.New(types.T_decimal256, 76, 76), // 76 integral + 76 scale = 152 > 76 + } + result := coalesceCheck(overloads, inputs) + require.Equal(t, failedFunctionParametersWrong, result.status) +} + +// the aligned decimal type has no matching overload -> coalesce fails. +func Test_CoalesceCheck_DecimalNoOverloadFails(t *testing.T) { + overloads := []overload{ + {args: []types.T{types.T_decimal64}}, + {args: []types.T{types.T_decimal128}}, + } // no decimal256 overload + inputs := []types.Type{ + types.New(types.T_decimal128, 38, 0), + types.New(types.T_decimal128, 38, 38), // requires decimal256 + } + result := coalesceCheck(overloads, inputs) + require.Equal(t, failedFunctionParametersWrong, result.status) +} + func Test_CaseWhen_WithNullAndStringComparison(t *testing.T) { // Test CASE WHEN with NULL value compared to string // This should not error, matching MySQL behavior diff --git a/pkg/sql/plan/function/operator_between.go b/pkg/sql/plan/function/operator_between.go index 07984727fa56d..d1ed7ec431d92 100644 --- a/pkg/sql/plan/function/operator_between.go +++ b/pkg/sql/plan/function/operator_between.go @@ -73,6 +73,10 @@ func betweenImpl(parameters []*vector.Vector, result vector.FunctionResultWrappe return opBetweenFixedWithFn(parameters, rs, proc, length, func(lhs, rhs types.Decimal128) bool { return lhs.Compare(rhs) <= 0 }) + case types.T_decimal256: + return opBetweenFixedWithFn(parameters, rs, proc, length, func(lhs, rhs types.Decimal256) bool { + return lhs.Compare(rhs) <= 0 + }) case types.T_Rowid: return opBetweenFixedWithFn(parameters, rs, proc, length, func(lhs, rhs types.Rowid) bool { return lhs.LE(&rhs) diff --git a/pkg/sql/plan/function/type_check.go b/pkg/sql/plan/function/type_check.go index 43a11d74f0e3f..11a0c63d7f4b2 100644 --- a/pkg/sql/plan/function/type_check.go +++ b/pkg/sql/plan/function/type_check.go @@ -26,6 +26,19 @@ import ( // 3. >= > < <= // 4. Mod func fixedTypeCastRule1(s1, s2 types.Type) (bool, types.Type, types.Type) { + // decimal256 is not present in the static binary cast matrix below. + // When one side is decimal256, handle the cases missing from the static + // binary cast matrix below. Decimal/integer pairs keep exact decimal256 + // semantics. Float pairs follow the existing decimal64/decimal128 vs float + // rule and use approximate float64 comparison. + if s1.Oid == types.T_decimal256 || s2.Oid == types.T_decimal256 { + if isDecimalOrInteger(s1) && isDecimalOrInteger(s2) { + return true, decimal256FromSource(s1), decimal256FromSource(s2) + } + if isFloatType(s1.Oid) || isFloatType(s2.Oid) { + return true, types.T_float64.ToType(), types.T_float64.ToType() + } + } check := fixedBinaryCastRule1[s1.Oid][s2.Oid] if check.cast { t1, t2 := check.left.ToType(), check.right.ToType() @@ -481,6 +494,42 @@ func integerIntegralWidth(oid types.T) int32 { } } +// isDecimalOrInteger reports whether t is a decimal or an integer type, i.e. a +// numeric type that can be losslessly cast into a wider decimal. +func isDecimalOrInteger(t types.Type) bool { + return t.Oid.IsDecimal() || t.IsIntOrUint() +} + +// decimal256FromSource converts a decimal-or-integer source type into the +// decimal256 type used for a decimal256 comparison. A decimal source keeps its +// own width/scale; an integer source becomes decimal256(integralWidth, 0). +func decimal256FromSource(s types.Type) types.Type { + if s.Oid == types.T_decimal256 { + return s + } + if s.Oid.IsDecimal() { + return types.New(types.T_decimal256, s.Width, s.Scale) + } + return types.New(types.T_decimal256, integerIntegralWidth(s.Oid), 0) +} + +// widestDecimalFamily returns the widest decimal family present among the +// inputs (decimal256 > decimal128 > decimal64). It is used as the floor family +// for setSafeDecimalWidthAndScaleFromSource so the computed common type never +// drops below any operand's family. +func widestDecimalFamily(inputs []types.Type) types.T { + widest := types.T_decimal64 + for i := range inputs { + switch inputs[i].Oid { + case types.T_decimal256: + return types.T_decimal256 + case types.T_decimal128: + widest = types.T_decimal128 + } + } + return widest +} + func setMaxWidthFromSource(t *types.Type, source []types.Type) { t.Width = -1 for i := range source { diff --git a/test/distributed/cases/expression/case_when.result b/test/distributed/cases/expression/case_when.result index ddcd342b89dc6..b205431d439a4 100755 --- a/test/distributed/cases/expression/case_when.result +++ b/test/distributed/cases/expression/case_when.result @@ -278,3 +278,78 @@ CAST(1 AS DECIMAL(38,0)), CAST(0 AS DECIMAL(38,20))) AS iff_decimal256_false; iff_decimal256_false 0E-20 +SELECT 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) AS direct_mul; +direct_mul +-408125.3580000 +SELECT COALESCE( +CAST(NULL AS DECIMAL(23,2)), +7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) +) AS coalesce_decimal_scale; +coalesce_decimal_scale +-408125.3580000 +SELECT COALESCE( +CAST(1.23 AS DECIMAL(23,2)), +7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) +) AS coalesce_first_non_null; +coalesce_first_non_null +1.2300000 +SELECT (CASE WHEN 1 = 1 THEN CAST(1 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) += CAST(1 AS DECIMAL(38,20)) AS decimal256_eq_decimal128; +decimal256_eq_decimal128 +1 +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) +> CAST(1 AS DECIMAL(38,20)) AS decimal256_gt_decimal128; +decimal256_gt_decimal128 +1 +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) +< CAST(1 AS DECIMAL(38,20)) AS decimal256_lt_decimal128; +decimal256_lt_decimal128 +0 +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) +!= CAST(1 AS DECIMAL(38,20)) AS decimal256_ne_decimal128; +decimal256_ne_decimal128 +1 +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) +BETWEEN CAST(1 AS DECIMAL(38,20)) AND CAST(10 AS DECIMAL(38,20)) AS decimal256_between; +decimal256_between +1 +SELECT COALESCE(CAST(1 AS DECIMAL(38,0)), CAST(0.5 AS DECIMAL(30,30))) AS coalesce_promote_decimal256; +coalesce_promote_decimal256 +1.000000000000000000000000000000 +SELECT COALESCE(CAST(12345678901234567890123456789012345678 AS DECIMAL(38,0)), CAST(0.5 AS DECIMAL(30,30))) AS coalesce_promote_bignum; +coalesce_promote_bignum +12345678901234567890123456789012345678.000000000000000000000000000000 +drop table if exists t_dec256_between; +create table t_dec256_between (a decimal(38,0)); +insert into t_dec256_between values (5),(50),(500); +select +(case when 1 = 1 then a else cast(0 as decimal(38,30)) end) +between cast(1 as decimal(38,20)) and cast(100 as decimal(38,20)) as in_range +from t_dec256_between order by a; +in_range +1 +1 +0 +drop table t_dec256_between; +SELECT (CASE WHEN 1 = 1 THEN CAST(1 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) = 1 AS decimal256_eq_int; +decimal256_eq_int +1 +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) > 1 AS decimal256_gt_int; +decimal256_gt_int +1 +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) +ELSE CAST(0 AS DECIMAL(38,20)) END) +BETWEEN 1 AND 10 AS decimal256_between_int; +decimal256_between_int +1 +SELECT CAST(1 AS DECIMAL(38,0)) +BETWEEN CAST(0.5 AS DECIMAL(38,30)) AND CAST(2 AS DECIMAL(38,30)) AS between_promote_decimal256; +between_promote_decimal256 +1 diff --git a/test/distributed/cases/expression/case_when.sql b/test/distributed/cases/expression/case_when.sql index a33e6a9fd152c..2cb417ea04b5d 100755 --- a/test/distributed/cases/expression/case_when.sql +++ b/test/distributed/cases/expression/case_when.sql @@ -1,225 +1,292 @@ --- @suit - --- @case --- @desc:test for case_when expression with constant operand --- @label:bvt -select CASE "b" when "a" then 1 when "b" then 2 END; -select CASE "c" when "a" then 1 when "b" then 2 END; -select CASE "c" when "a" then 1 when "b" then 2 ELSE 3 END; -select CASE when 1=0 then "true" else "false" END; -select CASE 1 when 1 then "one" WHEN 2 then "two" ELSE "more" END; -select CASE 2.0 when 1 then "one" WHEN 2.0 then "two" ELSE "more" END; - -select (CASE "two" when "one" then "1" WHEN "two" then "2" END) | 0; - -select (CASE "two" when "one" then 1.00 WHEN "two" then 2.00 END) +0.0; -select case 1/0 when "a" then "true" else "false" END; -select case 1/0 when "a" then "true" END; - -select (case 1/0 when "a" then "true" END) | 0; - -select (case 1/0 when "a" then "true" END) + 0.0; -select case when 1>0 then "TRUE" else "FALSE" END; -select case when 1<0 then "TRUE" else "FALSE" END; -SELECT CAST(CASE WHEN 0 THEN '2001-01-01' END AS DATE); -SELECT CAST(CASE WHEN 0 THEN DATE'2001-01-01' END AS DATE); -select case 1.0 when 0.1 then "a" when 1.0 then "b" else "c" END; -select case 0.1 when 0.1 then "a" when 1.0 then "b" else "c" END; -select case 1 when 0.1 then "a" when 1.0 then "b" else "c" END; -select case 1.0 when 0.1 then "a" when 1 then "b" else "c" END; -select case 1.001 when 0.1 then "a" when 1 then "b" else "c" END; - --- @case --- @desc:test for case_when expression with normal select --- @label:bvt -drop table if exists t1; -drop table if exists t2; -CREATE TABLE t1 (a varchar(10), PRIMARY KEY (a)); -CREATE TABLE t2 (a varchar(10), b date, PRIMARY KEY(a)); -INSERT INTO t1 VALUES ('test1'); -INSERT INTO t2 VALUES -('test1','2016-12-13'),('test2','2016-12-14'),('test3','2016-12-15'); --- @bvt:issue#3254 -SELECT b, b = '20161213', - CASE b WHEN '20161213' then 'found' ELSE 'not found' END FROM t2; --- @bvt:issue - - --- @case --- @desc:test for case_when expression with group by --- @label:bvt -drop table if exists t1; -create table t1 (a int); -insert into t1 values(1),(2),(3),(4); -select case a when 1 then 2 when 2 then 3 else 0 end as fcase, count(*) from t1 group by fcase; -select case a when 1 then "one" when 2 then "two" else "nothing" end as fcase, count(*) from t1 group by fcase; -drop table if exists t1; - --- @case --- @desc:test for case_when expression with function --- @label:bvt -create table t1 (`row` int not null, col int not null, val varchar(255) not null); -insert into t1 values (1,1,'orange'),(1,2,'large'),(2,1,'yellow'),(2,2,'medium'),(3,1,'green'),(3,2,'small'); -select col,val, case when val="orange" then 1 when upper(val)="LARGE" then 2 else 3 end from t1; -select max(case col when 1 then val else null end) as color from t1 group by `row`; -drop table if exists t1; - -create table t1(a float, b int default 3); -insert into t1 (a) values (2), (11), (8); -select min(a), min(case when 1=1 then a else NULL end), - min(case when 1!=1 then NULL else a end) -from t1 where b=3 group by b; - -drop table if exists t1; -CREATE TABLE t1 (a INT, b INT); -INSERT INTO t1 VALUES (1,1),(2,1),(3,2),(4,2),(5,3),(6,3); -SELECT CASE WHEN AVG(a)>=0 THEN 'Positive' ELSE 'Negative' END FROM t1 GROUP BY b; - -drop table if exists t1; - --- @case --- @desc:test for case_when expression with join --- @label:bvt -drop table if exists t1; -drop table if exists t2; -create table t1 (a int, b bigint unsigned); -create table t2 (c int); -insert into t1 (a, b) values (1,4572794622775114594), (2,18196094287899841997), - (3,11120436154190595086); -insert into t2 (c) values (1), (2), (3); -select t1.a, (case t1.a when 0 then 0 else t1.b end) d from t1 - join t2 on t1.a=t2.c order by d; -select t1.a, (case t1.a when 0 then 0 else t1.b end) d from t1 - join t2 on t1.a=t2.c where b=11120436154190595086 order by d; -drop table if exists small; -drop table if exists big; -CREATE TABLE small (id int not null,PRIMARY KEY (id)); -CREATE TABLE big (id int not null,PRIMARY KEY (id)); -INSERT INTO small VALUES (1), (2); -INSERT INTO big VALUES (1), (2), (3), (4); -SELECT big.*, dt.* FROM big LEFT JOIN (SELECT id as dt_id, - CASE id WHEN 0 THEN 0 ELSE 1 END AS simple, - CASE WHEN id=0 THEN NULL ELSE 1 END AS cond - FROM small) AS dt - ON big.id=dt.dt_id; - -drop table if exists small; -drop table if exists big; - --- @case --- @desc:test for case_when expression with union --- @label:bvt -SELECT 'case+union+test' -UNION -SELECT CASE '1' WHEN '2' THEN 'BUG' ELSE 'nobug' END; - --- @case --- @desc:test for case_when expression in where filter --- @label:bvt -drop table t1; -CREATE TABLE t1(a int); -insert into t1 values(1),(1),(2),(1),(3),(2),(1); -SELECT 1 FROM t1 WHERE a=1 AND CASE 1 WHEN a THEN 1 ELSE 1 END; -DROP TABLE if exists t1; - --- @case --- @desc:test for case_when expression with count() --- @label:bvt -DROP TABLE if exists t1; -create table t1 (USR_ID int not null, MAX_REQ int not null); -insert into t1 values (1, 3); -select count(*) + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ from t1 group by MAX_REQ; -select Case When Count(*) < MAX_REQ Then 1 Else 0 End from t1 where t1.USR_ID = 1 group by MAX_REQ; -DROP TABLE if exists t1; - -select case when 1 in (1.0, 2.0, 3.0) then true else false end; - -DROP TABLE if exists t1; -CREATE TABLE t1 ( - id int NOT NULL AUTO_INCREMENT, - key_num int NOT NULL DEFAULT '0', - hiredate date NOT NULL, - PRIMARY KEY (id), - KEY key_num (key_num) -); - -insert into t1 values - (1, 7369, '1980-12-17'), - (2, 7499, '1981-02-20'), - (3, 7521, '1981-02-22'), - (4, 7566, '1981-04-02'), - (5, 7654, '1981-09-28'), - (6, 7698, '1981-05-01'), - (7, 7782, '1981-06-09'), - (8, 7788, '0087-07-13'), - (9, 7839, '1981-11-17'), - (10, 7844, '1981-09-08'), - (11, 7876, '2007-07-13'), - (12, 7900, '1981-12-03'), - (13, 7980, '1987-07-13'), - (14, 7981, '2001-11-17'), - (15, 7982, '1951-11-08'), - (16, 7983, '1927-10-13'), - (17, 7984, '1671-12-09'), - (18, 7985, '1981-11-06'), - (19, 7986, '1771-12-06'), - (20, 7987, '1985-10-06'); -select id, case when id < 5 then 0 when id < 10 then 1 when id < 15 then 2 when true then 3 else -1 end as xxx from t1; -DROP TABLE t1; -create table t1(a varchar(100)); -insert into t1 values ("a"); -select a, case when a="a" then 1 when upper(a)="b" then 2 end from t1; -drop table if exists t1; - --- @case --- @desc:test for case_when expression with mixed decimal scales --- @label:bvt -SELECT - 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) AS direct_mul, - CASE WHEN 'USD' = 'RMB' - THEN CAST(-58140.00 AS DECIMAL(23,2)) - ELSE 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) - END AS bug_case; - --- @case --- @desc:test for case_when expression with then branch decimal cast --- @label:bvt -SELECT - CASE WHEN 'USD' = 'USD' - THEN CAST(-58140.00 AS DECIMAL(23,2)) - ELSE 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) - END AS bug_case_then; - --- @case --- @desc:test for iff expression with mixed decimal scales --- @label:bvt -SELECT - IFF('USD' = 'USD', - CAST(-58140.00 AS DECIMAL(23,2)), - 7.01970 * CAST(-58140.00 AS DECIMAL(23,2))) AS bug_iff; - --- @case --- @desc:test for case_when expression with decimal128 branches promoting to decimal256 result type --- @label:bvt -SELECT - CASE WHEN 1 = 1 - THEN CAST(1 AS DECIMAL(38,0)) - ELSE CAST(0 AS DECIMAL(38,20)) - END AS case_decimal256_then; -SELECT - CASE WHEN 1 = 2 - THEN CAST(1 AS DECIMAL(38,0)) - ELSE CAST(0 AS DECIMAL(38,20)) - END AS case_decimal256_else; - --- @case --- @desc:test for iff expression with decimal128 branches promoting to decimal256 result type --- @label:bvt -SELECT - IFF(1 = 1, - CAST(1 AS DECIMAL(38,0)), - CAST(0 AS DECIMAL(38,20))) AS iff_decimal256_true; -SELECT - IFF(1 = 2, - CAST(1 AS DECIMAL(38,0)), - CAST(0 AS DECIMAL(38,20))) AS iff_decimal256_false; +-- @suit + +-- @case +-- @desc:test for case_when expression with constant operand +-- @label:bvt +select CASE "b" when "a" then 1 when "b" then 2 END; +select CASE "c" when "a" then 1 when "b" then 2 END; +select CASE "c" when "a" then 1 when "b" then 2 ELSE 3 END; +select CASE when 1=0 then "true" else "false" END; +select CASE 1 when 1 then "one" WHEN 2 then "two" ELSE "more" END; +select CASE 2.0 when 1 then "one" WHEN 2.0 then "two" ELSE "more" END; + +select (CASE "two" when "one" then "1" WHEN "two" then "2" END) | 0; + +select (CASE "two" when "one" then 1.00 WHEN "two" then 2.00 END) +0.0; +select case 1/0 when "a" then "true" else "false" END; +select case 1/0 when "a" then "true" END; + +select (case 1/0 when "a" then "true" END) | 0; + +select (case 1/0 when "a" then "true" END) + 0.0; +select case when 1>0 then "TRUE" else "FALSE" END; +select case when 1<0 then "TRUE" else "FALSE" END; +SELECT CAST(CASE WHEN 0 THEN '2001-01-01' END AS DATE); +SELECT CAST(CASE WHEN 0 THEN DATE'2001-01-01' END AS DATE); +select case 1.0 when 0.1 then "a" when 1.0 then "b" else "c" END; +select case 0.1 when 0.1 then "a" when 1.0 then "b" else "c" END; +select case 1 when 0.1 then "a" when 1.0 then "b" else "c" END; +select case 1.0 when 0.1 then "a" when 1 then "b" else "c" END; +select case 1.001 when 0.1 then "a" when 1 then "b" else "c" END; + +-- @case +-- @desc:test for case_when expression with normal select +-- @label:bvt +drop table if exists t1; +drop table if exists t2; +CREATE TABLE t1 (a varchar(10), PRIMARY KEY (a)); +CREATE TABLE t2 (a varchar(10), b date, PRIMARY KEY(a)); +INSERT INTO t1 VALUES ('test1'); +INSERT INTO t2 VALUES +('test1','2016-12-13'),('test2','2016-12-14'),('test3','2016-12-15'); +-- @bvt:issue#3254 +SELECT b, b = '20161213', + CASE b WHEN '20161213' then 'found' ELSE 'not found' END FROM t2; +-- @bvt:issue + + +-- @case +-- @desc:test for case_when expression with group by +-- @label:bvt +drop table if exists t1; +create table t1 (a int); +insert into t1 values(1),(2),(3),(4); +select case a when 1 then 2 when 2 then 3 else 0 end as fcase, count(*) from t1 group by fcase; +select case a when 1 then "one" when 2 then "two" else "nothing" end as fcase, count(*) from t1 group by fcase; +drop table if exists t1; + +-- @case +-- @desc:test for case_when expression with function +-- @label:bvt +create table t1 (`row` int not null, col int not null, val varchar(255) not null); +insert into t1 values (1,1,'orange'),(1,2,'large'),(2,1,'yellow'),(2,2,'medium'),(3,1,'green'),(3,2,'small'); +select col,val, case when val="orange" then 1 when upper(val)="LARGE" then 2 else 3 end from t1; +select max(case col when 1 then val else null end) as color from t1 group by `row`; +drop table if exists t1; + +create table t1(a float, b int default 3); +insert into t1 (a) values (2), (11), (8); +select min(a), min(case when 1=1 then a else NULL end), + min(case when 1!=1 then NULL else a end) +from t1 where b=3 group by b; + +drop table if exists t1; +CREATE TABLE t1 (a INT, b INT); +INSERT INTO t1 VALUES (1,1),(2,1),(3,2),(4,2),(5,3),(6,3); +SELECT CASE WHEN AVG(a)>=0 THEN 'Positive' ELSE 'Negative' END FROM t1 GROUP BY b; + +drop table if exists t1; + +-- @case +-- @desc:test for case_when expression with join +-- @label:bvt +drop table if exists t1; +drop table if exists t2; +create table t1 (a int, b bigint unsigned); +create table t2 (c int); +insert into t1 (a, b) values (1,4572794622775114594), (2,18196094287899841997), + (3,11120436154190595086); +insert into t2 (c) values (1), (2), (3); +select t1.a, (case t1.a when 0 then 0 else t1.b end) d from t1 + join t2 on t1.a=t2.c order by d; +select t1.a, (case t1.a when 0 then 0 else t1.b end) d from t1 + join t2 on t1.a=t2.c where b=11120436154190595086 order by d; +drop table if exists small; +drop table if exists big; +CREATE TABLE small (id int not null,PRIMARY KEY (id)); +CREATE TABLE big (id int not null,PRIMARY KEY (id)); +INSERT INTO small VALUES (1), (2); +INSERT INTO big VALUES (1), (2), (3), (4); +SELECT big.*, dt.* FROM big LEFT JOIN (SELECT id as dt_id, + CASE id WHEN 0 THEN 0 ELSE 1 END AS simple, + CASE WHEN id=0 THEN NULL ELSE 1 END AS cond + FROM small) AS dt + ON big.id=dt.dt_id; + +drop table if exists small; +drop table if exists big; + +-- @case +-- @desc:test for case_when expression with union +-- @label:bvt +SELECT 'case+union+test' +UNION +SELECT CASE '1' WHEN '2' THEN 'BUG' ELSE 'nobug' END; + +-- @case +-- @desc:test for case_when expression in where filter +-- @label:bvt +drop table t1; +CREATE TABLE t1(a int); +insert into t1 values(1),(1),(2),(1),(3),(2),(1); +SELECT 1 FROM t1 WHERE a=1 AND CASE 1 WHEN a THEN 1 ELSE 1 END; +DROP TABLE if exists t1; + +-- @case +-- @desc:test for case_when expression with count() +-- @label:bvt +DROP TABLE if exists t1; +create table t1 (USR_ID int not null, MAX_REQ int not null); +insert into t1 values (1, 3); +select count(*) + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ + MAX_REQ - MAX_REQ from t1 group by MAX_REQ; +select Case When Count(*) < MAX_REQ Then 1 Else 0 End from t1 where t1.USR_ID = 1 group by MAX_REQ; +DROP TABLE if exists t1; + +select case when 1 in (1.0, 2.0, 3.0) then true else false end; + +DROP TABLE if exists t1; +CREATE TABLE t1 ( + id int NOT NULL AUTO_INCREMENT, + key_num int NOT NULL DEFAULT '0', + hiredate date NOT NULL, + PRIMARY KEY (id), + KEY key_num (key_num) +); + +insert into t1 values + (1, 7369, '1980-12-17'), + (2, 7499, '1981-02-20'), + (3, 7521, '1981-02-22'), + (4, 7566, '1981-04-02'), + (5, 7654, '1981-09-28'), + (6, 7698, '1981-05-01'), + (7, 7782, '1981-06-09'), + (8, 7788, '0087-07-13'), + (9, 7839, '1981-11-17'), + (10, 7844, '1981-09-08'), + (11, 7876, '2007-07-13'), + (12, 7900, '1981-12-03'), + (13, 7980, '1987-07-13'), + (14, 7981, '2001-11-17'), + (15, 7982, '1951-11-08'), + (16, 7983, '1927-10-13'), + (17, 7984, '1671-12-09'), + (18, 7985, '1981-11-06'), + (19, 7986, '1771-12-06'), + (20, 7987, '1985-10-06'); +select id, case when id < 5 then 0 when id < 10 then 1 when id < 15 then 2 when true then 3 else -1 end as xxx from t1; +DROP TABLE t1; +create table t1(a varchar(100)); +insert into t1 values ("a"); +select a, case when a="a" then 1 when upper(a)="b" then 2 end from t1; +drop table if exists t1; + +-- @case +-- @desc:test for case_when expression with mixed decimal scales +-- @label:bvt +SELECT + 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) AS direct_mul, + CASE WHEN 'USD' = 'RMB' + THEN CAST(-58140.00 AS DECIMAL(23,2)) + ELSE 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) + END AS bug_case; + +-- @case +-- @desc:test for case_when expression with then branch decimal cast +-- @label:bvt +SELECT + CASE WHEN 'USD' = 'USD' + THEN CAST(-58140.00 AS DECIMAL(23,2)) + ELSE 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) + END AS bug_case_then; + +-- @case +-- @desc:test for iff expression with mixed decimal scales +-- @label:bvt +SELECT + IFF('USD' = 'USD', + CAST(-58140.00 AS DECIMAL(23,2)), + 7.01970 * CAST(-58140.00 AS DECIMAL(23,2))) AS bug_iff; + +-- @case +-- @desc:test for case_when expression with decimal128 branches promoting to decimal256 result type +-- @label:bvt +SELECT + CASE WHEN 1 = 1 + THEN CAST(1 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) + END AS case_decimal256_then; +SELECT + CASE WHEN 1 = 2 + THEN CAST(1 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) + END AS case_decimal256_else; + +-- @case +-- @desc:test for iff expression with decimal128 branches promoting to decimal256 result type +-- @label:bvt +SELECT + IFF(1 = 1, + CAST(1 AS DECIMAL(38,0)), + CAST(0 AS DECIMAL(38,20))) AS iff_decimal256_true; +SELECT + IFF(1 = 2, + CAST(1 AS DECIMAL(38,0)), + CAST(0 AS DECIMAL(38,20))) AS iff_decimal256_false; + +-- @case +-- @desc:test for coalesce over decimal branches with different scales aligns scale/width +-- @label:bvt +SELECT 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) AS direct_mul; +SELECT COALESCE( + CAST(NULL AS DECIMAL(23,2)), + 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) +) AS coalesce_decimal_scale; +SELECT COALESCE( + CAST(1.23 AS DECIMAL(23,2)), + 7.01970 * CAST(-58140.00 AS DECIMAL(23,2)) +) AS coalesce_first_non_null; + +-- @case +-- @desc:test for comparing a decimal256 case result with a decimal128 value +-- @label:bvt +SELECT (CASE WHEN 1 = 1 THEN CAST(1 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) + = CAST(1 AS DECIMAL(38,20)) AS decimal256_eq_decimal128; +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) + > CAST(1 AS DECIMAL(38,20)) AS decimal256_gt_decimal128; +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) + < CAST(1 AS DECIMAL(38,20)) AS decimal256_lt_decimal128; +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) + != CAST(1 AS DECIMAL(38,20)) AS decimal256_ne_decimal128; +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) + BETWEEN CAST(1 AS DECIMAL(38,20)) AND CAST(10 AS DECIMAL(38,20)) AS decimal256_between; + +-- @case +-- @desc:test for coalesce promoting decimal branches to decimal256 when integral+scale overflows decimal128 +-- @label:bvt +SELECT COALESCE(CAST(1 AS DECIMAL(38,0)), CAST(0.5 AS DECIMAL(30,30))) AS coalesce_promote_decimal256; +SELECT COALESCE(CAST(12345678901234567890123456789012345678 AS DECIMAL(38,0)), CAST(0.5 AS DECIMAL(30,30))) AS coalesce_promote_bignum; + +-- @case +-- @desc:test for non-constant decimal256 BETWEEN with mixed-scale bounds +-- @label:bvt +drop table if exists t_dec256_between; +create table t_dec256_between (a decimal(38,0)); +insert into t_dec256_between values (5),(50),(500); +select + (case when 1 = 1 then a else cast(0 as decimal(38,30)) end) + between cast(1 as decimal(38,20)) and cast(100 as decimal(38,20)) as in_range +from t_dec256_between order by a; +drop table t_dec256_between; + +-- @case +-- @desc:test for comparing a decimal256 case result with a bare integer literal +-- @label:bvt +SELECT (CASE WHEN 1 = 1 THEN CAST(1 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) = 1 AS decimal256_eq_int; +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) > 1 AS decimal256_gt_int; +SELECT (CASE WHEN 1 = 1 THEN CAST(5 AS DECIMAL(38,0)) + ELSE CAST(0 AS DECIMAL(38,20)) END) + BETWEEN 1 AND 10 AS decimal256_between_int; + +-- @case +-- @desc:test for between aligning a decimal value to a wider scale that overflows decimal128 into decimal256 +-- @label:bvt +SELECT CAST(1 AS DECIMAL(38,0)) + BETWEEN CAST(0.5 AS DECIMAL(38,30)) AND CAST(2 AS DECIMAL(38,30)) AS between_promote_decimal256;