diff --git a/pkg/meta/model/index.go b/pkg/meta/model/index.go index d65af0255cbbc..da8518aebc000 100644 --- a/pkg/meta/model/index.go +++ b/pkg/meta/model/index.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/planner/cascades/base" ) // DistanceMetric is the distance metric used by the vector index. @@ -78,6 +79,27 @@ type IndexInfo struct { VectorInfo *VectorIndexInfo `json:"vector_index"` // VectorInfo is the vector index information. } +// Hash64 implement HashEquals interface. +func (index *IndexInfo) Hash64(h base.Hasher) { + h.HashInt64(index.ID) +} + +// Equals implements HashEquals interface. +func (index *IndexInfo) Equals(other any) bool { + // any(nil) can still be converted as (*IndexInfo)(nil) + index2, ok := other.(*IndexInfo) + if !ok { + return false + } + if index == nil { + return index2 == nil + } + if index2 == nil { + return false + } + return index.ID == index2.ID +} + // Clone clones IndexInfo. func (index *IndexInfo) Clone() *IndexInfo { if index == nil { diff --git a/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go b/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go index 8c8e7bf9d1b64..b3d22ec5317b7 100644 --- a/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go +++ b/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go @@ -40,6 +40,7 @@ func GenHash64Equals4LogicalOps() ([]byte, error) { logicalop.LogicalExpand{}, logicalop.LogicalLimit{}, logicalop.LogicalMaxOneRow{}, logicalop.DataSource{}, logicalop.LogicalMemTable{}, logicalop.LogicalUnionAll{}, logicalop.LogicalPartitionUnionAll{}, logicalop.LogicalProjection{}, logicalop.LogicalSelection{}, logicalop.LogicalShow{}, logicalop.LogicalShowDDLJobs{}, logicalop.LogicalSort{}, + logicalop.LogicalTableDual{}, logicalop.LogicalTopN{}, logicalop.LogicalUnionScan{}, logicalop.LogicalWindow{}, } c := new(cc) c.write(codeGenHash64EqualsPrefix) @@ -144,6 +145,14 @@ func logicalOpName2PlanCodecString(name string) string { return "plancodec.TypeShowDDLJobs" case "LogicalSort": return "plancodec.TypeSort" + case "LogicalTableDual": + return "plancodec.TypeDual" + case "LogicalTopN": + return "plancodec.TypeTopN" + case "LogicalUnionScan": + return "plancodec.TypeUnionScan" + case "LogicalWindow": + return "plancodec.TypeWindow" default: return "" } diff --git a/pkg/planner/core/operator/logicalop/hash64_equals_generated.go b/pkg/planner/core/operator/logicalop/hash64_equals_generated.go index 6664067b8ea44..99444d045048e 100644 --- a/pkg/planner/core/operator/logicalop/hash64_equals_generated.go +++ b/pkg/planner/core/operator/logicalop/hash64_equals_generated.go @@ -809,3 +809,222 @@ func (op *LogicalSort) Equals(other any) bool { } return true } + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalTableDual) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeDual) + op.LogicalSchemaProducer.Hash64(h) + h.HashInt64(int64(op.RowCount)) +} + +// Equals implements the Hash64Equals interface, only receive *LogicalTableDual pointer. +func (op *LogicalTableDual) Equals(other any) bool { + op2, ok := other.(*LogicalTableDual) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) { + return false + } + if op.RowCount != op2.RowCount { + return false + } + return true +} + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalTopN) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeTopN) + if op.ByItems == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.ByItems)) + for _, one := range op.ByItems { + one.Hash64(h) + } + } + if op.PartitionBy == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.PartitionBy)) + for _, one := range op.PartitionBy { + one.Hash64(h) + } + } + h.HashUint64(uint64(op.Offset)) + h.HashUint64(uint64(op.Count)) + h.HashBool(op.PreferLimitToCop) +} + +// Equals implements the Hash64Equals interface, only receive *LogicalTopN pointer. +func (op *LogicalTopN) Equals(other any) bool { + op2, ok := other.(*LogicalTopN) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if (op.ByItems == nil && op2.ByItems != nil) || (op.ByItems != nil && op2.ByItems == nil) || len(op.ByItems) != len(op2.ByItems) { + return false + } + for i, one := range op.ByItems { + if !one.Equals(op2.ByItems[i]) { + return false + } + } + if (op.PartitionBy == nil && op2.PartitionBy != nil) || (op.PartitionBy != nil && op2.PartitionBy == nil) || len(op.PartitionBy) != len(op2.PartitionBy) { + return false + } + for i, one := range op.PartitionBy { + if !one.Equals(&op2.PartitionBy[i]) { + return false + } + } + if op.Offset != op2.Offset { + return false + } + if op.Count != op2.Count { + return false + } + if op.PreferLimitToCop != op2.PreferLimitToCop { + return false + } + return true +} + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalUnionScan) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeUnionScan) + if op.Conditions == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.Conditions)) + for _, one := range op.Conditions { + one.Hash64(h) + } + } + op.HandleCols.Hash64(h) +} + +// Equals implements the Hash64Equals interface, only receive *LogicalUnionScan pointer. +func (op *LogicalUnionScan) Equals(other any) bool { + op2, ok := other.(*LogicalUnionScan) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if (op.Conditions == nil && op2.Conditions != nil) || (op.Conditions != nil && op2.Conditions == nil) || len(op.Conditions) != len(op2.Conditions) { + return false + } + for i, one := range op.Conditions { + if !one.Equals(op2.Conditions[i]) { + return false + } + } + if !op.HandleCols.Equals(op2.HandleCols) { + return false + } + return true +} + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalWindow) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeWindow) + op.LogicalSchemaProducer.Hash64(h) + if op.WindowFuncDescs == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.WindowFuncDescs)) + for _, one := range op.WindowFuncDescs { + one.Hash64(h) + } + } + if op.PartitionBy == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.PartitionBy)) + for _, one := range op.PartitionBy { + one.Hash64(h) + } + } + if op.OrderBy == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.OrderBy)) + for _, one := range op.OrderBy { + one.Hash64(h) + } + } + if op.Frame == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + op.Frame.Hash64(h) + } +} + +// Equals implements the Hash64Equals interface, only receive *LogicalWindow pointer. +func (op *LogicalWindow) Equals(other any) bool { + op2, ok := other.(*LogicalWindow) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) { + return false + } + if (op.WindowFuncDescs == nil && op2.WindowFuncDescs != nil) || (op.WindowFuncDescs != nil && op2.WindowFuncDescs == nil) || len(op.WindowFuncDescs) != len(op2.WindowFuncDescs) { + return false + } + for i, one := range op.WindowFuncDescs { + if !one.Equals(op2.WindowFuncDescs[i]) { + return false + } + } + if (op.PartitionBy == nil && op2.PartitionBy != nil) || (op.PartitionBy != nil && op2.PartitionBy == nil) || len(op.PartitionBy) != len(op2.PartitionBy) { + return false + } + for i, one := range op.PartitionBy { + if !one.Equals(&op2.PartitionBy[i]) { + return false + } + } + if (op.OrderBy == nil && op2.OrderBy != nil) || (op.OrderBy != nil && op2.OrderBy == nil) || len(op.OrderBy) != len(op2.OrderBy) { + return false + } + for i, one := range op.OrderBy { + if !one.Equals(&op2.OrderBy[i]) { + return false + } + } + if !op.Frame.Equals(op2.Frame) { + return false + } + return true +} diff --git a/pkg/planner/core/operator/logicalop/logical_table_dual.go b/pkg/planner/core/operator/logicalop/logical_table_dual.go index fa05796f91b45..2fa5bb3da5285 100644 --- a/pkg/planner/core/operator/logicalop/logical_table_dual.go +++ b/pkg/planner/core/operator/logicalop/logical_table_dual.go @@ -33,10 +33,10 @@ import ( // outputting 0/1 row with zero column. This semantic may be different from your expectation sometimes but should not // cause any actual problems now. type LogicalTableDual struct { - LogicalSchemaProducer + LogicalSchemaProducer `hash64-equals:"true"` // RowCount could only be 0 or 1. - RowCount int + RowCount int `hash64-equals:"true"` } // Init initializes LogicalTableDual. diff --git a/pkg/planner/core/operator/logicalop/logical_top_n.go b/pkg/planner/core/operator/logicalop/logical_top_n.go index b29eb13d7fc62..2535205a1321a 100644 --- a/pkg/planner/core/operator/logicalop/logical_top_n.go +++ b/pkg/planner/core/operator/logicalop/logical_top_n.go @@ -32,12 +32,12 @@ import ( type LogicalTopN struct { BaseLogicalPlan - ByItems []*util.ByItems + ByItems []*util.ByItems `hash64-equals:"true"` // PartitionBy is used for extended TopN to consider K heaps. Used by rule_derive_topn_from_window - PartitionBy []property.SortItem // This is used for enhanced topN optimization - Offset uint64 - Count uint64 - PreferLimitToCop bool + PartitionBy []property.SortItem `hash64-equals:"true"` // This is used for enhanced topN optimization + Offset uint64 `hash64-equals:"true"` + Count uint64 `hash64-equals:"true"` + PreferLimitToCop bool `hash64-equals:"true"` } // Init initializes LogicalTopN. diff --git a/pkg/planner/core/operator/logicalop/logical_union_scan.go b/pkg/planner/core/operator/logicalop/logical_union_scan.go index 3eb493e7e6867..6bca79636d276 100644 --- a/pkg/planner/core/operator/logicalop/logical_union_scan.go +++ b/pkg/planner/core/operator/logicalop/logical_union_scan.go @@ -32,9 +32,9 @@ import ( type LogicalUnionScan struct { BaseLogicalPlan - Conditions []expression.Expression + Conditions []expression.Expression `hash64-equals:"true"` - HandleCols util.HandleCols + HandleCols util.HandleCols `hash64-equals:"true"` } // Init initializes LogicalUnionScan. diff --git a/pkg/planner/core/operator/logicalop/logical_window.go b/pkg/planner/core/operator/logicalop/logical_window.go index e4636f4721ad1..e218d12bcfc08 100644 --- a/pkg/planner/core/operator/logicalop/logical_window.go +++ b/pkg/planner/core/operator/logicalop/logical_window.go @@ -15,9 +15,12 @@ package logicalop import ( + "fmt" + "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/expression/aggregation" "github.com/pingcap/tidb/pkg/parser/ast" + base2 "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/base" ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util" "github.com/pingcap/tidb/pkg/planner/property" @@ -31,12 +34,12 @@ import ( // LogicalWindow represents a logical window function plan. type LogicalWindow struct { - LogicalSchemaProducer + LogicalSchemaProducer `hash64-equals:"true"` - WindowFuncDescs []*aggregation.WindowFuncDesc - PartitionBy []property.SortItem - OrderBy []property.SortItem - Frame *WindowFrame + WindowFuncDescs []*aggregation.WindowFuncDesc `hash64-equals:"true"` + PartitionBy []property.SortItem `hash64-equals:"true"` + OrderBy []property.SortItem `hash64-equals:"true"` + Frame *WindowFrame `hash64-equals:"true"` } // WindowFrame represents a window function frame. @@ -46,6 +49,36 @@ type WindowFrame struct { End *FrameBound } +// Hash64 implements HashEquals interface. +func (wf *WindowFrame) Hash64(h base2.Hasher) { + h.HashInt(int(wf.Type)) + if wf.Start != nil { + h.HashByte(base2.NotNilFlag) + wf.Start.Hash64(h) + } else { + h.HashByte(base2.NilFlag) + wf.End.Hash64(h) + } +} + +// Equals implements HashEquals interface. +func (wf *WindowFrame) Equals(other any) bool { + wf2, ok := other.(*WindowFrame) + if !ok { + return false + } + if wf == nil { + return wf2 == nil + } + if wf2 == nil { + return false + } + if wf.Type != wf2.Type || !wf.Start.Equals(wf2.Start) || !wf.End.Equals(wf2.End) { + return false + } + return true +} + // Clone copies a window frame totally. func (wf *WindowFrame) Clone() *WindowFrame { cloned := new(WindowFrame) @@ -76,6 +109,85 @@ type FrameBound struct { IsExplicitRange bool } +// Hash64 implement HashEquals interface. +func (fb *FrameBound) Hash64(h base2.Hasher) { + h.HashInt(int(fb.Type)) + h.HashBool(fb.UnBounded) + h.HashUint64(fb.Num) + if fb.CalcFuncs == nil { + h.HashByte(base2.NilFlag) + } else { + h.HashByte(base2.NotNilFlag) + h.HashInt(len(fb.CalcFuncs)) + for _, one := range fb.CalcFuncs { + one.Hash64(h) + } + } + if fb.CompareCols == nil { + h.HashByte(base2.NilFlag) + } else { + h.HashByte(base2.NotNilFlag) + h.HashInt(len(fb.CompareCols)) + for _, one := range fb.CompareCols { + one.Hash64(h) + } + } + if fb.CmpFuncs == nil { + h.HashByte(base2.NilFlag) + } else { + h.HashByte(base2.NotNilFlag) + h.HashInt(len(fb.CmpFuncs)) + for _, f := range fb.CmpFuncs { + h.HashString(fmt.Sprintf("%p", f)) + } + } + h.HashInt64(int64(fb.CmpDataType)) + h.HashBool(fb.IsExplicitRange) +} + +// Equals implement HashEquals interface. +func (fb *FrameBound) Equals(other any) bool { + fb2, ok := other.(*FrameBound) + if !ok { + return false + } + if fb == nil { + return fb2 == nil + } + if fb2 == nil { + return false + } + if fb.Type != fb2.Type || fb.UnBounded != fb2.UnBounded || fb.Num != fb2.Num { + return false + } + if fb.CalcFuncs == nil && fb2.CalcFuncs != nil || fb.CalcFuncs != nil && fb2.CalcFuncs == nil || len(fb.CalcFuncs) != len(fb2.CmpFuncs) { + return false + } + for i, one := range fb.CalcFuncs { + if !one.Equals(fb2.CalcFuncs[i]) { + return false + } + } + if fb.CompareCols == nil && fb2.CompareCols != nil || fb.CompareCols != nil && fb2.CompareCols == nil || len(fb.CompareCols) != len(fb2.CompareCols) { + return false + } + for i, one := range fb.CompareCols { + if !one.Equals(fb2.CompareCols[i]) { + return false + } + } + if fb.CmpFuncs == nil && fb2.CmpFuncs != nil || fb.CmpFuncs != nil && fb2.CmpFuncs == nil || len(fb.CmpFuncs) != len(fb2.CmpFuncs) { + return false + } + for i, one := range fb.CmpFuncs { + // com function addr + if fmt.Sprintf("%p", one) != fmt.Sprintf("%p", fb2.CmpFuncs[i]) { + return false + } + } + return fb.CmpDataType == fb2.CmpDataType && fb.IsExplicitRange == fb2.IsExplicitRange +} + // Clone copies a frame bound totally. func (fb *FrameBound) Clone() *FrameBound { cloned := new(FrameBound) diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel index 6efece856dfb4..2f46bd6c57799 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel +++ b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "logical_mem_table_predicate_extractor_test.go", ], flaky = True, - shard_count = 22, + shard_count = 25, deps = [ "//pkg/domain", "//pkg/expression", @@ -30,6 +30,7 @@ go_test( "//pkg/session/types", "//pkg/testkit", "//pkg/types", + "//pkg/util/chunk", "//pkg/util/hint", "//pkg/util/mock", "//pkg/util/set", diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go index 701495c79c9ae..337a97ffd9e23 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go +++ b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go @@ -26,7 +26,9 @@ import ( "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" "github.com/stretchr/testify/require" ) @@ -498,3 +500,214 @@ func TestLogicalAggregationHash64Equals(t *testing.T) { require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) require.False(t, la1.Equals(la2)) } + +func MockFunc(sctx expression.EvalContext, lhsArg, rhsArg expression.Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + return 0, false, nil +} +func MockFunc2(sctx expression.EvalContext, lhsArg, rhsArg expression.Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + return 0, false, nil +} + +func TestFrameBoundHash64Equals(t *testing.T) { + col := &expression.Column{ + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + Index: 1, + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + fb1 := &logicalop.FrameBound{ + Type: ast.Preceding, + UnBounded: true, + Num: 1, + CalcFuncs: []expression.Expression{col}, + CompareCols: []expression.Expression{col}, + CmpFuncs: []expression.CompareFunc{MockFunc}, + CmpDataType: 1, + IsExplicitRange: false, + } + fb2 := &logicalop.FrameBound{ + Type: ast.Preceding, + UnBounded: true, + Num: 1, + CalcFuncs: []expression.Expression{col}, + CompareCols: []expression.Expression{col}, + CmpFuncs: []expression.CompareFunc{MockFunc}, + CmpDataType: 1, + IsExplicitRange: false, + } + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + fb1.Hash64(hasher1) + fb2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, fb1.Equals(fb2)) + + fb2.Type = ast.CurrentRow + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.Type = ast.Preceding + fb2.UnBounded = false + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.UnBounded = true + fb2.Num = 2 + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.Num = 1 + fb2.CalcFuncs = []expression.Expression{col2} + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.CalcFuncs = []expression.Expression{col} + fb2.CompareCols = []expression.Expression{col2} + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.CompareCols = []expression.Expression{col} + fb2.CmpFuncs = []expression.CompareFunc{MockFunc2} + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.CmpFuncs = []expression.CompareFunc{MockFunc} + hasher2.Reset() + fb2.Hash64(hasher2) + require.True(t, fb1.Equals(fb2)) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + + fb2.CmpDataType = 2 + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) + + fb2.CmpDataType = 1 + fb2.IsExplicitRange = true + hasher2.Reset() + fb2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, fb1.Equals(fb2)) +} + +func TestWindowFrameHash64Equals(t *testing.T) { + col := &expression.Column{ + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + start := &logicalop.FrameBound{ + Type: ast.Preceding, + UnBounded: true, + Num: 1, + CalcFuncs: []expression.Expression{col}, + CompareCols: []expression.Expression{col}, + CmpFuncs: []expression.CompareFunc{MockFunc}, + CmpDataType: 1, + IsExplicitRange: false, + } + end := start + wf1 := &logicalop.WindowFrame{ + Type: 1, + Start: start, + End: end, + } + wf2 := &logicalop.WindowFrame{ + Type: 1, + Start: start, + End: end, + } + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + wf1.Hash64(hasher1) + wf2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, wf1.Equals(wf2)) + + wf2.Type = 2 + hasher2.Reset() + wf2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, wf1.Equals(wf2)) +} + +func TestHandleColsHash64Equals(t *testing.T) { + col1 := &expression.Column{ + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + UniqueID: 2, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + ctx := mock.NewContext() + handles1 := util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 1}, &model.IndexInfo{ID: 1}, []*expression.Column{col1, col2}) + handles2 := util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 1}, &model.IndexInfo{ID: 1}, []*expression.Column{col1, col2}) + + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + handles1.Hash64(hasher1) + handles2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, handles1.Equals(handles2)) + + handles2 = util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 2}, &model.IndexInfo{ID: 1}, []*expression.Column{col1, col2}) + hasher2.Reset() + handles2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, handles1.Equals(handles2)) + + handles2 = util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 1}, &model.IndexInfo{ID: 2}, []*expression.Column{col1, col2}) + hasher2.Reset() + handles2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, handles1.Equals(handles2)) + + handles2 = util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 1}, &model.IndexInfo{ID: 1}, []*expression.Column{col2, col2}) + hasher2.Reset() + handles2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, handles1.Equals(handles2)) + + handles2 = util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 1}, &model.IndexInfo{ID: 1}, []*expression.Column{col2, col1}) + hasher2.Reset() + handles2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, handles1.Equals(handles2)) + + handles2 = util.NewCommonHandlesColsWithoutColsAlign(ctx.GetSessionVars().StmtCtx, &model.TableInfo{ID: 1}, &model.IndexInfo{ID: 1}, []*expression.Column{col1, col2}) + hasher2.Reset() + handles2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, handles1.Equals(handles2)) + + intH1 := util.NewIntHandleCols(col1) + intH2 := util.NewIntHandleCols(col1) + hasher1.Reset() + hasher2.Reset() + intH1.Hash64(hasher1) + intH2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, handles1.Equals(handles2)) + + intH2 = util.NewIntHandleCols(col2) + hasher2.Reset() + intH2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, intH1.Equals(intH2)) +} diff --git a/pkg/planner/util/handle_cols.go b/pkg/planner/util/handle_cols.go index 06d6ad97008a5..21145b512a93d 100644 --- a/pkg/planner/util/handle_cols.go +++ b/pkg/planner/util/handle_cols.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" @@ -34,6 +35,7 @@ import ( // HandleCols is the interface that holds handle columns. type HandleCols interface { expression.StringerWithCtx + base.HashEquals // BuildHandle builds a Handle from a row. BuildHandle(row chunk.Row) (kv.Handle, error) @@ -86,6 +88,59 @@ func (cb *CommonHandleCols) Clone(newCtx *stmtctx.StatementContext) HandleCols { } } +// Hash64 implements HashEquals interface. +func (cb *CommonHandleCols) Hash64(h base.Hasher) { + if cb.tblInfo != nil { + h.HashByte(base.NotNilFlag) + cb.tblInfo.Hash64(h) + } else { + h.HashByte(base.NilFlag) + } + if cb.idxInfo != nil { + h.HashByte(base.NotNilFlag) + cb.idxInfo.Hash64(h) + } else { + h.HashByte(base.NilFlag) + } + if cb.columns != nil { + h.HashByte(base.NotNilFlag) + h.HashInt(len(cb.columns)) + for _, one := range cb.columns { + one.Hash64(h) + } + } else { + h.HashByte(base.NilFlag) + } +} + +// Equals implements HashEquals interface. +func (cb *CommonHandleCols) Equals(other any) bool { + cb2, ok := other.(*CommonHandleCols) + if !ok { + return false + } + if cb == nil { + return cb2 == nil + } + if cb2 == nil { + return false + } + if !cb.tblInfo.Equals(cb2.tblInfo) || !cb.idxInfo.Equals(cb2.idxInfo) { + return false + } + if cb.columns == nil && cb2.columns != nil || + cb.columns != nil && cb2.columns == nil || + len(cb.columns) != len(cb2.columns) { + return false + } + for i, one := range cb.columns { + if !one.Equals(cb2.columns[i]) { + return false + } + } + return true +} + // GetColumns returns all the internal columns out. func (cb *CommonHandleCols) GetColumns() []*expression.Column { return cb.columns @@ -262,6 +317,31 @@ type IntHandleCols struct { col *expression.Column } +// Hash64 implements HashEquals interface. +func (ib *IntHandleCols) Hash64(h base.Hasher) { + if ib.col != nil { + h.HashByte(base.NotNilFlag) + ib.col.Hash64(h) + } else { + h.HashByte(base.NilFlag) + } +} + +// Equals implements HashEquals interface. +func (ib *IntHandleCols) Equals(other any) bool { + ib2, ok := other.(*IntHandleCols) + if !ok { + return false + } + if ib == nil { + return ib2 == nil + } + if ib2 == nil { + return false + } + return ib.col.Equals(ib2.col) +} + // Clone implements the kv.HandleCols interface. func (ib *IntHandleCols) Clone(*stmtctx.StatementContext) HandleCols { return &IntHandleCols{col: ib.col.Clone().(*expression.Column)}