// FunctionalTensorWrapper.h (forked from pytorch/pytorch)
#pragma once
#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/DispatchKey.h>
namespace at {
// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a pytorch program.
//
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
// It's also necessary in order to remove mutation from a program, which is
// needed in Functorch.
//
// Consider this program:
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                       # XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)  # Our functionalization pass machinery knows that a and b are
//            # aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its alias'd tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
// TensorImpl subclass used by the functionalization pass. It wraps an
// ordinary tensor (`value_`) and carries the bookkeeping declared below: the
// stack of view metadata, a generation counter, and mutation counters/flags.
// See Note [Functionalization Pass In Core].
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from an
  // underlying tensor that was created from a view. For example, the code b =
  // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
  // view1_meta)
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      const functionalization::ViewMeta& meta);

  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  }
  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  }
  void set_level(int64_t level) {
    level_ = level;
  }
  // Returns the has_metadata_mutation_ flag.
  bool has_metadata_mutation() const {
    return has_metadata_mutation_;
  }

  // Denotes a mutation that's hidden from autograd,
  // e.g. for the purposes of passing a tensor to a triton kernel
  void mark_mutation_hidden_from_autograd() {
    mutation_hidden_from_autograd_counter_++;
  }
  // Records one mutation that happened under no_grad or inference_mode.
  void mark_mutation_during_no_grad_or_inference_mode() {
    mutation_during_no_grad_or_inference_mode_++;
  }
  // Are all the mutations happening to the tensor hidden from autograd
  bool are_all_mutations_hidden_from_autograd() const {
    return mutation_hidden_from_autograd_counter_ == mutation_counter_;
  }
  // Did all mutations happen under no_grad or inference_mode
  // (We also need to ignore mutations fully hidden from autograd here)
  bool are_all_mutations_under_no_grad_or_inference_mode() const {
    return mutation_hidden_from_autograd_counter_ +
        mutation_during_no_grad_or_inference_mode_ ==
        mutation_counter_;
  }

  // Sync's the underlying tensor with its alias, if it's out of date. This
  // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
  // Replay the views (if any) to regenerate the current tensor off of the
  // updated alias.
  void sync_();
  // Performs step (2) of the sync (replaying the views off of the base). This
  // is its own public API because it's needed by view_inplace ops like
  // transpose_. See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync (applying pending updates). This is its own
  // public API because it's needed by functorch. functorch wants to make sure
  // that all input tensors to a functionalized program have been properly
  // synced so it can properly propagate mutations to inputs. It can't just
  // call sync_(), because the FunctionalTensorWrapper will look like it has no
  // aliases and sync_ will be a noop. We use the reference count on storage_
  // to determine if the wrapper is aliased, and by the time functorch is ready
  // to propagate updates to inputs, any intermediate views of the input
  // created by the program will have been deallocated. This function also
  // returns whether or not the base actually had any updates to apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a pending
  // update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's "generation".
  // Separately, each tensor maintains its own "generation" counter, which is
  // used to determine if it's up-to-date with its alias. The act of syncing a
  // tensor will set a tensor's generation equal to its alias's generation.
  bool is_up_to_date() const;
  // Freezes the storage of this tensor, preventing subsequent mutations
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  // tensor by replaying the views off of the alias.
  void mutate_view_meta(const at::functionalization::ViewMeta& meta);

  // Custom implementation of self.set_(src)
  void set__impl(const FunctionalTensorWrapper* other);

  // Returns whether the current tensor's data was ever mutated.
  // NOTE(review): declared non-const; making it const would also require
  // changing the out-of-line definition, which is not visible here.
  bool has_data_mutation();

  // Returns whether the current FunctionalTensorWrapper
  // experienced a set_() call.
  bool was_storage_changed() const {
    return was_storage_changed_;
  }

  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). e.g:
  //
  // a.add_(1)
  //
  // will turn into:
  //
  // tmp = a.add(1)
  // a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other);

  // Returns the is_multi_output_view_ flag.
  bool is_multi_output_view() const {
    return is_multi_output_view_;
  }

  // See Note[resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  // Replaces the storage with a new functional storage,
  // and clears the view_metas_ stack.
  // WARNING: Calling this function will sever the aliasing relationship between
  // the current FunctionalTensorWrapper and any of its outstanding aliases.
  // Please only call if you know what you're doing.
  void _unsafe_reset_storage();

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;

  // FunctionalTensorWrapper overrides all custom size/stride function,
  // so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;

 private:
  const char* tensorimpl_type_name() const override;
  // Shared setup helper; defined out of line.
  // NOTE(review): presumably called from both constructors - confirm.
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just need
  // to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  // Defined out of line.
  // NOTE(review): presumably copies metadata from src_impl to dest_impl and
  // then refreshes dest_impl's derived state - confirm.
  void copy_tensor_metadata_and_refresh(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_{};
  // These counters are used for identifying whether all the mutations on a
  // given tensor are hidden from autograd (or happened under
  // no_grad/inference_mode) or not. If we have an input mutation that is
  // hidden from autograd, then once we convert the input mutation to a copy_()
  // we know it will be safe to hide the copy_() from autograd as well.
  uint64_t mutation_counter_ = 0;
  uint64_t mutation_hidden_from_autograd_counter_ = 0;
  uint64_t mutation_during_no_grad_or_inference_mode_ = 0;
  bool has_metadata_mutation_ = false;
  bool is_multi_output_view_ = false;
  // Did the tensor experience a set_() call.
  bool was_storage_changed_ = false;

  // This tensor's own "generation" (see is_up_to_date()).
  size_t generation_ = 0;
  // Series of view ops that produced this tensor from its base (see
  // mutate_view_meta()).
  std::vector<at::functionalization::ViewMeta> view_metas_;

 protected:
  static void copy_tensor_metadata(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change);
};
// Utility functions for the functionalization pass.
namespace functionalization {
namespace impl {
// Returns the TensorImpl behind `tensor` viewed as a FunctionalTensorWrapper.
// The downcast is an unchecked static_cast, so the caller is responsible for
// knowing that `tensor` really is functional; the only guard here is a
// debug-only non-null assertion.
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto* const raw_impl = tensor.unsafeGetTensorImpl();
  // The variable name is kept as-is in case the assert macro stringifies its
  // condition into the failure message.
  FunctionalTensorWrapper* const functional_impl =
      static_cast<FunctionalTensorWrapper*>(raw_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  return functional_impl;
}
// Predicates on functional tensors. Only the declarations live in this
// header. The four overloads accept, respectively: a single tensor, an
// optional tensor, a list of optional tensors, and an ITensorListRef.
// NOTE(review): how the optional/list overloads treat empty or mixed inputs is
// not visible here - check the definitions.
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);
// Conversion to functional tensor(s), with the same four argument shapes as
// isFunctionalTensor. The ITensorListRef overload returns a std::vector.
// NOTE(review): presumably wraps each input in a FunctionalTensorWrapper -
// confirm against the definitions.
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
const c10::optional<Tensor>& tensor);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
// NOTE(review): presumably forwards to FunctionalTensorWrapper::freeze_storage()
// on the wrapper behind `tensor` - confirm.
TORCH_API void freeze_functional_tensor(const Tensor& tensor);
// Conversion back from functional tensor(s). Only the single-tensor and
// optional-tensor overloads expose `assert_functional` (default: true); the
// two list overloads take no such flag.
// NOTE(review): presumably `assert_functional` controls whether a
// non-functional input is an error - confirm.
TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
const c10::optional<Tensor>& t,
bool assert_functional = true);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
// Free-function counterparts of FunctionalTensorWrapper operations, taking
// Tensor handles (or collections of them) instead of a wrapper pointer. Only
// the declarations are visible in this header.
// NOTE(review): each presumably forwards to the same-named member
// (sync_ / replace_ / commit_update / _unsafe_reset_storage / ...) on the
// wrapper behind the tensor - confirm against the definitions.
//
// sync: overloads for a single tensor, an optional tensor, a list of optional
// tensors, and an ITensorListRef.
TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
TORCH_API void sync(const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API void sync(ITensorListRef t_list);
// replace_: single-tensor and list forms.
// NOTE(review): the list form presumably pairs functional_tensor[i] with
// other[i] - confirm.
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
const ITensorListRef functional_tensor,
ITensorListRef other);
// commit_update: single-tensor and list forms.
TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);
TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
// Mutation bookkeeping: one marker and two queries, named after the
// corresponding FunctionalTensorWrapper members.
TORCH_API void mark_mutation_hidden_from_autograd(
const Tensor& functional_tensor);
TORCH_API bool are_all_mutations_hidden_from_autograd(
const Tensor& functional_tensor);
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
const Tensor& functional_tensor);
// These two methods are XLA-specific logic and are no-ops
// for the normal functionalization flow.
TORCH_API void propagate_xla_data(
const Tensor& functional_tensor,
const Tensor& other);
TORCH_API void propagate_xla_data(
const ITensorListRef functional_tensor,
ITensorListRef other);
// View-related helpers. Unlike the declarations above, these are declared
// without TORCH_API. Definitions are not visible in this header.
//
// Single-output form: `meta` is taken by value and `out_idx` defaults to 0.
// NOTE(review): presumably builds a functional tensor for `view_to_wrap` that
// aliases `base` via `meta`, with `out_idx` selecting the output of a
// multi-output view - confirm.
Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta,
int64_t out_idx = 0);
// Multi-output form: takes a list of views and returns one tensor per entry
// as a std::vector; `meta` is taken by const reference here.
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const Tensor& base,
const functionalization::ViewMeta& meta);
// NOTE(review): presumably forwards to
// FunctionalTensorWrapper::mutate_view_meta on the wrapper behind `self` -
// confirm.
void mutate_view_meta(
const Tensor& self,
const functionalization::ViewMeta& meta);
// Single-tensor and vector forms.
// NOTE(review): presumably copies sizes, strides and storage offset from
// `meta_out(s)` onto `out(s)`; whether the vector form requires equal lengths
// is not visible here - confirm.
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
const std::vector<Tensor>& outs,
const std::vector<Tensor>& meta_outs);
// ~~~~~ TLS used in functionalization ~~~~~

