#pragma once

#include "visited_list_pool.h"
#include "hnswlib.h"
#include <atomic>
#include <random>
#include <stdlib.h>
#include <streambuf>
#include <istream>
#include <assert.h>
#include <memory>
#include <string>
#include <unordered_set>
#include <set>
#include <cstring>

namespace hnswlib
{
    // X-macro: expands to ACTION(name) once for every header field, in the order listed.
    // The names match PERSISTENCE_VERSION and data members of HierarchicalNSW declared
    // later in this file.
    // NOTE(review): the expansion sites are not visible in this section — presumably the
    // header (de)serialization code; confirm before reordering, since order likely matters.
    // Comments must stay outside the macro body because of the line continuations.
    #define HEADER_FIELDS(ACTION) \
    ACTION(PERSISTENCE_VERSION) \
    ACTION(offsetLevel0_) \
    ACTION(max_elements_) \
    ACTION(cur_element_count) \
    ACTION(size_data_per_element_) \
    ACTION(label_offset_) \
    ACTION(offsetData_) \
    ACTION(maxlevel_) \
    ACTION(enterpoint_node_) \
    ACTION(maxM_) \
    ACTION(maxM0_) \
    ACTION(M_) \
    ACTION(mult_) \
    ACTION(ef_construction_)
    
    // --- Input Streambuf (read from memory) ---
    // Read-only stream buffer whose get area is a caller-owned byte range.
    // The buffer is not copied, so `data` must outlive this object.
    // Seeking is supported (seekg/tellg on a stream using this buffer work);
    // without the overrides below std::streambuf reports every seek as a failure.
    class in_membuf : public std::streambuf {
    public:
        // data: first byte of the range (may be null only when size == 0)
        // size: number of readable bytes
        in_membuf(const char* data, std::size_t size) {
            // The get area is typed char*; this class never writes through it.
            char* p = const_cast<char*>(data);
            setg(p, p, p + size);  // set get pointers
        }

    protected:
        // Repositions the read pointer relative to the beginning, current position or end.
        // Returns the new absolute offset, or pos_type(off_type(-1)) when the request does
        // not concern the input sequence or the target lies outside [0, size].
        pos_type seekoff(off_type off, std::ios_base::seekdir dir,
                         std::ios_base::openmode which = std::ios_base::in | std::ios_base::out) override {
            const pos_type failure = pos_type(off_type(-1));
            if ((which & std::ios_base::in) != std::ios_base::in) {
                return failure;
            }

            const off_type total = static_cast<off_type>(egptr() - eback());
            off_type base = 0;
            if (dir == std::ios_base::beg) {
                base = 0;
            } else if (dir == std::ios_base::cur) {
                base = static_cast<off_type>(gptr() - eback());
            } else if (dir == std::ios_base::end) {
                base = total;
            } else {
                return failure;
            }

            const off_type target = base + off;
            if (target < 0 || target > total) {
                return failure;
            }
            setg(eback(), eback() + target, egptr());
            return pos_type(target);
        }

        // Absolute reposition; same validation as seekoff from the beginning.
        pos_type seekpos(pos_type pos,
                         std::ios_base::openmode which = std::ios_base::in | std::ios_base::out) override {
            return seekoff(off_type(pos), std::ios_base::beg, which);
        }
    };
    
    // --- Input Stream ---
    class memistream : public std::istream {
     public:
        memistream(const char* data, std::size_t size)
            : std::istream(&buf), buf(data, size) {}
    
     private:
        in_membuf buf;
    };

    // Bundle of four shared input streams, one per persisted component
    // (header, level-0 data, length, link lists).
    // NOTE(review): the consumers of this struct are not visible in this section.
    struct InputPersistenceStreams {
        std::shared_ptr<std::istream> header_stream;      // stream for the header component
        std::shared_ptr<std::istream> data_level0_stream; // stream for the level-0 data component
        std::shared_ptr<std::istream> length_stream;      // stream for the length component
        std::shared_ptr<std::istream> link_list_stream;   // stream for the link-list component
    };

    // --- Memory Buffer Container for Persistence ---
    // Bundle of four (pointer, size) pairs describing the persisted components of an
    // index: header, level-0 data, length and link lists. This struct never allocates
    // or frees the buffers it points to.
    // BufferType is `char` for a mutable bundle and `const char` for a read-only view.
    template <typename BufferType>
    struct HnswData {
        static_assert(
            std::is_same<BufferType, char>::value || std::is_same<BufferType, const char>::value,
            "HnswData BufferType must be char or const char"
        );

        BufferType* header_buffer;
        size_t header_size;
        BufferType* data_level0_buffer;
        size_t data_level0_size;
        BufferType* length_buffer;
        size_t length_size;
        BufferType* link_list_buffer;
        size_t link_list_size;

        // Nothing to release: the pointers are only stored here.
        ~HnswData() {}

        // Constructor for external buffers (from Rust)
        HnswData(BufferType* header_buf, size_t header_sz,
                 BufferType* data_level0_buf, size_t data_level0_sz,
                 BufferType* length_buf, size_t length_sz,
                 BufferType* link_list_buf, size_t link_list_sz)
            : header_buffer(header_buf), header_size(header_sz),
              data_level0_buffer(data_level0_buf), data_level0_size(data_level0_sz),
              length_buffer(length_buf), length_size(length_sz),
              link_list_buffer(link_list_buf), link_list_size(link_list_sz) {}

        // Returns this object reinterpreted as the read-only variant.
        // Relies on HnswData<char> and HnswData<const char> having the same layout
        // (they differ only in the const qualification of the pointees).
        const HnswData<const BufferType> *getView() const {
            return reinterpret_cast<const HnswData<const BufferType> *>(this);
        }

        // USED ONLY IN TESTS
        // Compares every buffer byte-for-byte with `<directory>/<name>.bin`, where name is
        // header, data_level0, length and link_lists. Progress and the first mismatch are
        // printed to stdout. Returns true only when all four files exist and match exactly;
        // returns false when any buffer pointer is null.
        bool matchesWithDirectory(const std::string& directory) const {
            if (header_buffer == nullptr || data_level0_buffer == nullptr || length_buffer == nullptr || link_list_buffer == nullptr) {
                printf("HnswData is not initialized\n");
                return false;
            }

            struct file_test_info {
                const char *filename;
                const BufferType *buffer_ptr;
                size_t buffer_size;
            };

            const file_test_info files[] = {
                {"header", header_buffer, header_size},
                {"data_level0", data_level0_buffer, data_level0_size},
                {"length", length_buffer, length_size},
                {"link_lists", link_list_buffer, link_list_size}
            };

            for (const auto& file : files) {
                printf("testing %s\n", file.filename);
                // Binary mode: the comparison is byte-exact, so no newline translation may happen.
                std::ifstream file_stream(directory + "/" + file.filename + ".bin", std::ios::binary);
                if (!file_stream.is_open()) {
                    printf("File %s not found\n", file.filename);
                    return false;
                }

                file_stream.seekg(0, file_stream.end);
                const std::streamoff end_pos = file_stream.tellg();
                file_stream.seekg(0, file_stream.beg);
                if (end_pos < 0) {
                    // tellg() reports failure as -1; do not turn that into a huge size.
                    printf("File %s size could not be determined\n", file.filename);
                    return false;
                }
                const size_t file_size = static_cast<size_t>(end_pos);

                if (file_size != file.buffer_size) {
                    printf("File %s size mismatch %zu != %zu\n", file.filename, file_size, file.buffer_size);
                    return false;
                }

                std::vector<char> file_content(file_size);
                file_stream.read(file_content.data(), static_cast<std::streamsize>(file_size));
                const bool read_complete = static_cast<size_t>(file_stream.gcount()) == file_size;
                file_stream.close();
                if (!read_complete) {
                    printf("File %s could not be read completely\n", file.filename);
                    return false;
                }

                // Sizes are equal here; compare in place instead of copying the buffer.
                if (file_size != 0 && std::memcmp(file_content.data(), file.buffer_ptr, file_size) != 0) {
                    printf("File %s content mismatch\n", file.filename);
                    printf("File content:\n");
                    for (size_t i = 0; i < file_size; i++) {
                        // Go through unsigned char so bytes >= 0x80 print as two hex digits.
                        printf("%02x ", static_cast<unsigned int>(static_cast<unsigned char>(file_content[i])));
                    }
                    printf("\n");
                    printf("Buffer content:\n");
                    for (size_t i = 0; i < file.buffer_size; i++) {
                        printf("%02x ", static_cast<unsigned int>(static_cast<unsigned char>(file.buffer_ptr[i])));
                    }
                    printf("\n");
                    return false;
                }
            }

            return true;
        }
    };

    // Mutable and read-only instantiations of the buffer bundle.
    using HnswDataMut = HnswData<char>;
    using HnswDataView = HnswData<const char>;

    // tableint: type of internal element ids (index into the level-0 memory and link lists).
    // linklistsizeint: type of the word a link-list pointer refers to (see get_linklist0 / getListCount).
    typedef unsigned int tableint;
    typedef unsigned int linklistsizeint;
    const int PERSISTENCE_VERSION = 1; // Used by persistent indices to check if the index on disk is compatible with the code

    template <typename dist_t>
    class HierarchicalNSW : public AlgorithmInterface<dist_t>
    {
    public:
        // Number of mutexes in label_op_locks_; getLabelOpMutex masks a label with (this - 1).
        static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
        // Bit value of the "deleted" flag. NOTE(review): the code that sets/tests it is outside this section.
        static const unsigned char DELETE_MARK = 0x01;

        size_t max_elements_{0};                          // capacity the buffers were allocated for
        mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
        size_t size_data_per_element_{0};                 // bytes per element in data_level0_memory_ (links + vector + label)
        unsigned int size_links_per_element_{0};          // bytes of one upper-level link list (count word + maxM_ ids)
        mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
        size_t M_{0};               // number of neighbors selected per element on insertion
        size_t maxM_{0};            // link capacity per element on levels above 0 (set to M_)
        size_t maxM0_{0};           // link capacity per element on level 0 (set to 2 * M_)
        size_t ef_construction_{0}; // candidate-list size used by searchBaseLayer
        size_t ef_{0};              // query-time candidate-list size; set through setEf

        // mult_ = 1 / log(M_) and revSize_ = 1 / mult_ (computed in the constructor).
        double mult_{0.0}, revSize_{0.0};
        int maxlevel_{0}; // -1 until the first element is inserted (see constructor)

        VisitedListPool *visited_list_pool_{nullptr}; // owned; deleted in the destructor

        // Locks operations with element by label value
        mutable std::vector<std::mutex> label_op_locks_;

        std::mutex global; // NOTE(review): usage is outside this section
        std::vector<std::mutex> link_list_locks_; // one mutex per internal id, guarding that element's link lists

        tableint enterpoint_node_{0}; // set to (tableint)-1 by the constructor until a first element exists

        size_t size_links_level0_{0}; // bytes of the level-0 link list (count word + maxM0_ ids)
        // Byte offsets inside one element record of data_level0_memory_.
        size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{0};

        char *data_level0_memory_{nullptr}; // malloc'ed: max_elements_ records of size_data_per_element_ bytes
        char *length_memory_{nullptr};      // malloc'ed: max_elements_ * sizeof(float) bytes
        char **linkLists_{nullptr};         // malloc'ed: per-element pointer to its upper-level link lists
        std::vector<int> element_levels_; // keeps level of each element

        size_t data_size_{0}; // bytes of one vector, as reported by the space

        DISTFUNC<dist_t> fstdistfunc_;   // distance function obtained from the space
        void *dist_func_param_{nullptr}; // opaque parameter forwarded to fstdistfunc_

