Halide 16.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
FunctionDAG.h
Go to the documentation of this file.
1/** This file defines the class FunctionDAG, which is our
2 * representation of a Halide pipeline, and contains methods for using
3 * Halide's bounds tools to query properties of it. */
4
5#ifndef FUNCTION_DAG_H
6#define FUNCTION_DAG_H
7
8#include <algorithm>
9#include <cstdint>
10#include <map>
11#include <string>
12#include <utility>
13#include <vector>
14
15#include "Errors.h"
16#include "Featurization.h"
17#include "Halide.h"
18#include "PerfectHashMap.h"
19
20namespace Halide {
21namespace Internal {
22namespace Autoscheduler {
23
24using std::map;
25using std::pair;
26using std::string;
27using std::unique_ptr;
28using std::vector;
29
30// First we have various utility classes.
31
32// An optional rational type used when analyzing memory dependencies.
33struct OptionalRational {
35
// Whether this rational holds a value. A zero denominator is the
// marker for "no value".
36 bool exists() const {
37 return denominator != 0;
38 }
39
40 OptionalRational() = default;
44
46 if ((denominator & other.denominator) == 0) {
48 return;
49 }
50 if (denominator == other.denominator) {
51 numerator += other.numerator;
52 return;
53 }
54
55 int64_t l = lcm(denominator, other.denominator);
57 denominator = l;
58 numerator += other.numerator * (l / other.denominator);
60 numerator /= g;
61 denominator /= g;
62 }
63
65 if ((*this) == 0) {
66 return *this;
67 }
68 int64_t num = numerator * factor;
70 }
71
73 if ((*this) == 0) {
74 return *this;
75 }
76 if (other == 0) {
77 return other;
78 }
79 int64_t num = numerator * other.numerator;
80 int64_t den = denominator * other.denominator;
81 return OptionalRational{num, den};
82 }
83
84 // Because this type is optional (exists may be false), we don't
85 // have a total ordering. These methods all return false when the
86 // operators are not comparable, so a < b is not the same as !(a
87 // >= b).
// Is this rational strictly less than the integer x? Returns false
// when the value does not exist (zero denominator). The test
// cross-multiplies x by the denominator instead of dividing, so the
// direction of the inequality is reversed for a negative denominator.
88 bool operator<(int x) const {
89 if (denominator == 0) {
90 return false;
91 } else if (denominator > 0) {
92 return numerator < x * denominator;
93 } else {
// Negative denominator: multiplying through flips the comparison.
94 return numerator > x * denominator;
95 }
96 }
97
// Is this rational less than or equal to the integer x? Returns
// false when the value does not exist (zero denominator). The test
// cross-multiplies x by the denominator instead of dividing, so the
// direction of the inequality is reversed for a negative denominator.
98 bool operator<=(int x) const {
99 if (denominator == 0) {
100 return false;
101 } else if (denominator > 0) {
102 return numerator <= x * denominator;
103 } else {
// Negative denominator: multiplying through flips the comparison.
104 return numerator >= x * denominator;
105 }
106 }
107
108 bool operator>(int x) const {
109 if (!exists()) {
110 return false;
111 }
112 return !((*this) <= x);
113 }
114
115 bool operator>=(int x) const {
116 if (!exists()) {
117 return false;
118 }
119 return !((*this) < x);
120 }
121
122 bool operator==(int x) const {
123 return exists() && (numerator == (x * denominator));
124 }
125
// Equality between two optional rationals, tested by cross-multiplying
// so that unreduced fractions with the same value compare equal.
// Both sides must agree on exists(). When neither side exists both
// denominators are zero, so both products are zero and the two
// values compare equal regardless of their numerators.
126 bool operator==(const OptionalRational &other) const {
127 return (exists() == other.exists()) && (numerator * other.denominator == denominator * other.numerator);
128 }
129};
130
131// A LoadJacobian records the derivative of the coordinate accessed in
132// some producer w.r.t the loops of the consumer.
133class LoadJacobian {
134 std::vector<OptionalRational> coeffs;
135 int64_t c;
136 size_t rows, cols;
137
138public:
141 coeffs.resize(rows * cols);
142 }
143
144 bool all_coeffs_exist() const {
145 for (const auto &coeff : coeffs) {
146 if (!coeff.exists()) {
147 return false;
148 }
149 }
150 return true;
151 }
152
153 bool empty() const {
154 return rows == 0;
155 }
156
157 size_t producer_storage_dims() const {
158 return rows;
159 }
160
161 size_t consumer_loop_dims() const {
162 return cols;
163 }
164
// True when every coefficient exists and is equal to zero (and
// therefore also true when there are no coefficients at all).
// Note that the loop variable c shadows the member c (the count
// returned by count()) for the duration of the loop.
165 bool is_constant() const {
166 for (const auto &c : coeffs) {
// operator==(int) already returns false for a non-existent value,
// so the explicit exists() test is redundant but harmless.
167 if (!c.exists() || !(c == 0)) {
168 return false;
169 }
170 }
171
172 return true;
173 }
174
176 if (producer_storage_dims() == 0 || consumer_loop_dims() == 0) {
177 // The producer or consumer is scalar, so all strides are zero.
178 return {0, 1};
179 }
180 return coeffs[producer_storage_dim * cols + consumer_loop_dim];
181 }
182
186
187 // To avoid redundantly re-recording copies of the same
188 // load Jacobian, we keep a count of how many times a
189 // load with this Jacobian occurs.
190 int64_t count() const {
191 return c;
192 }
193
194 // Try to merge another LoadJacobian into this one, increasing the
195 // count if the coefficients match.
// Returns true and adds other's count to this one's when both
// Jacobians have the same shape and equal coefficients. Returns
// false, leaving this object unmodified, otherwise.
196 bool merge(const LoadJacobian &other) {
// Differently-shaped matrices can never match.
197 if (other.rows != rows || other.cols != cols) {
198 return false;
199 }
// Element-wise comparison over the flattened rows * cols storage.
200 for (size_t i = 0; i < rows * cols; i++) {
201 if (!(other.coeffs[i] == coeffs[i])) {
202 return false;
203 }
204 }
205 c += other.count();
206 return true;
207 }
208
209 // Scale the matrix coefficients by the given factors
// Returns a new Jacobian of the same shape and with the same count,
// in which column j (consumer loop dimension j) has been multiplied
// by factors[j]. factors is indexed without a bounds check, so it
// must hold at least consumer_loop_dims() entries.
210 LoadJacobian operator*(const std::vector<int64_t> &factors) const {
211 LoadJacobian result(rows, cols, c);
212 for (size_t i = 0; i < producer_storage_dims(); i++) {
213 for (size_t j = 0; j < consumer_loop_dims(); j++) {
214 result(i, j) = (*this)(i, j) * factors[j];
215 }
216 }
217 return result;
218 }
219
220 // Multiply Jacobians, used to look at memory dependencies through
221 // inlined functions.
223 LoadJacobian result(producer_storage_dims(), other.consumer_loop_dims(), count() * other.count());
224 for (size_t i = 0; i < producer_storage_dims(); i++) {
225 for (size_t j = 0; j < other.consumer_loop_dims(); j++) {
226 result(i, j) = OptionalRational{0, 1};
227 for (size_t k = 0; k < consumer_loop_dims(); k++) {
228 result(i, j) += (*this)(i, k) * other(k, j);
229 }
230 }
231 }
232 return result;
233 }
234
235 void dump(const char *prefix) const;
236};
237
238// Classes to represent a concrete set of bounds for a Func. A Span is
239// single-dimensional, and a Bound is a multi-dimensional box. For
240// each dimension we track the estimated size, and also whether or not
241// the size is known to be constant at compile-time. For each Func we
242// track three different types of bounds:
243
244// 1) The region required by consumers of the Func, which determines
245// 2) The region actually computed, which in turn determines
246// 3) The min and max of all loops in the loop nest.
247
248// 3 in turn determines the region required of the inputs to a Func,
249// which determines their region computed, and hence their loop nest,
250// and so on back up the Function DAG from outputs back to inputs.
251
252class Span {
253 int64_t min_, max_;
254 bool constant_extent_;
255
256public:
257 int64_t min() const {
258 return min_;
259 }
260 int64_t max() const {
261 return max_;
262 }
// Number of points covered. Both min_ and max_ are inclusive, hence
// the + 1.
// NOTE(review): for the sentinel returned by empty_span()
// (min_ = INT64_MAX, max_ = INT64_MIN) this subtraction overflows
// int64_t, so extent() should not be called on an empty span.
263 int64_t extent() const {
264 return max_ - min_ + 1;
265 }
266 bool constant_extent() const {
267 return constant_extent_;
268 }
269
// Grow this span in place so that it covers both itself and other:
// take the smaller min and the larger max. The result is flagged as
// having a constant extent only if both inputs were.
270 void union_with(const Span &other) {
271 min_ = std::min(min_, other.min());
272 max_ = std::max(max_, other.max());
273 constant_extent_ = constant_extent_ && other.constant_extent();
274 }
275
277 max_ = min_ + e - 1;
278 }
279
281 min_ += x;
282 max_ += x;
283 }
284
285 Span(int64_t a, int64_t b, bool c)
286 : min_(a), max_(b), constant_extent_(c) {
287 }
288 Span() = default;
289 Span(const Span &other) = default;
// A sentinel span with min = INT64_MAX and max = INT64_MIN, flagged
// as constant-extent. Because union_with() takes the min of the mins
// and the max of the maxes and ANDs the constant_extent flags,
// unioning this sentinel with any span yields exactly that span.
290 static Span empty_span() {
291 return Span(INT64_MAX, INT64_MIN, true);
292 }
293};
294
295// Bounds objects are created and destroyed very frequently while
296// exploring scheduling options, so we have a custom allocator and
297// memory pool. Much like IR nodes, we treat them as immutable once
298// created and wrapped in a Bound object so that they can be shared
299// safely across scheduling alternatives.
300struct BoundContents {
301 mutable RefCount ref_count;
302
303 class Layout;
304 const Layout *layout = nullptr;
305
// Pointer to the first Span belonging to this object. Constness is
// cast away, so a mutable Span pointer is returned even from a const
// BoundContents.
306 Span *data() const {
307 // This struct is a header
// for the Span array: the returned address is the first byte past
// the end of this object (this + 1), reinterpreted as Span *.
308 return (Span *)(const_cast<BoundContents *>(this) + 1);
309 }
310
312 return data()[i];
313 }
314
316 return data()[i + layout->computed_offset];
317 }
318
319 Span &loops(int i, int j) {
320 return data()[j + layout->loop_offset[i]];
321 }
322
323 const Span &region_required(int i) const {
324 return data()[i];
325 }
326
327 const Span &region_computed(int i) const {
328 return data()[i + layout->computed_offset];
329 }
330
331 const Span &loops(int i, int j) const {
332 return data()[j + layout->loop_offset[i]];
333 }
334
336 auto *b = layout->make();
337 size_t bytes = sizeof(data()[0]) * layout->total_size;
338 memcpy(b->data(), data(), bytes);
339 return b;
340 }
341
342 void validate() const;
343
344 // We're frequently going to need to make these concrete bounds
345 // arrays. It makes things more efficient if we figure out the
346 // memory layout of those data structures once ahead of time, and
347 // make each individual instance just use that. Note that this is
348 // not thread-safe.
349 class Layout {
350 // A memory pool of free BoundContents objects with this layout
351 mutable std::vector<BoundContents *> pool;
352
353 // All the blocks of memory allocated
354 mutable std::vector<void *> blocks;
355
356 mutable size_t num_live = 0;
357
358 void allocate_some_more() const;
359
360 public:
361 // number of Span to allocate
362 int total_size;
363
364 // region_computed comes next at the following index
365 int computed_offset;
366
367 // the loop for each stage starts at the following index
368 std::vector<int> loop_offset;
369
370 Layout() = default;
372
373 Layout(const Layout &) = delete;
374 void operator=(const Layout &) = delete;
375 Layout(Layout &&) = delete;
376 void operator=(Layout &&) = delete;
377
378 // Make a BoundContents object with this layout
380
381 // Release a BoundContents object with this layout back to the pool
382 void release(const BoundContents *b) const;
383 };
384};
385
387
388// A representation of the function DAG. The nodes and edges are both
389// in reverse realization order, so if you want to walk backwards up
390// the DAG, just iterate the nodes or edges in-order.
391struct FunctionDAG {
392
393 // An edge is a producer-consumer relationship
394 struct Edge;
395
396 struct SymbolicInterval {
399 };
400
401 // A Node represents a single Func
402 struct Node {
403 // A pointer back to the owning DAG
404 FunctionDAG *dag;
405
406 // The Halide Func this represents
408
409 // The number of bytes per point stored.
410 double bytes_per_point;
411
412 // The min/max variables used to denote a symbolic region of
413 // this Func. Used in the cost above, and in the Edges below.
415
416 // A concrete region required from a bounds estimate. Only
417 // defined for outputs.
419
420 // The region computed of a Func, in terms of the region
421 // required. For simple Funcs this is identical to the
422 // region_required. However, in some Funcs computing one
423 // output requires computing other outputs too. You can't
424 // really ask for a single output pixel from something blurred
425 // with an IIR without computing the others, for example.
426 struct RegionComputedInfo {
427 // The min and max in their full symbolic glory. We use
428 // these in the general case.
429 Interval in;
430
431 // Analysis used to accelerate common cases
433 int64_t c_min = 0, c_max = 0;
434 };
437
438 // Expand a region required into a region computed, using the
439 // symbolic intervals above.
441
442 // Metadata about one symbolic loop in a Func's default loop nest.
443 struct Loop {
444 string var;
445 bool pure, rvar;
446 Expr min, max;
447
448 // Which pure dimension does this loop correspond to? Invalid if it's an rvar
449 int pure_dim;
450
451 // Precomputed metadata to accelerate common cases:
452
453 // If true, the loop bounds are just the region computed in the given dimension
454 bool equals_region_computed = false;
455 int region_computed_dim = 0;
456
457 // If true, the loop bounds are a constant with the given min and max
458 bool bounds_are_constant = false;
459 int64_t c_min = 0, c_max = 0;
460
461 // A persistent fragment of source for getting this Var
462 // from its owner Func. Used for printing source code
463 // equivalent to a computed schedule.
464 string accessor;
465 };
466
467 // Get the loop nest shape as a function of the region computed
468 void loop_nest_for_region(int stage_idx, const Span *computed, Span *loop) const;
469
470 // One stage of a Func
471 struct Stage {
472 // The owning Node
473 Node *node;
474
475 // Which stage of the Func is this. 0 = pure.
476 int index;
477
478 // The loop nest that computes this stage, from innermost out.
480 bool loop_nest_all_common_cases = false;
481
482 // The vectorization width that will be used for
483 // compute. Corresponds to the natural width for the
484 // narrowest type used.
485 int vector_size;
486
487 // The featurization of the compute done
489
490 // The actual Halide front-end stage object
492
493 // The name for scheduling (e.g. "foo.update(3)")
494 string name;
495
497
498 // Ids for perfect hashing on stages.
499 int id, max_id;
500
501 std::unique_ptr<LoadJacobian> store_jacobian;
502
504
505 vector<bool> dependencies;
// Looks up the given node in this stage's dependencies table, which
// is indexed by node id. No bounds check is performed on n.id.
506 bool downstream_of(const Node &n) const {
507 return dependencies[n.id];
508 };
509
511 : stage(std::move(s)) {
512 }
513
// Linear search of this stage's loop list for a loop whose var name
// equals the argument. Returns the position of the first match, or
// -1 if no loop has that name.
514 int get_loop_index_from_var(const std::string &var) const {
// i tracks the position of l within loop.
515 int i = 0;
516 for (const auto &l : loop) {
517 if (l.var == var) {
518 return i;
519 }
520
521 ++i;
522 }
523
524 return -1;
525 }
526 };
528
530
531 // Max vector size across the stages
532 int vector_size;
533
534 // A unique ID for this node, allocated consecutively starting
535 // at zero for each pipeline.
536 int id, max_id;
537
538 // Just func->dimensions(), but we ask for it so many times
539 // that it's worth avoiding the function call into
540 // libHalide.
541 int dimensions;
542
543 // Is a single pointwise call to another Func
544 bool is_wrapper;
545
546 // We represent the input buffers as nodes, though we do not attempt to schedule them.
547 bool is_input;
548
549 // Is one of the pipeline outputs
550 bool is_output;
551
552 // Only uses pointwise calls
553 bool is_pointwise;
554
555 // Only uses pointwise calls + clamping on all indices
557
558 std::unique_ptr<BoundContents::Layout> bounds_memory_layout;
559
561 return bounds_memory_layout->make();
562 }
563 };
564
565 // A representation of a producer-consumer relationship
566 struct Edge {
567 struct BoundInfo {
568 // The symbolic expression for the bound in this dimension
569 Expr expr;
570
571 // Fields below are the results of additional analysis
572 // used to evaluate this bound more quickly.
575 bool affine, uses_max;
576
578 };
579
580 // Memory footprint on producer required by consumer.
582
585
586 // The number of calls the consumer makes to the producer, per
587 // point in the loop nest of the consumer.
588 int calls;
589
591
593
595
597
598 // Given a loop nest of the consumer stage, expand a region
599 // required of the producer to be large enough to include all
600 // points required.
602 };
603
606
608
609 // We're going to be querying this DAG a lot while searching for
610 // an optimal schedule, so we'll also create a variety of
611 // auxiliary data structures.
613
614 // Create the function DAG, and do all the dependency and cost
615 // analysis. This is done once up-front before the tree search.
616 FunctionDAG(const vector<Function> &outputs, const Target &target);
617
618 void dump() const;
619 std::ostream &dump(std::ostream &os) const;
620
621 // This class uses a lot of internal pointers, so we'll hide the copy constructor.
622 FunctionDAG(const FunctionDAG &other) = delete;
623 void operator=(const FunctionDAG &other) = delete;
624
625private:
626 // Compute the featurization for the entire DAG
627 void featurize();
628
629 template<typename OS>
630 void dump_internal(OS &os) const;
631};
632
633template<typename T>
635
636class ExprBranching : public VariadicVisitor<ExprBranching, int, int> {
638
639private:
640 const NodeMap<int64_t> &inlined;
641
642public:
643 int visit(const Reinterpret *op);
644 int visit(const IntImm *op);
645 int visit(const UIntImm *op);
646 int visit(const FloatImm *op);
647 int visit(const StringImm *op);
648 int visit(const Broadcast *op);
649 int visit(const Cast *op);
650 int visit(const Variable *op);
651 int visit(const Add *op);
652 int visit(const Sub *op);
653 int visit(const Mul *op);
654 int visit(const Div *op);
655 int visit(const Mod *op);
656 int visit(const Min *op);
657 int visit(const Max *op);
658 int visit(const EQ *op);
659 int visit(const NE *op);
660 int visit(const LT *op);
661 int visit(const LE *op);
662 int visit(const GT *op);
663 int visit(const GE *op);
664 int visit(const And *op);
665 int visit(const Or *op);
666 int visit(const Not *op);
667 int visit(const Select *op);
668 int visit(const Ramp *op);
669 int visit(const Load *op);
670 int visit(const Call *op);
671 int visit(const Shuffle *op);
672 int visit(const Let *op);
673 int visit(const VectorReduce *op);
674 int visit_binary(const Expr &a, const Expr &b);
675 int visit_nary(const std::vector<Expr> &exprs);
676
678 : inlined{inlined} {
679 }
680
681 int compute(const Function &f);
682};
683
684void sanitize_names(std::string &str);
685
686} // namespace Autoscheduler
687} // namespace Internal
688} // namespace Halide
689
690#endif // FUNCTION_DAG_H
void release(const BoundContents *b) const
int visit_binary(const Expr &a, const Expr &b)
ExprBranching(const NodeMap< int64_t > &inlined)
int visit_nary(const std::vector< Expr > &exprs)
OptionalRational operator()(int producer_storage_dim, int consumer_loop_dim) const
LoadJacobian operator*(const LoadJacobian &other) const
LoadJacobian(size_t producer_storage_dims, size_t consumer_loop_dims, int64_t count)
OptionalRational & operator()(int producer_storage_dim, int consumer_loop_dim)
bool merge(const LoadJacobian &other)
void dump(const char *prefix) const
LoadJacobian operator*(const std::vector< int64_t > &factors) const
Span(int64_t a, int64_t b, bool c)
Span(const Span &other)=default
void union_with(const Span &other)
A reference-counted handle to Halide's internal representation of a function.
Definition Function.h:39
A class representing a reference count to be used with IntrusivePtr.
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
Definition IRVisitor.h:159
A single definition of a Func.
Definition Func.h:70
A Halide variable, to be used when defining functions.
Definition Var.h:19
void sanitize_names(std::string &str)
int64_t gcd(int64_t, int64_t)
The greatest common divisor of two integers.
int64_t lcm(int64_t, int64_t)
The least common multiple of two integers.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
Expr cast(Expr a)
Cast an expression to the halide type corresponding to the C++ type T.
Definition IROperator.h:358
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
void * memcpy(void *s1, const void *s2, size_t n)
A fragment of Halide syntax.
Definition Expr.h:257
The sum of two expressions.
Definition IR.h:48
Logical and - are both expressions true.
Definition IR.h:167
const Span & loops(int i, int j) const
BoundInfo(const Expr &e, const Node::Stage &consumer)
void expand_footprint(const Span *consumer_loop, Span *producer_required) const
vector< pair< BoundInfo, BoundInfo > > bounds
int get_loop_index_from_var(const std::string &var) const
void loop_nest_for_region(int stage_idx, const Span *computed, Span *loop) const
std::unique_ptr< BoundContents::Layout > bounds_memory_layout
void required_to_computed(const Span *required, Span *computed) const
void operator=(const FunctionDAG &other)=delete
FunctionDAG(const FunctionDAG &other)=delete
map< int, const Node * > stage_id_to_node_map
std::ostream & dump(std::ostream &os) const
FunctionDAG(const vector< Function > &outputs, const Target &target)
void operator+=(const OptionalRational &other)
Definition FunctionDAG.h:45
bool operator==(const OptionalRational &other) const
OptionalRational operator*(int64_t factor) const
Definition FunctionDAG.h:64
OptionalRational operator*(const OptionalRational &other) const
Definition FunctionDAG.h:72
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:251
A function call.
Definition IR.h:482
The actual IR nodes begin here.
Definition IR.h:29
The ratio of two expressions.
Definition IR.h:75
Is the first expression equal to the second.
Definition IR.h:113
Floating point constants.
Definition Expr.h:235
Is the first expression greater than or equal to the second.
Definition IR.h:158
Is the first expression greater than the second.
Definition IR.h:149
Integer constants.
Definition Expr.h:217
A class to represent ranges of Exprs.
Definition Interval.h:14
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
Is the first expression less than or equal to the second.
Definition IR.h:140
Is the first expression less than the second.
Definition IR.h:131
A let expression, like you might find in a functional language.
Definition IR.h:263
Load a value from a named symbol if predicate is true.
Definition IR.h:209
The greater of two values.
Definition IR.h:104
The lesser of two values.
Definition IR.h:95
The remainder of a / b.
Definition IR.h:86
The product of two expressions.
Definition IR.h:66
Is the first expression not equal to the second.
Definition IR.h:122
Logical not - true if the expression is false.
Definition IR.h:185
Logical or - is at least one of the expressions true.
Definition IR.h:176
A linear ramp vector node.
Definition IR.h:239
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition IR.h:39
A ternary operator.
Definition IR.h:196
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:819
String constants.
Definition Expr.h:244
The difference of two expressions.
Definition IR.h:57
Unsigned integer constants.
Definition Expr.h:226
A named variable.
Definition IR.h:741
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:929
A struct representing a target machine and os to generate code for.
Definition Target.h:19