TVM  0.9.4
Variable.h
Go to the documentation of this file.
1 
3 #pragma once
4 
5 #include <tvm/Range.h>
6 #include <tvm/Space.h>
7 
9 
10 #include <tvm/deprecated.hh>
11 
12 #include <Eigen/Core>
13 
14 #include <memory>
15 #include <string>
16 #include <unordered_map>
17 #include <vector>
18 
19 namespace tvm
20 {
21 class Variable;
22 class VariableVector;
23 
31 VariablePtr TVM_DLLAPI dot(VariablePtr var, int ndiff = 1, bool autoName = false);
32 
48 class TVM_DLLAPI Variable : public tvm::internal::ObjWithId, public std::enable_shared_from_this<Variable>
49 {
50 private:
51  struct make_shared_token
52  {};
53 
54 public:
58  Variable(const Variable &) = delete;
62  Variable & operator=(const Variable &) = delete;
72  VariablePtr duplicate(std::string_view name = "") const;
74  const std::string & name() const;
78  int size() const;
84  const Space & space() const;
89  const Space & spaceShift() const;
95  bool isEuclidean() const;
99  TVM_DEPRECATED inline void value(const VectorConstRef & x) { set(x); }
101  void set(const VectorConstRef & x);
103  void set(Eigen::DenseIndex idx, double value);
105  void set(Eigen::DenseIndex idx, Eigen::DenseIndex length, const VectorConstRef & value);
107  void setZero();
112  int derivativeNumber() const;
116  bool isDerivativeOf(const Variable & v) const;
120  bool isPrimitiveOf(const Variable & v) const;
124  bool isBasePrimitive() const;
126  template<int n = 1>
127  VariablePtr primitive() const;
132 
134  bool isSubvariable() const;
139  bool isSuperVariable() const;
146  bool isSuperVariableOf(const Variable & v) const;
161  bool contains(const Variable & v) const;
163  bool intersects(const Variable & v) const;
164 
179  [[nodiscard]] VariablePtr subvariable(Space space, std::string_view baseName, Space shift = {0}) const;
180 
196  [[nodiscard]] VariablePtr subvariable(Space space, Space shift = {0}) const;
197 
202  Range getMappingIn(const VariableVector & variables) const;
203 
210  [[nodiscard]] VariablePtr shared_from_this();
211 
213  Eigen::CommaInitializer<VectorRef> operator<<(double d);
214 
216  template<typename Derived>
217  Eigen::CommaInitializer<VectorRef> operator<<(const Eigen::DenseBase<Derived> & other);
218 
219  friend bool operator==(const Variable & u, const Variable & v);
220  friend bool operator!=(const Variable & u, const Variable & v);
221 
223  Variable(make_shared_token, VariablePtr var, const Space & space, std::string_view name, const Space & shift);
224 
225 private:
226  struct MappingHelper
227  {
228  int start;
229  int stamp;
230  };
231 
233  Variable(const Space & s, std::string_view name);
234 
239  Variable(Variable * var, bool autoName);
240 
244  [[nodiscard]] VariablePtr subvariable(Space space, std::string_view baseName, Space shift, bool autoName) const;
245 
247  template<int n>
248  VariablePtr primitiveNoCheck() const;
249 
251  std::string name_;
252 
254  Space space_;
255 
259  Space shift_;
260 
262  double * memory_;
263 
265  VectorRef value_;
266 
270  int derivativeNumber_;
271 
275  Variable * primitive_;
276 
280  VariablePtr superVariable_;
281 
283  std::unique_ptr<Variable> derivative_;
284 
288  mutable std::unordered_map<int, MappingHelper> startIn_;
289 
291  friend class Space;
292  friend VariablePtr TVM_DLLAPI dot(VariablePtr, int, bool);
293  friend class VariableVector;
294 };
295 
296 template<int n>
298 {
299  if(n <= derivativeNumber_)
300  return primitiveNoCheck<n>();
301  else
302  throw std::runtime_error("This variable is not the n-th derivative of an other variable.");
303 }
304 
305 template<int n>
306 inline VariablePtr Variable::primitiveNoCheck() const
307 {
308  static_assert(n > 0, "Works only for non-negative numbers.");
309  return primitive_->primitive<n - 1>();
310 }
311 
312 template<>
313 inline VariablePtr Variable::primitiveNoCheck<1>() const
314 {
315  if(derivativeNumber_ > 1)
316  return {basePrimitive(), primitive_};
317  else
318  return primitive_->shared_from_this();
319 }
320 
321 inline Eigen::CommaInitializer<VectorRef> Variable::operator<<(double d) { return {value_, d}; }
322 
323 template<typename Derived>
324 inline Eigen::CommaInitializer<VectorRef> Variable::operator<<(const Eigen::DenseBase<Derived> & other)
325 { return {value_, other}; }
326 
327 inline bool operator==(const Variable & u, const Variable & v)
328 { return u.value_.data() == v.value_.data() && u.size() == v.size(); }
329 
330 inline bool operator!=(const Variable & u, const Variable & v) { return !(u == v); }
331 
332 } // namespace tvm
333 
335 inline Eigen::CommaInitializer<tvm::VectorRef> operator<<(tvm::VariablePtr & v, double d) { return *v.get() << d; }
336 
338 inline Eigen::CommaInitializer<tvm::VectorRef> operator<<(tvm::VariablePtr && v, double d) { return *v.get() << d; }
339 
341 template<typename Derived>
342 inline Eigen::CommaInitializer<tvm::VectorRef> operator<<(tvm::VariablePtr & v, const Eigen::DenseBase<Derived> & other)
343 { return *v.get() << other; }
344 
346 template<typename Derived>
347 inline Eigen::CommaInitializer<Eigen::VectorXd> operator<<(tvm::VariablePtr && v,
348  const Eigen::DenseBase<Derived> & other)
349 { return *v.get() << other; }
Eigen::CommaInitializer< tvm::VectorRef > operator<<(tvm::VariablePtr &v, double d)
Definition: Variable.h:335
#define TVM_DLLAPI
Definition: api.h:35
Definition: Range.h:19
Definition: Space.h:33
Definition: VariableVector.h:41
Definition: Variable.h:49
bool contains(const Variable &v) const
Variable & operator=(const Variable &)=delete
const Space & space() const
int derivativeNumber() const
VariablePtr subvariable(Space space, Space shift={0}) const
VariablePtr basePrimitive() const
VariablePtr subvariable(Space space, std::string_view baseName, Space shift={0}) const
Eigen::CommaInitializer< VectorRef > operator<<(double d)
Definition: Variable.h:321
bool isSubvariable() const
void set(const VectorConstRef &x)
VariablePtr superVariable() const
int size() const
VectorConstRef value() const
bool isBasePrimitive() const
bool isSuperVariable() const
bool isPrimitiveOf(const Variable &v) const
Variable(const Variable &)=delete
const Space & spaceShift() const
TVM_DEPRECATED void value(const VectorConstRef &x)
Definition: Variable.h:99
VariablePtr shared_from_this()
Range subvariableRange() const
VariablePtr primitive() const
Definition: Variable.h:297
bool isEuclidean() const
friend VariablePtr TVM_DLLAPI dot(VariablePtr, int, bool)
bool intersects(const Variable &v) const
const std::string & name() const
bool isSuperVariableOf(const Variable &v) const
void set(Eigen::DenseIndex idx, double value)
bool isDerivativeOf(const Variable &v) const
Range getMappingIn(const VariableVector &variables) const
Variable(make_shared_token, VariablePtr var, const Space &space, std::string_view name, const Space &shift)
Range tSubvariableRange() const
void set(Eigen::DenseIndex idx, Eigen::DenseIndex length, const VectorConstRef &value)
VariablePtr duplicate(std::string_view name="") const
Definition: ObjWithId.h:14
Definition: Clock.h:12
std::shared_ptr< Variable > VariablePtr
Definition: defs.h:65
bool operator!=(const Variable &u, const Variable &v)
Definition: Variable.h:330
Eigen::Ref< Eigen::VectorXd > VectorRef
Definition: defs.h:51
bool operator==(const Variable &u, const Variable &v)
Definition: Variable.h:327
Eigen::Ref< const Eigen::VectorXd > VectorConstRef
Definition: defs.h:50
VariablePtr TVM_DLLAPI dot(VariablePtr var, int ndiff=1, bool autoName=false)