diff --git a/src/ATen/native/xpu/PsRoiAlign.cpp b/src/ATen/native/xpu/PsRoiAlign.cpp
new file mode 100644
index 000000000..e781c68fc
--- /dev/null
+++ b/src/ATen/native/xpu/PsRoiAlign.cpp
@@ -0,0 +1,69 @@
+#include <ATen/core/Tensor.h>
+#include <ATen/TensorUtils.h>
+#include <c10/core/DeviceGuard.h>
+#include <torch/library.h>
+#include <ATen/native/xpu/sycl/PsRoiAlignKernels.h>
+namespace at::native::xpu {
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_align(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio) {
+  TORCH_CHECK(input.is_xpu(), "input must be an XPU tensor");
+  TORCH_CHECK(rois.is_xpu(), "rois must be an XPU tensor");
+  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ps_roi_align_forward_kernel";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  c10::DeviceGuard device_guard(input.device());
+  return ps_roi_align_kernel(
+      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+}
+
+at::Tensor _ps_roi_align_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  TORCH_CHECK(grad.is_xpu(), "grad must be an XPU tensor");
+  TORCH_CHECK(rois.is_xpu(), "rois must be an XPU tensor");
+  TORCH_CHECK(channel_mapping.is_xpu(), "channel_mapping must be an XPU tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
+      channel_mapping_t{channel_mapping, "channel_mapping", 3};
+
+  at::CheckedFrom c = "ps_roi_align_backward_kernel";
+  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  c10::DeviceGuard device_guard(grad.device());
+
+  return ps_roi_align_backward_kernel(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio,
+      batch_size,
+      channels,
+      height,
+      width);
+}
+
+} // namespace at::native::xpu
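For context on the op these wrappers back: ps_roi_align is position-sensitive, so the input channel count must be divisible by pooled_height * pooled_width, the pooled output has channels / (pooled_height * pooled_width) channels, and a second int tensor (channel_mapping) records which input channel fed each output element for reuse in the backward pass. A minimal shape sketch, assuming an XPU build of PyTorch and a torchvision build that picks up the registrations added later in this patch (all sizes are illustrative):

import torch
from torchvision import ops

x = torch.rand(2, 49, 16, 16, device="xpu")  # 49 = 7 * 7 input channels
rois = torch.tensor([[0, 0.0, 0.0, 15.0, 15.0]], device="xpu")  # (batch_idx, x1, y1, x2, y2)

# Public API: returns only the pooled output.
out = ops.ps_roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
print(out.shape)  # torch.Size([1, 1, 7, 7]); 49 / (7 * 7) = 1 output channel

# Raw op, as implemented by ps_roi_align_kernel: also returns channel_mapping.
out, channel_mapping = torch.ops.torchvision.ps_roi_align(x, rois, 1.0, 7, 7, 2)
print(channel_mapping.dtype)  # torch.int32, one entry per output element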
diff --git a/src/ATen/native/xpu/PsRoiPool.cpp b/src/ATen/native/xpu/PsRoiPool.cpp
new file mode 100644
index 000000000..a4ca4145e
--- /dev/null
+++ b/src/ATen/native/xpu/PsRoiPool.cpp
@@ -0,0 +1,66 @@
+#include <ATen/core/Tensor.h>
+#include <ATen/TensorUtils.h>
+#include <c10/core/DeviceGuard.h>
+#include <torch/library.h>
+#include <ATen/native/xpu/sycl/PsRoiPoolKernels.h>
+namespace at::native::xpu {
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  TORCH_CHECK(input.is_xpu(), "input must be an XPU tensor");
+  TORCH_CHECK(rois.is_xpu(), "rois must be an XPU tensor");
+  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ps_roi_pool_forward_kernel";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  c10::DeviceGuard device_guard(input.device());
+  return ps_roi_pool_kernel(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+}
+
+at::Tensor _ps_roi_pool_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  TORCH_CHECK(grad.is_xpu(), "grad must be an XPU tensor");
+  TORCH_CHECK(rois.is_xpu(), "rois must be an XPU tensor");
+  TORCH_CHECK(channel_mapping.is_xpu(), "channel_mapping must be an XPU tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
+      channel_mapping_t{channel_mapping, "channel_mapping", 3};
+
+  at::CheckedFrom c = "ps_roi_pool_backward_kernel";
+  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  c10::DeviceGuard device_guard(grad.device());
+
+  return ps_roi_pool_backward_kernel(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width);
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/RoiPool.cpp b/src/ATen/native/xpu/RoiPool.cpp
new file mode 100644
index 000000000..ac86c52db
--- /dev/null
+++ b/src/ATen/native/xpu/RoiPool.cpp
@@ -0,0 +1,66 @@
+#include <ATen/core/Tensor.h>
+#include <ATen/TensorUtils.h>
+#include <c10/core/DeviceGuard.h>
+#include <torch/library.h>
+#include <ATen/native/xpu/sycl/RoiPoolKernels.h>
+namespace at::native::xpu {
+
+std::tuple<at::Tensor, at::Tensor> roi_pool(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  TORCH_CHECK(input.is_xpu(), "input must be an XPU tensor");
+  TORCH_CHECK(rois.is_xpu(), "rois must be an XPU tensor");
+  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "roi_pool_forward_kernel";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  c10::DeviceGuard device_guard(input.device());
+  return roi_pool_kernel(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+}
+
+at::Tensor _roi_pool_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& argmax,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  TORCH_CHECK(grad.is_xpu(), "grad must be an XPU tensor");
+  TORCH_CHECK(rois.is_xpu(), "rois must be an XPU tensor");
+  TORCH_CHECK(argmax.is_xpu(), "argmax must be an XPU tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
+      argmax_t{argmax, "argmax", 3};
+
+  at::CheckedFrom c = "roi_pool_backward_kernel";
+  at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  c10::DeviceGuard device_guard(grad.device());
+
+  return roi_pool_backward_kernel(
+      grad,
+      rois,
+      argmax,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width);
+}
+
+} // namespace at::native::xpu
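The host-side wrappers above only become reachable once they are registered against the torchvision library, which the XPUFallback.template change below wires up through the same lazy registration table already used for nms and roi_align. A small end-to-end sketch, assuming an XPU-enabled PyTorch build with torchvision installed (sizes are illustrative):

import torch
import torchvision  # noqa: F401  -- loading torchvision defines the torchvision::* ops

x = torch.rand(1, 8, 16, 16, device="xpu")
rois = torch.tensor([[0, 0.0, 0.0, 15.0, 15.0]], device="xpu")

# Dispatch on XPU tensors goes through the lazily registered
# torchvision::roi_pool impl, i.e. roi_pool_kernel defined later in this patch.
output, argmax = torch.ops.torchvision.roi_pool(x, rois, 1.0, 7, 7)
print(output.shape, argmax.dtype)  # torch.Size([1, 8, 7, 7]) torch.int32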
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 62e5770ba..77b44dae1 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -26,6 +26,12 @@ namespace native::xpu {
 Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_);
 Tensor roi_align(const Tensor& input, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio, bool aligned);
 Tensor _roi_align_backward(const Tensor& grad, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t batch_size, int64_t channels, int64_t height, int64_t width, int64_t sampling_ratio, bool aligned);
+std::tuple<Tensor, Tensor> ps_roi_align(const Tensor& input, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio);
+Tensor _ps_roi_align_backward(const Tensor& grad, const Tensor& rois, const Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio, int64_t batch_size, int64_t channels, int64_t height, int64_t width);
+std::tuple<Tensor, Tensor> roi_pool(const Tensor& input, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width);
+Tensor _roi_pool_backward(const Tensor& grad, const Tensor& rois, const Tensor& argmax, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t batch_size, int64_t channels, int64_t height, int64_t width);
+std::tuple<Tensor, Tensor> ps_roi_pool(const Tensor& input, const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width);
+Tensor _ps_roi_pool_backward(const Tensor& grad, const Tensor& rois, const Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t batch_size, int64_t channels, int64_t height, int64_t width);
 }
 
 // Register op's implementation lazily since sometimes the op is not defined,
@@ -38,6 +44,12 @@ static std::map<std::string, bool> torchvision_ops_dispatching_table_ = {
     {"torchvision::nms", false},
     {"torchvision::roi_align", false},
     {"torchvision::_roi_align_backward", false},
+    {"torchvision::ps_roi_align", false},
+    {"torchvision::_ps_roi_align_backward", false},
+    {"torchvision::roi_pool", false},
+    {"torchvision::_roi_pool_backward", false},
+    {"torchvision::ps_roi_pool", false},
+    {"torchvision::_ps_roi_pool_backward", false},
 };
 
 // Return:
@@ -56,6 +68,18 @@ static bool lazy_registration_and_redispatch(
       TORCH_SELECTIVE_NAME("torchvision::roi_align"),TORCH_FN(at::native::xpu::roi_align));
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),TORCH_FN(at::native::xpu::_roi_align_backward));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),TORCH_FN(at::native::xpu::ps_roi_align));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),TORCH_FN(at::native::xpu::_ps_roi_align_backward));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),TORCH_FN(at::native::xpu::roi_pool));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),TORCH_FN(at::native::xpu::_roi_pool_backward));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),TORCH_FN(at::native::xpu::ps_roi_pool));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),TORCH_FN(at::native::xpu::_ps_roi_pool_backward));
 };
 
 static const torch::detail::TorchLibraryInit
diff --git a/src/ATen/native/xpu/sycl/PsRoiAlignKernels.cpp b/src/ATen/native/xpu/sycl/PsRoiAlignKernels.cpp
new file mode 100644
index 000000000..f2006d9ab
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/PsRoiAlignKernels.cpp
@@ -0,0 +1,494 @@
+#pragma clang diagnostic push
+#pragma GCC diagnostic push
+// Avoid SYCL compiler return-type error
+#pragma clang diagnostic ignored "-Wreturn-type"
+#pragma GCC diagnostic ignored "-Wreturn-type"
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/ceil_div.h>
+#include <ATen/native/xpu/sycl/Atomics.h>
+#include <comm/SYCLContext.h>
+
+#include <ATen/native/xpu/sycl/PsRoiAlignKernels.h>
+
+namespace at::native::xpu {
+
+template <typename T>
+T bilinear_interpolate(
+    const T* input,
+    int height,
+    int width,
+    T y,
+    T x,
+    int index /* index for debug only*/) {
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    return 0;
+  }
+
+  if (y <= 0)
+    y = 0;
+  if (x <= 0)
+    x = 0;
+
+  int y_low = (int)y;
+  int x_low = (int)x;
+  int y_high;
+  int x_high;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y -
y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +struct PsRoiAlignForwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int c_out = (index / pooled_width_ / pooled_height_) % channels_out_; + int n = index / pooled_width_ / pooled_height_ / channels_out_; + + // (n, c_in, ph, pw) is the associated element in the input + int c_in = (c_out * pooled_height_ + ph) * pooled_width_ + pw; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale_ - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale_ - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale_ - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale_ - static_cast(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio_ > 0) + ? sampling_ratio_ + : std::ceil(roi_height / pooled_height_); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio_ > 0) + ? 
sampling_ratio_ + : std::ceil(roi_width / pooled_width_); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const T* offset_input = + input_ + (roi_batch_ind * channels_ + c_in) * height_ * width_; + T out_sum = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = + bilinear_interpolate(offset_input, height_, width_, y, x, index); + out_sum += val; + } + } + out_sum /= count; + output_[index] = out_sum; + channel_mapping_[index] = c_in; + } + } + PsRoiAlignForwardKernel( + int nthreads, + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + const T* rois, + int channels_out, + T* output, + int* channel_mapping) + : nthreads_(nthreads), + input_(input), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + sampling_ratio_(sampling_ratio), + rois_(rois), + channels_out_(channels_out), + output_(output), + channel_mapping_(channel_mapping) {} + + private: + int nthreads_; + const T* input_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + int sampling_ratio_; + const T* rois_; + int channels_out_; + T* output_; + int* channel_mapping_; +}; + +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +struct PsRoiAlignBackwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int n = index / pooled_width_ / pooled_height_ / channels_out_; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale_ - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale_ - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale_ - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale_ - static_cast(0.5); + + // Force small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + int c_in = channel_mapping_[index]; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output_[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio_ > 0) + ? sampling_ratio_ + : std::ceil(roi_height / pooled_height_); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio_ > 0) + ? 
sampling_ratio_ + : std::ceil(roi_width / pooled_width_); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const int offset = (roi_batch_ind * channels_ + c_in) * height_ * width_; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height_, + width_, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + offset + y_low * width_ + x_low), + static_cast(g1)); + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + offset + y_low * width_ + x_high), + static_cast(g2)); + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + offset + y_high * width_ + x_low), + static_cast(g3)); + atomicAdd( + (sycl_global_ptr< + T>)(grad_input_ + offset + y_high * width_ + x_high), + static_cast(g4)); + } // if + } // ix + } // iy + } // XPU_KERNEL_LOOP + } + PsRoiAlignBackwardKernel( + int nthreads, + const T* grad_output, + const int* channel_mapping, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + int channels_out, + T* grad_input, + const T* rois) + : nthreads_(nthreads), + grad_output_(grad_output), + channel_mapping_(channel_mapping), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + sampling_ratio_(sampling_ratio), + channels_out_(channels_out), + grad_input_(grad_input), + rois_(rois) {} + + private: + int nthreads_; + const T* grad_output_; + const int* channel_mapping_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + int sampling_ratio_; + int channels_out_; + T* grad_input_; + const T* rois_; +}; + +std::tuple ps_roi_align_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + at::Tensor output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + at::Tensor channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + auto output_size = output.numel(); + int64_t global_range = std::min( + ceil_div(static_cast(output_size), static_cast(512)), + static_cast(4096)); + int64_t local_range = 512; + + if (output.numel() == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ps_roi_align_forward_kernel_xpu", [&] { + auto kfn = PsRoiAlignForwardKernel( + output_size, + 
input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.data_ptr(), + channels_out, + output.data_ptr(), + channel_mapping.data_ptr()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return std::make_tuple(output, channel_mapping); +} + +Tensor ps_roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + int64_t global_range = std::min( + ceil_div(static_cast(grad.numel()), static_cast(512)), + static_cast(4096)); + int64_t local_range = 512; + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel_xpu"); + + auto grad_ = grad.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ps_roi_align_backward_kernel_xpu", [&] { + auto kfn = PsRoiAlignBackwardKernel( + grad.numel(), + grad_.data_ptr(), + channel_mapping.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.data_ptr(), + rois_.data_ptr()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return grad_input; +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/PsRoiAlignKernels.h b/src/ATen/native/xpu/sycl/PsRoiAlignKernels.h new file mode 100644 index 000000000..24cf97a22 --- /dev/null +++ b/src/ATen/native/xpu/sycl/PsRoiAlignKernels.h @@ -0,0 +1,26 @@ +#pragma once + +#include +namespace at::native::xpu { + +TORCH_XPU_API std::tuple ps_roi_align_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +TORCH_XPU_API Tensor ps_roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/PsRoiPoolKernels.cpp b/src/ATen/native/xpu/sycl/PsRoiPoolKernels.cpp new file mode 100644 index 000000000..d5bf13c20 --- /dev/null +++ b/src/ATen/native/xpu/sycl/PsRoiPoolKernels.cpp @@ -0,0 +1,322 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct PsRoiPoolForwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c_out, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int c_out = (index / 
pooled_width_ / pooled_height_) % channels_out_; + int n = index / pooled_width_ / pooled_height_ / channels_out_; + + // (n, c_in, ph, pw) is the associated element in the input + int c_in = (c_out * pooled_height_ + ph) * pooled_width_ + pw; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = std::round(offset_rois[1] * spatial_scale_); + int roi_start_h = std::round(offset_rois[2] * spatial_scale_); + int roi_end_w = std::round(offset_rois[3] * spatial_scale_); + int roi_end_h = std::round(offset_rois[4] * spatial_scale_); + + // Force malformed ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + int hstart = + static_cast(std::floor(static_cast(ph) * bin_size_h)); + int wstart = + static_cast(std::floor(static_cast(pw) * bin_size_w)); + int hend = + static_cast(std::ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(std::ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height_ - 1); + hend = std::min(std::max(hend + roi_start_h, 0), height_ - 1); + wstart = std::min(std::max(wstart + roi_start_w, 0), width_ - 1); + wend = std::min(std::max(wend + roi_start_w, 0), width_ - 1); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + const T* offset_input = + input_ + (roi_batch_ind * channels_ + c_in) * height_ * width_; + T out_sum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width_ + w; + out_sum += offset_input[input_index]; + } + } + + T bin_area = (hend - hstart) * (wend - wstart); + output_[index] = is_empty ? 
static_cast(0) : out_sum / bin_area; + channel_mapping_[index] = c_in; + } + } + PsRoiPoolForwardKernel( + int nthreads, + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + const T* rois, + int channels_out, + T* output, + int* channel_mapping) + : nthreads_(nthreads), + input_(input), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + rois_(rois), + channels_out_(channels_out), + output_(output), + channel_mapping_(channel_mapping) {} + + private: + int nthreads_; + const T* input_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + const T* rois_; + int channels_out_; + T* output_; + int* channel_mapping_; +}; + +std::tuple ps_roi_pool_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + at::Tensor output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + at::Tensor channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + auto output_size = output.numel(); + int64_t global_range = std::min( + ceil_div(static_cast(output_size), static_cast(512)), + static_cast(4096)); + int64_t local_range = 512; + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ps_roi_pool_forward_kernel_xpu", [&] { + auto kfn = PsRoiPoolForwardKernel( + output_size, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.data_ptr(), + channels_out, + output.data_ptr(), + channel_mapping.data_ptr()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return std::make_tuple(output, channel_mapping); +} + +template +struct PsRoiPoolBackwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int n = index / pooled_width_ / pooled_height_ / channels_out_; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = std::roundf(offset_rois[1] * spatial_scale_); + int roi_start_h = std::roundf(offset_rois[2] * spatial_scale_); + int roi_end_w = std::roundf(offset_rois[3] * spatial_scale_); + int roi_end_h = std::roundf(offset_rois[4] * spatial_scale_); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + int hstart = + static_cast(std::floor(static_cast(ph) * bin_size_h)); + int wstart = + static_cast(std::floor(static_cast(pw) * bin_size_w)); + int hend 
= + static_cast(std::ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(std::ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height_); + hend = std::min(std::max(hend + roi_start_h, 0), height_); + wstart = std::min(std::max(wstart + roi_start_w, 0), width_); + wend = std::min(std::max(wend + roi_start_w, 0), width_); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + int c_in = channel_mapping_[index]; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = + is_empty ? static_cast(0) : grad_output_[index] / bin_area; + + const int offset = (roi_batch_ind * channels_ + c_in) * height_ * width_; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int grad_input_index = h * width_ + w; + atomicAdd( + (sycl_global_ptr)(grad_input_ + offset + grad_input_index), + static_cast(diff_val)); + } + } + } + } + PsRoiPoolBackwardKernel( + int nthreads, + const T* grad_output, + const int* channel_mapping, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int channels_out, + T* grad_input, + const T* rois) + : nthreads_(nthreads), + grad_output_(grad_output), + channel_mapping_(channel_mapping), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + channels_out_(channels_out), + grad_input_(grad_input), + rois_(rois) {} + + private: + int nthreads_; + const T* grad_output_; + const int* channel_mapping_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + int channels_out_; + T* grad_input_; + const T* rois_; +}; + +Tensor ps_roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + int64_t global_range = std::min( + ceil_div(static_cast(grad.numel()), static_cast(512)), + static_cast(4096)); + int64_t local_range = 512; + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + auto grad_ = grad.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ps_roi_pool_backward_kernel_xpu", [&] { + auto kfn = PsRoiPoolBackwardKernel( + grad.numel(), + grad_.data_ptr(), + channel_mapping.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.data_ptr(), + rois_.data_ptr()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return grad_input; +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/PsRoiPoolKernels.h b/src/ATen/native/xpu/sycl/PsRoiPoolKernels.h new file mode 100644 index 000000000..297d9591e --- /dev/null +++ b/src/ATen/native/xpu/sycl/PsRoiPoolKernels.h @@ -0,0 +1,24 @@ +#pragma once + +#include +namespace at::native::xpu { + +TORCH_XPU_API std::tuple ps_roi_pool_kernel( + const at::Tensor& input, + const 
at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +TORCH_XPU_API Tensor ps_roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp b/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp index f95d82c42..5534874cd 100644 --- a/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp +++ b/src/ATen/native/xpu/sycl/RoiAlignKernels.cpp @@ -156,6 +156,7 @@ struct RoiAlignForwardKernel { width_(width), pooled_height_(pooled_height), pooled_width_(pooled_width), + sampling_ratio_(sampling_ratio), aligned_(aligned), rois_(rois), output_(output) {} diff --git a/src/ATen/native/xpu/sycl/RoiPoolKernels.cpp b/src/ATen/native/xpu/sycl/RoiPoolKernels.cpp new file mode 100644 index 000000000..279d70a7d --- /dev/null +++ b/src/ATen/native/xpu/sycl/RoiPoolKernels.cpp @@ -0,0 +1,305 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct RoiPoolForwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int c = (index / pooled_width_ / pooled_height_) % channels_; + int n = index / pooled_width_ / pooled_height_ / channels_; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = std::round(offset_rois[1] * spatial_scale_); + int roi_start_h = std::round(offset_rois[2] * spatial_scale_); + int roi_end_w = std::round(offset_rois[3] * spatial_scale_); + int roi_end_h = std::round(offset_rois[4] * spatial_scale_); + + // Force malformed ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); + int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height_); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); + + int hstart = + static_cast(std::floor(static_cast(ph) * bin_size_h)); + int wstart = + static_cast(std::floor(static_cast(pw) * bin_size_w)); + int hend = + static_cast(std::ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(std::ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height_); + hend = std::min(std::max(hend + roi_start_h, 0), height_); + wstart = std::min(std::max(wstart + roi_start_w, 0), width_); + wend = std::min(std::max(wend + roi_start_w, 0), width_); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 
0.0 : std::numeric_limits::lowest(); + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + const T* offset_input = + input_ + (roi_batch_ind * channels_ + c) * height_ * width_; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width_ + w; + if (offset_input[input_index] > maxval) { + maxval = offset_input[input_index]; + maxidx = input_index; + } + } + } + output_[index] = maxval; + argmax_[index] = maxidx; + } + } + RoiPoolForwardKernel( + int nthreads, + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + const T* rois, + T* output, + int* argmax) + : nthreads_(nthreads), + input_(input), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + rois_(rois), + output_(output), + argmax_(argmax) {} + + private: + int nthreads_; + const T* input_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + const T* rois_; + T* output_; + int* argmax_; +}; + +std::tuple roi_pool_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, + input.options().dtype(at::kInt)); + + auto output_size = num_rois * pooled_height * pooled_width * channels; + int64_t global_range = + ceil_div(static_cast(output_size), static_cast(512)); + int64_t local_range = 512; + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_pool_forward_kernel_xpu", [&] { + auto kfn = RoiPoolForwardKernel( + output_size, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.data_ptr(), + output.data_ptr(), + argmax.data_ptr()); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return std::make_tuple(output, argmax); +} + +template +struct RoiPoolBackwardKernel { + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, nthreads_) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width_; + int ph = (index / pooled_width_) % pooled_height_; + int c = (index / pooled_width_ / pooled_height_) % channels_; + int n = index / pooled_width_ / pooled_height_ / channels_; + + const T* offset_rois = rois_ + n * 5; + int roi_batch_ind = offset_rois[0]; + + const int output_offset = n * n_stride_ + c * c_stride_; + const int* argmax_data_offset = + argmax_data_ + (n * channels_ + c) * pooled_height_ * pooled_width_; + const int argmax = argmax_data_offset[ph * pooled_width_ + pw]; + const int offset = (roi_batch_ind * channels_ + c) * height_ * width_; + + if (argmax != -1) { + atomicAdd( + (sycl_global_ptr)(grad_input_ + offset + argmax), + static_cast( + grad_output_[output_offset + ph * h_stride_ + pw * w_stride_])); + } + } + } + RoiPoolBackwardKernel( + int nthreads, + const T* grad_output, + 
const int* argmax_data, + int num_rois, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + T* grad_input, + const T* rois, + int n_stride, + int c_stride, + int h_stride, + int w_stride) + : nthreads_(nthreads), + grad_output_(grad_output), + argmax_data_(argmax_data), + num_rois_(num_rois), + spatial_scale_(spatial_scale), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + grad_input_(grad_input), + rois_(rois), + n_stride_(n_stride), + c_stride_(c_stride), + h_stride_(h_stride), + w_stride_(w_stride) {} + + private: + int nthreads_; + const T* grad_output_; + const int* argmax_data_; + int num_rois_; + const T spatial_scale_; + int channels_; + int height_; + int width_; + int pooled_height_; + int pooled_width_; + T* grad_input_; + const T* rois_; + int n_stride_; + int c_stride_; + int h_stride_; + int w_stride_; +}; + +Tensor roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + int64_t global_range = std::min( + ceil_div(static_cast(grad.numel()), static_cast(512)), + static_cast(4096)); + int64_t local_range = 512; + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto num_rois = rois.size(0); + auto argmax_ = argmax.contiguous(); + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "roi_pool_backward_kernel_xpu", [&] { + auto kfn = RoiPoolBackwardKernel( + grad.numel(), + grad.data_ptr(), + argmax_.data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + sycl_kernel_submit( + global_range * local_range, + local_range, + at::xpu::getCurrentSYCLQueue(), + kfn); + }); + return grad_input; +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/RoiPoolKernels.h b/src/ATen/native/xpu/sycl/RoiPoolKernels.h new file mode 100644 index 000000000..1adac5177 --- /dev/null +++ b/src/ATen/native/xpu/sycl/RoiPoolKernels.h @@ -0,0 +1,24 @@ +#pragma once + +#include +namespace at::native::xpu { + +TORCH_XPU_API std::tuple roi_pool_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +TORCH_XPU_API Tensor roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); +} // namespace at::native::xpu \ No newline at end of file diff --git a/test/regressions/optests_failures_dict.json b/test/regressions/optests_failures_dict.json new file mode 100644 index 000000000..3bad0bbb0 --- /dev/null +++ b/test/regressions/optests_failures_dict.json @@ -0,0 +1,5 @@ +{ + "_description": "This is a dict containing failures for tests autogenerated by 
generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", + "_version": 1, + "data": {} +} diff --git a/test/regressions/test_torchvision_roi_align.py b/test/regressions/test_torchvision_roi_align.py deleted file mode 100644 index e0111bc72..000000000 --- a/test/regressions/test_torchvision_roi_align.py +++ /dev/null @@ -1,23 +0,0 @@ -# Owner(s): ["module: intel"] -import torch -from torch.testing._internal.common_utils import TestCase - - -class TestTorchVisionMethod(TestCase): - def test_roi_align(self): - atol = 1e-1 - rtol = 5e-5 - a_ref = torch.zeros([4, 256, 296, 304]).requires_grad_(True) - b_ref = torch.zeros([2292, 5]).requires_grad_(True) - - a_xpu = torch.zeros( - [4, 256, 296, 304], device=torch.device("xpu") - ).requires_grad_(True) - b_xpu = torch.zeros([2292, 5], device=torch.device("xpu")).requires_grad_(True) - - ref = torch.ops.torchvision.roi_align(a_ref, b_ref, 0.25, 7, 7, 2, False) - res = torch.ops.torchvision.roi_align(a_xpu, b_xpu, 0.25, 7, 7, 2, False) - ref.sum().backward() - res.sum().backward() - self.assertEqual(ref, res.cpu()) - self.assertEqual(a_ref.grad, a_xpu.grad.cpu(), rtol=rtol, atol=atol) diff --git a/test/regressions/test_torchvision_roi_ops.py b/test/regressions/test_torchvision_roi_ops.py new file mode 100644 index 000000000..2039984f0 --- /dev/null +++ b/test/regressions/test_torchvision_roi_ops.py @@ -0,0 +1,468 @@ +import math +from abc import ABC, abstractmethod + +import numpy as np +import pytest +import torch +import torch.fx +from torch import nn +from torch.autograd import gradcheck +from torchvision import ops +from torchvision.models.feature_extraction import get_graph_node_names + + +# Context manager for setting deterministic flag and automatically +# resetting it to its original value +class DeterministicGuard: + def __init__(self, deterministic, *, warn_only=False): + self.deterministic = deterministic + self.warn_only = warn_only + + def __enter__(self): + self.deterministic_restore = torch.are_deterministic_algorithms_enabled() + self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled() + torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only) + + def __exit__(self, exception_type, exception_value, traceback): + torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore) + + +class RoIOpTesterModuleWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 2 + + def forward(self, a, b): + self.layer(a, b) + + +class MultiScaleRoIAlignModuleWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 3 + + def forward(self, a, b, c): + self.layer(a, b, c) + + +class RoIOpTester(ABC): + dtype = torch.float64 + mps_dtype = torch.float32 + mps_backward_atol = 2e-2 + + @pytest.mark.parametrize("device", ("xpu",)) + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize( + "x_dtype", + ( + torch.float16, + torch.float32, + torch.float64, + ), + ids=str, + ) + def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs): + if device == "mps" and x_dtype is torch.float64: + pytest.skip("MPS does not support float64") + + rois_dtype = x_dtype if rois_dtype is None else rois_dtype + + tol = 1e-5 + if x_dtype is torch.half: + if device == "mps": + tol = 5e-3 + else: + tol = 4e-3 + elif x_dtype == torch.bfloat16: + tol = 5e-3 + + 
pool_size = 5 + # n_channels % (pool_size ** 2) == 0 required for PS operations. + n_channels = 2 * (pool_size**2) + x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) + if not contiguous: + x = x.permute(0, 1, 3, 2) + rois = torch.tensor( + [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy) + dtype=rois_dtype, + device=device, + ) + + pool_h, pool_w = pool_size, pool_size + with DeterministicGuard(deterministic): + y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) + # the following should be true whether we're running an autocast test or not. + assert y.dtype == x.dtype + gt_y = self.expected_fn( + x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs + ) + + torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) + + @pytest.mark.parametrize("device", ("xpu",)) + def test_is_leaf_node(self, device): + op_obj = self.make_obj(wrap=True).to(device=device) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + + @pytest.mark.parametrize("device", ("xpu",)) + def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float): + op_obj = self.make_obj().to(device=device) + graph_module = torch.fx.symbolic_trace(op_obj) + pool_size = 5 + n_channels = 2 * (pool_size**2) + x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device) + rois = torch.tensor( + [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy) + dtype=rois_dtype, + device=device, + ) + output_gt = op_obj(x, rois) + assert output_gt.dtype == x.dtype + output_fx = graph_module(x, rois) + assert output_fx.dtype == x.dtype + tol = 1e-5 + torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol) + + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("device", ("xpu",)) + @pytest.mark.parametrize("contiguous", (True, False)) + def test_backward(self, seed, device, contiguous, deterministic=False): + atol = self.mps_backward_atol if device == "mps" else 1e-05 + dtype = self.mps_dtype if device == "mps" else self.dtype + + torch.random.manual_seed(seed) + pool_size = 2 + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True) + if not contiguous: + x = x.permute(0, 1, 3, 2) + rois = torch.tensor( + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device # format is (xyxy) + ) + + def func(z): + return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1) + + script_func = self.get_script_fn(rois, pool_size) + + with DeterministicGuard(deterministic): + gradcheck(func, (x,), atol=atol) + + gradcheck(script_func, (x,), atol=atol) + + @abstractmethod + def fn(*args, **kwargs): + pass + + @abstractmethod + def make_obj(*args, **kwargs): + pass + + @abstractmethod + def get_script_fn(*args, **kwargs): + pass + + @abstractmethod + def expected_fn(*args, **kwargs): + pass + + +class TestRoiPool(RoIOpTester): + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): + return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois) + + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False): + obj = ops.RoIPool((pool_h, pool_w), spatial_scale) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + + def get_script_fn(self, rois, pool_size): + scriped = 
torch.jit.script(ops.roi_pool) + return lambda x: scriped(x, rois, pool_size) + + def expected_fn( + self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64 + ): + if device is None: + device = torch.device("cpu") + + n_channels = x.size(1) + y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device) + + def get_slice(k, block): + return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block))) + + for roi_idx, roi in enumerate(rois): + batch_idx = int(roi[0]) + j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) + roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1] + + roi_h, roi_w = roi_x.shape[-2:] + bin_h = roi_h / pool_h + bin_w = roi_w / pool_w + + for i in range(0, pool_h): + for j in range(0, pool_w): + bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)] + if bin_x.numel() > 0: + y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0] + return y + + +class TestPSRoIPool(RoIOpTester): + mps_backward_atol = 5e-2 + + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): + return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois) + + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False): + obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + + def get_script_fn(self, rois, pool_size): + scriped = torch.jit.script(ops.ps_roi_pool) + return lambda x: scriped(x, rois, pool_size) + + def expected_fn( + self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64 + ): + if device is None: + device = torch.device("cpu") + n_input_channels = x.size(1) + assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw" + n_output_channels = int(n_input_channels / (pool_h * pool_w)) + y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device) + + def get_slice(k, block): + return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block))) + + for roi_idx, roi in enumerate(rois): + batch_idx = int(roi[0]) + j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) + roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1] + + roi_height = max(i_end - i_begin, 1) + roi_width = max(j_end - j_begin, 1) + bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w) + + for i in range(0, pool_h): + for j in range(0, pool_w): + bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)] + if bin_x.numel() > 0: + area = bin_x.size(-2) * bin_x.size(-1) + for c_out in range(0, n_output_channels): + c_in = c_out * (pool_h * pool_w) + pool_w * i + j + t = torch.sum(bin_x[c_in, :, :]) + y[roi_idx, c_out, i, j] = t / area + return y + + +def bilinear_interpolate(data, y, x, snap_border=False): + height, width = data.shape + + if snap_border: + if -1 < y <= 0: + y = 0 + elif height - 1 <= y < height: + y = height - 1 + + if -1 < x <= 0: + x = 0 + elif width - 1 <= x < width: + x = width - 1 + + y_low = int(math.floor(y)) + x_low = int(math.floor(x)) + y_high = y_low + 1 + x_high = x_low + 1 + + wy_h = y - y_low + wx_h = x - x_low + wy_l = 1 - wy_h + wx_l = 1 - wx_h + + val = 0 + for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): + for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): + if 0 <= yp < height and 0 <= xp < width: + val += wx * wy * data[yp, xp] + return val + + +class TestRoIAlign(RoIOpTester): + mps_backward_atol = 
6e-2 + + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs): + return ops.RoIAlign( + (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned + )(x, rois) + + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False): + obj = ops.RoIAlign( + (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned + ) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + + def get_script_fn(self, rois, pool_size): + scriped = torch.jit.script(ops.roi_align) + return lambda x: scriped(x, rois, pool_size) + + def expected_fn( + self, + in_data, + rois, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + aligned=False, + device=None, + dtype=torch.float64, + ): + if device is None: + device = torch.device("cpu") + n_channels = in_data.size(1) + out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device) + + offset = 0.5 if aligned else 0.0 + + for r, roi in enumerate(rois): + batch_idx = int(roi[0]) + j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:]) + + roi_h = i_end - i_begin + roi_w = j_end - j_begin + bin_h = roi_h / pool_h + bin_w = roi_w / pool_w + + for i in range(0, pool_h): + start_h = i_begin + i * bin_h + grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h)) + for j in range(0, pool_w): + start_w = j_begin + j * bin_w + grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) + + for channel in range(0, n_channels): + val = 0 + for iy in range(0, grid_h): + y = start_h + (iy + 0.5) * bin_h / grid_h + for ix in range(0, grid_w): + x = start_w + (ix + 0.5) * bin_w / grid_w + val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True) + val /= grid_h * grid_w + + out_data[r, channel, i, j] = val + return out_data + + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("device", ("xpu",)) + @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64)) # , ids=str) + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) + @pytest.mark.opcheck_only_one() + def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None): + if deterministic and device == "cpu": + pytest.skip("cpu is always deterministic, don't retest") + super().test_forward( + device=device, + contiguous=contiguous, + deterministic=deterministic, + x_dtype=x_dtype, + rois_dtype=rois_dtype, + aligned=aligned, + ) + + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) + @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) + @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) + @pytest.mark.opcheck_only_one() + def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): + with torch.amp.autocast("xpu"): + self.test_forward( + torch.device("xpu"), + contiguous=False, + deterministic=deterministic, + aligned=aligned, + x_dtype=x_dtype, + rois_dtype=rois_dtype, + ) + + +class TestPSRoIAlign(RoIOpTester): + mps_backward_atol = 5e-2 + + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): + return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois) + + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False): + 
obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + + def get_script_fn(self, rois, pool_size): + scriped = torch.jit.script(ops.ps_roi_align) + return lambda x: scriped(x, rois, pool_size) + + def expected_fn( + self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64 + ): + if device is None: + device = torch.device("cpu") + n_input_channels = in_data.size(1) + assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw" + n_output_channels = int(n_input_channels / (pool_h * pool_w)) + out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device) + + for r, roi in enumerate(rois): + batch_idx = int(roi[0]) + j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:]) + + roi_h = i_end - i_begin + roi_w = j_end - j_begin + bin_h = roi_h / pool_h + bin_w = roi_w / pool_w + + for i in range(0, pool_h): + start_h = i_begin + i * bin_h + grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h)) + for j in range(0, pool_w): + start_w = j_begin + j * bin_w + grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) + for c_out in range(0, n_output_channels): + c_in = c_out * (pool_h * pool_w) + pool_w * i + j + + val = 0 + for iy in range(0, grid_h): + y = start_h + (iy + 0.5) * bin_h / grid_h + for ix in range(0, grid_w): + x = start_w + (ix + 0.5) * bin_w / grid_w + val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True) + val /= grid_h * grid_w + + out_data[r, c_out, i, j] = val + return out_data + + +class TestMultiScaleRoIAlign: + def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False): + if fmap_names is None: + fmap_names = ["0"] + obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) + return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj + + @pytest.mark.parametrize("device", ("xpu",)) + def test_is_leaf_node(self, device): + op_obj = self.make_obj(wrap=True).to(device=device) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + + +if __name__ == "__main__": + pytest.main([__file__])
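Outside the pytest suite above, a quick XPU-vs-CPU smoke check in the spirit of the removed test_torchvision_roi_align.py can exercise all three new kernels through the public torchvision API. This is only a sketch, assuming an XPU build of PyTorch with torchvision installed; shapes and tolerances are illustrative:

import torch
from torchvision import ops

def check(op, *extra_args):
    # 49 input channels keep the PS variants valid: divisible by 7 * 7.
    x_cpu = torch.rand(2, 49, 14, 14, dtype=torch.float32, requires_grad=True)
    rois_cpu = torch.tensor([[0, 0.0, 0.0, 9.0, 9.0], [1, 2.0, 2.0, 13.0, 13.0]])
    x_xpu = x_cpu.detach().to("xpu").requires_grad_(True)
    rois_xpu = rois_cpu.to("xpu")

    # Forward and backward on both devices, then compare results and gradients.
    ref = op(x_cpu, rois_cpu, (7, 7), *extra_args)
    res = op(x_xpu, rois_xpu, (7, 7), *extra_args)
    ref.sum().backward()
    res.sum().backward()
    torch.testing.assert_close(ref, res.cpu(), rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(x_cpu.grad, x_xpu.grad.cpu(), rtol=1e-4, atol=1e-4)

check(ops.roi_pool)
check(ops.ps_roi_pool)
check(ops.ps_roi_align, 1.0, 2)  # spatial_scale=1.0, sampling_ratio=2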