diff --git a/parser/createtable_visitor.go b/parser/createtable_visitor.go index 1d111ea..3448b25 100644 --- a/parser/createtable_visitor.go +++ b/parser/createtable_visitor.go @@ -156,6 +156,7 @@ type Column struct { func (c *CreateTable) Convert() *Table { var ret Table ret.Name = onlyTableName(c.Name) + primaryKeyExists := checkIfPrimaryKeyExists(c.Constraints) for _, e := range c.Columns { definition := e.ColumnDefinition var data Column @@ -163,6 +164,10 @@ func (c *CreateTable) Convert() *Table { if definition != nil { data.DataType = definition.DataType data.Constraint = definition.ColumnConstraint + + if definition.ColumnConstraint != nil && !primaryKeyExists && definition.ColumnConstraint.Primary { + c.Constraints = append(c.Constraints, &TableConstraint{ColumnPrimaryKey: []string{e.Name}}) + } } ret.Columns = append(ret.Columns, &data) } @@ -177,3 +182,13 @@ func onlyTableName(name string) string { ss := strings.Split(name, "`.`") return ss[len(ss)-1] } + +func checkIfPrimaryKeyExists(constraints []*TableConstraint) bool { + for _, constraint := range constraints { + if len(constraint.ColumnPrimaryKey) > 0 { + return true + } + } + + return false +} diff --git a/parser/createtable_visitor_test.go b/parser/createtable_visitor_test.go index 7142c71..e12ae54 100644 --- a/parser/createtable_visitor_test.go +++ b/parser/createtable_visitor_test.go @@ -34,3 +34,81 @@ func Test_onlyTableName(t *testing.T) { }) } } + +func Test_checkIfPrimaryKeyExists(t *testing.T) { + type args struct { + constraints []*TableConstraint + } + + tests := []struct { + name string + args args + want bool + }{ + { + name: "No constraints (nil slice)", + args: args{ + constraints: nil, + }, + want: false, + }, + { + name: "Empty constraints slice", + args: args{ + constraints: []*TableConstraint{}, + }, + want: false, + }, + { + name: "One constraint without primary key", + args: args{ + constraints: []*TableConstraint{ + { + ColumnPrimaryKey: []string{}, + }, + }, + }, + want: false, + }, + { + name: "Multiple constraints, none with primary key", + args: args{ + constraints: []*TableConstraint{ + {ColumnPrimaryKey: []string{}}, + {ColumnPrimaryKey: []string{}}, + }, + }, + want: false, + }, + { + name: "One constraint with primary key", + args: args{ + constraints: []*TableConstraint{ + { + ColumnPrimaryKey: []string{"id"}, + }, + }, + }, + want: true, + }, + { + name: "Multiple constraints, one with primary key", + args: args{ + constraints: []*TableConstraint{ + {ColumnPrimaryKey: []string{}}, + {ColumnPrimaryKey: []string{"user_id"}}, + {ColumnPrimaryKey: []string{}}, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := checkIfPrimaryKeyExists(tt.args.constraints); got != tt.want { + t.Errorf("checkIfPrimaryKeyExists() = %v, want %v", got, tt.want) + } + }) + } +}