        mutable std::mutex label_lookup_lock; // lock for label_lookup_
        std::unordered_map<labeltype, tableint> label_lookup_; // external label -> internal id

        std::default_random_engine level_generator_;              // seeded with random_seed; drives getRandomLevel
        std::default_random_engine update_probability_generator_; // seeded with random_seed + 1

        // Counters updated by searchBaseLayerST when collect_metrics is enabled.
        mutable std::atomic<long> metric_distance_computations{0};
        mutable std::atomic<long> metric_hops{0};

        bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
        bool normalize_ = false;             // flag to normalize vectors before insertion

        std::mutex deleted_elements_lock;              // lock for deleted_elements
        std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements

        bool persist_on_write_ = false;          // when true the index is backed by files under persist_location_
        std::string persist_location_;           // directory / path used for persistence
        std::mutex elements_to_persist_lock_;    // lock for elements_to_persist_
        std::set<tableint> elements_to_persist_; // dirty elements to persist
        // File handles for persistence
        std::ofstream output_header_;              // output stream for header
        std::ofstream output_data_level0_;         // output stream for data level 0
        std::ofstream output_length_;              // output stream for length
        std::ofstream output_link_lists_;          // output stream for link lists
        bool _persist_file_handles_opened = false; // flag to check if file handles are opened

        // Creates an empty index object: every member keeps its in-class default and
        // nothing is allocated. The space argument is accepted but not used here.
        HierarchicalNSW(SpaceInterface<dist_t> *s)
        {
            (void)s; // intentionally unused
        }

        // Loads an existing index from `location`.
        //   s                     - space providing the distance function and data size
        //   location              - path handed to loadIndex, or the persistence directory
        //   nmslib                - accepted but not used by this constructor
        //   max_elements          - forwarded to the loader
        //   allow_replace_deleted - stored in allow_replace_deleted_
        //   normalize             - stored in normalize_
        //   persist_on_write      - selects the persisted-index loader instead of loadIndex
        HierarchicalNSW(
            SpaceInterface<dist_t> *s,
            const std::string &location,
            bool nmslib = false,
            size_t max_elements = 0,
            bool allow_replace_deleted = false,
            bool normalize = false,
            bool persist_on_write = false)
            : allow_replace_deleted_(allow_replace_deleted),
              normalize_(normalize),
              persist_on_write_(persist_on_write),
              persist_location_(location)
        {
            (void)nmslib; // intentionally unused

            if (!persist_on_write_)
            {
                loadIndex(location, s, max_elements);
                return;
            }

            // Persisted indices are stored differently
            loadPersistedIndex(s, max_elements);
        }

        // Loads an index from in-memory buffers via loadPersistedIndexFromMemory.
        // Persistence-on-write is disabled and no persistence location is set.
        // `nmslib` is accepted but not used by this constructor.
        HierarchicalNSW(
            SpaceInterface<dist_t> *s,
            const HnswDataView &buffers,
            bool nmslib = false,
            size_t max_elements = 0,
            bool allow_replace_deleted = false,
            bool normalize = false)
            : allow_replace_deleted_(allow_replace_deleted),
              normalize_(normalize),
              persist_on_write_(false),
              persist_location_()
        {
            (void)nmslib; // intentionally unused
            loadPersistedIndexFromMemory(s, buffers, max_elements);
        }

        // Creates a new, empty index with room for `max_elements` elements.
        //   s                     - space providing data size, distance function and its parameter
        //   max_elements          - capacity of all per-element buffers and lock tables
        //   M                     - neighbors per element (maxM_ = M, maxM0_ = 2 * M)
        //   ef_construction       - construction candidate-list size, raised to at least M
        //   random_seed           - seeds the level generator (seed) and update generator (seed + 1)
        //   allow_replace_deleted - stored in allow_replace_deleted_
        //   normalize             - stored in normalize_
        //   persist_on_write      - when true, initPersistentIndex() is called
        //   persist_location      - required (non-empty) when persist_on_write is true
        // Throws std::runtime_error when an allocation fails or when persistence is
        // requested without a location. On any exception the raw allocations made so far
        // are released before rethrowing, because the destructor does not run for an
        // object whose constructor throws.
        HierarchicalNSW(
            SpaceInterface<dist_t> *s,
            size_t max_elements,
            size_t M = 16,
            size_t ef_construction = 200,
            size_t random_seed = 100,
            bool allow_replace_deleted = false,
            bool normalize = false,
            bool persist_on_write = false,
            const std::string &persist_location = "")
            : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), // listed in declaration order
              link_list_locks_(max_elements),
              element_levels_(max_elements),
              allow_replace_deleted_(allow_replace_deleted),
              normalize_(normalize),
              persist_on_write_(persist_on_write),
              persist_location_(persist_location)
        {
            max_elements_ = max_elements;
            num_deleted_ = 0;
            data_size_ = s->get_data_size();
            fstdistfunc_ = s->get_dist_func();
            dist_func_param_ = s->get_dist_func_param();
            M_ = M;
            maxM_ = M_;
            maxM0_ = M_ * 2;
            ef_construction_ = std::max(ef_construction, M_);
            ef_ = 10;

            level_generator_.seed(random_seed);
            update_probability_generator_.seed(random_seed + 1);

            // Level-0 record layout: [link list][vector data][label]
            size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
            size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
            offsetData_ = size_links_level0_;
            label_offset_ = size_links_level0_ + data_size_;
            offsetLevel0_ = 0;

            try
            {
                data_level0_memory_ = static_cast<char *>(malloc(max_elements_ * size_data_per_element_));
                if (data_level0_memory_ == nullptr)
                    throw std::runtime_error("Not enough memory");

                length_memory_ = static_cast<char *>(malloc(max_elements_ * sizeof(float)));
                if (length_memory_ == nullptr)
                    throw std::runtime_error("Not enough memory");

                cur_element_count = 0;

                visited_list_pool_ = new VisitedListPool(1, max_elements);

                // initializations for special treatment of the first node
                enterpoint_node_ = -1;
                maxlevel_ = -1;

                linkLists_ = static_cast<char **>(malloc(sizeof(void *) * max_elements_));
                if (linkLists_ == nullptr)
                    throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
                size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
                mult_ = 1 / log(1.0 * M_);
                revSize_ = 1.0 / mult_;

                if (persist_on_write_)
                {
                    if (persist_location_.empty())
                    {
                        throw std::runtime_error("persist_location_ is empty");
                    }
                    initPersistentIndex();
                }
            }
            catch (...)
            {
                // Release everything acquired above; members with their own destructors
                // (vectors, streams, strings) are cleaned up automatically.
                free(linkLists_);
                linkLists_ = nullptr;
                free(length_memory_);
                length_memory_ = nullptr;
                free(data_level0_memory_);
                data_level0_memory_ = nullptr;
                delete visited_list_pool_;
                visited_list_pool_ = nullptr;
                throw;
            }
        }

        // Releases every malloc'ed buffer, the visited-list pool, and finally closes
        // the persistence state via closePersistentIndex().
        ~HierarchicalNSW()
        {
            // Only elements whose level is above 0 own an upper-level link-list buffer.
            for (tableint id = 0; id < cur_element_count; ++id)
            {
                if (element_levels_[id] <= 0)
                    continue;
                free(linkLists_[id]);
            }
            free(linkLists_);

            free(data_level0_memory_);
            free(length_memory_);

            delete visited_list_pool_;
            closePersistentIndex();
        }

        struct CompareByFirst
        {
            constexpr bool operator()(std::pair<dist_t, tableint> const &a,
                                      std::pair<dist_t, tableint> const &b) const noexcept
            {
                return a.first < b.first;
            }
        };

        void setEf(size_t ef)
        {
            ef_ = ef;
        }

        // Returns the mutex that guards operations on `label`. Labels are mapped onto the
        // lock table by masking with (MAX_LABEL_OPERATION_LOCKS - 1), so different labels
        // may share a mutex.
        inline std::mutex &getLabelOpMutex(labeltype label) const
        {
            return label_op_locks_[label & (MAX_LABEL_OPERATION_LOCKS - 1)];
        }

        // Reads the external label stored in the level-0 record of `internal_id`.
        // memcpy is used so the read does not depend on the slot's alignment.
        inline labeltype getExternalLabel(tableint internal_id) const
        {
            const char *label_slot = data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_;
            labeltype label;
            memcpy(&label, label_slot, sizeof(labeltype));
            return label;
        }

        // Writes `label` into the level-0 record of `internal_id`.
        // The method is const because only the pointed-to buffer changes, not a member.
        inline void setExternalLabel(tableint internal_id, labeltype label) const
        {
            char *label_slot = data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_;
            memcpy(label_slot, &label, sizeof(labeltype));
        }

        // Returns a pointer to the label slot inside the level-0 record of `internal_id`.
        inline labeltype *getExternalLabeLp(tableint internal_id) const
        {
            char *label_slot = data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_;
            return reinterpret_cast<labeltype *>(label_slot);
        }

        // Returns a pointer to the vector data inside the level-0 record of `internal_id`.
        inline char *getDataByInternalId(tableint internal_id) const
        {
            char *record = data_level0_memory_ + internal_id * size_data_per_element_;
            return record + offsetData_;
        }

        // Draws a level as the integer part of -log(u) * reverse_size, with u taken
        // uniformly from [0, 1) using level_generator_.
        int getRandomLevel(double reverse_size)
        {
            std::uniform_real_distribution<double> unit_interval(0.0, 1.0);
            const double sample = unit_interval(level_generator_);
            const double scaled = -log(sample) * reverse_size;
            return static_cast<int>(scaled);
        }

        size_t getMaxElements()
        {
            return max_elements_;
        }

        // Returns the current value of the atomic element counter.
        size_t getCurrentElementCount()
        {
            return cur_element_count.load();
        }

        // Returns the current value of the atomic deleted-element counter.
        size_t getDeletedCount()
        {
            return num_deleted_.load();
        }

        // Writes the unit-length version of `data` into `norm_array` and returns the
        // original Euclidean length.
        //   data       - input vector with `dim` components
        //   norm_array - output buffer with room for `dim` floats (may alias `data`)
        //   dim        - number of components
        // A tiny epsilon is added to the length before dividing, so an all-zero input
        // produces an all-zero output instead of NaN.
        float normalize_vector(float *data, float *norm_array, size_t dim)
        {
            float squared_sum = 0.0f;
            // size_t index: matches the type of `dim`, so there is no signed/unsigned
            // comparison and no overflow for dimensions above INT_MAX.
            for (size_t i = 0; i < dim; i++)
                squared_sum += data[i] * data[i];

            const float length = sqrtf(squared_sum);
            const float inv_length = 1.0f / (length + 1e-30f);
            for (size_t i = 0; i < dim; i++)
            {
                norm_array[i] = data[i] * inv_length;
            }
            return length;
        }