// Getter/setter for the "reapply views" flag (defined out of line).
TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);

// Scoped guard for the "reapply views" flag: the constructor remembers the
// current value and installs `reapply_views`; the destructor restores the
// remembered value. The guard is neither copyable nor movable.
class TORCH_API FunctionalizationReapplyViewsGuard {
 public:
  // NOTE(review): left non-explicit to stay source-compatible with existing
  // callers; consider making it `explicit`.
  FunctionalizationReapplyViewsGuard(bool reapply_views)
      : prev_(getFunctionalizationReapplyViewsTLS()) {
    setFunctionalizationReapplyViewsTLS(reapply_views);
  }

  ~FunctionalizationReapplyViewsGuard() {
    setFunctionalizationReapplyViewsTLS(prev_);
  }

  FunctionalizationReapplyViewsGuard(
      const FunctionalizationReapplyViewsGuard&) = delete;
  // Deleted assignment operators use the conventional reference return type.
  FunctionalizationReapplyViewsGuard& operator=(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
      delete;
  FunctionalizationReapplyViewsGuard& operator=(
      FunctionalizationReapplyViewsGuard&&) = delete;

 private:
  // Flag value observed at construction; restored by the destructor.
  bool prev_;
};
} // namespace impl
// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
// Takes the operator handle plus a pointer to the stack; this is the function
// that _functionalize_aten_op::call (below) wraps as a BoxedKernel.
// NOTE(review): presumably arguments are read from, and results written back
// to, `stack` - the definition is not visible here.
TORCH_API void functionalize_op_helper(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
// Primary template: intentionally empty. Only the partial specialization
// below, selected when the third argument is a function type
// ReturnType(ParameterTypes...), provides call().
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};

