Skip to content

Commit

Permalink
do early table name length check and calculate possible lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
jayjanssen committed Jun 27, 2024
1 parent 4fad849 commit 5fdeb6d
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 19 deletions.
15 changes: 8 additions & 7 deletions pkg/check/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ const (
)

type Resources struct {
DB *sql.DB
Replica *sql.DB
Table *table.TableInfo
Alter string
TargetChunkTime time.Duration
Threads int
ReplicaMaxLag time.Duration
DB *sql.DB
Replica *sql.DB
Table *table.TableInfo
Alter string
TargetChunkTime time.Duration
Threads int
ReplicaMaxLag time.Duration
SkipDropAfterCutover bool
// The following resources are only used by the
// pre-run checks
Host string
Expand Down
62 changes: 62 additions & 0 deletions pkg/check/tablename.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package check

import (
"context"
"fmt"
"strings"

"github.com/siddontang/loggers"
)

const (
// Max table name length in MySQL
maxTableNameLength = 64

// Formats for table names
NameFormatSentinel = "_%s_sentinel"
NameFormatCheckpoint = "_%s_chkpnt"
NameFormatNew = "_%s_new"
NameFormatOld = "_%s_old"
NameFormatOldTimeStamp = "_%s_old_%s"
NameFormatTimestamp = "20060102_150405"
)

var (
// The number of extra characters needed for table names with all possible
// formats. These vars are calculated in the `init` function below.
NameFormatNormalExtraChars = 0
NameFormatTimestampExtraChars = 0
)

func init() {
registerCheck("tablename", tableNameCheck, ScopePreflight)

// Calculate the number of extra characters needed table names with all possible formats
for _, format := range []string{NameFormatSentinel, NameFormatCheckpoint, NameFormatNew, NameFormatOld} {
extraChars := len(strings.Replace(format, "%s", "", -1))
if extraChars > NameFormatNormalExtraChars {
NameFormatNormalExtraChars = extraChars
}
}

// Calculate the number of extra characters needed for table names with the old timestamp format
NameFormatTimestampExtraChars = len(strings.Replace(NameFormatOldTimeStamp, "%s", "", -1)) + len(NameFormatTimestamp)
}

func tableNameCheck(ctx context.Context, r Resources, logger loggers.Advanced) error {
tableName := r.Table.TableName
if len(tableName) < 1 {
return fmt.Errorf("table name must be at least 1 character")

Check failure on line 49 in pkg/check/tablename.go

View workflow job for this annotation

GitHub Actions / lint

fmt.Errorf can be replaced with errors.New (perfsprint)
}

timestampTableNameLength := maxTableNameLength - NameFormatTimestampExtraChars
if r.SkipDropAfterCutover && len(tableName) > timestampTableNameLength {
return fmt.Errorf("table name must be less than %d characters when --skip-drop-after-cutover is set", timestampTableNameLength)
}

normalTableNameLength := maxTableNameLength - NameFormatNormalExtraChars
if len(tableName) > normalTableNameLength {
return fmt.Errorf("table name must be less than %d characters", normalTableNameLength)
}
return nil
}
43 changes: 43 additions & 0 deletions pkg/check/tablename_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package check

import (
"context"
"testing"

"github.com/cashapp/spirit/pkg/table"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

func TestCheckTableNameConstants(t *testing.T) {
// Calculated extra chars should always be greater than 0
assert.Greater(t, NameFormatNormalExtraChars, 0)

Check failure on line 14 in pkg/check/tablename_test.go

View workflow job for this annotation

GitHub Actions / lint

negative-positive: use assert.Positive (testifylint)
assert.Greater(t, NameFormatTimestampExtraChars, 0)

Check failure on line 15 in pkg/check/tablename_test.go

View workflow job for this annotation

GitHub Actions / lint

negative-positive: use assert.Positive (testifylint)

// Calculated extra chars should be less than the max table name length
assert.Less(t, NameFormatNormalExtraChars, maxTableNameLength)
assert.Less(t, NameFormatTimestampExtraChars, maxTableNameLength)
}

func TestCheckTableName(t *testing.T) {
testTableName := func(name string, skipDropAfterCutover bool) error {
r := Resources{
Table: &table.TableInfo{
TableName: name,
},
SkipDropAfterCutover: skipDropAfterCutover,
}
return tableNameCheck(context.Background(), r, logrus.New())
}

assert.NoError(t, testTableName("a", false))
assert.NoError(t, testTableName("a", true))

assert.ErrorContains(t, testTableName("", false), "table name must be at least 1 character")
assert.ErrorContains(t, testTableName("", true), "table name must be at least 1 character")

longName := "thisisareallylongtablenamethisisareallylongtablenamethisisareallylongtablename"
assert.ErrorContains(t, testTableName(longName, false), "table name must be less than")
assert.ErrorContains(t, testTableName(longName, true), "table name must be less than")

}

Check failure on line 43 in pkg/check/tablename_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)
27 changes: 15 additions & 12 deletions pkg/migration/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,10 @@ func (r *Runner) runChecks(ctx context.Context, scope check.ScopeFlag) error {
ReplicaMaxLag: r.migration.ReplicaMaxLag,
// For the pre-run checks we don't have a DB connection yet.
// Instead we check the credentials provided.
Host: r.migration.Host,
Username: r.migration.Username,
Password: r.migration.Password,
Host: r.migration.Host,
Username: r.migration.Username,
Password: r.migration.Password,
SkipDropAfterCutover: r.migration.SkipDropAfterCutover,
}, r.logger, scope)
}