        // Best-first search on a single layer, used with ef_construction_ as the result-list size.
        //   ep_id      - internal id of the entry point
        //   data_point - query vector handed to fstdistfunc_
        //   layer      - layer to search (0 reads the level-0 list, otherwise the upper-level list)
        // Returns a max-heap (largest distance on top) of at most ef_construction_
        // (distance, internal id) pairs. Elements marked as deleted are expanded during
        // the traversal but never placed in the result.
        // Each visited node's link list is read while holding that node's link_list_locks_ mutex.
        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
        searchBaseLayer(tableint ep_id, const void *data_point, int layer)
        {
            VisitedList *vl = visited_list_pool_->getFreeVisitedList();
            vl_type *visited_array = vl->mass;
            vl_type visited_array_tag = vl->curV;

            // top_candidates: current results, worst (largest distance) on top.
            // candidateSet: nodes still to expand; distances are stored negated so the
            // closest node is on top of this max-heap.
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;

            // lowerBound tracks the distance of the worst result currently kept.
            dist_t lowerBound;
            if (!isMarkedDeleted(ep_id))
            {
                dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
                top_candidates.emplace(dist, ep_id);
                lowerBound = dist;
                candidateSet.emplace(-dist, ep_id);
            }
            else
            {
                // A deleted entry point is still expanded, but with the worst possible distance.
                lowerBound = std::numeric_limits<dist_t>::max();
                candidateSet.emplace(-lowerBound, ep_id);
            }
            visited_array[ep_id] = visited_array_tag;

            while (!candidateSet.empty())
            {
                std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
                // Stop once the closest unexpanded node is farther than the worst kept
                // result and the result list is full.
                if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_)
                {
                    break;
                }
                candidateSet.pop();

                tableint curNodeNum = curr_el_pair.second;

                std::unique_lock<std::mutex> lock(link_list_locks_[curNodeNum]);

                int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
                if (layer == 0)
                {
                    data = (int *)get_linklist0(curNodeNum);
                }
                else
                {
                    data = (int *)get_linklist(curNodeNum, layer);
                    //                    data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
                }
                // First word of the list holds the neighbor count; the ids follow it.
                size_t size = getListCount((linklistsizeint *)data);
                tableint *datal = (tableint *)(data + 1);
#ifdef USE_SSE
                _mm_prefetch((char *)(visited_array + *(data + 1)), _MM_HINT_T0);
                _mm_prefetch((char *)(visited_array + *(data + 1) + 64), _MM_HINT_T0);
                _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
                _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
#endif

                for (size_t j = 0; j < size; j++)
                {
                    tableint candidate_id = *(datal + j);
//                    if (candidate_id == 0) continue;
#ifdef USE_SSE
                    // Prefetch for the next neighbor (index j + 1).
                    _mm_prefetch((char *)(visited_array + *(datal + j + 1)), _MM_HINT_T0);
                    _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
#endif
                    if (visited_array[candidate_id] == visited_array_tag)
                        continue;
                    visited_array[candidate_id] = visited_array_tag;
                    char *currObj1 = (getDataByInternalId(candidate_id));

                    dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
                    if (top_candidates.size() < ef_construction_ || lowerBound > dist1)
                    {
                        candidateSet.emplace(-dist1, candidate_id);
#ifdef USE_SSE
                        _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif

                        // Deleted elements stay traversable but are not reported.
                        if (!isMarkedDeleted(candidate_id))
                            top_candidates.emplace(dist1, candidate_id);

                        if (top_candidates.size() > ef_construction_)
                            top_candidates.pop();

                        if (!top_candidates.empty())
                            lowerBound = top_candidates.top().first;
                    }
                }
            }
            visited_list_pool_->releaseVisitedList(vl);

