Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions pkg/sql/plan/build_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,18 +315,21 @@ func buildDefaultExpr(col *tree.ColumnTableDef, typ plan.Type, proc *process.Pro
}
}

originExpr := expr
semanticExpr := unwrapParenExpr(expr)

colNameOrigin := col.Name.ColNameOrigin()
if typ.Id == int32(types.T_json) {
if expr != nil && !isNullAstExpr(expr) {
if semanticExpr != nil && !isNullAstExpr(semanticExpr) {
return nil, moerr.NewNotSupported(proc.Ctx, fmt.Sprintf("JSON column '%s' cannot have default value", colNameOrigin))
}
}
if isGeometryPlanType(&typ) {
if expr != nil && !isNullAstExpr(expr) {
if semanticExpr != nil && !isNullAstExpr(semanticExpr) {
return nil, moerr.NewNotSupported(proc.Ctx, fmt.Sprintf("GEOMETRY column '%s' cannot have default value", colNameOrigin))
}
}
if !nullAbility && isNullAstExpr(expr) {
if !nullAbility && isNullAstExpr(semanticExpr) {
return nil, moerr.NewInvalidInputf(proc.Ctx, "invalid default value for column '%s'", colNameOrigin)
}

Expand All @@ -337,15 +340,16 @@ func buildDefaultExpr(col *tree.ColumnTableDef, typ plan.Type, proc *process.Pro
OriginString: "",
}, nil
}
_, isExpressionDefault := originExpr.(*tree.ParenExpr)

binder := NewDefaultBinder(proc.Ctx, nil, nil, typ, nil)
planExpr, err := binder.BindExpr(expr, 0, false)
planExpr, err := binder.BindExpr(semanticExpr, 0, false)
if err != nil {
return nil, err
}

if defaultFunc := planExpr.GetF(); defaultFunc != nil {
if int(typ.Id) != int(types.T_uuid) && defaultFunc.Func.ObjName == "uuid" {
if int(typ.Id) != int(types.T_uuid) && defaultFunc.Func.ObjName == "uuid" && !isExpressionDefault {
return nil, moerr.NewInvalidInputf(proc.Ctx, "invalid default value for column '%s'", colNameOrigin)
}
}
Expand All @@ -362,7 +366,7 @@ func buildDefaultExpr(col *tree.ColumnTableDef, typ plan.Type, proc *process.Pro
}

fmtCtx := tree.NewFmtCtx(dialect.MYSQL, tree.WithSingleQuoteString())
fmtCtx.PrintExpr(expr, expr, false)
fmtCtx.PrintExpr(originExpr, originExpr, false)
return &plan.Default{
NullAbility: nullAbility,
Expr: newExpr,
Expand Down
88 changes: 88 additions & 0 deletions pkg/sql/plan/build_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,91 @@ func TestBuildDefaultExprGeometryAllowsNullDefault(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, def)
}

func TestBuildDefaultExprParenthesizedNullMatchesNullDefault(t *testing.T) {
proc := testutil.NewProcess(t)

tests := []struct {
name string
sql string
wantErr string
}{
{
name: "not null rejects parenthesized null",
sql: "create table t (a int not null default (null))",
wantErr: "invalid default value for column 'a'",
},
{
name: "json allows parenthesized null",
sql: "create table t (j json default (null))",
},
{
name: "geometry allows parenthesized null",
sql: "create table t (g geometry default (null))",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stmt, err := mysql.ParseOne(context.Background(), tt.sql, 1)
require.NoError(t, err)

createTable, ok := stmt.(*tree.CreateTable)
require.True(t, ok)
colDef, ok := createTable.Defs[0].(*tree.ColumnTableDef)
require.True(t, ok)

typ, err := getTypeFromAst(context.Background(), colDef.Type)
require.NoError(t, err)

def, err := buildDefaultExpr(colDef, typ, proc)
if tt.wantErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.wantErr)
return
}
require.NoError(t, err)
require.NotNil(t, def)
})
}
}

func TestBuildDefaultExprAllowsParenthesizedUuidForStringDefault(t *testing.T) {
proc := testutil.NewProcess(t)

stmt, err := mysql.ParseOne(context.Background(), "create table t (id varchar(191) not null default (uuid()))", 1)
require.NoError(t, err)

createTable, ok := stmt.(*tree.CreateTable)
require.True(t, ok)
colDef, ok := createTable.Defs[0].(*tree.ColumnTableDef)
require.True(t, ok)

typ, err := getTypeFromAst(context.Background(), colDef.Type)
require.NoError(t, err)

def, err := buildDefaultExpr(colDef, typ, proc)
require.NoError(t, err)
require.NotNil(t, def)
require.NotNil(t, def.Expr)
require.Equal(t, "(uuid())", def.OriginString)
}

func TestBuildDefaultExprKeepsBareUuidTypeGuard(t *testing.T) {
proc := testutil.NewProcess(t)

stmt, err := mysql.ParseOne(context.Background(), "create table t (a int default uuid())", 1)
require.NoError(t, err)

createTable, ok := stmt.(*tree.CreateTable)
require.True(t, ok)
colDef, ok := createTable.Defs[0].(*tree.ColumnTableDef)
require.True(t, ok)

typ, err := getTypeFromAst(context.Background(), colDef.Type)
require.NoError(t, err)

_, err = buildDefaultExpr(colDef, typ, proc)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid default value for column 'a'")
}
11 changes: 11 additions & 0 deletions test/distributed/cases/dtype/uuid_type_and_uuid_func.result
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ length(cast(a as varchar))
36
36
drop table t4;
drop table if exists t5;
create table t5 (id varchar(191) not null default (uuid()), n varchar(50), primary key(id));
show create table t5;
Table Create Table
t5 CREATE TABLE `t5` (\n `id` varchar(191) NOT NULL DEFAULT (uuid()),\n `n` varchar(50) DEFAULT NULL,\n PRIMARY KEY (`id`)\n)
insert into t5(n) values('a'),('b');
select length(id), n from t5 order by n;
length(id) n
36 a
36 b
drop table t5;
select length(cast(uuid() as varchar));
length(cast(uuid() as varchar))
36
Expand Down
8 changes: 8 additions & 0 deletions test/distributed/cases/dtype/uuid_type_and_uuid_func.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ insert into t4 values (uuid());
select length(cast(a as varchar)) from t4;
drop table t4;

-- test MySQL 8 expression default uuid on string column
drop table if exists t5;
create table t5 (id varchar(191) not null default (uuid()), n varchar(50), primary key(id));
show create table t5;
insert into t5(n) values('a'),('b');
select length(id), n from t5 order by n;
drop table t5;


-- test cast string
select length(cast(uuid() as varchar));
Expand Down
Loading