While running this part of the code (python==3.10.12 & faiss-gpu==1.7.2):
from retro_pytorch.retrieval import chunks_to_index_and_embed

# build the faiss index over the memmapped chunks and return their embeddings
index, embeddings = chunks_to_index_and_embed(
    num_chunks = 1000,
    chunk_size = 64,
    chunk_memmap_path = './train.chunks.dat'
)

query_vector = embeddings[:1]                    # use first embedding as query
_, indices = index.search(query_vector, k = 2)   # fetch 2 neighbors, the first index should be the query itself
neighbor_embeddings = embeddings[indices]        # (1, 2, 768)
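For reference, the snippet expects `index.search` to return the query itself as its first neighbor. Below is a minimal brute-force sketch of that expectation. It assumes an exact squared-L2 search, does not use faiss or retro_pytorch, and uses random vectors as a hypothetical stand-in for the real chunk embeddings. `brute_force_knn` is a hypothetical helper, not a library function.

```python
import numpy as np


def brute_force_knn(embeddings: np.ndarray, queries: np.ndarray, k: int):
    """Hypothetical helper: exact k-nearest-neighbor search by squared L2 distance.

    Returns (distances, indices), each of shape (num_queries, k), mirroring
    the (D, I) pair that a faiss `index.search(queries, k)` call returns.
    """
    # Pairwise squared distances ||q||^2 - 2 q.e + ||e||^2, shape (num_queries, num_embeddings)
    q_sq = (queries ** 2).sum(axis=1, keepdims=True)
    e_sq = (embeddings ** 2).sum(axis=1)
    dists = q_sq - 2.0 * queries @ embeddings.T + e_sq
    # Indices of the k smallest distances per query, in ascending order of distance
    idx = np.argsort(dists, axis=1)[:, :k]
    return np.take_along_axis(dists, idx, axis=1), idx


# Hypothetical stand-in for the real chunk embeddings: 1000 random 768-d vectors
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((1000, 768)).astype(np.float32)

query_vector = embeddings[:1]                          # first embedding as query, shape (1, 768)
_, indices = brute_force_knn(embeddings, query_vector, k=2)
neighbor_embeddings = embeddings[indices]              # fancy indexing with a (1, 2) array -> (1, 2, 768)

print(indices[0, 0])               # 0: the query's nearest neighbor is itself
print(neighbor_embeddings.shape)   # (1, 2, 768)
```

Indexing the `(num_chunks, 768)` embedding array with the `(1, 2)` integer array of neighbor ids is what yields the `(1, 2, 768)` shape noted in the original comment.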