TVM  0.9.4
VariableVectorPartition.h
Go to the documentation of this file.
1 
3 #pragma once
4 
5 #include <tvm/Variable.h>
7 
8 namespace tvm::internal
9 {
53 template<class VarVector>
55 {
56 public:
64  VariableVectorPartition(const VarVector & v, const VariableCountingVector & partition)
65  : var_(v), partition_(partition.variables())
66  { assert(partition.split()); }
67 
68  class iterator
69  {
70  public:
71  iterator(const VarVector & var, const VariableVector & p, int iv) : ip_(0), iv_(iv), var_(var), partition_(p)
72  {
73  if(iv == p.numberOfVariables()) // end
74  {
75  ip_ = iv;
76  }
77  else
78  {
79  const auto & v = var[iv];
80  for(; ip_ < p.numberOfVariables(); ++ip_)
81  {
82  if(v->contains(*p[ip_]))
83  return;
84  }
85  throw std::runtime_error("Invalid partition");
86  }
87  }
88 
90  {
91  ++ip_;
92  if(ip_ < partition_.numberOfVariables() && var_[iv_]->contains(*partition_[ip_]))
93  {
94  return *this;
95  }
96  else
97  {
98  assert(partition_[ip_ - 1]->subvariableRange().end() == var_[iv_]->subvariableRange().end()
99  && "End of variable in partition is not the same as the end of variable in the reference vector.");
100  ++iv_;
101  if(iv_ == static_cast<int>(var_.end() - var_.begin()))
102  {
103  // return end
104  ip_ = partition_.numberOfVariables();
105  return *this;
106  }
107  const auto & v = var_[iv_];
108  for(ip_ = 0; ip_ < partition_.numberOfVariables(); ++ip_)
109  {
110  if(v->contains(*partition_[ip_]))
111  return *this;
112  }
113  throw std::runtime_error("Invalid partition");
114  }
115  }
116 
117  bool operator==(iterator other) const
118  {
119  assert(&var_ == &other.var_ && &partition_ == &other.partition_);
120  return ip_ == other.ip_;
121  }
122  bool operator!=(iterator other) const { return !(*this == other); }
123  VariablePtr operator*() { return partition_[ip_]; }
124 
125  private:
126  int ip_;
127  int iv_;
128  const VarVector & var_;
129  const VariableVector & partition_;
130  };
131 
132  iterator begin() { return {var_, partition_, 0}; }
133  iterator end() { return {var_, partition_, partition_.numberOfVariables()}; }
134 
135 private:
136  const VarVector & var_;
137  const VariableVector & partition_;
138 };
139 } // namespace tvm::internal
Definition: VariableVector.h:41
int numberOfVariables() const
Definition: VariableCountingVector.h:30
bool split() const
Definition: VariableCountingVector.h:68
Definition: VariableVectorPartition.h:69
bool operator!=(iterator other) const
Definition: VariableVectorPartition.h:122
iterator & operator++()
Definition: VariableVectorPartition.h:89
VariablePtr operator*()
Definition: VariableVectorPartition.h:123
iterator(const VarVector &var, const VariableVector &p, int iv)
Definition: VariableVectorPartition.h:71
bool operator==(iterator other) const
Definition: VariableVectorPartition.h:117
Definition: VariableVectorPartition.h:55
VariableVectorPartition(const VarVector &v, const VariableCountingVector &partition)
Definition: VariableVectorPartition.h:64
iterator end()
Definition: VariableVectorPartition.h:133
iterator begin()
Definition: VariableVectorPartition.h:132
Definition: CallbackManager.h:12
std::shared_ptr< Variable > VariablePtr
Definition: defs.h:65