diff --git a/pkg/cmd/vecbench/main.go b/pkg/cmd/vecbench/main.go index 3d1edea8e3ab..ed012cc32925 100644 --- a/pkg/cmd/vecbench/main.go +++ b/pkg/cmd/vecbench/main.go @@ -247,12 +247,12 @@ func downloadDataset(ctx context.Context, datasetName string) { // Use progressWriter to track download progress var buf bytes.Buffer - progressWriter := &progressWriter{ + writer := &progressWriter{ Writer: &buf, Total: attrs.Size, } - if _, err = io.Copy(progressWriter, reader); err != nil { + if _, err = io.Copy(writer, reader); err != nil { log.Fatalf("Failed to copy object data: %v", err) } @@ -348,15 +348,6 @@ func buildIndex(ctx context.Context, datasetName string) { panic(err) } - // Insert empty root partition. - func() { - txn := beginTransaction(ctx, store) - defer commitTransaction(ctx, store, txn) - if err := index.CreateRoot(ctx, txn); err != nil { - panic(err) - } - }() - // Create unique primary key for each vector in a single large byte buffer. primaryKeys := make([]byte, data.Train.Count*4) for i := 0; i < data.Train.Count; i++ { @@ -439,13 +430,12 @@ func loadStore(fileName string) *vecstore.InMemoryStore { panic(err) } - var inMemStore vecstore.InMemoryStore - err = inMemStore.UnmarshalBinary(data) + inMemStore, err := vecstore.LoadInMemoryStore(data) if err != nil { panic(err) } - return &inMemStore + return inMemStore } // loadDataset deserializes a dataset saved as a gob file. diff --git a/pkg/sql/vecindex/fixup_processor.go b/pkg/sql/vecindex/fixup_processor.go index 89c73db4e4ed..c2b922b6594b 100644 --- a/pkg/sql/vecindex/fixup_processor.go +++ b/pkg/sql/vecindex/fixup_processor.go @@ -89,6 +89,9 @@ type fixupProcessor struct { // pendingVectors tracks pending fixups for deleting vectors. pendingVectors map[string]bool + + // waitForFixups broadcasts to any waiters when all fixups are processed. + waitForFixups sync.Cond } // -------------------------------------------------- @@ -101,10 +104,6 @@ type fixupProcessor struct { // maxFixups limit has been reached. fixupsLimitHit log.EveryN - // pendingCount tracks the number of pending fixups that still need to be - // processed. - pendingCount sync.WaitGroup - // -------------------------------------------------- // The following fields should only be accessed on a single background // goroutine (or a single foreground goroutine in deterministic tests). @@ -135,6 +134,7 @@ func (fp *fixupProcessor) Init(index *VectorIndex, seed int64) { } fp.mu.pendingPartitions = make(map[partitionFixupKey]bool, maxFixups) fp.mu.pendingVectors = make(map[string]bool, maxFixups) + fp.mu.waitForFixups.L = &fp.mu fp.fixups = make(chan fixup, maxFixups) fp.fixupsLimitHit = log.Every(time.Second) } @@ -197,7 +197,11 @@ func (fp *fixupProcessor) Start(ctx context.Context) { // Wait blocks until all pending fixups have been processed by the background // goroutine. This is useful in testing. func (fp *fixupProcessor) Wait() { - fp.pendingCount.Wait() + fp.mu.Lock() + defer fp.mu.Unlock() + for len(fp.mu.pendingVectors) > 0 || len(fp.mu.pendingPartitions) > 0 { + fp.mu.waitForFixups.Wait() + } } // runAll processes all fixups in the queue. This should only be called by tests @@ -270,9 +274,6 @@ func (fp *fixupProcessor) run(ctx context.Context, wait bool) (ok bool, err erro fp.mu.Lock() defer fp.mu.Unlock() - // Decrement the number of pending fixups. - fp.pendingCount.Done() - switch next.Type { case splitFixup, mergeFixup: key := partitionFixupKey{Type: next.Type, PartitionKey: next.PartitionKey} @@ -282,6 +283,11 @@ func (fp *fixupProcessor) run(ctx context.Context, wait bool) (ok bool, err erro delete(fp.mu.pendingVectors, string(next.VectorKey)) } + // If there are no more pending fixups, notify any waiters. + if len(fp.mu.pendingPartitions) == 0 && len(fp.mu.pendingVectors) == 0 { + fp.mu.waitForFixups.Broadcast() + } + return true, err } @@ -319,9 +325,6 @@ func (fp *fixupProcessor) addFixup(ctx context.Context, fixup fixup) { panic(errors.AssertionFailedf("unknown fixup %d", fixup.Type)) } - // Increment the number of pending fixups. - fp.pendingCount.Add(1) - // Note that the channel send operation should never block, since it has // maxFixups capacity. fp.fixups <- fixup diff --git a/pkg/sql/vecindex/vecstore/in_memory_store.go b/pkg/sql/vecindex/vecstore/in_memory_store.go index 20db29b5ec74..33e0061fb067 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store.go @@ -58,10 +58,10 @@ type InMemoryStore struct { txnLock syncutil.RWMutex mu struct { syncutil.Mutex - index map[PartitionKey]*Partition - nextKey PartitionKey - vectors map[string]vector.T - stats IndexStats + partitions map[PartitionKey]*Partition + nextKey PartitionKey + vectors map[string]vector.T + stats IndexStats } } @@ -76,9 +76,21 @@ func NewInMemoryStore(dims int, seed int64) *InMemoryStore { dims: dims, seed: seed, } - st.mu.index = make(map[PartitionKey]*Partition) + st.mu.partitions = make(map[PartitionKey]*Partition) + + // Create empty root partition. + var empty vector.Set + quantizer := quantize.NewUnQuantizer(dims) + quantizedSet := quantizer.Quantize(context.Background(), &empty) + st.mu.partitions[RootKey] = &Partition{ + quantizer: quantizer, + quantizedSet: quantizedSet, + level: LeafLevel, + } + st.mu.nextKey = RootKey + 1 st.mu.vectors = make(map[string]vector.T) + st.mu.stats.NumPartitions = 1 return st } @@ -98,7 +110,7 @@ func (s *InMemoryStore) CommitTransaction(ctx context.Context, txn Txn) error { s.mu.Lock() defer s.mu.Unlock() - partition, ok := s.mu.index[inMemTxn.unbalancedKey] + partition, ok := s.mu.partitions[inMemTxn.unbalancedKey] if ok && partition.Count() == 0 && partition.Level() > LeafLevel { panic(errors.AssertionFailedf( "K-means tree is unbalanced, with empty non-leaf partition %d", inMemTxn.unbalancedKey)) @@ -135,7 +147,7 @@ func (s *InMemoryStore) GetPartition( s.mu.Lock() defer s.mu.Unlock() - partition, ok := s.mu.index[partitionKey] + partition, ok := s.mu.partitions[partitionKey] if !ok { return nil, ErrPartitionNotFound } @@ -152,9 +164,9 @@ func (s *InMemoryStore) SetRootPartition(ctx context.Context, txn Txn, partition s.mu.Lock() defer s.mu.Unlock() - _, ok := s.mu.index[RootKey] + _, ok := s.mu.partitions[RootKey] if !ok { - s.mu.stats.NumPartitions++ + panic(errors.AssertionFailedf("the root partition cannot be found")) } // Grow or shrink CVStats slice if a new level is being added or removed. @@ -164,7 +176,7 @@ func (s *InMemoryStore) SetRootPartition(ctx context.Context, txn Txn, partition } s.mu.stats.CVStats = s.mu.stats.CVStats[:expectedLevels] - s.mu.index[RootKey] = partition + s.mu.partitions[RootKey] = partition return nil } @@ -179,7 +191,7 @@ func (s *InMemoryStore) InsertPartition( partitionKey := s.mu.nextKey s.mu.nextKey++ - s.mu.index[partitionKey] = partition + s.mu.partitions[partitionKey] = partition s.mu.stats.NumPartitions++ return partitionKey, nil } @@ -193,11 +205,15 @@ func (s *InMemoryStore) DeletePartition( s.mu.Lock() defer s.mu.Unlock() - _, ok := s.mu.index[partitionKey] + if partitionKey == RootKey { + panic(errors.AssertionFailedf("cannot delete the root partition")) + } + + _, ok := s.mu.partitions[partitionKey] if !ok { return ErrPartitionNotFound } - delete(s.mu.index, partitionKey) + delete(s.mu.partitions, partitionKey) s.mu.stats.NumPartitions-- return nil } @@ -211,7 +227,7 @@ func (s *InMemoryStore) AddToPartition( s.mu.Lock() defer s.mu.Unlock() - partition, ok := s.mu.index[partitionKey] + partition, ok := s.mu.partitions[partitionKey] if !ok { return 0, ErrPartitionNotFound } @@ -230,7 +246,7 @@ func (s *InMemoryStore) RemoveFromPartition( s.mu.Lock() defer s.mu.Unlock() - partition, ok := s.mu.index[partitionKey] + partition, ok := s.mu.partitions[partitionKey] if !ok { return 0, ErrPartitionNotFound } @@ -265,7 +281,7 @@ func (s *InMemoryStore) SearchPartitions( defer s.mu.Unlock() for i := 0; i < len(partitionKeys); i++ { - partition, ok := s.mu.index[partitionKeys[i]] + partition, ok := s.mu.partitions[partitionKeys[i]] if !ok { return 0, ErrPartitionNotFound } @@ -295,7 +311,7 @@ func (s *InMemoryStore) GetFullVectors(ctx context.Context, txn Txn, refs []Vect ref := &refs[i] if ref.Key.PartitionKey != InvalidKey { // Return the partition's centroid. - partition, ok := s.mu.index[ref.Key.PartitionKey] + partition, ok := s.mu.partitions[ref.Key.PartitionKey] if !ok { return ErrPartitionNotFound } @@ -417,14 +433,14 @@ func (s *InMemoryStore) MarshalBinary() (data []byte, err error) { storeProto := StoreProto{ Dims: s.dims, Seed: s.seed, - Partitions: make([]PartitionProto, 0, len(s.mu.index)), + Partitions: make([]PartitionProto, 0, len(s.mu.partitions)), NextKey: s.mu.nextKey, Vectors: make([]VectorProto, 0, len(s.mu.vectors)), Stats: s.mu.stats, } // Remap partitions to protobufs. - for partitionKey, partition := range s.mu.index { + for partitionKey, partition := range s.mu.partitions { partitionProto := PartitionProto{ PartitionKey: partitionKey, ChildKeys: partition.ChildKeys(), @@ -451,24 +467,24 @@ func (s *InMemoryStore) MarshalBinary() (data []byte, err error) { return protoutil.Marshal(&storeProto) } -// UnmarshalBinary loads the in-memory store from bytes that were previously +// LoadInMemoryStore loads the in-memory store from bytes that were previously // saved by MarshalBinary. -func (s *InMemoryStore) UnmarshalBinary(data []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - +func LoadInMemoryStore(data []byte) (*InMemoryStore, error) { // Unmarshal bytes into a protobuf. var storeProto StoreProto if err := protoutil.Unmarshal(data, &storeProto); err != nil { - return err + return nil, err } // Construct the InMemoryStore object. - s.seed = storeProto.Seed - s.mu.index = make(map[PartitionKey]*Partition, len(storeProto.Partitions)) - s.mu.nextKey = storeProto.NextKey - s.mu.vectors = make(map[string]vector.T, len(storeProto.Vectors)) - s.mu.stats = storeProto.Stats + inMemStore := &InMemoryStore{ + dims: storeProto.Dims, + seed: storeProto.Seed, + } + inMemStore.mu.partitions = make(map[PartitionKey]*Partition, len(storeProto.Partitions)) + inMemStore.mu.nextKey = storeProto.NextKey + inMemStore.mu.vectors = make(map[string]vector.T, len(storeProto.Vectors)) + inMemStore.mu.stats = storeProto.Stats raBitQuantizer := quantize.NewRaBitQuantizer(storeProto.Dims, storeProto.Seed) unquantizer := quantize.NewUnQuantizer(storeProto.Dims) @@ -487,16 +503,16 @@ func (s *InMemoryStore) UnmarshalBinary(data []byte) error { partition.quantizer = unquantizer partition.quantizedSet = partitionProto.UnQuantized } - s.mu.index[partitionProto.PartitionKey] = &partition + inMemStore.mu.partitions[partitionProto.PartitionKey] = &partition } // Insert vectors into the in-memory store. for i := range storeProto.Vectors { vectorProto := storeProto.Vectors[i] - s.mu.vectors[string(vectorProto.PrimaryKey)] = vectorProto.Vector + inMemStore.mu.vectors[string(vectorProto.PrimaryKey)] = vectorProto.Vector } - return nil + return inMemStore, nil } // acquireTxnLock acquires a data or partition lock within the scope of the diff --git a/pkg/sql/vecindex/vecstore/in_memory_store_test.go b/pkg/sql/vecindex/vecstore/in_memory_store_test.go index 1edbb20da3cf..e0625656fc41 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store_test.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store_test.go @@ -65,15 +65,10 @@ func TestInMemoryStore(t *testing.T) { }, vectors) }) - t.Run("insert empty root partition into the store", func(t *testing.T) { + t.Run("search empty root partition", func(t *testing.T) { txn := beginTransaction(ctx, t, store) defer commitTransaction(ctx, t, store, txn) - vectors := vector.MakeSet(2) - quantizedSet := quantizer.Quantize(ctx, &vectors) - root := NewPartition(quantizer, quantizedSet, []ChildKey{}, LeafLevel) - require.NoError(t, store.SetRootPartition(ctx, txn, root)) - searchSet := SearchSet{MaxResults: 2} partitionCounts := []int{0} level, err := store.SearchPartitions( @@ -373,7 +368,7 @@ func TestInMemoryStoreMarshalling(t *testing.T) { dims: 2, seed: 42, } - store.mu.index = map[PartitionKey]*Partition{ + store.mu.partitions = map[PartitionKey]*Partition{ 10: { quantizer: unquantizer, quantizedSet: &quantize.UnQuantizedVectorSet{ @@ -418,15 +413,14 @@ func TestInMemoryStoreMarshalling(t *testing.T) { data, err := store.MarshalBinary() require.NoError(t, err) - var store2 InMemoryStore - err = store2.UnmarshalBinary(data) + store2, err := LoadInMemoryStore(data) require.NoError(t, err) - require.Len(t, store2.mu.index, 2) - require.Equal(t, Level(1), store2.mu.index[10].level) - require.Equal(t, 3, store2.mu.index[10].quantizedSet.GetCount()) - require.Equal(t, 2, store2.mu.index[20].quantizer.GetOriginalDims()) - require.Len(t, store2.mu.index[20].childKeys, 3) + require.Len(t, store2.mu.partitions, 2) + require.Equal(t, Level(1), store2.mu.partitions[10].level) + require.Equal(t, 3, store2.mu.partitions[10].quantizedSet.GetCount()) + require.Equal(t, 2, store2.mu.partitions[20].quantizer.GetOriginalDims()) + require.Len(t, store2.mu.partitions[20].childKeys, 3) require.Equal(t, PartitionKey(100), store2.mu.nextKey) require.Len(t, store2.mu.vectors, 2) require.Equal(t, vector.T{12, 13}, store2.mu.vectors[string([]byte{3, 4})]) diff --git a/pkg/sql/vecindex/vector_index.go b/pkg/sql/vecindex/vector_index.go index b50a74a5f884..3c71714cbf2e 100644 --- a/pkg/sql/vecindex/vector_index.go +++ b/pkg/sql/vecindex/vector_index.go @@ -199,18 +199,6 @@ func (vi *VectorIndex) ProcessFixups() { vi.fixups.Wait() } -// CreateRoot creates an empty root partition in the store. This should only be -// called once when the index is first created. -func (vi *VectorIndex) CreateRoot(ctx context.Context, txn vecstore.Txn) error { - // Use the UnQuantizer because vectors in the root are not quantized. - dims := vi.rootQuantizer.GetRandomDims() - vectors := vector.MakeSet(dims) - rootQuantizedSet := vi.rootQuantizer.Quantize(ctx, &vectors) - rootPartition := vecstore.NewPartition( - vi.rootQuantizer, rootQuantizedSet, []vecstore.ChildKey{}, vecstore.LeafLevel) - return vi.store.SetRootPartition(ctx, txn, rootPartition) -} - // Insert adds a new vector with the given primary key to the index. This is // called within the scope of a transaction so that the index does not appear to // change during the insert. diff --git a/pkg/sql/vecindex/vector_index_test.go b/pkg/sql/vecindex/vector_index_test.go index c4fb6796301c..3d61e93967e3 100644 --- a/pkg/sql/vecindex/vector_index_test.go +++ b/pkg/sql/vecindex/vector_index_test.go @@ -129,11 +129,6 @@ func (s *testState) NewIndex(d *datadriven.TestData) string { s.Index, err = NewVectorIndex(s.Ctx, s.InMemStore, s.Quantizer, &s.Options, stopper) require.NoError(s.T, err) - // Insert empty root partition. - txn := beginTransaction(s.Ctx, s.T, s.InMemStore) - require.NoError(s.T, s.Index.CreateRoot(s.Ctx, txn)) - commitTransaction(s.Ctx, s.T, s.InMemStore, txn) - // Insert initial vectors. return s.Insert(d) }