            return top_candidates;
        }

        // Best-first search on level 0 with an explicit result-list size `ef`.
        //   has_deletions   - when true, elements marked as deleted are excluded from the result
        //   collect_metrics - when true, metric_hops / metric_distance_computations are updated
        //   ep_id           - internal id of the entry point
        //   data_point      - query vector handed to fstdistfunc_
        //   ef              - maximum number of results kept
        //   isIdAllowed     - optional filter applied to the external label; null means "allow all"
        // Returns a max-heap (largest distance on top) of at most `ef` (distance, internal id) pairs.
        // This function takes no link-list locks itself.
        template <bool has_deletions, bool collect_metrics = false>
        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
        searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor *isIdAllowed = nullptr) const
        {
            VisitedList *vl = visited_list_pool_->getFreeVisitedList();
            vl_type *visited_array = vl->mass;
            vl_type visited_array_tag = vl->curV;

            // top_candidates: current results, worst on top.
            // candidate_set: nodes still to expand, keyed by negated distance (closest on top).
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

            dist_t lowerBound;
            if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))
            {
                dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
                lowerBound = dist;
                top_candidates.emplace(dist, ep_id);
                candidate_set.emplace(-dist, ep_id);
            }
            else
            {
                // Entry point is deleted or filtered out: expand it, but do not report it.
                lowerBound = std::numeric_limits<dist_t>::max();
                candidate_set.emplace(-lowerBound, ep_id);
            }

            visited_array[ep_id] = visited_array_tag;

            while (!candidate_set.empty())
            {
                std::pair<dist_t, tableint> current_node_pair = candidate_set.top();

                // Stop when the closest unexpanded node is farther than the worst kept result
                // and either the result list is full or nothing can be excluded from it
                // (no filter and no deletions).
                if ((-current_node_pair.first) > lowerBound &&
                    (top_candidates.size() == ef || (!isIdAllowed && !has_deletions)))
                {
                    break;
                }
                candidate_set.pop();

                tableint current_node_id = current_node_pair.second;
                // data[0] is the list header read by getListCount; neighbor ids are data[1..size].
                int *data = (int *)get_linklist0(current_node_id);
                size_t size = getListCount((linklistsizeint *)data);
                //                bool cur_node_deleted = isMarkedDeleted(current_node_id);
                if (collect_metrics)
                {
                    // Counts the whole neighbor list, including neighbors that were already visited.
                    metric_hops++;
                    metric_distance_computations += size;
                }

#ifdef USE_SSE
                _mm_prefetch((char *)(visited_array + *(data + 1)), _MM_HINT_T0);
                _mm_prefetch((char *)(visited_array + *(data + 1) + 64), _MM_HINT_T0);
                _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
                _mm_prefetch((char *)(data + 2), _MM_HINT_T0);
#endif

                for (size_t j = 1; j <= size; j++)
                {
                    int candidate_id = *(data + j);
//                    if (candidate_id == 0) continue;
#ifdef USE_SSE
                    // Prefetch for the next neighbor (index j + 1).
                    _mm_prefetch((char *)(visited_array + *(data + j + 1)), _MM_HINT_T0);
                    _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
                                 _MM_HINT_T0); ////////////
#endif
                    if (!(visited_array[candidate_id] == visited_array_tag))
                    {
                        visited_array[candidate_id] = visited_array_tag;

                        char *currObj1 = (getDataByInternalId(candidate_id));
                        dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

                        if (top_candidates.size() < ef || lowerBound > dist)
                        {
                            candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
                            _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
                                             offsetLevel0_, ///////////
                                         _MM_HINT_T0);      ////////////////////////
#endif

                            // Only non-deleted elements that pass the filter are reported.
                            if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
                                top_candidates.emplace(dist, candidate_id);

                            if (top_candidates.size() > ef)
                                top_candidates.pop();

                            if (!top_candidates.empty())
                                lowerBound = top_candidates.top().first;
                        }
                    }
                }
            }

            visited_list_pool_->releaseVisitedList(vl);
            return top_candidates;
        }

        void getNeighborsByHeuristic2(
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
            const size_t M)
        {
            if (top_candidates.size() < M)
            {
                return;
            }

            std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
            std::vector<std::pair<dist_t, tableint>> return_list;
            while (top_candidates.size() > 0)
            {
                queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
                top_candidates.pop();
            }

            while (queue_closest.size())
            {
                if (return_list.size() >= M)
                    break;
                std::pair<dist_t, tableint> curent_pair = queue_closest.top();
                dist_t dist_to_query = -curent_pair.first;
                queue_closest.pop();
                bool good = true;

                for (std::pair<dist_t, tableint> second_pair : return_list)
                {
                    dist_t curdist =
                        fstdistfunc_(getDataByInternalId(second_pair.second),
                                     getDataByInternalId(curent_pair.second),
                                     dist_func_param_);
                    if (curdist < dist_to_query)
                    {
                        good = false;
                        break;
                    }
                }
                if (good)
                {
                    return_list.push_back(curent_pair);
                }
            }

            for (std::pair<dist_t, tableint> curent_pair : return_list)
            {
                top_candidates.emplace(-curent_pair.first, curent_pair.second);
            }
        }

        // Returns the level-0 link list of `internal_id` inside data_level0_memory_.
        linklistsizeint *get_linklist0(tableint internal_id) const
        {
            char *record = data_level0_memory_ + internal_id * size_data_per_element_;
            return reinterpret_cast<linklistsizeint *>(record + offsetLevel0_);
        }

        // Same as the single-argument overload, but relative to the buffer passed in.
        // The parameter deliberately shadows the member of the same name.
        linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const
        {
            char *record = data_level0_memory_ + internal_id * size_data_per_element_;
            return reinterpret_cast<linklistsizeint *>(record + offsetLevel0_);
        }

        // Returns the link list of `internal_id` on an upper level.
        // Level 1 is stored first in the element's linkLists_ buffer, hence (level - 1).
        linklistsizeint *get_linklist(tableint internal_id, int level) const
        {
            char *upper_lists = linkLists_[internal_id];
            return reinterpret_cast<linklistsizeint *>(upper_lists + (level - 1) * size_links_per_element_);
        }

        // Returns the link list of `internal_id` on `level`, choosing between the
        // level-0 storage and the upper-level storage.
        linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const
        {
            if (level == 0)
                return get_linklist0(internal_id);
            return get_linklist(internal_id, level);
        }

        // Records `internal_id` in the set of dirty elements, under elements_to_persist_lock_.
        void markElementToPersist(tableint internal_id)
        {
            std::lock_guard<std::mutex> guard(elements_to_persist_lock_);
            elements_to_persist_.insert(internal_id);
        }

        // Links element `cur_c` with neighbors chosen from `top_candidates` on `level`, in both directions.
        //   data_point     - not read in this function body
        //   cur_c          - internal id of the element being connected
        //   top_candidates - candidate neighbors; consumed (emptied) by this call
        //   level          - layer on which the links are created
        //   isUpdate       - true when cur_c already exists: its list lock is taken here and
        //                    non-empty lists are tolerated
        // Returns the id taken from the back of the selected-neighbor list (popped last from
        // the max-heap, i.e. the closest one), used by callers as the next entry point.
        // Throws std::runtime_error on the consistency violations checked below.
        // cur_c and every selected neighbor are marked dirty for persistence.
        // NOTE(review): if the heuristic leaves no candidates, selectedNeighbors.back() is
        // called on an empty vector — callers presumably guarantee at least one candidate; confirm.
        tableint mutuallyConnectNewElement(
            const void *data_point,
            tableint cur_c,
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
            int level,
            bool isUpdate)
        {

            // mark cur_c as dirty
            markElementToPersist(cur_c);

            // Capacity of a neighbor's list on this level.
            size_t Mcurmax = level ? maxM_ : maxM0_;
            getNeighborsByHeuristic2(top_candidates, M_);
            if (top_candidates.size() > M_)
                throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");

            // Drain the max-heap: farthest neighbor first, closest last.
            std::vector<tableint> selectedNeighbors;
            selectedNeighbors.reserve(M_);
            while (top_candidates.size() > 0)
            {
                selectedNeighbors.push_back(top_candidates.top().second);
                top_candidates.pop();
            }

            tableint next_closest_entry_point = selectedNeighbors.back();

            // Step 1: write cur_c's own link list.
            {
                // lock only during the update
                // because during the addition the lock for cur_c is already acquired
                std::unique_lock<std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
                if (isUpdate)
                {
                    lock.lock();
                }
                linklistsizeint *ll_cur;
                if (level == 0)
                    ll_cur = get_linklist0(cur_c);
                else
                    ll_cur = get_linklist(cur_c, level);

                if (*ll_cur && !isUpdate)
                {
                    throw std::runtime_error("The newly inserted element should have blank link list");
                }
                setListCount(ll_cur, selectedNeighbors.size());
                tableint *data = (tableint *)(ll_cur + 1);
                for (size_t idx = 0; idx < selectedNeighbors.size(); idx++)
                {
                    // A fresh element's slots are expected to be zero.
                    if (data[idx] && !isUpdate)
                        throw std::runtime_error("Possible memory corruption");
                    if (level > element_levels_[selectedNeighbors[idx]])
                        throw std::runtime_error("Trying to make a link on a non-existent level");

                    data[idx] = selectedNeighbors[idx];
                }
            }

            // Step 2: add the back-link cur_c to every selected neighbor.
            for (size_t idx = 0; idx < selectedNeighbors.size(); idx++)
            {
                // Note: We may want to lock _elements_to_persist outside the loop. Should profile this to see if it matters.
                markElementToPersist(selectedNeighbors[idx]);

                std::unique_lock<std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);

                linklistsizeint *ll_other;
                if (level == 0)
                    ll_other = get_linklist0(selectedNeighbors[idx]);
                else
                    ll_other = get_linklist(selectedNeighbors[idx], level);

                size_t sz_link_list_other = getListCount(ll_other);

                if (sz_link_list_other > Mcurmax)
                    throw std::runtime_error("Bad value of sz_link_list_other");
                if (selectedNeighbors[idx] == cur_c)
                    throw std::runtime_error("Trying to connect an element to itself");
                if (level > element_levels_[selectedNeighbors[idx]])
                    throw std::runtime_error("Trying to make a link on a non-existent level");

                tableint *data = (tableint *)(ll_other + 1);

                // Only an update can find cur_c already linked from the neighbor.
                bool is_cur_c_present = false;
                if (isUpdate)
                {
                    for (size_t j = 0; j < sz_link_list_other; j++)
                    {
                        if (data[j] == cur_c)
                        {
                            is_cur_c_present = true;
                            break;
                        }
                    }
                }

                // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
                if (!is_cur_c_present)
                {
                    if (sz_link_list_other < Mcurmax)
                    {
                        // Free slot available: append.
                        data[sz_link_list_other] = cur_c;
                        setListCount(ll_other, sz_link_list_other + 1);
                    }
                    else
                    {
                        // List is full: re-select among the existing links plus cur_c.
                        // finding the "weakest" element to replace it with the new one
                        dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
                                                    dist_func_param_);
                        // Heuristic:
                        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
                        candidates.emplace(d_max, cur_c);

                        for (size_t j = 0; j < sz_link_list_other; j++)
                        {
                            candidates.emplace(
                                fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
                                             dist_func_param_),
                                data[j]);
                        }

                        getNeighborsByHeuristic2(candidates, Mcurmax);

                        // Overwrite the neighbor's list with the survivors.
                        int indx = 0;
                        while (candidates.size() > 0)
                        {
                            data[indx] = candidates.top().second;
                            candidates.pop();
                            indx++;
                        }

                        setListCount(ll_other, indx);
                        // Nearest K:
                        /*int indx = -1;
                        for (int j = 0; j < sz_link_list_other; j++) {
                            dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
                            if (d > d_max) {
                                indx = j;
                                d_max = d;
                            }
                        }
                        if (indx >= 0) {
                            data[indx] = cur_c;
                        } */
                    }
                }
            }

            return next_closest_entry_point;
        }

        void resizeIndex(size_t new_max_elements)
        {
            if (new_max_elements < cur_element_count)
                throw std::runtime_error("Cannot resize, max element is less than the current number of elements");

            delete visited_list_pool_;
            visited_list_pool_ = new VisitedListPool(1, new_max_elements);

            element_levels_.resize(new_max_elements);

            std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);

            // Reallocate base layer
            char *data_level0_memory_new = (char *)realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
            if (data_level0_memory_new == nullptr)
                throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
            data_level0_memory_ = data_level0_memory_new;

            // Reallocate length memory
            char *length_memory_new = (char *)realloc(length_memory_, new_max_elements * sizeof(float));
            if (length_memory_new == nullptr)
                throw std::runtime_error("Not enough memory: resizeIndex failed to allocate length memory");
            length_memory_ = length_memory_new;

            // Reallocate all other layers
            char **linkLists_new = (char **)realloc(linkLists_, sizeof(void *) * new_max_elements);
            if (linkLists_new == nullptr)
                throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
            linkLists_ = linkLists_new;

            max_elements_ = new_max_elements;
        }

        /*
         * Writes the whole index into a single binary file at `location`.
         *
         * Layout: header fields, level-0 memory for `cur_element_count` elements,
         * length memory for `cur_element_count` elements, then for every element a
         * size-prefixed block with its upper-layer link lists.
         * Note that this single-file header does not start with PERSISTENCE_VERSION.
         *
         * Throws std::runtime_error if the file cannot be opened or a write fails.
         */
        void saveIndex(const std::string &location)
        {
            std::ofstream output(location, std::ios::binary);
            if (!output.is_open())
                throw std::runtime_error("Cannot open file for writing: " + location);

            // IF THIS IS CHANGED: PLEASE MAKE CORRESPONDING MODIFICATIONS TO
            // THE HEADER_FIELDS MACRO.
            writeBinaryPOD(output, offsetLevel0_);
            writeBinaryPOD(output, max_elements_);
            writeBinaryPOD(output, cur_element_count);
            writeBinaryPOD(output, size_data_per_element_);
            writeBinaryPOD(output, label_offset_);
            writeBinaryPOD(output, offsetData_);
            writeBinaryPOD(output, maxlevel_);
            writeBinaryPOD(output, enterpoint_node_);
            writeBinaryPOD(output, maxM_);

            writeBinaryPOD(output, maxM0_);
            writeBinaryPOD(output, M_);
            writeBinaryPOD(output, mult_);
            writeBinaryPOD(output, ef_construction_);

            output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
            output.write(length_memory_, cur_element_count * sizeof(float));

            for (size_t i = 0; i < cur_element_count; i++)
            {
                // Elements that live only on level 0 have no upper-layer lists: size 0, no payload.
                unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
                writeBinaryPOD(output, linkListSize);
                if (linkListSize)
                    output.write(linkLists_[i], linkListSize);
            }
            output.close();

            // Any failed write (or a failed flush on close) leaves failbit/badbit set.
            if (output.fail())
                throw std::runtime_error("Failed to write index to file: " + location);
        }

        // Path of the persisted header file inside the persistence directory.
        std::string getHeaderLocation()
        {
            std::string path(persist_location_);
            path += "/header.bin";
            return path;
        }

        // Path of the persisted level-0 data file inside the persistence directory.
        std::string getDataLevel0Location()
        {
            std::string path(persist_location_);
            path += "/data_level0.bin";
            return path;
        }

        // Path of the persisted length file inside the persistence directory.
        std::string getLengthLocation()
        {
            std::string path(persist_location_);
            path += "/length.bin";
            return path;
        }

        // Path of the persisted link-list file inside the persistence directory.
        std::string getLinkListLocation()
        {
            std::string path(persist_location_);
            path += "/link_lists.bin";
            return path;
        }

        // #pragma region PersistentIndex
        /*
         * Opens the four persistence files for in-place updates and stores the handles
         * in the output_* members.
         *
         * The files are opened with in | out so that existing content is kept rather than
         * truncated; consequently every file must already exist.
         *
         * Throws std::runtime_error naming the first file that cannot be opened.
         */
        void setupPersistentIndexFileHandles()
        {
            auto open_for_update = [](std::ofstream &handle, const std::string &path)
            {
                handle = std::ofstream(path, std::ios::in | std::ios::out | std::ios::binary);
                if (!handle.is_open())
                {
                    // The exception object used to be constructed without `throw`,
                    // which silently ignored the failure.
                    throw std::runtime_error("Cannot open file: " + path);
                }
            };

            open_for_update(this->output_header_, this->getHeaderLocation());
            open_for_update(this->output_data_level0_, this->getDataLevel0Location());
            open_for_update(this->output_length_, this->getLengthLocation());
            open_for_update(this->output_link_lists_, this->getLinkListLocation());
        }

        void closePersistentIndexFileHandles()
        {
            this->output_header_.close();
            this->output_data_level0_.close();
            this->output_length_.close();
            this->output_link_lists_.close();
        }

        /*
         * Opens the write handles of the four persistence files so that calling processes can
         * manage file handle utilization themselves. Safe to call repeatedly: nothing happens
         * when the handles are already open or when the index does not persist on write.
         */
        void openPersistentIndex()
        {
            if (_persist_file_handles_opened || !persist_on_write_)
            {
                return;
            }

            setupPersistentIndexFileHandles();
            _persist_file_handles_opened = true;
        }

        /*
         * Counterpart of openPersistentIndex: closes the write handles if they are open.
         * Does nothing when the handles are closed or the index does not persist on write.
         */
        void closePersistentIndex()
        {
            if (!_persist_file_handles_opened || !persist_on_write_)
            {
                return;
            }

            closePersistentIndexFileHandles();
            _persist_file_handles_opened = false;
        }

        void initPersistentIndex()
        {
            // A persisted index is stored as four files
            // The latter 3 files are stored seperately so that they can each grow as the index grows
            // 1. The header
            // 2. The data_level_0
            // 3. length_memory_
            // 4. linkLists

            if (!persist_on_write_)
            {
                throw std::runtime_error("initPersistentIndex called for an index that is not set to persist on write");
            }

            // Create the file handles for initial write
            std::ofstream output_header(this->getHeaderLocation(), std::ios::binary);
            std::ofstream output_data_level0(this->getDataLevel0Location(), std::ios::binary);
            std::ofstream output_length(this->getLengthLocation(), std::ios::binary);
            std::ofstream output_link_lists(this->getLinkListLocation(), std::ios::binary);

            // Write header
            persistHeader(output_header);

            // Write data_level0
            output_data_level0.seekp(0, std::ios::beg);
            output_data_level0.write(data_level0_memory_, max_elements_ * size_data_per_element_);
            output_data_level0.flush();

            // Write lengths
            output_length.seekp(0, std::ios::beg);
            output_length.write(length_memory_, max_elements_ * sizeof(float));
            output_length.flush();

            // Close file handles
            output_header.close();
            output_data_level0.close();
            output_length.close();
            output_link_lists.close();

            // Create file handles for further writing
            openPersistentIndex();
        }

        /*
         * Writes every field listed in HEADER_FIELDS, in macro order, to the beginning of
         * `output_header` and flushes the stream.
         */
        void persistHeader(std::ostream &output_header)
        {
            // Always overwrite the header from the start of the stream.
            output_header.seekp(0, std::ios::beg);

            // HEADER_FIELDS expands to one writeBinaryPOD call per header field.
            #define WRITE_ACTION(field) writeBinaryPOD(output_header, field);
            HEADER_FIELDS(WRITE_ACTION)
            #undef WRITE_ACTION

            output_header.flush();
        }

        // Persistence functions
        /*
         * Flushes every element recorded in `elements_to_persist_` to the persistence files
         * and clears the dirty set afterwards. Returns immediately when nothing is dirty.
         *
         * Throws std::runtime_error if the index is not set to persist on write or if the
         * persistence file handles have not been opened.
         */
        void persistDirty()
        {
            if (elements_to_persist_.size() == 0)
            {
                return;
            }

            if (!persist_on_write_)
            {
                throw std::runtime_error("persistDirty called for an index that is not set to persist on write");
            }

            if (!_persist_file_handles_opened)
            {
                throw std::runtime_error("persistDirty called for an index that has not opened its file handles");
            }

            // The header is rewritten in full on every flush.
            persistHeader(this->output_header_);

            // Note: We could benefit a lot from async IO here. Either via classic POSIX AIO or via libaio
            // Generally, this storage scheme is a bit naive, and we could do a lot better in terms of disk access patterns
            this->output_data_level0_.seekp(0, std::ios::beg);
            for (const auto &id : elements_to_persist_)
            {
                // Write the _data_level0_memory
                // Each element is stored in a contiguous block of size_data_per_element_
                // Where each element is the size of llist, the llist, the data, and the label
                // Records are fixed-size, so the file offset is computed directly from the id.
                this->output_data_level0_.seekp(id * size_data_per_element_, this->output_data_level0_.beg);
                this->output_data_level0_.write(data_level0_memory_ + id * size_data_per_element_, size_data_per_element_);
            }
            this->output_data_level0_.flush();

            // Write the dirty lengths
            this->output_length_.seekp(0, std::ios::beg);
            for (const auto &id : elements_to_persist_)
            {
                // Write the _length_memory (one float per element, addressed by id)
                this->output_length_.seekp(id * sizeof(float), this->output_length_.beg);
                writeBinaryPOD(this->output_length_, ((float *)length_memory_)[id]);
            }
            this->output_length_.flush();

            // Write the dirty link lists
            // Link-list records are variable-sized (size prefix + payload), so the file is walked
            // from the start: dirty records are rewritten, clean ones are skipped by seeking.
            // NOTE(review): the walk compares i against *dirty_elements_iter in increasing order,
            // which presumably relies on elements_to_persist_ iterating in ascending id order - confirm.
            this->output_link_lists_.seekp(0, std::ios::beg);
            auto dirty_elements_iter = elements_to_persist_.begin();
            // TODO: don't need to iterate over potentially all elements, could store it or memoize
            for (size_t i = 0; i < cur_element_count && dirty_elements_iter != elements_to_persist_.end(); i++)
            {
                unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
                if (i == *dirty_elements_iter)
                {
                    writeBinaryPOD(this->output_link_lists_, linkListSize);
                    if (linkListSize)
                        this->output_link_lists_.write(linkLists_[i], linkListSize);
                    dirty_elements_iter = std::next(dirty_elements_iter);
                }
                else
                {
                    // Skip over the size prefix and payload of a clean record.
                    this->output_link_lists_.seekp(linkListSize + sizeof(unsigned int), this->output_link_lists_.cur);
                }
            }
            this->output_link_lists_.flush();

            // Note: It would make sense to do a fsync here
            elements_to_persist_.clear();
        }

        // Returns the size in bytes of the serialized header: the sum of sizeof() over
        // every field listed in HEADER_FIELDS.
        constexpr size_t calculateHeaderSize() const {
            // Each field expands to "+ sizeof(field)", appended to the leading 0 below.
            #define SIZE_ACTION(field) + sizeof(field)
            return 0 HEADER_FIELDS(SIZE_ACTION);
            #undef SIZE_ACTION
        }

        // Returns the number of bytes needed to serialize the upper-layer link lists of all
        // stored elements: one unsigned int size prefix per element plus its payload.
        size_t calculateLinkListSize() {
            size_t bytes_needed = 0;
            for (size_t id = 0; id < cur_element_count; ++id) {
                const auto level = element_levels_[id];
                unsigned int payload_bytes = 0;
                if (level > 0) {
                    payload_bytes = size_links_per_element_ * level;
                }
                bytes_needed += sizeof(payload_bytes) + payload_bytes;
            }
            return bytes_needed;
        }

        // Size in bytes of the level-0 block for the full capacity (max_elements_) of the index.
        size_t calculateDataLevel0Size() {
            const size_t total_bytes = max_elements_ * size_data_per_element_;
            return total_bytes;
        }

        // Size in bytes of the length block: one float for each of the max_elements_ slots.
        size_t calculateLengthSize() {
            const size_t total_bytes = max_elements_ * sizeof(float);
            return total_bytes;
        }

        void serializeHeaderToBuffer(char* buffer, size_t buffer_size) {
            char *init_buffer = buffer;
            #define WRITE_ACTION(field) \
                do { memcpy(buffer, &(field), sizeof(field)); buffer += sizeof(field); } while (0);
            HEADER_FIELDS(WRITE_ACTION)
            #undef WRITE_ACTION
        }

        /*
         * Copies the level-0 block for the full capacity of the index into `buffer`.
         *
         * buffer_size - capacity of `buffer` in bytes; must be at least
         *               max_elements_ * size_data_per_element_.
         *
         * Throws std::runtime_error if the buffer is too small (nothing is written in that case).
         */
        void serializeDataLevel0ToBuffer(char* buffer, size_t buffer_size) {
            const size_t required = max_elements_ * size_data_per_element_;
            if (buffer_size < required) {
                throw std::runtime_error("Data level0 buffer too small");
            }
            memcpy(buffer, data_level0_memory_, required);
        }

        /*
         * Copies the length block (one float per slot, max_elements_ slots) into `buffer`.
         *
         * buffer_size - capacity of `buffer` in bytes; must be at least
         *               max_elements_ * sizeof(float).
         *
         * Throws std::runtime_error if the buffer is too small (nothing is written in that case).
         */
        void serializeLengthToBuffer(char* buffer, size_t buffer_size) {
            const size_t required = max_elements_ * sizeof(float);
            if (buffer_size < required) {
                throw std::runtime_error("Length buffer too small");
            }
            memcpy(buffer, length_memory_, required);
        }

        /*
         * Serializes the upper-layer link lists of all stored elements into `buffer`.
         * Every element contributes an unsigned int size prefix followed by that many
         * payload bytes (no payload when the size is 0).
         *
         * buffer_size - capacity of `buffer` in bytes.
         *
         * Throws std::runtime_error as soon as the next record would not fit; records
         * written before that point remain in the buffer.
         */
        void serializeLinkListsToBuffer(char* buffer, size_t buffer_size) {
            size_t remaining = buffer_size;
            for (size_t i = 0; i < cur_element_count; i++) {
                unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
                const size_t record_size = sizeof(linkListSize) + static_cast<size_t>(linkListSize);
                if (remaining < record_size) {
                    throw std::runtime_error("Link list buffer too small");
                }
                memcpy(buffer, &linkListSize, sizeof(linkListSize));
                buffer += sizeof(linkListSize);
                if (linkListSize) {
                    memcpy(buffer, linkLists_[i], linkListSize);
                    buffer += linkListSize;
                }
                remaining -= record_size;
            }
        }

        // Functions for Rust to query required buffer sizes
        // Number of bytes a caller must provide for the header buffer.
        size_t getRequiredHeaderSize() const {
            const size_t required_bytes = calculateHeaderSize();
            return required_bytes;
        }
        
        // Number of bytes a caller must provide for the level-0 data buffer.
        size_t getRequiredDataLevel0Size() {
            const size_t required_bytes = calculateDataLevel0Size();
            return required_bytes;
        }
        
        // Number of bytes a caller must provide for the length buffer.
        size_t getRequiredLengthSize() {
            const size_t required_bytes = calculateLengthSize();
            return required_bytes;
        }
        
        // Number of bytes a caller must provide for the link-list buffer.
        size_t getRequiredLinkListSize() {
            const size_t required_bytes = calculateLinkListSize();
            return required_bytes;
        }
        
        // Serialize to externally provided buffers (from Rust)
        /*
         * Serializes the index into the four caller-provided buffers described by `hnsw_data`
         * and returns the same pointer.
         *
         * Every buffer capacity is validated before anything is written; the first buffer
         * found to be too small raises std::runtime_error.
         */
        HnswDataMut* serializeToHnswData(HnswDataMut* hnsw_data) {
            const auto require_capacity = [](size_t available, size_t needed, const char* message) {
                if (available < needed) {
                    throw std::runtime_error(message);
                }
            };

            // Checks run in the same order as the buffers are filled below.
            require_capacity(hnsw_data->header_size, calculateHeaderSize(), "Header buffer too small");
            require_capacity(hnsw_data->data_level0_size, calculateDataLevel0Size(), "Data level0 buffer too small");
            require_capacity(hnsw_data->length_size, calculateLengthSize(), "Length buffer too small");
            require_capacity(hnsw_data->link_list_size, calculateLinkListSize(), "Link list buffer too small");

            serializeHeaderToBuffer(hnsw_data->header_buffer, hnsw_data->header_size);
            serializeDataLevel0ToBuffer(hnsw_data->data_level0_buffer, hnsw_data->data_level0_size);
            serializeLengthToBuffer(hnsw_data->length_buffer, hnsw_data->length_size);
            serializeLinkListsToBuffer(hnsw_data->link_list_buffer, hnsw_data->link_list_size);

            return hnsw_data;
        }

        /*
         * Restores the index from the four persistence streams (header, level-0 data,
         * lengths, link lists).
         *
         * s              - space providing the data size, distance function and its parameter.
         * input_streams  - the four streams; all four pointers must be non-null.
         * max_elements_i - requested capacity; when it is smaller than the persisted element
         *                  count, the persisted capacity is used instead.
         *
         * Throws std::runtime_error if a stream is missing or not in a good state, if the
         * header cannot be read completely, if the persistence version does not match, or
         * if memory cannot be allocated.
         */
        void readPersistedIndexFromStreams(SpaceInterface<dist_t> *s,
                                           InputPersistenceStreams& input_streams,
                                           size_t max_elements_i = 0)
        {
            // Reject null stream pointers before they are dereferenced.
            if (!input_streams.header_stream || !input_streams.data_level0_stream ||
                !input_streams.length_stream || !input_streams.link_list_stream)
                throw std::runtime_error("Cannot read persisted index: missing input stream");

            auto& input_header = *input_streams.header_stream;
            auto& input_data_level0 = *input_streams.data_level0_stream;
            auto& input_length = *input_streams.length_stream;
            auto& input_link_list = *input_streams.link_list_stream;

            if (!input_header.good())
                throw std::runtime_error("Header stream is not in good state");

            // Read the header
            int persisted_version = 0;
            readBinaryPOD(input_header, persisted_version);
            // Check the read itself first so the version comparison never sees garbage.
            if (!input_header)
                throw std::runtime_error("Cannot read persisted index: header is truncated or unreadable");
            // For now, version is a simple equality check, we may add backwards compatibility later
            if (persisted_version != PERSISTENCE_VERSION)
                throw std::runtime_error("Cannot read persisted index: wrong persistence version");

            readBinaryPOD(input_header, offsetLevel0_);
            readBinaryPOD(input_header, max_elements_);
            readBinaryPOD(input_header, cur_element_count);
            readBinaryPOD(input_header, size_data_per_element_);
            readBinaryPOD(input_header, label_offset_);
            readBinaryPOD(input_header, offsetData_);
            readBinaryPOD(input_header, maxlevel_);
            readBinaryPOD(input_header, enterpoint_node_);

            readBinaryPOD(input_header, maxM_);
            readBinaryPOD(input_header, maxM0_);
            readBinaryPOD(input_header, M_);
            readBinaryPOD(input_header, mult_);
            readBinaryPOD(input_header, ef_construction_);

            // A short header would leave sizes undefined; stop before they drive any allocation.
            if (!input_header)
                throw std::runtime_error("Cannot read persisted index: header is truncated or unreadable");

            // Use the requested capacity unless it cannot hold the persisted elements.
            size_t max_elements = max_elements_i;
            if (max_elements < cur_element_count)
                max_elements = max_elements_;
            max_elements_ = max_elements;

            data_size_ = s->get_data_size();
            fstdistfunc_ = s->get_dist_func();
            dist_func_param_ = s->get_dist_func_param();

            // Read data_level0_memory_
            if (!input_data_level0.good())
                throw std::runtime_error("Data level0 stream is not in good state");

            data_level0_memory_ = (char *)malloc(max_elements * size_data_per_element_);
            if (data_level0_memory_ == nullptr)
                throw std::runtime_error("Not enough memory: loadPersistedIndex failed to allocate level0");
            // NOTE(review): the outcome of this read is not checked; when max_elements_i exceeds the
            // persisted capacity the stream presumably holds fewer bytes than requested - confirm.
            input_data_level0.read(data_level0_memory_, max_elements * size_data_per_element_);

            // Read length_memory_
            if (!input_length.good())
                throw std::runtime_error("Length stream is not in good state");

            length_memory_ = (char *)malloc(max_elements * sizeof(float));
            if (length_memory_ == nullptr)
                throw std::runtime_error("Not enough memory: loadPersistedIndex failed to allocate length_memory_");
            input_length.read(length_memory_, max_elements * sizeof(float));

            // Read the linkLists
            if (!input_link_list.good())
                throw std::runtime_error("Link list stream is not in good state");
            loadLinkLists(input_link_list);

            loadDeleted();
        }

        /*
         * Loads a persisted index from the four files in the persistence directory and then
         * opens the write handles for subsequent incremental persistence.
         *
         * Throws std::runtime_error when one of the files cannot be opened (files are opened
         * and checked one after another: header, data_level0, length, link lists).
         */
        void loadPersistedIndex(SpaceInterface<dist_t> *s, size_t max_elements_i = 0)
        {
            const auto open_binary = [](const std::string &path, const char *failure_message)
            {
                auto file = std::make_shared<std::ifstream>(path, std::ios::binary);
                if (!file->is_open())
                    throw std::runtime_error(failure_message);
                return file;
            };

            {
                // Scoped so that the read handles are released before the write handles are opened.
                InputPersistenceStreams input_streams;
                input_streams.header_stream = open_binary(getHeaderLocation(), "Cannot open header file");
                input_streams.data_level0_stream = open_binary(getDataLevel0Location(), "Cannot open data_level0 file");
                input_streams.length_stream = open_binary(getLengthLocation(), "Cannot open length file");
                input_streams.link_list_stream = open_binary(getLinkListLocation(), "Cannot open link list file");

                readPersistedIndexFromStreams(s, input_streams, max_elements_i);
            }

            openPersistentIndex();
        }

        /*
         * Loads a persisted index from in-memory buffers instead of files.
         * Each (pointer, size) pair in `buffers` is wrapped in a read-only memory stream
         * and handed to readPersistedIndexFromStreams.
         */
        void loadPersistedIndexFromMemory(SpaceInterface<dist_t> *s,
                                          const HnswDataView &buffers,
                                          size_t max_elements_i = 0)
        {
            InputPersistenceStreams input_streams;
            input_streams.header_stream =
                std::make_shared<memistream>(buffers.header_buffer, buffers.header_size);
            input_streams.data_level0_stream =
                std::make_shared<memistream>(buffers.data_level0_buffer, buffers.data_level0_size);
            input_streams.length_stream =
                std::make_shared<memistream>(buffers.length_buffer, buffers.length_size);
            input_streams.link_list_stream =
                std::make_shared<memistream>(buffers.link_list_buffer, buffers.link_list_size);

            readPersistedIndexFromStreams(s, input_streams, max_elements_i);
        }

        // #pragma endregion

        /*
         * Rebuilds the in-memory structures that depend on the header values and reads the
         * size-prefixed upper-layer link lists of all stored elements from `input_link_list`.
         * Also fills the label lookup table and resets ef_ to 10.
         *
         * Throws std::runtime_error when memory for the link lists cannot be allocated.
         */
        void loadLinkLists(std::istream &input_link_list)
        {
            // Derived sizes: a list header followed by maxM_ (or maxM0_) neighbour ids.
            size_links_per_element_ = sizeof(linklistsizeint) + maxM_ * sizeof(tableint);
            size_links_level0_ = sizeof(linklistsizeint) + maxM0_ * sizeof(tableint);

            std::vector<std::mutex>(max_elements_).swap(link_list_locks_);
            std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);

            visited_list_pool_ = new VisitedListPool(1, max_elements_);

            linkLists_ = static_cast<char **>(malloc(sizeof(void *) * max_elements_));
            if (!linkLists_)
                throw std::runtime_error("Not enough memory: loadPersistedIndex failed to allocate linklists");
            element_levels_ = std::vector<int>(max_elements_);
            revSize_ = 1.0 / mult_;
            ef_ = 10;

            for (size_t id = 0; id < cur_element_count; ++id)
            {
                label_lookup_[getExternalLabel(id)] = id;

                unsigned int list_bytes;
                readBinaryPOD(input_link_list, list_bytes);

                if (list_bytes != 0)
                {
                    // The level is recovered from the payload size.
                    element_levels_[id] = list_bytes / size_links_per_element_;
                    char *lists = static_cast<char *>(malloc(list_bytes));
                    linkLists_[id] = lists;
                    if (!lists)
                        throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
                    input_link_list.read(lists, list_bytes);
                    continue;
                }

                // Size 0: the element has no upper-layer lists.
                element_levels_[id] = 0;
                linkLists_[id] = nullptr;
            }
        }

        // Scans all stored elements and accounts for those carrying the deleted mark:
        // bumps num_deleted_ and, when replacement of deleted elements is enabled,
        // records the id in deleted_elements.
        void loadDeleted()
        {
            for (size_t id = 0; id < cur_element_count; ++id)
            {
                if (!isMarkedDeleted(id))
                    continue;

                num_deleted_ += 1;
                if (allow_replace_deleted_)
                    deleted_elements.insert(id);
            }
        }

        /*
         * Loads an index from the single-file format at `location`.
         *
         * location       - path of the index file.
         * s              - space providing the data size, distance function and its parameter.
         * max_elements_i - requested capacity; when it is smaller than the stored element
         *                  count, the capacity stored in the file is used instead.
         *
         * Throws std::runtime_error if the file cannot be opened, if the layout check
         * fails ("corrupted or unsupported"), or if memory cannot be allocated.
         */
        void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0)
        {
            std::ifstream input(location, std::ios::binary);

            if (!input.is_open())
                throw std::runtime_error("Cannot open file");

            // get file size:
            input.seekg(0, input.end);
            std::streampos total_filesize = input.tellg();
            input.seekg(0, input.beg);

            // Header fields; this format has no leading persistence version.
            readBinaryPOD(input, offsetLevel0_);
            readBinaryPOD(input, max_elements_);
            readBinaryPOD(input, cur_element_count);

            // Use the requested capacity unless it cannot hold the stored elements.
            size_t max_elements = max_elements_i;
            if (max_elements < cur_element_count)
                max_elements = max_elements_;
            max_elements_ = max_elements;
            readBinaryPOD(input, size_data_per_element_);
            readBinaryPOD(input, label_offset_);
            readBinaryPOD(input, offsetData_);
            readBinaryPOD(input, maxlevel_);
            readBinaryPOD(input, enterpoint_node_);

            readBinaryPOD(input, maxM_);
            readBinaryPOD(input, maxM0_);
            readBinaryPOD(input, M_);
            readBinaryPOD(input, mult_);
            readBinaryPOD(input, ef_construction_);

            data_size_ = s->get_data_size();
            fstdistfunc_ = s->get_dist_func();
            dist_func_param_ = s->get_dist_func_param();

            // Remember where the payload starts so it can be re-read after the dry run below.
            auto pos = input.tellg();

            /// Optional - check if index is ok:
            // Dry run: skip the level-0 and length blocks, then walk every size-prefixed
            // link-list record without reading its payload.
            input.seekg(cur_element_count * size_data_per_element_ + cur_element_count * sizeof(float), input.cur);
            for (size_t i = 0; i < cur_element_count; i++)
            {
                if (input.tellg() < 0 || input.tellg() >= total_filesize)
                {
                    throw std::runtime_error("Index seems to be corrupted or unsupported");
                }

                unsigned int linkListSize;
                readBinaryPOD(input, linkListSize);
                if (linkListSize != 0)
                {
                    input.seekg(linkListSize, input.cur);
                }
            }

            // throw exception if it either corrupted or old index
            // (the walk must end exactly at the end of the file)
            if (input.tellg() != total_filesize)
                throw std::runtime_error("Index seems to be corrupted or unsupported");

            // Reset any stream state flags raised during the dry run.
            input.clear();
            /// Optional check end

            input.seekg(pos, input.beg);

            // Buffers are allocated for the full capacity but only the stored elements are read.
            data_level0_memory_ = (char *)malloc(max_elements * size_data_per_element_);
            if (data_level0_memory_ == nullptr)
                throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
            input.read(data_level0_memory_, cur_element_count * size_data_per_element_);

            length_memory_ = (char *)malloc(max_elements * sizeof(float));
            if (length_memory_ == nullptr)
                throw std::runtime_error("Not enough memory: loadIndex failed to allocate length_memory_");
            input.read(length_memory_, cur_element_count * sizeof(float));

            loadLinkLists(input);
            loadDeleted();

            input.close();

            return;
        }

        // Returns {number of labels not marked deleted, number of labels marked deleted}.
        // The label lookup table is locked for the duration of the scan.
        std::pair<size_t, size_t> getLabelCounts() const
        {
            std::pair<size_t, size_t> counts(0, 0);
            std::unique_lock<std::mutex> label_lock(label_lookup_lock);
            for (const auto &entry : label_lookup_)
            {
                if (isMarkedDeleted(entry.second))
                    ++counts.second;
                else
                    ++counts.first;
            }
            return counts;
        }

        // Get all labels, segregated by deleted and non deleted.
        std::pair<std::vector<labeltype>, std::vector<labeltype>> getAllLabels() const
        {
            // first: labels of elements not marked deleted; second: labels of deleted elements.
            std::vector<labeltype> labels;
            std::vector<labeltype> deleted_labels;
            {
                // The lookup table is locked only while it is being scanned.
                std::unique_lock<std::mutex> label_lock(label_lookup_lock);
                for (const auto &entry : label_lookup_)
                {
                    if (!isMarkedDeleted(entry.second))
                    {
                        labels.push_back(entry.first);
                    }
                    else
                    {
                        deleted_labels.push_back(entry.first);
                    }
                }
            }
            // Move the vectors into the result; make_pair on lvalues would deep-copy both.
            return std::make_pair(std::move(labels), std::move(deleted_labels));
        }

        /*
         * Returns a copy of the vector stored under `label`.
         *
         * When normalize_ is set, every stored component is multiplied by the element's
         * entry in length_memory_ before being returned; otherwise components are copied as is.
         * The dimension is read from dist_func_param_, interpreted as a pointer to size_t.
         *
         * Throws std::runtime_error("Label not found") if the label is unknown or the
         * element is marked deleted.
         */
        template <typename data_t>
        std::vector<data_t> getDataByLabel(labeltype label) const
        {
            // lock all operations with element by label
            std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));

            std::unique_lock<std::mutex> lock_table(label_lookup_lock);
            auto search = label_lookup_.find(label);
            if (search == label_lookup_.end() || isMarkedDeleted(search->second))
            {
                throw std::runtime_error("Label not found");
            }
            tableint internalId = search->second;
            lock_table.unlock();

            char *data_ptrv = getDataByInternalId(internalId);
            float length = 1.0f;
            if (normalize_)
            {
                length = ((float *)length_memory_)[internalId];
            }
            const size_t dim = *((size_t *)dist_func_param_);
            const data_t *data_ptr = (const data_t *)data_ptrv;

            std::vector<data_t> data;
            data.reserve(dim);
            // size_t index: avoids the signed/unsigned comparison against dim.
            for (size_t i = 0; i < dim; i++)
            {
                if (normalize_)
                {
                    data.push_back(data_ptr[i] * length);
                }
                else
                {
                    data.push_back(data_ptr[i]);
                }
            }
            return data;
        }

        /*
         * Marks an element with the given label deleted, does NOT really change the current graph.
         */
        void markDelete(labeltype label)
        {
            // Serialize all operations on this label.
            std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));

            // Resolve the label to its internal id; the lookup table stays locked only for the lookup.
            tableint internalId;
            {
                std::unique_lock<std::mutex> lock_table(label_lookup_lock);
                const auto found = label_lookup_.find(label);
                if (found == label_lookup_.end())
                {
                    throw std::runtime_error("Label not found");
                }
                internalId = found->second;
            }

            markDeletedInternal(internalId);
            // Queue the element so the changed mark is persisted.
            markElementToPersist(internalId);
        }

        /*
         * Uses the last 16 bits of the memory for the linked list size to store the mark,
         * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases.
         */
        void markDeletedInternal(tableint internalId)
        {
            assert(internalId < cur_element_count);
            if (isMarkedDeleted(internalId))
            {
                throw std::runtime_error("The requested to delete element is already deleted");
            }

            // The mark lives in the byte at offset 2 of the element's level-0 link list header.
            unsigned char *mark_byte = ((unsigned char *)get_linklist0(internalId)) + 2;
            *mark_byte |= DELETE_MARK;
            num_deleted_ += 1;

            if (allow_replace_deleted_)
            {
                std::unique_lock<std::mutex> lock_deleted_elements(deleted_elements_lock);
                deleted_elements.insert(internalId);
            }
        }

        /*
         * Removes the deleted mark of the node, does NOT really change the current graph.
         *
         * Note: the method is not safe to use when replacement of deleted elements is enabled,
         *  because elements marked as deleted can be completely removed by addPoint
         */
        void unmarkDelete(labeltype label)
        {
            // Serialise with every other operation that targets this label.
            std::unique_lock<std::mutex> label_guard(getLabelOpMutex(label));

            // Translate the external label into an internal id. The lookup
            // table lock is held only inside this scope.
            tableint internalId;
            {
                std::unique_lock<std::mutex> table_guard(label_lookup_lock);
                auto entry = label_lookup_.find(label);
                if (entry == label_lookup_.end())
                {
                    throw std::runtime_error("Label not found");
                }
                internalId = entry->second;
            }

            // Clear the flag, then record the element as needing persistence.
            unmarkDeletedInternal(internalId);
            markElementToPersist(internalId);
        }

        /*
         * Remove the deleted mark of the node.
         */
        void unmarkDeletedInternal(tableint internalId)
        {
            assert(internalId < cur_element_count);

            // Only an element that is currently flagged can be restored.
            if (!isMarkedDeleted(internalId))
            {
                throw std::runtime_error("The requested to undelete element is not deleted");
            }

            // Clear the flag bit in the byte at offset 2 of the level-0 link list header.
            unsigned char *mark_byte = ((unsigned char *)get_linklist0(internalId)) + 2;
            *mark_byte &= ~DELETE_MARK;
            num_deleted_ -= 1;

            // When slot reuse is enabled, the id is no longer a vacant slot.
            if (allow_replace_deleted_)
            {
                std::unique_lock<std::mutex> vacant_guard(deleted_elements_lock);
                deleted_elements.erase(internalId);
            }
        }

        /*
         * Checks the delete-mark bit in the byte at offset 2 of the element's level-0 link list header
         * to see if the element is marked deleted.
         */
        bool isMarkedDeleted(tableint internalId) const
        {
            // The flag lives in the byte at offset 2 of the level-0 link list header.
            const unsigned char *header_bytes = (const unsigned char *)get_linklist0(internalId);
            return (header_bytes[2] & DELETE_MARK) != 0;
        }

        // Reads the neighbor count stored as an `unsigned short` at the start of the link list header `ptr`.
        unsigned short int getListCount(linklistsizeint *ptr) const
        {
            const unsigned short int *count_field = reinterpret_cast<const unsigned short int *>(ptr);
            return *count_field;
        }

        // Writes `size` into the `unsigned short` at the start of the link list header `ptr`.
        // Bytes of the header beyond that `unsigned short` are left untouched.
        void setListCount(linklistsizeint *ptr, unsigned short int size) const
        {
            unsigned short int *count_field = reinterpret_cast<unsigned short int *>(ptr);
            *count_field = size;
        }

        /*
         * Adds point. Updates the point if it is already in the index.
         * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
         */
        void addPoint(const void *data_point, labeltype label, bool replace_deleted = false)
        {
            if ((allow_replace_deleted_ == false) && (replace_deleted == true))
            {
                throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
            }

            // lock all operations with element by label
            std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));
            if (!replace_deleted)
            {
                addPoint(data_point, label, -1);
                return;
            }
            // check if there is vacant place
            tableint internal_id_replaced;
            std::unique_lock<std::mutex> lock_deleted_elements(deleted_elements_lock);
            bool is_vacant_place = !deleted_elements.empty();
            if (is_vacant_place)
            {
                internal_id_replaced = *deleted_elements.begin();
                deleted_elements.erase(internal_id_replaced);
            }
            lock_deleted_elements.unlock();

            // if there is no vacant place then add or update point
            // else add point to vacant place
            if (!is_vacant_place)
            {
                addPoint(data_point, label, -1);
            }
            else
            {
                // we assume that there are no concurrent operations on deleted element
                labeltype label_replaced = getExternalLabel(internal_id_replaced);
                setExternalLabel(internal_id_replaced, label);

                std::unique_lock<std::mutex> lock_table(label_lookup_lock);
                label_lookup_.erase(label_replaced);
                label_lookup_[label] = internal_id_replaced;
                lock_table.unlock();

                unmarkDeletedInternal(internal_id_replaced);
                updatePoint(data_point, internal_id_replaced, 1.0);
            }
        }

        // Overwrites the stored vector of an existing element and rewires the graph around it.
        //
        //   dataPoint                 - new raw vector; it is normalized first when normalize_ is set,
        //                               and its length is written to length_memory_.
        //   internalId                - internal id of the element being updated.
        //   updateNeighborProbability - a one-hop neighbor has its link list rebuilt unless a uniform
        //                               draw from [0, 1) exceeds this value (1.0 rebuilds all of them).
        //
        // After the neighbors are rewired, the element itself is reconnected through
        // repairConnectionsForUpdate().
        void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability)
        {
            const void *newPoint = dataPoint;

            // Scratch buffer for the normalized vector. It is declared at function scope because
            // newPoint may point into it until the end of the function, but it is only sized
            // when normalization is actually requested.
            std::vector<float> norm_array;
            if (normalize_)
            {
                size_t dim = *((size_t *)dist_func_param_);
                norm_array.resize(dim);
                float length = normalize_vector((float *)dataPoint, norm_array.data(), dim);
                memcpy(length_memory_ + internalId * sizeof(float), &length, sizeof(float));
                newPoint = norm_array.data();
            }
            // update the feature vector associated with existing point with new vector
            memcpy(getDataByInternalId(internalId), newPoint, data_size_);

            int maxLevelCopy = maxlevel_;
            tableint entryPointCopy = enterpoint_node_;
            // If point to be updated is entry point and graph just contains single element then just return.
            if (entryPointCopy == internalId && cur_element_count == 1)
                return;

            int elemLevel = element_levels_[internalId];
            std::uniform_real_distribution<float> distribution(0.0, 1.0);
            for (int layer = 0; layer <= elemLevel; layer++)
            {
                // sCand: candidate pool (the element, its one-hop and selected two-hop neighbors).
                // sNeigh: one-hop neighbors whose link lists will be rebuilt on this layer.
                std::unordered_set<tableint> sCand;
                std::unordered_set<tableint> sNeigh;
                std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
                if (listOneHop.size() == 0)
                    continue;

                sCand.insert(internalId);

                for (auto &&elOneHop : listOneHop)
                {
                    sCand.insert(elOneHop);

                    // Randomly skip this neighbor; one draw is consumed per one-hop neighbor.
                    if (distribution(update_probability_generator_) > updateNeighborProbability)
                        continue;

                    sNeigh.insert(elOneHop);

                    std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
                    for (auto &&elTwoHop : listTwoHop)
                    {
                        sCand.insert(elTwoHop);
                    }
                }

                for (auto &&neigh : sNeigh)
                {
                    // if (neigh == internalId)
                    //     continue;

                    std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
                    size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
                    size_t elementsToKeep = std::min(ef_construction_, size);
                    for (auto &&cand : sCand)
                    {
                        if (cand == neigh)
                            continue;

                        // Bounded heap: hold at most elementsToKeep candidates, replacing the
                        // current top whenever a strictly smaller distance shows up.
                        dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
                        if (candidates.size() < elementsToKeep)
                        {
                            candidates.emplace(distance, cand);
                        }
                        else
                        {
                            if (distance < candidates.top().first)
                            {
                                candidates.pop();
                                candidates.emplace(distance, cand);
                            }
                        }
                    }

                    // Retrieve neighbours using heuristic and set connections.
                    getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);

                    {
                        // Overwrite neigh's link list on this layer while holding its lock.
                        std::unique_lock<std::mutex> lock(link_list_locks_[neigh]);
                        linklistsizeint *ll_cur;
                        ll_cur = get_linklist_at_level(neigh, layer);
                        size_t candSize = candidates.size();
                        setListCount(ll_cur, candSize);
                        tableint *data = (tableint *)(ll_cur + 1);
                        for (size_t idx = 0; idx < candSize; idx++)
                        {
                            data[idx] = candidates.top().second;
                            candidates.pop();
                        }
                    }
                }
            }

            repairConnectionsForUpdate(newPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
        }

        // Reconnects an updated element to the graph.
        //
        //   dataPoint            - the element's (already stored) vector.
        //   entryPointInternalId - entry point to start the descent from.
        //   dataPointInternalId  - internal id of the element being repaired.
        //   dataPointLevel       - top level of the element.
        //   maxLevel             - top level of the graph at the time of the update.
        //
        // First walks greedily from the entry point down to level dataPointLevel + 1, moving to any
        // neighbor that is closer to dataPoint. Then, for every level from dataPointLevel down to 0,
        // collects candidates with searchBaseLayer(), drops the element itself from them and
        // reconnects it through mutuallyConnectNewElement().
        //
        // Throws std::runtime_error when dataPointLevel > maxLevel.
        void repairConnectionsForUpdate(
            const void *dataPoint,
            tableint entryPointInternalId,
            tableint dataPointInternalId,
            int dataPointLevel,
            int maxLevel)
        {
            tableint currObj = entryPointInternalId;
            if (dataPointLevel < maxLevel)
            {
                dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
                for (int level = maxLevel; level > dataPointLevel; level--)
                {
                    bool changed = true;
                    while (changed)
                    {
                        changed = false;
                        unsigned int *data;
                        // The current node's link list is read under its lock.
                        std::unique_lock<std::mutex> lock(link_list_locks_[currObj]);
                        data = get_linklist_at_level(currObj, level);
                        int size = getListCount(data);
                        tableint *datal = (tableint *)(data + 1);
#ifdef USE_SSE
                        _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
#endif
                        for (int i = 0; i < size; i++)
                        {
#ifdef USE_SSE
                            // Prefetch the next neighbor's vector, but never read the slot
                            // past the `size` counted entries of this list.
                            if (i + 1 < size)
                                _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
#endif
                            tableint cand = datal[i];
                            dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
                            if (d < curdist)
                            {
                                curdist = d;
                                currObj = cand;
                                changed = true;
                            }
                        }
                    }
                }
            }

            if (dataPointLevel > maxLevel)
                throw std::runtime_error("Level of item to be updated cannot be bigger than max level");

            for (int level = dataPointLevel; level >= 0; level--)
            {
                std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
                    currObj, dataPoint, level);

                // Drop the element itself from the candidates so it is never linked to itself.
                std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
                while (topCandidates.size() > 0)
                {
                    if (topCandidates.top().second != dataPointInternalId)
                        filteredTopCandidates.push(topCandidates.top());

                    topCandidates.pop();
                }

                // Since element_levels_ is used to get `dataPointLevel`, there can be cases where `topCandidates` contains only the entry point itself.
                // To prevent self loops, `topCandidates` is filtered and can therefore end up empty.
                if (filteredTopCandidates.size() > 0)
                {
                    // A deleted entry point is still offered as a candidate; the queue is
                    // trimmed back to at most ef_construction_ entries afterwards.
                    bool epDeleted = isMarkedDeleted(entryPointInternalId);
                    if (epDeleted)
                    {
                        filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
                        if (filteredTopCandidates.size() > ef_construction_)
                            filteredTopCandidates.pop();
                    }

                    currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
                }
            }
        }

        // Returns a copy of the neighbor ids of `internalId` at `level`.
        // The copy is made while holding that element's link list lock.
        std::vector<tableint> getConnectionsWithLock(tableint internalId, int level)
        {
            std::unique_lock<std::mutex> guard(link_list_locks_[internalId]);
            unsigned int *header = get_linklist_at_level(internalId, level);
            const int count = getListCount(header);
            // The neighbor ids follow immediately after the header word.
            const tableint *first = (const tableint *)(header + 1);
            return std::vector<tableint>(first, first + count);
        }

        // Inserts a new element for `label`, or updates the existing element when the label is
        // already present in label_lookup_.
        //
        //   data_point - raw vector; it is normalized before being stored when normalize_ is set.
        //   label      - external label of the element.
        //   level      - when > 0, forces the element's level; otherwise the level drawn by
        //                getRandomLevel(mult_) is used.
        //
        // Returns the internal id of the inserted or updated element.
        //
        // Throws std::runtime_error when:
        //   - the label exists, is marked deleted and replacement of deleted elements is enabled;
        //   - the index already holds max_elements_ elements;
        //   - the upper-level link list cannot be allocated;
        //   - a neighbor id outside [0, max_elements_) is met during the descent.
        tableint addPoint(const void *data_point, labeltype label, int level)
        {
            tableint cur_c = 0;
            {
                // Checking if the element with the same label already exists
                // if so, updating it *instead* of creating a new element.
                std::unique_lock<std::mutex> lock_table(label_lookup_lock);
                auto search = label_lookup_.find(label);
                if (search != label_lookup_.end())
                {
                    tableint existingInternalId = search->second;
                    if (allow_replace_deleted_)
                    {
                        if (isMarkedDeleted(existingInternalId))
                        {
                            throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
                        }
                    }
                    lock_table.unlock();

                    if (isMarkedDeleted(existingInternalId))
                    {
                        unmarkDeletedInternal(existingInternalId);
                    }
                    updatePoint(data_point, existingInternalId, 1.0);

                    return existingInternalId;
                }

                if (cur_element_count >= max_elements_)
                {
                    throw std::runtime_error("The number of elements exceeds the specified limit");
                }

                // Reserve the next internal id while the lookup table is still locked.
                cur_c = cur_element_count;
                cur_element_count++;
                label_lookup_[label] = cur_c;
            }

            std::unique_lock<std::mutex> lock_el(link_list_locks_[cur_c]);
            // The random level is always drawn, even when `level` overrides it.
            int curlevel = getRandomLevel(mult_);
            if (level > 0)
                curlevel = level;

            element_levels_[cur_c] = curlevel;

            // `global` stays locked until the function returns when the new element raises the
            // maximum level, so the entry point update at the end happens under the lock.
            std::unique_lock<std::mutex> templock(global);
            int maxlevelcopy = maxlevel_;
            if (curlevel <= maxlevelcopy)
                templock.unlock();
            tableint currObj = enterpoint_node_;
            tableint enterpoint_copy = enterpoint_node_;

            memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);

            // Initialisation of the data and label and if appropriate the length
            const void *normalized_vector = data_point;
            // Scratch buffer for the normalized vector. It is declared at function scope because
            // normalized_vector may point into it until the end of the function, but it is only
            // sized when normalization is actually requested.
            std::vector<float> norm_array;
            if (normalize_)
            {
                size_t dim = *((size_t *)dist_func_param_);
                norm_array.resize(dim);
                float length = normalize_vector((float *)data_point, norm_array.data(), dim);
                memcpy(length_memory_ + cur_c * sizeof(float), &length, sizeof(float));
                normalized_vector = norm_array.data();
            }
            memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
            memcpy(getDataByInternalId(cur_c), normalized_vector, data_size_);

            if (curlevel)
            {
                // Zero-initialized link lists for levels 1..curlevel.
                linkLists_[cur_c] = (char *)malloc(size_links_per_element_ * curlevel + 1);
                if (linkLists_[cur_c] == nullptr)
                    throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
                memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
            }

            if ((signed)currObj != -1)
            {
                if (curlevel < maxlevelcopy)
                {
                    // Greedy descent through the levels above the new element's level.
                    dist_t curdist = fstdistfunc_(normalized_vector, getDataByInternalId(currObj), dist_func_param_);
                    for (int lvl = maxlevelcopy; lvl > curlevel; lvl--)
                    {
                        bool changed = true;
                        while (changed)
                        {
                            changed = false;
                            unsigned int *data;
                            std::unique_lock<std::mutex> lock(link_list_locks_[currObj]);
                            data = get_linklist(currObj, lvl);
                            int size = getListCount(data);

                            tableint *datal = (tableint *)(data + 1);
                            for (int i = 0; i < size; i++)
                            {
                                tableint cand = datal[i];
                                // Internal ids are assigned from cur_element_count, which never
                                // reaches max_elements_, so max_elements_ itself is out of range.
                                if (cand < 0 || cand >= max_elements_)
                                    throw std::runtime_error("cand error");
                                dist_t d = fstdistfunc_(normalized_vector, getDataByInternalId(cand), dist_func_param_);
                                if (d < curdist)
                                {
                                    curdist = d;
                                    currObj = cand;
                                    changed = true;
                                }
                            }
                        }
                    }
                }

                bool epDeleted = isMarkedDeleted(enterpoint_copy);
                for (int lvl = std::min(curlevel, maxlevelcopy); lvl >= 0; lvl--)
                {
                    if (lvl > maxlevelcopy || lvl < 0) // possible?
                        throw std::runtime_error("Level error");

                    std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
                        currObj, normalized_vector, lvl);
                    // A deleted entry point is still offered as a candidate; the queue is
                    // trimmed back to at most ef_construction_ entries afterwards.
                    if (epDeleted)
                    {
                        top_candidates.emplace(fstdistfunc_(normalized_vector, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
                        if (top_candidates.size() > ef_construction_)
                            top_candidates.pop();
                    }
                    currObj = mutuallyConnectNewElement(normalized_vector, cur_c, top_candidates, lvl, false);
                }
            }
            else
            {
                // Do nothing for the first element
                enterpoint_node_ = 0;
                maxlevel_ = curlevel;

                // mark cur_c as dirty
                markElementToPersist(cur_c);
            }

            // The new element becomes the entry point when it raises the maximum level.
            // `templock` is still held in that case and is released on return.
            if (curlevel > maxlevelcopy)
            {
                enterpoint_node_ = cur_c;
                maxlevel_ = curlevel;
            }
            return cur_c;
        }

        // Finds up to `k` nearest elements to `query_data`.
        //
        //   query_data  - query vector, passed as-is to the distance function.
        //   k           - maximum number of results.
        //   isIdAllowed - optional filter forwarded to searchBaseLayerST().
        //
        // Returns a priority queue of (distance, external label) pairs with at most `k` entries;
        // with the default comparator its top() is the pair with the largest distance.
        // The queue is empty when the index holds no elements.
        //
        // Throws std::runtime_error("cand error") when a neighbor id outside
        // [0, max_elements_) is met during the descent through the upper levels.
        std::priority_queue<std::pair<dist_t, labeltype>>
        searchKnn(const void *query_data, size_t k, BaseFilterFunctor *isIdAllowed = nullptr) const
        {
            std::priority_queue<std::pair<dist_t, labeltype>> result;
            if (cur_element_count == 0)
                return result;

            tableint currObj = enterpoint_node_;
            dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

            // Greedy descent through levels maxlevel_..1: keep moving to any closer neighbor.
            for (int level = maxlevel_; level > 0; level--)
            {
                bool changed = true;
                while (changed)
                {
                    changed = false;
                    unsigned int *data;
                    data = (unsigned int *)get_linklist(currObj, level);
                    int size = getListCount(data);
                    metric_hops++;
                    metric_distance_computations += size;

                    tableint *datal = (tableint *)(data + 1);
                    for (int i = 0; i < size; i++)
                    {
                        tableint cand = datal[i];
                        // Valid internal ids are strictly below max_elements_, so
                        // max_elements_ itself must be rejected as well.
                        if (cand < 0 || cand >= max_elements_)
                            throw std::runtime_error("cand error");
                        dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

                        if (d < curdist)
                        {
                            curdist = d;
                            currObj = cand;
                            changed = true;
                        }
                    }
                }
            }

            // Base layer search. The first template argument is set only when at least one
            // element is marked deleted.
            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
            if (num_deleted_)
            {
                top_candidates = searchBaseLayerST<true, true>(
                    currObj, query_data, std::max(ef_, k), isIdAllowed);
            }
            else
            {
                top_candidates = searchBaseLayerST<false, true>(
                    currObj, query_data, std::max(ef_, k), isIdAllowed);
            }

            // Trim to at most k candidates, then translate internal ids to external labels.
            while (top_candidates.size() > k)
            {
                top_candidates.pop();
            }
            while (top_candidates.size() > 0)
            {
                std::pair<dist_t, tableint> rez = top_candidates.top();
                result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
                top_candidates.pop();
            }
            return result;
        }

        void checkIntegrity()
        {
            int connections_checked = 0;
            std::vector<int> inbound_connections_num(cur_element_count, 0);
            for (int i = 0; i < cur_element_count; i++)
            {
                for (int l = 0; l <= element_levels_[i]; l++)
                {
                    linklistsizeint *ll_cur = get_linklist_at_level(i, l);
                    int size = getListCount(ll_cur);
                    tableint *data = (tableint *)(ll_cur + 1);
                    std::unordered_set<tableint> s;
                    for (int j = 0; j < size; j++)
                    {
                        if (data[j] < 0 || data[j] >= cur_element_count || data[j] == i)
                            throw std::runtime_error("HNSW Integrity failure: invalid neighbor index");
                        inbound_connections_num[data[j]]++;
                        s.insert(data[j]);
                        connections_checked++;
                    }
                    if (s.size() != size)
                        throw std::runtime_error("HNSW Integrity failure: duplicate neighbor index");
                }
            }
            if (cur_element_count > 1)
            {
                int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
                for (int i = 0; i < cur_element_count; i++)
                {
                    // This should always be true regardless the data is corrupted or not
                    assert(inbound_connections_num[i] > 0);
                    min1 = std::min(inbound_connections_num[i], min1);
                    max1 = std::max(inbound_connections_num[i], max1);
                }
                std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
            }
            std::cout << "integrity ok, checked " << connections_checked << " connections\n";
        }
    };
} // namespace hnswlib
