Skip to content

Commit

Permalink
Merge pull request #8347 from dolthub/zachmu/reset
Browse files Browse the repository at this point in the history
Bug fixes for schema names in reset operation
  • Loading branch information
zachmu authored Sep 18, 2024
2 parents a8dca5b + 8ff973a commit ae4c8c7
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 97 deletions.
36 changes: 17 additions & 19 deletions go/libraries/doltcore/env/actions/reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,11 @@ func resetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots
if err != nil {
return nil, doltdb.Roots{}, err
}

// untracked tables exist in |working| but not in |staged|
staged, err := roots.Staged.GetTableNames(ctx, doltdb.DefaultSchemaName)
if err != nil {
return nil, doltdb.Roots{}, err
}
staged := GetAllTableNames(ctx, roots.Staged)
for _, name := range staged {
delete(untracked, doltdb.TableName{Name: name})
delete(untracked, name)
}

newWkRoot := roots.Head
Expand Down Expand Up @@ -116,27 +114,18 @@ func resetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots
}

// need to save the state of files that aren't tracked
untrackedTables := make(map[string]*doltdb.Table)
wTblNames, err := roots.Working.GetTableNames(ctx, doltdb.DefaultSchemaName)

if err != nil {
return nil, doltdb.Roots{}, err
}
untrackedTables := make(map[doltdb.TableName]*doltdb.Table)
wTblNames := GetAllTableNames(ctx, roots.Working)

for _, tblName := range wTblNames {
untrackedTables[tblName], _, err = roots.Working.GetTable(ctx, doltdb.TableName{Name: tblName})
untrackedTables[tblName], _, err = roots.Working.GetTable(ctx, tblName)

if err != nil {
return nil, doltdb.Roots{}, err
}
}

headTblNames, err := roots.Staged.GetTableNames(ctx, doltdb.DefaultSchemaName)

if err != nil {
return nil, doltdb.Roots{}, err
}

headTblNames := GetAllTableNames(ctx, roots.Staged)
for _, tblName := range headTblNames {
delete(untrackedTables, tblName)
}
Expand All @@ -147,6 +136,15 @@ func resetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots
return newHead, roots, nil
}

func GetAllTableNames(ctx context.Context, root doltdb.RootValue) []doltdb.TableName {
tableNames := make([]doltdb.TableName, 0)
_ = root.IterTables(ctx, func(name doltdb.TableName, table *doltdb.Table, sch schema.Schema) (stop bool, err error) {
tableNames = append(tableNames, name)
return false, nil
})
return tableNames
}

