TVM  0.9.4
AffineExpr.h
Go to the documentation of this file.
1 
3 #pragma once
4 
5 #include <tvm/api.h>
6 #include <tvm/defs.h>
7 
8 #include <tvm/Variable.h>
9 #include <tvm/internal/meta.h>
11 
12 #include <Eigen/Core>
13 
14 #include <tuple>
15 
16 namespace tvm
17 {
18 namespace utils
19 {
24 template<typename Derived>
26 {
27 public:
29  LinearExpr(const Eigen::MatrixBase<Derived> & matrix, const VariablePtr & v) : matrix_(matrix.derived()), var_(v)
30  { assert(matrix.cols() == v->size()); }
31 
33  template<class T = Derived, typename std::enable_if_t<std::is_same_v<T, internal::IdentityType>, int> = 0>
34  LinearExpr(const tvm::VariablePtr & v) : matrix_(Eigen::MatrixXd::Identity(v->size(), v->size())), var_(v)
35  {}
36 
38  template<class T = Derived, typename std::enable_if_t<std::is_same_v<T, internal::MultIdentityType>, int> = 0>
39  LinearExpr(double a, const tvm::VariablePtr & v)
40  : matrix_(a * Eigen::MatrixXd::Identity(v->size(), v->size())), var_(v)
41  {}
42 
44  template<class T = Derived, typename std::enable_if_t<std::is_same_v<T, internal::MinusIdentityType>, int> = 0>
45  LinearExpr(const tvm::VariablePtr & v) : matrix_(-Eigen::MatrixXd::Identity(v->size(), v->size())), var_(v)
46  {}
47 
48  const Derived & matrix() const { return matrix_; };
49  const VariablePtr & variable() const { return var_; }
50 
51 private:
53  typename internal::RefSelector_t<Derived> matrix_;
55  const VariablePtr & var_;
56 };
57 
69 template<typename CstDerived, typename... Derived>
71 {
72 public:
74  template<class T = CstDerived, typename std::enable_if_t<std::is_same_v<T, internal::NoConstant>, int> = 0>
76  : linear_(std::forward_as_tuple(linear...)), constant_(internal::NoConstant())
77  {}
78 
80  template<class T = CstDerived, typename std::enable_if_t<!std::is_same_v<T, internal::NoConstant>, int> = 0>
81  AffineExpr(const Eigen::MatrixBase<CstDerived> & constant, const LinearExpr<Derived> &... linear)
82  : linear_(std::forward_as_tuple(linear...)), constant_(constant.derived())
83  { EIGEN_STATIC_ASSERT_VECTOR_ONLY(CstDerived) }
84 
86  template<class T = CstDerived, typename std::enable_if_t<std::is_same_v<T, internal::NoConstant>, int> = 0>
88  : linear_(linear), constant_(constant)
89  {}
90 
92  template<class T = CstDerived, typename std::enable_if_t<!std::is_same_v<T, internal::NoConstant>, int> = 0>
93  AffineExpr(const Eigen::MatrixBase<CstDerived> & constant, const std::tuple<LinearExpr<Derived>...> & linear)
94  : linear_(linear), constant_(constant.derived())
95  { EIGEN_STATIC_ASSERT_VECTOR_ONLY(CstDerived) }
96 
97  const std::tuple<LinearExpr<Derived>...> & linear() const { return linear_; }
98  const CstDerived & constant() const { return constant_; }
99 
100 private:
101  // ConstantType::Type is NoConstant in case CstType==NoConstant and ref_selector<CstDerived> otherwise.
102  // Use of ref_selector<CstDerived>: for a vector, we need to keep a const ref, while for a vector
103  // expression we need to take a copy of the expression (same use as in e.g. CWiseBinaryOp).
104  template<typename T>
105  struct ConstantType
106  {
107  using Type = typename internal::RefSelector_t<CstDerived>;
108  };
109 
111  std::tuple<LinearExpr<Derived>...> linear_;
113  typename ConstantType<CstDerived>::Type constant_;
114 };
115 
117 template<typename CstDerived, typename... Derived>
118 AffineExpr<CstDerived, Derived...> make_AffineExpr(const CstDerived & constant,
119  const std::tuple<LinearExpr<Derived>...> & linear)
120 { return {constant, linear}; }
121 } // namespace utils
122 } // namespace tvm
123 
Definition: AffineExpr.h:71
AffineExpr(const Eigen::MatrixBase< CstDerived > &constant, const std::tuple< LinearExpr< Derived >... > &linear)
Definition: AffineExpr.h:93
AffineExpr(const Eigen::MatrixBase< CstDerived > &constant, const LinearExpr< Derived > &... linear)
Definition: AffineExpr.h:81
const CstDerived & constant() const
Definition: AffineExpr.h:98
AffineExpr(const internal::NoConstant &constant, const std::tuple< LinearExpr< Derived >... > &linear)
Definition: AffineExpr.h:87
const std::tuple< LinearExpr< Derived >... > & linear() const
Definition: AffineExpr.h:97
AffineExpr(const LinearExpr< Derived > &... linear)
Definition: AffineExpr.h:75
Definition: AffineExpr.h:26
LinearExpr(double a, const tvm::VariablePtr &v)
Definition: AffineExpr.h:39
LinearExpr(const tvm::VariablePtr &v)
Definition: AffineExpr.h:34
const VariablePtr & variable() const
Definition: AffineExpr.h:49
const Derived & matrix() const
Definition: AffineExpr.h:48
LinearExpr(const Eigen::MatrixBase< Derived > &matrix, const VariablePtr &v)
Definition: AffineExpr.h:29
Definition: AffineExprDetail.h:26
Definition: AffineExprDetail.h:95
Definition: probe.h:44
Type
Definition: enums.h:15
typename RefSelector< Derived >::type RefSelector_t
Definition: AffineExprDetail.h:70
AffineExpr< CstDerived, Derived... > make_AffineExpr(const CstDerived &constant, const std::tuple< LinearExpr< Derived >... > &linear)
Definition: AffineExpr.h:118
Definition: Clock.h:12
std::shared_ptr< Variable > VariablePtr
Definition: defs.h:65