diff --git a/include/xtensor-sparse/xsparse_scheme.hpp b/include/xtensor-sparse/xsparse_scheme.hpp new file mode 100644 index 0000000..26062d2 --- /dev/null +++ b/include/xtensor-sparse/xsparse_scheme.hpp @@ -0,0 +1,864 @@ +#ifndef XSPARSE_SCHEME_HPP +#define XSPARSE_SCHEME_HPP + +#include +#include + +namespace xt +{ + + /*********************************************************************** + * xsparse_polymorphic_scheme_nz_iterator as a bridge for type erasure * + ***********************************************************************/ + + template + class xsparse_abstract_scheme_nz_iterator; + + template + class xsparse_polymorphic_scheme_nz_iterator + { + public: + + using self_type = xsparse_polymorphic_scheme_nz_iterator; + using abstract_iterator = xsparse_abstract_scheme_nz_iterator; + using index_type = xtl::any; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::ptrdiff_t; + + xsparse_polymorphic_scheme_nz_iterator(abstract_iterator *it); + ~xsparse_polymorphic_scheme_nz_iterator(); + + self_type& operator++(); + self_type& operator--(); + + self_type& operator+=(difference_type n); + self_type& operator-=(difference_type n); + + difference_type operator-(const self_type& rhs) const; + + reference operator*() const; + pointer operator->() const; + const index_type& index() const; + + bool equal(const self_type& rhs) const; + bool less_than(const self_type& rhs) const; + + private: + abstract_iterator *m_it = nullptr; + }; + + template + bool operator == (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs); + + template + bool operator < (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs); + + /*************************************************************************** + * xsparse_abstract_scheme_nz_iterator as top-level class for type erasure * + ***************************************************************************/ + + template + class xsparse_abstract_scheme_nz_iterator + { + public: + + using self_type = xsparse_abstract_scheme_nz_iterator; + using index_type = xtl::any; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::ptrdiff_t; + + virtual ~xsparse_abstract_scheme_nz_iterator() = default; + + virtual reference operator*() const = 0; + virtual pointer operator->() const = 0; + virtual const index_type& index() const = 0; + + virtual bool equal(const self_type& rhs) const = 0; + virtual bool less_than(const self_type& rhs) const = 0; + + virtual difference_type distance(const self_type& rhs) const = 0; + + virtual void advance(void) = 0; + virtual void rewind(void) = 0; + virtual void advance(difference_type n) = 0; + virtual void rewind(difference_type n) = 0; + }; + + /****************************************************************** + * xsparse_crtp_scheme_nz_iterator as base class for type erasure * + ******************************************************************/ + + template + class xsparse_crtp_scheme_nz_iterator : public xsparse_abstract_scheme_nz_iterator + { + public: + + using derived_type = D; + + using self_type = xsparse_crtp_scheme_nz_iterator; + using index_type = xtl::any; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::ptrdiff_t; + + const index_type& index() const final; + + bool equal(const self_type& rhs) const final; + bool less_than(const self_type& rhs) const final; + + difference_type distance(const self_type& rhs) const final; + + void advance(void) final; + void rewind(void) final; + void advance(difference_type n) final; + void rewind(difference_type n) final; + + private: + + derived_type& derived_cast() & noexcept; + const derived_type& derived_cast() const & noexcept; + derived_type derived_cast() && noexcept; + + index_type m_index; + }; + + /************************************************ + * xsparse_coo_scheme_nz_iterator as an example * + ************************************************/ + + namespace detail + { + template + struct xsparse_coo_scheme_storage_type + { + using storage_type = typename scheme::storage_type; + using value_iterator = typename storage_type::iterator; + }; + + template + struct xsparse_coo_scheme_storage_type + { + using storage_type = typename scheme::storage_type; + using value_iterator = typename storage_type::const_iterator; + }; + + template + struct xsparse_coo_scheme_nz_iterator_types : xsparse_coo_scheme_storage_type + { + using base_type = xsparse_coo_scheme_storage_type; + using index_type = typename scheme::index_type; + using coordinate_type = typename scheme::coordinate_type; + using coordinate_iterator = typename coordinate_type::const_iterator; + using value_iterator = typename base_type::value_iterator; + using value_type = typename value_iterator::value_type; + using reference = typename value_iterator::reference; + using pointer = typename value_iterator::pointer; + using difference_type = typename value_iterator::difference_type; + }; + } + + template + class xsparse_coo_scheme_nz_iterator : public xsparse_crtp_scheme_nz_iterator>, + xtl::xrandom_access_iterator_base3, + detail::xsparse_coo_scheme_nz_iterator_types> + { + public: + + using self_type = xsparse_coo_scheme_nz_iterator; + using scheme_type = scheme; + using iterator_types = detail::xsparse_coo_scheme_nz_iterator_types; + using index_type = typename iterator_types::index_type; + using coordinate_type = typename iterator_types::coordinate_type; + using coordinate_iterator = typename iterator_types::coordinate_iterator; + using value_iterator = typename iterator_types::value_iterator; + using value_type = typename iterator_types::value_type; + using reference = typename iterator_types::reference; + using pointer = typename iterator_types::pointer; + using difference_type = typename iterator_types::difference_type; + using iterator_category = std::random_access_iterator_tag; + + xsparse_coo_scheme_nz_iterator() = default; + xsparse_coo_scheme_nz_iterator(scheme& s, coordinate_iterator cit, value_iterator vit); + + self_type& operator++(); + self_type& operator--(); + + self_type& operator+=(difference_type n); + self_type& operator-=(difference_type n); + + difference_type operator-(const self_type& rhs) const; + + reference operator*() const; + pointer operator->() const; + const index_type& index() const; + + bool equal(const self_type& rhs) const; + bool less_than(const self_type& rhs) const; + + private: + + scheme_type* p_scheme = nullptr; + coordinate_iterator m_cit; + value_iterator m_vit; + }; + + template + bool operator==(const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs); + + template + bool operator<(const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs); + + /********************************************************* + * xsparse_polymorphic_scheme_nz_iterator implementation * + *********************************************************/ + + template + inline xsparse_polymorphic_scheme_nz_iterator::xsparse_polymorphic_scheme_nz_iterator(abstract_iterator *it) : m_it(it) + { + } + + template + inline xsparse_polymorphic_scheme_nz_iterator::~xsparse_polymorphic_scheme_nz_iterator() + { + if (m_it) + delete m_it; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator++() -> self_type& + { + m_it->advance(); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator--() -> self_type& + { + m_it->rewind(); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator+=(difference_type n) -> self_type& + { + m_it->advance(n); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator-=(difference_type n) -> self_type& + { + m_it->rewind(n); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator-(const self_type& rhs) const -> difference_type + { + return m_it->distance(rhs); + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator*() const -> reference + { + return m_it->reference(); + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator->() const -> pointer + { + return m_it->pointer(); + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::index() const -> const index_type& + { + return m_it->index(); + } + + template + inline bool xsparse_polymorphic_scheme_nz_iterator::equal(const self_type& rhs) const + { + return m_it->equal(*(rhs->m_it)); + } + + template + inline bool xsparse_polymorphic_scheme_nz_iterator::less_than(const self_type& rhs) const + { + return m_it->less_than(*(rhs->m_it)); + } + + template + inline bool operator == (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs) + { + return lhs->equal(rhs); + } + + template + inline bool operator < (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs) + { + return lhs->less_than(rhs); + } + + /************************************************** + * xsparse_crtp_scheme_nz_iterator implementation * + **************************************************/ + + template + inline auto xsparse_crtp_scheme_nz_iterator::index() const -> const index_type& + { + m_index = this->derived_cast().index(); + return m_index; + } + + template + inline bool xsparse_crtp_scheme_nz_iterator::equal(const self_type& rhs) const + { + return this->derived_cast() == static_cast(rhs); + } + + template + inline bool xsparse_crtp_scheme_nz_iterator::less_than(const self_type& rhs) const + { + return this->derived_cast() < static_cast(rhs); + } + + template + inline auto xsparse_crtp_scheme_nz_iterator::distance(const self_type& rhs) const -> difference_type + { + auto self = this->derived_cast(); + auto other = static_cast(rhs); + + auto diff = self - other; + return (difference_type)(diff); + } + + template + inline void xsparse_crtp_scheme_nz_iterator::advance(void) + { + ++(this->derived_cast()); + } + + template + inline void xsparse_crtp_scheme_nz_iterator::rewind(void) + { + --(this->derived_cast()); + } + + template + inline void xsparse_crtp_scheme_nz_iterator::advance(difference_type n) + { + (this->derived_cast()) += n; + } + + template + inline void xsparse_crtp_scheme_nz_iterator::rewind(difference_type n) + { + (this->derived_cast()) -= n; + } + + template + inline auto xsparse_crtp_scheme_nz_iterator::derived_cast() & noexcept -> derived_type& + { + return static_cast(*this); + } + + template + inline auto xsparse_crtp_scheme_nz_iterator::derived_cast() const & noexcept -> const derived_type& + { + return static_cast(*this); + } + + /************************************************* + * xsparse_coo_scheme_nz_iterator implementation * + *************************************************/ + + template + inline xsparse_coo_scheme_nz_iterator::xsparse_coo_scheme_nz_iterator(S& s, coordinate_iterator cit, value_iterator vit) + : p_scheme(&s) + , m_cit(cit) + , m_vit(vit) + { + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator++() -> self_type& + { + ++m_cit; + ++m_vit; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator--() -> self_type& + { + --m_cit; + --m_vit; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator+=(difference_type n) -> self_type& + { + m_cit += n; + m_vit += n; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator-=(difference_type n) -> self_type& + { + m_cit -= n; + m_vit -= n; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator-(const self_type& rhs) const -> difference_type + { + return m_cit - rhs.m_cit; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator*() const -> reference + { + return *m_vit; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator->() const -> pointer + { + return &(*m_vit); + } + + template + inline auto xsparse_coo_scheme_nz_iterator::index() const -> const index_type& + { + return *m_cit; + } + + template + inline bool xsparse_coo_scheme_nz_iterator::equal(const self_type& rhs) const + { + return p_scheme == rhs.p_scheme && m_cit == rhs.m_cit && m_vit == rhs.m_vit; + } + + template + inline bool xsparse_coo_scheme_nz_iterator::less_than(const self_type& rhs) const + { + return p_scheme == rhs.p_scheme && m_cit < rhs.m_cit && m_vit < rhs.m_vit; + } + + template + inline bool operator == (const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs) + { + return lhs.equal(rhs); + } + + template + inline bool operator < (const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs) + { + return lhs.less_than(rhs); + } + + /*********************************************************** + * xsparse_polymorphic_scheme as a bridge for type erasure * + ***********************************************************/ + + template + class xsparse_abstract_scheme; + + template + class xsparse_polymorphic_scheme + { + public: + + using self_type = xsparse_polymorphic_scheme; + using index_type = xtl::any; + + using value_type = T; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + using size_type = std::size_t; + using shape_type = svector; + using strides_type = svector; + using inner_shape_type = shape_type; + + using nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + using const_nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + + xsparse_polymorphic_scheme(); + xsparse_polymorphic_scheme(xsparse_abstract_scheme *scheme); + ~xsparse_polymorphic_scheme(); + + pointer find_element(const index_type& index); + const_pointer find_element(const index_type& index) const; + void insert_element(const index_type& index, const_reference value); + void remove_element(const index_type& index); + + void update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape); + + nz_iterator nz_begin(); + nz_iterator nz_end(); + const_nz_iterator nz_begin() const; + const_nz_iterator nz_end() const; + const_nz_iterator nz_cbegin() const; + const_nz_iterator nz_cend() const; + + private: + class xsparse_abstract_scheme *m_scheme = nullptr; + }; + + /*********************************************************** + * xsparse_abstract_scheme as base class for type erasure * + ***********************************************************/ + + template + class xsparse_abstract_scheme + { + public: + + using self_type = xsparse_abstract_scheme; + using index_type = xtl::any; + + using value_type = T; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + using size_type = std::size_t; + using shape_type = svector; + using strides_type = svector; + using inner_shape_type = shape_type; + + using nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + using const_nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + + + virtual ~xsparse_abstract_scheme = default; + + virtual pointer find_element(const index_type& index) = 0; + virtual const_pointer find_element(const index_type& index) const = 0; + virtual void insert_element(const index_type& index, const_reference value) = 0; + virtual void remove_element(const index_type& index) = 0; + + virtual void update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape) = 0; + + virtual nz_iterator nz_begin() = 0; + virtual nz_iterator nz_end() = 0; + virtual const_nz_iterator nz_begin() const = 0; + virtual const_nz_iterator nz_end() const = 0; + virtual const_nz_iterator nz_cbegin() const = 0; + virtual const_nz_iterator nz_cend() const = 0; + }; + + /********************** + * xsparse_coo_scheme * + **********************/ + + template > + class xsparse_coo_scheme + { + public: + + using self_type = xsparse_coo_scheme; + using position_type = P; + using coordinate_type = C; + using storage_type = ST; + using index_type = IT; + + using value_type = typename storage_type::value_type; + using reference = typename storage_type::reference; + using const_reference = typename storage_type::const_reference; + using pointer = typename storage_type::pointer; + using const_pointer = typename storage_type::const_pointer; + + using nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + using const_nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + + using coo_nz_iterator = xsparse_coo_scheme_nz_iterator; + using coo_const_nz_iterator = xsparse_coo_scheme_nz_iterator; + + xsparse_coo_scheme(); + + const position_type& position() const; + const coordinate_type& coordinate() const; + const storage_type& storage() const; + + + pointer find_element(const index_type& index); + const_pointer find_element(const index_type& index) const; + void insert_element(const index_type& index, const_reference value); + void remove_element(const index_type& index); + + template + void update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape); + + nz_iterator nz_begin(); + nz_iterator nz_end(); + const_nz_iterator nz_begin() const; + const_nz_iterator nz_end() const; + const_nz_iterator nz_cbegin() const; + const_nz_iterator nz_cend() const; + + private: + + const_pointer find_element_impl(const index_type& index) const; + + position_type m_pos; + coordinate_type m_coords; + storage_type m_storage; + + friend class xsparse_coo_scheme_nz_iterator; + friend class xsparse_coo_scheme_nz_iterator; + }; + + + /********************************************* + * xsparse_polymorphic_scheme implementation * + *********************************************/ + + template + inline xsparse_polymorphic_scheme::xsparse_polymorphic_scheme() + { + // m_scheme = xt::scheme_policy().scheme(); + } + + template + inline xsparse_polymorphic_scheme::xsparse_polymorphic_scheme(xsparse_abstract_scheme *scheme) : m_scheme(scheme) + { + + } + + template + inline xsparse_polymorphic_scheme::~xsparse_polymorphic_scheme() + { + if (m_scheme) + delete m_scheme; + } + + template + inline auto xsparse_polymorphic_scheme::find_element(const index_type& index) -> pointer + { + return m_scheme->find_element(index); + } + + template + inline auto xsparse_polymorphic_scheme::find_element(const index_type& index) const -> const_pointer + { + return m_scheme->find_element(index); + } + + template + inline void xsparse_polymorphic_scheme::insert_element(const index_type& index, const_reference value) + { + m_scheme->insert_element(index, value); + } + + template + inline void xsparse_polymorphic_scheme::remove_element(const index_type& index) + { + m_scheme->remove_element(index); + } + + template + inline void xsparse_polymorphic_scheme::update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape) + { + m_scheme->update_entries(old_strides, new_strides, new_shape); + } + + template + inline auto xsparse_polymorphic_scheme::nz_begin() -> nz_iterator + { + return m_scheme->nz_begin(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_end() -> nz_iterator + { + return m_scheme->nz_end(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_begin() const -> const_nz_iterator + { + return m_scheme->nz_begin(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_end() const -> const_nz_iterator + { + return m_scheme->nz_end(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_cbegin() const -> const_nz_iterator + { + return m_scheme->nz_cbegin(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_cend() const -> const_nz_iterator + { + return m_scheme->nz_cend(); + } + + /****************************************** + * xsparse_abstract_scheme implementation * + ******************************************/ + + template + inline xsparse_coo_scheme::xsparse_coo_scheme() + : m_pos(P{{0u, 0u}}) + + + template + inline auto xsparse_coo_scheme::position() const -> const position_type& + { + return m_pos; + } + + template + inline auto xsparse_coo_scheme::coordinate() const -> const coordinate_type& + { + return m_coords; + } + + template + inline auto xsparse_coo_scheme::storage() const -> const storage_type& + { + return m_storage; + } + + template + inline auto xsparse_coo_scheme::find_element(const index_type& index) -> pointer + { + return const_cast(find_element_impl(index)); + } + + template + inline auto xsparse_coo_scheme::find_element(const index_type& index) const -> const_pointer + { + return find_element_impl(index); + } + + template + inline void xsparse_coo_scheme::insert_element(const index_type& index, const_reference value) + { + auto it = std::upper_bound(m_coords.cbegin(), m_coords.cend(), index); + if (it != m_coords.cend()) + { + auto diff = std::distance(m_coords.cbegin(), it); + m_coords.insert(it, index); + m_storage.insert(m_storage.cbegin() + diff, value); + } + else + { + m_coords.push_back(index); + m_storage.push_back(value); + } + ++m_pos.back(); + } + + template + inline void xsparse_coo_scheme::remove_element(const index_type& index) + { + auto it = std::find(m_coords.begin(), m_coords.end(), index); + if (it != m_coords.end()) + { + auto diff = it - m_coords.begin(); + m_coords.erase(it); + m_pos.back()--; + m_storage.erase(m_storage.begin() + diff); + } + } + + template + template + inline void xsparse_coo_scheme::update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type&) + { + coordinate_type new_coords; + + for(auto& old_index: m_coords) + { + std::size_t offset = element_offset(old_strides, old_index.cbegin(), old_index.cend()); + index_type new_index = unravel_from_strides(offset, new_strides); + new_coords.push_back(new_index); + } + using std::swap; + swap(m_coords, new_coords); + } + + template + inline auto xsparse_coo_scheme::find_element_impl(const index_type& index) const -> const_pointer + { + auto it = std::find(m_coords.begin(), m_coords.end(), index); + return it == m_coords.end() ? nullptr : &*(m_storage.begin() + (it - m_coords.begin())); + } + + template + inline auto xsparse_coo_scheme::nz_begin() -> nz_iterator + { + return nz_iterator(new coo_nz_iterator(*this, m_coords.cbegin(), m_storage.begin())); + } + + template + inline auto xsparse_coo_scheme::nz_end() -> nz_iterator + { + return nz_iterator(new coo_nz_iterator(*this, m_coords.cend(), m_storage.end())); + } + + template + inline auto xsparse_coo_scheme::nz_begin() const -> const_nz_iterator + { + return nz_cbegin(); + } + + template + inline auto xsparse_coo_scheme::nz_end() const -> const_nz_iterator + { + return nz_cend(); + } + + template + inline auto xsparse_coo_scheme::nz_cbegin() const -> const_nz_iterator + { + return const_nz_iterator(new coo_const_nz_iterator(*this, m_coords.cbegin(), m_storage.cbegin())); + } + + template + inline auto xsparse_coo_scheme::nz_cend() const -> const_nz_iterator + { + return const_nz_iterator(new coo_const_nz_iterator(*this, m_coords.cend(), m_storage.cend())); + } +} + +#endif