diff --git a/tc/core/polyhedral/schedule_isl_conversion.cc b/tc/core/polyhedral/schedule_isl_conversion.cc index 2b01979b9..e6dae4452 100644 --- a/tc/core/polyhedral/schedule_isl_conversion.cc +++ b/tc/core/polyhedral/schedule_isl_conversion.cc @@ -311,7 +311,7 @@ std::unique_ptr fromIslSchedule(isl::schedule schedule) { // Note that the children of set and sequence nodes are always filters, so // they cannot be replaced by empty trees. bool validateSchedule(const ScheduleTree* st) { - return *st == *fromIslSchedule(toIslSchedule(st)); + return st->treeEquals(fromIslSchedule(toIslSchedule(st)).get()); } bool validateSchedule(isl::schedule sc) { diff --git a/tc/core/polyhedral/schedule_tree.cc b/tc/core/polyhedral/schedule_tree.cc index c17c2542a..c5ba49a45 100644 --- a/tc/core/polyhedral/schedule_tree.cc +++ b/tc/core/polyhedral/schedule_tree.cc @@ -336,21 +336,17 @@ vector ScheduleTree::collectDFSPreorder( return functional::Filter(filterType, collectDFSPreorder(tree)); } -bool ScheduleTree::operator==(const ScheduleTree& other) const { - // ctx_ cmp ? - if (type_ != other.type_) { +bool ScheduleTree::treeEquals(const ScheduleTree* other) const { + if (!nodeEquals(other)) { return false; } - if (children_.size() != other.children_.size()) { + if (numChildren() != other->numChildren()) { return false; } - if (!elemEquals(this, &other, type_)) { - return false; - } - TC_CHECK(!other.as()) + TC_CHECK(!other->as()) << "NYI: ScheduleTreeType::Set comparison"; - for (size_t i = 0; i < children_.size(); ++i) { - if (*children_[i] != *other.children_[i]) { + for (size_t i = 0, e = numChildren(); i < e; ++i) { + if (!child({i})->treeEquals(other->child({i}))) { return false; } } diff --git a/tc/core/polyhedral/schedule_tree.h b/tc/core/polyhedral/schedule_tree.h index 403dd0210..92144087f 100644 --- a/tc/core/polyhedral/schedule_tree.h +++ b/tc/core/polyhedral/schedule_tree.h @@ -156,11 +156,6 @@ struct ScheduleTree { public: virtual ~ScheduleTree(); - bool operator==(const ScheduleTree& other) const; - bool operator!=(const ScheduleTree& other) const { - return !(*this == other); - } - // Swap a tree with with the given tree. void swapChild(size_t pos, ScheduleTreeUPtr& swappee) { TC_CHECK_GE(pos, 0u) << "position out of children bounds"; @@ -469,6 +464,15 @@ struct ScheduleTree { // Note that this function does _not_ clone the child trees. virtual ScheduleTreeUPtr clone() const = 0; + // Compare the current node to the "other" node. + // Note that this function does _not_ compare the child trees, + // use treeEquals() instead to compare entire trees. + virtual bool nodeEquals(const ScheduleTree* other) const = 0; + + // Comapre the subtree rooted at the current node to the subtree + // rooted at "other". + bool treeEquals(const ScheduleTree* other) const; + // // Data members // diff --git a/tc/core/polyhedral/schedule_tree_elem.cc b/tc/core/polyhedral/schedule_tree_elem.cc index cc7ce86ff..77973894a 100644 --- a/tc/core/polyhedral/schedule_tree_elem.cc +++ b/tc/core/polyhedral/schedule_tree_elem.cc @@ -281,21 +281,26 @@ ScheduleTreeThreadSpecificMarker::make( return res; } -bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { - if (permutable_ != other.permutable_) { +bool ScheduleTreeBand::nodeEquals(const ScheduleTreeBand* otherBand) const { + if (!otherBand) { return false; } - if (coincident_.size() != other.coincident_.size()) { + if (permutable_ != otherBand->permutable_) { return false; } - if (unroll_.size() != other.unroll_.size()) { + if (coincident_.size() != otherBand->coincident_.size()) { + return false; + } + if (unroll_.size() != otherBand->unroll_.size()) { return false; } if (!std::equal( - coincident_.begin(), coincident_.end(), other.coincident_.begin())) { + coincident_.begin(), + coincident_.end(), + otherBand->coincident_.begin())) { return false; } - if (!std::equal(unroll_.begin(), unroll_.end(), other.unroll_.begin())) { + if (!std::equal(unroll_.begin(), unroll_.end(), otherBand->unroll_.begin())) { return false; } @@ -305,13 +310,13 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { // .domain() returns a zero-dimensional union set (in purely parameter space) // if there is no explicit domain. bool mupaIs0D = nMember() == 0; - bool otherMupaIs0D = other.nMember() == 0; + bool otherMupaIs0D = otherBand->nMember() == 0; if (mupaIs0D ^ otherMupaIs0D) { return false; } if (mupaIs0D && otherMupaIs0D) { auto d1 = mupa_.domain(); - auto d2 = other.mupa_.domain(); + auto d2 = otherBand->mupa_.domain(); auto res = d1.is_equal(d2); if (!res) { LOG_IF(INFO, FLAGS_debug_tc_mapper) @@ -322,7 +327,7 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { } } else { auto m1 = isl::union_map::from(mupa_); - auto m2 = isl::union_map::from(other.mupa_); + auto m2 = isl::union_map::from(otherBand->mupa_); { auto res = m1.is_equal(m2); if (!res) { @@ -337,74 +342,60 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const { return true; } -bool ScheduleTreeContext::operator==(const ScheduleTreeContext& other) const { - auto res = context_.is_equal(other.context_); - return res; +bool ScheduleTreeContext::nodeEquals(const ScheduleTreeContext* other) const { + return other && context_.is_equal(other->context_); } -bool ScheduleTreeDomain::operator==(const ScheduleTreeDomain& other) const { - auto res = domain_.is_equal(other.domain_); +bool ScheduleTreeDomain::nodeEquals(const ScheduleTreeDomain* other) const { + if (!other) { + return false; + } + auto res = domain_.is_equal(other->domain_); if (!res) { LOG_IF(INFO, FLAGS_debug_tc_mapper) << "ScheduleTreeDomain difference: " << domain_ << " VS " - << other.domain_ << "\n"; + << other->domain_ << "\n"; } return res; } -bool ScheduleTreeExtension::operator==( - const ScheduleTreeExtension& other) const { - auto res = extension_.is_equal(other.extension_); - return res; +bool ScheduleTreeExtension::nodeEquals( + const ScheduleTreeExtension* other) const { + return other && extension_.is_equal(other->extension_); } -bool ScheduleTreeFilter::operator==(const ScheduleTreeFilter& other) const { - auto res = filter_.is_equal(other.filter_); - return res; +bool ScheduleTreeFilter::nodeEquals(const ScheduleTreeFilter* other) const { + return other && filter_.is_equal(other->filter_); } -bool ScheduleTreeMapping::operator==(const ScheduleTreeMapping& other) const { - auto res = filter_.is_equal(other.filter_); - return res; +bool ScheduleTreeMapping::nodeEquals(const ScheduleTreeMapping* other) const { + if (mapping.size() != other->mapping.size()) { + return false; + } + for (const auto& kvp : mapping) { + if (other->mapping.count(kvp.first) == 0) { + return false; + } + if (!other->mapping.at(kvp.first).plain_is_equal(kvp.second)) { + return false; + } + } + return filter_.is_equal(other->filter_); } -bool ScheduleTreeSequence::operator==(const ScheduleTreeSequence& other) const { +bool ScheduleTreeSequence::nodeEquals(const ScheduleTreeSequence* other) const { return true; } -bool ScheduleTreeSet::operator==(const ScheduleTreeSet& other) const { +bool ScheduleTreeSet::nodeEquals(const ScheduleTreeSet* other) const { return true; } -bool elemEquals( - const ScheduleTree* e1, - const ScheduleTree* e2, - detail::ScheduleTreeType type) { -#define ELEM_EQUALS_CASE(CLASS) \ - else if (type == CLASS::NodeType) { \ - return *static_cast(e1) == *static_cast(e2); \ - } - - if (type == detail::ScheduleTreeType::None) { - LOG(FATAL) << "Hit Error node!"; - } - ELEM_EQUALS_CASE(ScheduleTreeBand) - ELEM_EQUALS_CASE(ScheduleTreeContext) - ELEM_EQUALS_CASE(ScheduleTreeDomain) - ELEM_EQUALS_CASE(ScheduleTreeExtension) - ELEM_EQUALS_CASE(ScheduleTreeFilter) - ELEM_EQUALS_CASE(ScheduleTreeMapping) - ELEM_EQUALS_CASE(ScheduleTreeSequence) - ELEM_EQUALS_CASE(ScheduleTreeSet) - else { - LOG(FATAL) << "NYI: ScheduleTree::operator== for type: " - << static_cast(type); - } - -#undef ELEM_EQUALS_CASE - - return false; +bool ScheduleTreeThreadSpecificMarker::nodeEquals( + const ScheduleTreeThreadSpecificMarker* other) const { + return true; } + } // namespace detail } // namespace polyhedral } // namespace tc diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h index 5d782049d..bc6e47ddf 100644 --- a/tc/core/polyhedral/schedule_tree_elem.h +++ b/tc/core/polyhedral/schedule_tree_elem.h @@ -52,15 +52,15 @@ struct ScheduleTreeContext : public ScheduleTree { const ScheduleTreeContext* tree, std::vector&& children = {}); - bool operator==(const ScheduleTreeContext& other) const; - bool operator!=(const ScheduleTreeContext& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherContext = other->as(); + return otherContext && nodeEquals(otherContext); + } + bool nodeEquals(const ScheduleTreeContext* otherContext) const; public: isl::set context_; @@ -88,15 +88,15 @@ struct ScheduleTreeDomain : public ScheduleTree { const ScheduleTreeDomain* tree, std::vector&& children = {}); - bool operator==(const ScheduleTreeDomain& other) const; - bool operator!=(const ScheduleTreeDomain& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherDomain = other->as(); + return otherDomain && nodeEquals(otherDomain); + } + bool nodeEquals(const ScheduleTreeDomain* otherDomain) const; public: isl::union_set domain_; @@ -124,15 +124,15 @@ struct ScheduleTreeExtension : public ScheduleTree { const ScheduleTreeExtension* tree, std::vector&& children = {}); - bool operator==(const ScheduleTreeExtension& other) const; - bool operator!=(const ScheduleTreeExtension& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherExtension = other->as(); + return otherExtension && nodeEquals(otherExtension); + } + bool nodeEquals(const ScheduleTreeExtension* otherExtension) const; public: isl::union_map extension_; @@ -153,11 +153,6 @@ struct ScheduleTreeFilter : public ScheduleTree { public: virtual ~ScheduleTreeFilter() override {} - bool operator==(const ScheduleTreeFilter& other) const; - bool operator!=(const ScheduleTreeFilter& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::union_set filter, std::vector&& children = {}); @@ -169,6 +164,11 @@ struct ScheduleTreeFilter : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherFilter = other->as(); + return otherFilter && nodeEquals(otherFilter); + } + bool nodeEquals(const ScheduleTreeFilter* otherFilter) const; public: isl::union_set filter_; @@ -193,11 +193,6 @@ struct ScheduleTreeMapping : public ScheduleTree { public: virtual ~ScheduleTreeMapping() override {} - bool operator==(const ScheduleTreeMapping& other) const; - bool operator!=(const ScheduleTreeMapping& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, const Mapping& mapping, @@ -210,6 +205,11 @@ struct ScheduleTreeMapping : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherMapping = other->as(); + return otherMapping && nodeEquals(otherMapping); + } + bool nodeEquals(const ScheduleTreeMapping* otherMapping) const; public: // Mapping from identifiers to affine functions on domain elements. @@ -231,11 +231,6 @@ struct ScheduleTreeSequence : public ScheduleTree { public: virtual ~ScheduleTreeSequence() override {} - bool operator==(const ScheduleTreeSequence& other) const; - bool operator!=(const ScheduleTreeSequence& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, std::vector&& children = {}); @@ -247,6 +242,11 @@ struct ScheduleTreeSequence : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherSequence = other->as(); + return otherSequence && nodeEquals(otherSequence); + } + bool nodeEquals(const ScheduleTreeSequence* otherSequence) const; }; struct ScheduleTreeSet : public ScheduleTree { @@ -261,11 +261,6 @@ struct ScheduleTreeSet : public ScheduleTree { public: virtual ~ScheduleTreeSet() override {} - bool operator==(const ScheduleTreeSet& other) const; - bool operator!=(const ScheduleTreeSet& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, std::vector&& children = {}); @@ -277,6 +272,11 @@ struct ScheduleTreeSet : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherSet = other->as(); + return otherSet && nodeEquals(otherSet); + } + bool nodeEquals(const ScheduleTreeSet* otherSet) const; }; struct ScheduleTreeBand : public ScheduleTree { @@ -295,15 +295,15 @@ struct ScheduleTreeBand : public ScheduleTree { virtual ~ScheduleTreeBand() override {} - bool operator==(const ScheduleTreeBand& other) const; - bool operator!=(const ScheduleTreeBand& other) const { - return !(*this == other); - } - virtual std::ostream& write(std::ostream& os) const override; virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherBand = other->as(); + return otherBand && nodeEquals(otherBand); + } + bool nodeEquals(const ScheduleTreeBand* other) const; // Make a schedule node band from partial schedule. // Replace "mupa" by its greatest integer part to ensure that the @@ -362,13 +362,6 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree { public: virtual ~ScheduleTreeThreadSpecificMarker() override {} - bool operator==(const ScheduleTreeThreadSpecificMarker& other) const { - return true; - } - bool operator!=(const ScheduleTreeThreadSpecificMarker& other) const { - return !(*this == other); - } - static std::unique_ptr make( isl::ctx ctx, std::vector&& children = {}); @@ -380,13 +373,13 @@ struct ScheduleTreeThreadSpecificMarker : public ScheduleTree { virtual ScheduleTreeUPtr clone() const override { return make(this); } + virtual bool nodeEquals(const ScheduleTree* other) const override { + auto otherMarker = other->as(); + return otherMarker && nodeEquals(otherMarker); + } + bool nodeEquals(const ScheduleTreeThreadSpecificMarker* other) const; }; -bool elemEquals( - const ScheduleTree* e1, - const ScheduleTree* e2, - detail::ScheduleTreeType type); - std::ostream& operator<<(std::ostream& os, detail::ScheduleTreeType nt); std::ostream& operator<<( std::ostream& os, diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 545610a6e..48cf236c9 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -144,8 +144,9 @@ struct PolyhedralMapperTest : public ::testing::Test { islNode = islNode.as().tile(mv); auto scheduleISL = fromIslSchedule(islNode.get_schedule().reset_user()); - ASSERT_TRUE(*scheduleISL == *scheduleISLPP) << *scheduleISL << "\nVS\n" - << *scheduleISLPP; + ASSERT_TRUE(scheduleISL->treeEquals(scheduleISLPP.get())) + << *scheduleISL << "\nVS\n" + << *scheduleISLPP; } }