forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConvShared.h
74 lines (61 loc) · 2.47 KB
/
ConvShared.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <ATen/ATen.h>
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/native/ConvUtils.h>
namespace at { namespace native {
// ---------------------------------------------------------------------
//
// Helper classes
//
// ---------------------------------------------------------------------
// Plain-old-data record describing one convolution configuration. It is
// deliberately kept a POD so that hashes of the parameters can be computed
// easily.
// NOTE(review): presumably the hash/equality is taken over the raw bytes of
// the struct, in which case padding bytes must be zeroed before the fields
// are filled in -- confirm against the implementation.
// `max_dim` is not defined in this header; it comes from an included header.
struct ConvolutionParams
{
// cuDNN element type of the convolution.
cudnnDataType_t dataType;
// Input sizes; holds up to 2 + max_dim entries (presumably batch and
// channel plus up to max_dim spatial dimensions -- TODO confirm).
int input_size[2 + max_dim];
// Presumably the number of valid entries in input_size -- TODO confirm.
uint8_t input_dim;
// Memory format associated with this configuration.
at::MemoryFormat memory_format;
// Weight sizes; same capacity as input_size.
int weight_size[2 + max_dim];
// Convolution padding / stride / dilation, up to max_dim entries each
// (presumably one per spatial dimension -- TODO confirm).
int padding[max_dim];
int stride[max_dim];
int dilation[max_dim];
// Number of convolution groups.
int64_t groups;
// Mirror the `deterministic` / `allow_tf32` flags passed to
// setConvolutionParams.
bool deterministic;
bool allow_tf32;
// NB: `transposed` is purposely omitted: transposed just swaps
// forward and backward, so the same benchmark entry can be reused.
};
// Stream-insertion operator for ConvolutionParams: writes a representation
// of `params` to `out`. Declaration only; the exact output format is defined
// by the implementation, which is not visible in this header.
std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params);
// Fills in `*params` from the given tensors and convolution settings.
//
// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need. (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing.)
//
// params        - output; the struct to populate.
// input, weight - tensors whose properties are recorded (presumably sizes,
//                 dtype and memory format -- confirm in the implementation).
// padding, stride, dilation, groups, deterministic, allow_tf32
//               - convolution settings; ConvolutionParams has a field of the
//                 same name for each.
void setConvolutionParams(
ConvolutionParams* params,
const at::Tensor& input, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool deterministic, bool allow_tf32);
// Builds a string from the given convolution parameters.
// NOTE(review): judging by the name, this is presumably a reproduction
// snippet for the described convolution (e.g. for error reports) -- confirm
// against the implementation.
std::string repro_from_args(const ConvolutionParams& args);
// ---------------------------------------------------------------------
//
// Raw functions
//
// ---------------------------------------------------------------------
// Forward convolution entry point. Declaration only; the definition is not
// visible in this header.
// NOTE(review): following the `_out` naming, `output` is presumably a
// pre-allocated tensor that receives the result even though it is passed as
// `const Tensor&` -- confirm against the implementation.
//
// output                    - destination tensor (see note above).
// input, weight             - convolution operands.
// padding, stride, dilation - convolution geometry.
// groups                    - number of convolution groups.
// benchmark, deterministic, allow_tf32
//                           - flags presumably steering cuDNN algorithm
//                             selection and math mode -- TODO confirm.
void raw_cudnn_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32);
// Backward-with-respect-to-input convolution entry point. Declaration only;
// the definition is not visible in this header.
// NOTE(review): following the `_out` naming, `grad_input` is presumably a
// pre-allocated tensor that receives the result even though it is passed as
// a const reference -- confirm against the implementation.
//
// grad_input                - destination tensor (see note above).
// grad_output               - gradient with respect to the convolution output.
// weight                    - convolution weight.
// padding, stride, dilation - convolution geometry.
// groups                    - number of convolution groups.
// benchmark, deterministic, allow_tf32
//                           - flags presumably steering cuDNN algorithm
//                             selection and math mode -- TODO confirm.
void raw_cudnn_convolution_backward_input_out(
const at::Tensor& grad_input,
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32);
// Backward-with-respect-to-weight convolution entry point. Declaration only;
// the definition is not visible in this header.
// NOTE(review): following the `_out` naming, `grad_weight` is presumably a
// pre-allocated tensor that receives the result even though it is passed as
// `const Tensor&` -- confirm against the implementation.
//
// grad_weight               - destination tensor (see note above).
// grad_output               - gradient with respect to the convolution output.
// input                     - convolution input.
// padding, stride, dilation - convolution geometry.
// groups                    - number of convolution groups.
// benchmark, deterministic, allow_tf32
//                           - flags presumably steering cuDNN algorithm
//                             selection and math mode -- TODO confirm.
void raw_cudnn_convolution_backward_weight_out(
const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32);
}}