// Unboxed entry point that routes a call to operator `Op` through the boxed
// functionalize_op_helper.
//   Op     - type exposing static `name` and `overload_name` members.
//   symint - forwarded to c10::maybe_keep_symint to pick each parameter's
//            type in the unboxed signature.
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  static ReturnType call(
      typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    using FuncType = ReturnType(
        typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
    // Look the operator up by (name, overload_name) and get a typed handle.
    // static_cast (rather than a C-style cast) keeps the conversion to
    // const char* from silently turning into a reinterpret_cast.
    auto op = c10::Dispatcher::singleton()
                  .findSchemaOrThrow(
                      static_cast<const char*>(Op::name),
                      static_cast<const char*>(Op::overload_name))
                  .typed<FuncType>();

    return c10::impl::BoxedKernelWrapper<FuncType>::call(
        c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
        op,
        // BoxedKernelWrapper knows to ignore this keyset argument,
        // because functionalize_op_helper doesn't take in a DispatchKeySet
        c10::DispatchKeySet(),
        args...);
  }
};
// Convenience alias: instantiates _functionalize_aten_op for `Op` with
// symint = false, using Op::schema as the function type.
template <class Op>
using functionalize_aten_op =
_functionalize_aten_op<Op, false, typename Op::schema>;
// Same as functionalize_aten_op, but with symint = true.
template <class Op>
using functionalize_aten_op_symint =
_functionalize_aten_op<Op, true, typename Op::schema>;
} // namespace functionalization
} // namespace at