// ResetHardTables resets the tables in working, staged, and head based on the given parameters. Returns the new
// head commit and resulting roots
func ResetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots doltdb.Roots) (*doltdb.Commit, doltdb.Roots, error) {
Expand Down Expand Up @@ -203,7 +201,7 @@ func ResetHard(
return nil
}

func ResetSoftTables(ctx context.Context, dbData env.DbData, apr *argparser.ArgParseResults, roots doltdb.Roots) (doltdb.Roots, error) {
func ResetSoftTables(ctx context.Context, apr *argparser.ArgParseResults, roots doltdb.Roots) (doltdb.Roots, error) {
tables, err := getUnionedTables(ctx, tableNamesFromArgs(apr.Args), roots.Staged, roots.Head)
if err != nil {
return doltdb.Roots{}, err
Expand Down
224 changes: 146 additions & 78 deletions go/libraries/doltcore/sqle/dprocedures/dolt_reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
)

// doltReset is the stored procedure version for the CLI command `dolt reset`.
Expand Down Expand Up @@ -80,101 +82,30 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) {
}

if apr.Contains(cli.HardResetParam) {
// Get the commitSpec for the branch if it exists
arg := ""
if apr.NArg() > 1 {
return 1, fmt.Errorf("--hard supports at most one additional param")
} else if apr.NArg() == 1 {
arg = apr.Arg(0)
}

var newHead *doltdb.Commit
newHead, roots, err = actions.ResetHardTables(ctx, dbData, arg, roots)
if err != nil {
return 1, err
}

// TODO: this overrides the transaction setting, needs to happen at commit, not here
if newHead != nil {
headRef, err := dbData.Rsr.CWBHeadRef()
if err != nil {
return 1, err
}
if err := dbData.Ddb.SetHeadToCommit(ctx, headRef, newHead); err != nil {
return 1, err
}
}

// TODO - refactor and make transactional with the head update above.
ws, err := dSess.WorkingSet(ctx, dbName)
if err != nil {
return 1, err
}
err = dSess.SetWorkingSet(ctx, dbName, ws.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
err = resetHard(ctx, apr, roots, dbData, dSess, dbName)
if err != nil {
return 1, err
}
err = dSess.ResetGlobals(ctx, dbName, roots.Working)
} else if apr.Contains(cli.SoftResetParam) {
err = resetSoft(ctx, apr, dbData, dSess, dbName)
if err != nil {
return 1, err
}

} else if apr.Contains(cli.SoftResetParam) {
arg := ""
if apr.NArg() > 1 {
return 1, fmt.Errorf("--soft supports at most one additional param")
} else if apr.NArg() == 1 {
arg = apr.Arg(0)
}

if arg != "" {
roots, err = actions.ResetSoftToRef(ctx, dbData, arg)
if err != nil {
return 1, err
}
ws, err := dSess.WorkingSet(ctx, dbName)
if err != nil {
return 1, err
}
err = dSess.SetWorkingSet(ctx, dbName, ws.WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
if err != nil {
return 1, err
}
}
} else {
if apr.NArg() != 1 || (apr.NArg() == 1 && apr.Arg(0) == ".") {
roots, err = actions.ResetSoftTables(ctx, dbData, apr, roots)
if err != nil {
return 1, err
}
err = dSess.SetRoots(ctx, dbName, roots)
err := resetSoftTables(ctx, apr, roots, dSess, dbName)
if err != nil {
return 1, err
}
} else {
// check if the input is a table name or commit ref
_, okHead, _ := roots.Head.ResolveTableName(ctx, doltdb.TableName{Name: apr.Arg(0)})
_, okStaged, _ := roots.Staged.ResolveTableName(ctx, doltdb.TableName{Name: apr.Arg(0)})
_, okWorking, _ := roots.Working.ResolveTableName(ctx, doltdb.TableName{Name: apr.Arg(0)})
if okHead || okStaged || okWorking {
roots, err = actions.ResetSoftTables(ctx, dbData, apr, roots)
if err != nil {
return 1, err
}
err = dSess.SetRoots(ctx, dbName, roots)
if isTableInRoots(ctx, roots, apr.Arg(0)) {
err := resetSoftTables(ctx, apr, roots, dSess, dbName)
if err != nil {
return 1, err
}
} else {
roots, err = actions.ResetSoftToRef(ctx, dbData, apr.Arg(0))
if err != nil {
return 1, err
}
ws, err := dSess.WorkingSet(ctx, dbName)
if err != nil {
return 1, err
}
err = dSess.SetWorkingSet(ctx, dbName, ws.WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
err := resetSoftToRef(ctx, dbData, apr.Arg(0), dSess, dbName)
if err != nil {
return 1, err
}
Expand All @@ -188,3 +119,140 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) {

return 0, nil
}

// resetSoftToRef resets the session HEAD to the commit ref given
func resetSoftToRef(
ctx *sql.Context,
dbData env.DbData,
firstArg string,
dSess *dsess.DoltSession,
dbName string,
) error {
roots, err := actions.ResetSoftToRef(ctx, dbData, firstArg)
if err != nil {
return err
}
ws, err := dSess.WorkingSet(ctx, dbName)
if err != nil {
return err
}
err = dSess.SetWorkingSet(ctx, dbName, ws.WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
if err != nil {
return err
}
return nil
}

// isTableInRoots returns true if the table given exists in any of the roots given
func isTableInRoots(ctx *sql.Context, roots doltdb.Roots, tableName string) bool {
_, tableNameInHead, _ := roots.Head.ResolveTableName(ctx, doltdb.TableName{Name: tableName})
_, tableNameInStaged, _ := roots.Staged.ResolveTableName(ctx, doltdb.TableName{Name: tableName})
_, tableNameInWorking, _ := roots.Working.ResolveTableName(ctx, doltdb.TableName{Name: tableName})
isTableName := tableNameInHead || tableNameInStaged || tableNameInWorking
return isTableName
}

// resetSoftTables replaces staged tables named from HEAD
func resetSoftTables(
ctx *sql.Context,
apr *argparser.ArgParseResults,
roots doltdb.Roots,
dSess *dsess.DoltSession,
dbName string,
) error {
roots, err := actions.ResetSoftTables(ctx, apr, roots)
if err != nil {
return err
}
err = dSess.SetRoots(ctx, dbName, roots)
if err != nil {
return err
}
return nil
}

// resetSoft resets the session HEAD without making changes to the working set
func resetSoft(
ctx *sql.Context,
apr *argparser.ArgParseResults,
dbData env.DbData,
dSess *dsess.DoltSession,
dbName string,
) error {
arg := ""
if apr.NArg() > 1 {
return fmt.Errorf("--soft supports at most one additional param")
} else if apr.NArg() == 1 {
arg = apr.Arg(0)
}

// If ref is "" that means HEAD, which makes reset --soft a no-op
if arg != "" {
roots, err := actions.ResetSoftToRef(ctx, dbData, arg)
if err != nil {
return err
}
ws, err := dSess.WorkingSet(ctx, dbName)
if err != nil {
return err
}
err = dSess.SetWorkingSet(ctx, dbName, ws.WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
if err != nil {
return err
}
}

return nil
}

// resetHard resets the session working and staged to HEAD
func resetHard(
ctx *sql.Context,
apr *argparser.ArgParseResults,
roots doltdb.Roots,
dbData env.DbData,
dSess *dsess.DoltSession,
dbName string,
) error {
// Get the commitSpec for the branch if it exists
arg := ""
if apr.NArg() > 1 {
return fmt.Errorf("--hard supports at most one additional param")
} else if apr.NArg() == 1 {
arg = apr.Arg(0)
}

var newHead *doltdb.Commit
newHead, roots, err := actions.ResetHardTables(ctx, dbData, arg, roots)

if err != nil {
return err
}

// TODO: this overrides the transaction setting, needs to happen at commit, not here
if newHead != nil {
headRef, err := dbData.Rsr.CWBHeadRef()
if err != nil {
return err
}
if err := dbData.Ddb.SetHeadToCommit(ctx, headRef, newHead); err != nil {
return err
}
}

// TODO - refactor and make transactional with the head update above.
ws, err := dSess.WorkingSet(ctx, dbName)
if err != nil {
return err
}
err = dSess.SetWorkingSet(ctx, dbName, ws.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
if err != nil {
return err
}
err = dSess.ResetGlobals(ctx, dbName, roots.Working)
if err != nil {
return err
}

return nil
}

0 comments on commit ae4c8c7

Please sign in to comment.