Skip to content

Commit 18f164f

Browse files
authored
Refactor next-plaid (#3354)
### What problem does this PR solve? - Larger centroid count for better retrieval performance with higher indexing overhead - Batched IVF Probing - Centroid Score Threshold - Sparse Centroid Scoring - IVF index flattening storage - Outlier-based Centroid Expansion - Memory-bounded Compression - Bucket Weight Lookup Table ### Type of change - [x] Refactoring - [x] Performance Improvement
1 parent e96ccb8 commit 18f164f

8 files changed

Lines changed: 969 additions & 144 deletions

src/storage/buffer/file_worker/plaid_index_file_worker.cppm

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@ export struct PlaidIndexFileWorker : IndexFileWorker {
4949
start_segment_offset_(start_segment_offset), rel_file_path_(std::make_shared<std::string>(fmt::format("{}/{}", *file_dir_, *file_name_))) {}
5050

5151
~PlaidIndexFileWorker() override {
52-
if (data_ != nullptr || mmap_data_ != nullptr) {
52+
if (data_ != nullptr) {
5353
FreeInMemory();
54-
data_ = nullptr;
55-
mmap_data_ = nullptr;
54+
}
55+
if (mmap_data_ != nullptr) {
56+
FreeFromMmapImpl();
5657
}
5758
}
5859

@@ -80,16 +81,6 @@ protected:
8081
bool ReadFromMmapImpl(const void *ptr, size_t size) override;
8182

8283
void FreeFromMmapImpl() override;
83-
84-
public:
85-
// Template wrappers for FileWorkerMap compatibility
86-
template <typename T = void>
87-
Status CleanupFile() const {
88-
return Status::OK();
89-
}
90-
91-
template <typename T = void>
92-
void MoveFile() {}
9384
};
9485

9586
} // namespace infinity

src/storage/buffer/file_worker/plaid_index_file_worker_impl.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,15 @@ void PlaidIndexFileWorker::AllocateInMemory() {
6363
}
6464

6565
void PlaidIndexFileWorker::FreeInMemory() {
66-
// For mmap-loaded index, object is in mmap_data_
67-
// For regular-loaded index, object is in data_
68-
if (mmap_data_) {
69-
auto *index = reinterpret_cast<PlaidIndex *>(mmap_data_);
70-
delete index;
71-
mmap_data_ = nullptr;
72-
} else if (data_) {
73-
auto *index = static_cast<PlaidIndex *>(data_);
74-
delete index;
75-
data_ = nullptr;
76-
} else {
66+
// FreeInMemory() must only handle data_, never mmap_data_.
67+
// mmap_data_ is exclusively managed by FreeFromMmapImpl() / Munmap().
68+
// This matches the contract used by HNSW, BMP, and all other FileWorkers.
69+
if (!data_) {
7770
UnrecoverableError("PlaidIndexFileWorker::FreeInMemory: Data is not allocated.");
7871
}
72+
auto *index = static_cast<PlaidIndex *>(data_);
73+
delete index;
74+
data_ = nullptr;
7975
}
8076

8177
bool PlaidIndexFileWorker::WriteToFileImpl(bool to_spill, bool &prepare_success, const FileWorkerSaveCtx &ctx) {

src/storage/knn_index/plaid/plaid_global_centroids_impl.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ void PlaidGlobalCentroids::Train(const u32 n_centroids, const f32 *embedding_dat
6363
UnrecoverableError(fmt::format("PlaidGlobalCentroids::Train: n_centroids must be a multiple of 8, got {}", n_centroids));
6464
}
6565

66-
// Minimum training data requirement
66+
// Minimum training data requirement - warn but don't fail (new centroid formula may produce larger k)
6767
const u64 min_data = std::max<u64>(32ul * n_centroids, 256ul);
6868
if (embedding_num < min_data) {
69-
LOG_WARN(fmt::format("PlaidGlobalCentroids::Train: Not enough training data. Have {}, need at least {}", embedding_num, min_data));
69+
LOG_WARN(fmt::format("PlaidGlobalCentroids::Train: Not enough training data. Have {}, need at least {}. Proceeding with reduced quality.",
70+
embedding_num,
71+
min_data));
7072
}
7173

7274
n_centroids_ = n_centroids;

src/storage/knn_index/plaid/plaid_index.cppm

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,37 @@ public:
149149
u32 total_docs,
150150
u64 total_embeddings);
151151

152+
// Flatten IVF lists from nested vector to contiguous arrays for cache-friendly search
153+
// Call after all data is added (before dump or after merge)
154+
void FinalizeIVF();
155+
156+
// Check if IVF is flattened (for search path selection)
157+
bool IsIVFFlattened() const { return ivf_flattened_; }
158+
159+
// Access flattened IVF data for a centroid
160+
const u32 *GetIVFListData(u32 centroid_id) const;
161+
u32 GetIVFListLength(u32 centroid_id) const;
162+
163+
// Outlier-based Centroid Expansion for incremental updates
164+
// Finds embeddings far from existing centroids, clusters them into new centroids
165+
// Returns number of new centroids added
166+
u32 ExpandCentroids(const f32 *new_embeddings, u64 n_new_embeddings, f32 cluster_threshold);
167+
168+
// Find outlier embeddings (distance > threshold^2 from nearest centroid)
169+
// Returns indices of outlier embeddings
170+
std::vector<u64> FindOutliers(const f32 *embeddings, u64 n_embeddings, f32 threshold_sq) const;
171+
152172
public:
153173
// Fixed parameters (set at construction)
154174
const u32 start_segment_offset_ = 0;
155175
const u32 embedding_dimension_ = 0;
156176
const u32 nbits_ = 4;
157177
const u32 requested_n_centroids_ = 0; // 0 = auto
158178

179+
// Batch size for batched IVF probing (0 = auto, based on n_centroids)
180+
static constexpr u32 DEFAULT_CENTROID_BATCH_SIZE = 4096;
181+
u32 centroid_batch_size_ = DEFAULT_CENTROID_BATCH_SIZE;
182+
159183
// Trained parameters
160184
u32 n_centroids_ = 0;
161185
std::vector<f32> centroids_data_; // [n_centroids_, embedding_dimension_]
@@ -174,9 +198,16 @@ public:
174198
std::vector<u32> centroid_ids_; // [n_total_embeddings_] centroid assignment for each embedding
175199
std::unique_ptr<u8[]> packed_residuals_; // Quantized residuals
176200
size_t packed_residuals_size_ = 0;
201+
size_t packed_residuals_capacity_ = 0; // Tracked capacity for amortized growth in MergeOneChunk
177202

178203
// Inverted index: centroid -> doc ids
204+
// Legacy: nested vector (used during build/incremental update)
179205
std::vector<std::vector<u32>> ivf_lists_; // [n_centroids_] posting lists
206+
// Flattened: contiguous storage (used after FinalizeIVF() for cache-friendly search)
207+
std::vector<u32> ivf_data_; // contiguous posting list entries
208+
std::vector<u32> ivf_offsets_; // [n_centroids_] start offset of each list in ivf_data_
209+
std::vector<u32> ivf_lengths_; // [n_centroids_] length of each list
210+
bool ivf_flattened_ = false; // true when using flattened IVF
180211

181212
// Quantizer
182213
std::unique_ptr<PlaidQuantizer> quantizer_;
@@ -206,6 +237,17 @@ public:
206237
const BlockIndex *block_index,
207238
TxnTimeStamp begin_ts) const;
208239

240+
// Batched search path: memory-efficient for large centroid counts
241+
PlaidQueryResultType GetQueryResultBatched(const f32 *query_ptr,
242+
u32 query_embedding_num,
243+
u32 n_ivf_probe,
244+
f32 centroid_score_threshold,
245+
u32 n_doc_to_score,
246+
u32 n_full_scores,
247+
u32 top_k,
248+
Bitmask &bitmask,
249+
u32 start_segment_offset) const;
250+
209251
// Compute approximate score using centroid lookups
210252
f32 ApproximateScore(const u32 *doc_centroid_ids, u32 doc_len, const f32 *query_centroid_scores, u32 n_query_tokens) const;
211253

@@ -214,6 +256,36 @@ public:
214256

215257
// Helper for batch centroid scoring
216258
std::unique_ptr<f32[]> ComputeQueryCentroidScores(const f32 *query_ptr, u32 n_query_tokens) const;
259+
260+
// Batched IVF Probing: processes centroids in chunks to bound memory usage
261+
// Returns: set of centroid IDs to probe (union of top-k across all query tokens)
262+
// Each query token maintains its own top-k heap; final result is the union
263+
void BatchedIVFProbe(const f32 *query_ptr,
264+
u32 n_query_tokens,
265+
u32 n_ivf_probe,
266+
f32 centroid_score_threshold,
267+
u32 centroid_batch_size,
268+
std::vector<u32> &probed_centroids,
269+
std::unique_ptr<f32[]> &sparse_centroid_scores,
270+
u32 &n_sparse_centroids) const;
271+
272+
// Compute approximate score using sparse centroid score lookup (for batched path)
273+
f32 ApproximateScoreSparse(const u32 *doc_centroid_ids,
274+
u32 doc_len,
275+
const f32 *sparse_scores,
276+
const u32 *sparse_centroid_id_map,
277+
u32 n_sparse_centroids,
278+
u32 n_query_tokens) const;
279+
280+
// Static helper: compute auto n_centroids using next-plaid formula
281+
static u32 ComputeAutoNCentroids(u64 embedding_count);
282+
283+
// Ensure IVF is in mutable (nested vector) form.
284+
// If currently flattened, reconstructs ivf_lists_ from ivf_data_ and clears
285+
// the flattened arrays. Must be called at the start of any IVF-mutating
286+
// operation (Add*, Merge*, ExpandCentroids, etc.).
287+
// Caller must hold an exclusive lock (rw_mutex_) before calling.
288+
void EnsureMutableIVF();
217289
};
218290

219291
} // namespace infinity

0 commit comments

Comments
 (0)