Skip to content

Commit

Permalink
Support revert old column attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
sunary committed Nov 20, 2024
1 parent c1bffc9 commit 6ff32a7
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 71 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ CREATE TABLE user (
//CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`);

println(sql1.StringDown())
//ALTER TABLE `user` MODIFY COLUMN `id` int(11);
//ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime;
//DROP INDEX `idx_name_age` ON `user`;
}
```
10 changes: 6 additions & 4 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,13 @@ CREATE TABLE user (

sql1.Diff(*sql2)
println(sql1.StringUp())
// ALTER TABLE `user` MODIFY COLUMN `id` int(11) AUTO_INCREMENT PRIMARY KEY;
// ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP();
// CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`);
//ALTER TABLE `user` MODIFY COLUMN `id` int(11) AUTO_INCREMENT PRIMARY KEY;
//ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP();
//CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`);

println(sql1.StringDown())
// DROP INDEX `idx_name_age` ON `user`;
//ALTER TABLE `user` MODIFY COLUMN `id` int(11);
//ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime;
//DROP INDEX `idx_name_age` ON `user`;
}
```
6 changes: 3 additions & 3 deletions avro/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,19 @@ func getAvroType(col element.Column) interface{} {
"type": "string",
"connect.version": 1,
"connect.parameters": map[string]string{
"allowed": strings.Join(col.MysqlType.Elems, ","),
"allowed": strings.Join(col.CurrentAttr.MysqlType.Elems, ","),
},
"connect.default": "init",
"connect.name": "io.debezium.data.Enum",
}
}

switch col.MysqlType.EvalType() {
switch col.CurrentAttr.MysqlType.EvalType() {
case types.ETInt:
return "int"

case types.ETDecimal:
displayFlen, displayDecimal := col.MysqlType.Flen, col.MysqlType.Decimal
displayFlen, displayDecimal := col.CurrentAttr.MysqlType.Flen, col.CurrentAttr.MysqlType.Decimal
return map[string]interface{}{
"type": "bytes",
"scale": displayDecimal,
Expand Down
88 changes: 65 additions & 23 deletions element/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,34 @@ const (
LowerRestoreFlag = format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameLowercase | format.RestoreNameBackQuotes
)

// Column ...
type Column struct {
Node
type SqlAttr struct {
MysqlType *types.FieldType
PgType *ptypes.T
LiteType *sqlite.Type
Options []*ast.ColumnOption
Comment string
}

// Column ...
type Column struct {
Node

CurrentAttr SqlAttr
PreviousAttr SqlAttr
}

// GetType ...
func (c Column) GetType() byte {
if c.MysqlType != nil {
return c.MysqlType.Tp
if c.CurrentAttr.MysqlType != nil {
return c.CurrentAttr.MysqlType.Tp
}

return 0
}

// HasDefaultValue ...
func (c Column) HasDefaultValue() bool {
for _, opt := range c.Options {
for _, opt := range c.CurrentAttr.Options {
if opt.Tp == ast.ColumnOptionDefaultValue {
return true
}
Expand All @@ -54,7 +60,7 @@ func (c Column) HasDefaultValue() bool {

func (c Column) hashValue() string {
strHash := sql.EscapeSqlName(c.Name)
strHash += c.typeDefinition()
strHash += c.typeDefinition(false)
hash := md5.Sum([]byte(strHash))
return hex.EncodeToString(hash[:])
}
Expand All @@ -71,7 +77,7 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
strSql += strings.Repeat(" ", ident-len(c.Name))
}

strSql += c.definition()
strSql += c.definition(false)

if ident < 0 {
if after != "" {
Expand All @@ -90,10 +96,27 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
return []string{fmt.Sprintf(sql.AlterTableDropColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name))}

case MigrateModifyAction:
def := strings.Replace(c.definition(), sql.PrimaryOption(), "", 1)
def, isPk := c.pkDefinition(false)
if isPk {
if _, isPrevPk := c.pkDefinition(true); isPrevPk {
// avoid repeat define primary key
def = strings.Replace(def, sql.PrimaryOption(), "", 1)
}
}

return []string{fmt.Sprintf(sql.AlterTableModifyColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name)+def)}

case MigrateRevertAction:
prevDef, isPrevPk := c.pkDefinition(true)
if isPrevPk {
if _, isPk := c.pkDefinition(false); isPk {
// avoid repeat define primary key
prevDef = strings.Replace(prevDef, sql.PrimaryOption(), "", 1)
}
}

return []string{fmt.Sprintf(sql.AlterTableModifyColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name)+prevDef)}

case MigrateRenameAction:
return []string{fmt.Sprintf(sql.AlterTableRenameColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.OldName), sql.EscapeSqlName(c.Name))}

Expand All @@ -103,12 +126,12 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
}

func (c Column) migrationCommentUp(tbName string) []string {
if c.Comment == "" || sql.GetDialect() != sql_templates.PostgresDialect {
if c.CurrentAttr.Comment == "" || sql.GetDialect() != sql_templates.PostgresDialect {
return nil
}

// apply for postgres only
return []string{fmt.Sprintf(sql.ColumnComment(), tbName, c.Name, c.Comment)}
return []string{fmt.Sprintf(sql.ColumnComment(), tbName, c.Name, c.CurrentAttr.Comment)}
}

func (c Column) migrationDown(tbName, after string) []string {
Expand All @@ -123,7 +146,7 @@ func (c Column) migrationDown(tbName, after string) []string {
c.Action = MigrateAddAction

case MigrateModifyAction:
return nil
c.Action = MigrateRevertAction

case MigrateRenameAction:
c.Name, c.OldName = c.OldName, c.Name
Expand All @@ -135,10 +158,19 @@ func (c Column) migrationDown(tbName, after string) []string {
return c.migrationUp(tbName, after, -1)
}

func (c Column) definition() string {
strSql := c.typeDefinition()
func (c Column) pkDefinition(isPrev bool) (string, bool) {
attr := c.CurrentAttr
if isPrev {
attr = c.PreviousAttr
}
strSql := c.typeDefinition(isPrev)

isPrimaryKey := false
for _, opt := range attr.Options {
if opt.Tp == ast.ColumnOptionPrimaryKey {
isPrimaryKey = true
}

for _, opt := range c.Options {
b := bytes.NewBufferString("")
var ctx *format.RestoreCtx

Expand All @@ -157,17 +189,27 @@ func (c Column) definition() string {
strSql += " " + b.String()
}

return strSql
return strSql, isPrimaryKey
}

func (c Column) definition(isPrev bool) string {
def, _ := c.pkDefinition(isPrev)
return def
}

func (c Column) typeDefinition() string {
func (c Column) typeDefinition(isPrev bool) string {
attr := c.CurrentAttr
if isPrev {
attr = c.PreviousAttr
}

switch {
case sql.IsPostgres() && c.PgType != nil:
return " " + c.PgType.SQLString()
case sql.IsSqlite() && c.LiteType != nil:
return " " + c.LiteType.Name.Name
case c.MysqlType != nil:
return " " + c.MysqlType.String()
case sql.IsPostgres() && attr.PgType != nil:
return " " + attr.PgType.SQLString()
case sql.IsSqlite() && attr.LiteType != nil:
return " " + attr.LiteType.Name.Name
case attr.MysqlType != nil:
return " " + attr.MysqlType.String()
}

return "" // column type is empty
Expand Down
2 changes: 1 addition & 1 deletion element/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (m *Migration) AddComment(tbName, colName, comment string) {
return
}

m.Tables[id].Columns[colIdx].Comment = comment
m.Tables[id].Columns[colIdx].CurrentAttr.Comment = comment
}

// AddIndex ...
Expand Down
2 changes: 2 additions & 0 deletions element/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ const (
MigrateRemoveAction
// MigrateModifyAction ...
MigrateModifyAction
// MigrateRevertAction ...
MigrateRevertAction
// MigrateRenameAction ...
MigrateRenameAction
)
Expand Down
19 changes: 10 additions & 9 deletions element/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,18 @@ func (t *Table) AddColumn(col Column) {
t.Columns[id] = col

default:
t.Columns[id].Options = append(t.Columns[id].Options, col.Options...)
t.Columns[id].CurrentAttr.Options = append(t.Columns[id].CurrentAttr.Options, col.CurrentAttr.Options...)

if size := len(t.Columns[id].Options); size > 0 {
for i := range t.Columns[id].Options[:size-1] {
if t.Columns[id].Options[i].Tp == ast.ColumnOptionPrimaryKey {
t.Columns[id].Options[i], t.Columns[id].Options[size-1] = t.Columns[id].Options[size-1], t.Columns[id].Options[i]
if size := len(t.Columns[id].CurrentAttr.Options); size > 0 {
for i := range t.Columns[id].CurrentAttr.Options[:size-1] {
if t.Columns[id].CurrentAttr.Options[i].Tp == ast.ColumnOptionPrimaryKey {
t.Columns[id].CurrentAttr.Options[i], t.Columns[id].CurrentAttr.Options[size-1] = t.Columns[id].CurrentAttr.Options[size-1], t.Columns[id].CurrentAttr.Options[i]
break
}
}
}

t.Columns[id].MysqlType = col.MysqlType
t.Columns[id].CurrentAttr.MysqlType = col.CurrentAttr.MysqlType
return
}

Expand Down Expand Up @@ -291,10 +291,11 @@ func (t *Table) Diff(old Table) {
for i := range t.Columns {
if j := old.getIndexColumn(t.Columns[i].Name); t.Columns[i].Action == MigrateAddAction &&
j >= 0 && old.Columns[j].Action != MigrateNoAction {
if hasChangedMysqlOptions(t.Columns[i].Options, old.Columns[j].Options) ||
hasChangedMysqlType(t.Columns[i].MysqlType, old.Columns[j].MysqlType) ||
hasChangePostgresType(t.Columns[i].PgType, old.Columns[j].PgType) {
if hasChangedMysqlOptions(t.Columns[i].CurrentAttr.Options, old.Columns[j].CurrentAttr.Options) ||
hasChangedMysqlType(t.Columns[i].CurrentAttr.MysqlType, old.Columns[j].CurrentAttr.MysqlType) ||
hasChangePostgresType(t.Columns[i].CurrentAttr.PgType, old.Columns[j].CurrentAttr.PgType) {
t.Columns[i].Action = MigrateModifyAction
t.Columns[i].PreviousAttr = old.Columns[j].CurrentAttr
} else {
t.Columns[i].Action = MigrateNoAction
}
Expand Down
42 changes: 25 additions & 17 deletions sql-parser/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
})
} else {
p.Migration.AddColumn(alter.Table.Text(), element.Column{
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
},
},
},
})
Expand Down Expand Up @@ -113,9 +115,11 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
if len(alter.Specs[i].NewColumns) > 0 {
for j := range alter.Specs[i].NewColumns {
col := element.Column{
Node: element.Node{Name: alter.Specs[i].NewColumns[j].Name.Name.O, Action: element.MigrateModifyAction},
MysqlType: alter.Specs[i].NewColumns[j].Tp,
Comment: alter.Specs[i].Comment,
Node: element.Node{Name: alter.Specs[i].NewColumns[j].Name.Name.O, Action: element.MigrateModifyAction},
CurrentAttr: element.SqlAttr{
MysqlType: alter.Specs[i].NewColumns[j].Tp,
Comment: alter.Specs[i].Comment,
},
}
p.Migration.AddColumn(alter.Table.Name.O, col)
}
Expand Down Expand Up @@ -161,11 +165,13 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
})
} else {
tb.AddColumn(element.Column{
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
},
},
},
})
Expand Down Expand Up @@ -218,10 +224,12 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
}

column := element.Column{
Node: element.Node{Name: def.Name.Name.O, Action: element.MigrateAddAction},
MysqlType: def.Tp,
Options: def.Options,
Comment: comment,
Node: element.Node{Name: def.Name.Name.O, Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
MysqlType: def.Tp,
Options: def.Options,
Comment: comment,
},
}
p.Migration.AddColumn("", column)
}
Expand Down
26 changes: 16 additions & 10 deletions sql-parser/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,24 @@ func (p *Parser) walker(ctx interface{}, node interface{}) (stop bool) {

case *tree.AlterTableAlterColumnType:
col := element.Column{
Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction},
PgType: nc.ToType,
Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction},
CurrentAttr: element.SqlAttr{
PgType: nc.ToType,
},
}
p.Migration.AddColumn(n.Table.String(), col)

case *tree.AlterTableSetDefault:
if nc.Default != nil {
col := element.Column{
Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction},
Options: []*ast.ColumnOption{{
Expr: nil,
Tp: ast.ColumnOptionDefaultValue,
StrValue: nc.Default.String(),
}},
CurrentAttr: element.SqlAttr{
Options: []*ast.ColumnOption{{
Expr: nil,
Tp: ast.ColumnOptionDefaultValue,
StrValue: nc.Default.String(),
}},
},
}
p.Migration.AddColumn(n.Table.String(), col)
}
Expand Down Expand Up @@ -166,9 +170,11 @@ func postgresColumn(n *tree.ColumnTableDef) (element.Column, []element.Index) {
}

return element.Column{
Node: element.Node{Name: n.Name.String(), Action: element.MigrateAddAction},
PgType: n.Type,
Options: opts,
Node: element.Node{Name: n.Name.String(), Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
PgType: n.Type,
Options: opts,
},
}, indexes
}

Expand Down
Loading

0 comments on commit 6ff32a7

Please sign in to comment.