Expand Down Expand Up @@ -566,10 +567,7 @@ func (r *Runner) dropCheckpoint(ctx context.Context) error {
}

func (r *Runner) createNewTable(ctx context.Context) error {
newName := fmt.Sprintf("_%s_new", r.table.TableName)
if len(newName) > 64 {
return fmt.Errorf("table name is too long: '%s'. new table name will exceed 64 characters", r.table.TableName)
}
newName := fmt.Sprintf(check.NameFormatNew, r.table.TableName)
// drop both if we've decided to call this func.
if err := dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, newName); err != nil {
return err
Expand Down Expand Up @@ -609,7 +607,12 @@ func (r *Runner) dropOldTable(ctx context.Context) error {
}

func (r *Runner) oldTableName() string {
return fmt.Sprintf("_%s_old_%s", r.table.TableName, r.startTime.UTC().Format("20060102_150405.000"))
// By default we just set the old table name to _<table>_old
// but if they've enabled SkipDropAfterCutover, we add a timestamp
if !r.migration.SkipDropAfterCutover {
return fmt.Sprintf(check.NameFormatOld, r.table.TableName)
}
return fmt.Sprintf(check.NameFormatOldTimeStamp, r.table.TableName, r.startTime.UTC().Format(check.NameFormatTimestamp))
}

func (r *Runner) attemptInstantDDL(ctx context.Context) error {
Expand All @@ -621,7 +624,7 @@ func (r *Runner) attemptInplaceDDL(ctx context.Context) error {
}

func (r *Runner) createCheckpointTable(ctx context.Context) error {
cpName := fmt.Sprintf("_%s_chkpnt", r.table.TableName)
cpName := fmt.Sprintf(check.NameFormatCheckpoint, r.table.TableName)
// drop both if we've decided to call this func.
if err := dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, cpName); err != nil {
return err
Expand Down Expand Up @@ -658,7 +661,7 @@ func (r *Runner) GetProgress() Progress {
}

func (r *Runner) sentinelTableName() string {
return fmt.Sprintf("_%s_sentinel", r.table.TableName)
return fmt.Sprintf(check.NameFormatSentinel, r.table.TableName)
}

func (r *Runner) createSentinelTable(ctx context.Context) error {
Expand Down Expand Up @@ -729,8 +732,8 @@ func (r *Runner) resumeFromCheckpoint(ctx context.Context) error {

// The objects for these are not available until we confirm
// tables exist and we
newName := fmt.Sprintf("_%s_new", r.table.TableName)
cpName := fmt.Sprintf("_%s_chkpnt", r.table.TableName)
newName := fmt.Sprintf(check.NameFormatNew, r.table.TableName)
cpName := fmt.Sprintf(check.NameFormatCheckpoint, r.table.TableName)

// Make sure we can read from the new table.
if err := dbconn.Exec(ctx, r.db, "SELECT * FROM %n.%n LIMIT 1",
Expand Down

0 comments on commit 5fdeb6d

Please sign in to comment.