Fix AEV and CUAEV Multi-GPU Device Bug (aiqm#597)
* Fix AEV and CUAEV GPU Device Bug

* format

* apply review suggestion

* apply review suggestion
yueyericardo authored Jul 19, 2021
1 parent ef83458 commit bf771af
Showing 2 changed files with 38 additions and 9 deletions.
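The change in one picture: each CUDA entry point now pins the current device to the device of its input tensors before asking for a stream, and every kernel launch is followed by an error check. Below is a minimal sketch of that guard-then-launch pattern using the same includes and calls as the diff; it is not code from this repository, and the kernel scale_kernel and wrapper scale are hypothetical stand-ins.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical kernel for illustration: multiply every element by 2.
__global__ void scale_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = 2.0f * in[i];
}

torch::Tensor scale(const torch::Tensor& input) {
  // Make input's device the current device for this scope, so the stream
  // below (and any allocations) belong to the right GPU. Without this, a
  // tensor on cuda:1 would be processed on a cuda:0 stream, which is
  // exactly the class of bug this commit fixes.
  at::cuda::CUDAGuard device_guard(input.device().index());
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  auto output = torch::empty_like(input);
  int n = static_cast<int>(input.numel());
  int block_size = 256;
  int nblocks = (n + block_size - 1) / block_size;
  scale_kernel<<<nblocks, block_size, 0, stream>>>(
      input.data_ptr<float>(), output.data_ptr<float>(), n);
  // Surface launch errors immediately instead of at some later sync point.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}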
17 changes: 13 additions & 4 deletions tests/test_cuaev.py
@@ -8,8 +8,8 @@

path = os.path.dirname(os.path.realpath(__file__))

-skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(),
-                              'There is no device to run this test')
+skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(), 'There is no device to run this test')
+skipIfNoMultiGPU = unittest.skipIf(not torch.cuda.device_count() >= 2, 'There are not enough GPU devices to run this test')
skipIfNoCUAEV = unittest.skipIf(not torchani.aev.has_cuaev, "only valid when cuaev is installed")


@@ -39,9 +39,9 @@ def testAEVComputer(self):
@skipIfNoCUAEV
class TestCUAEV(TestCase):

-    def setUp(self):
+    def setUp(self, device='cuda:0'):
        self.tolerance = 5e-5
-        self.device = 'cuda'
+        self.device = device
        Rcr = 5.2000e+00
        Rca = 3.5000e+00
        EtaR = torch.tensor([1.6000000e+01], device=self.device)
@@ -131,6 +131,15 @@ def testSimple(self):
        _, cu_aev = self.cuaev_computer((species, coordinates))
        self.assertEqual(cu_aev, aev)

+    @skipIfNoMultiGPU
+    def testMultiGPU(self):
+        self.setUp(device='cuda:1')
+        self.testSimple()
+        self.testSimpleBackward()
+        self.testSimpleDoubleBackward_1()
+        self.testSimpleDoubleBackward_2()
+        self.setUp(device='cuda:0')
+
    def testSimpleBackward(self):
        coordinates = torch.tensor([
            [[0.03192167, 0.00638559, 0.01301679],
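The test above exercises the fix by simply re-running the single-GPU tests with every tensor allocated on cuda:1, which is exactly the case the old code mishandled. The equivalent probe from the C++ side would be to build cuaev-style inputs on each visible device and hand them to the extension; here is a sketch under that assumption, with run_smoke_test as a hypothetical entry point (the real tests do this from Python, as shown above).

#include <torch/torch.h>

// Hypothetical smoke test: build inputs on a given device. The shapes
// match the 2-molecule, 5-atom case used by testSimple.
void run_smoke_test(const torch::Device& device) {
  auto coordinates = torch::rand(
      {2, 5, 3}, torch::dtype(torch::kFloat32).device(device));
  auto species = torch::zeros(
      {2, 5}, torch::dtype(torch::kInt32).device(device));
  // ... hand coordinates/species to the extension under test ...
  (void)coordinates; (void)species;  // silence unused warnings in this sketch
}

int main() {
  // Before the fix, only device index 0 worked; after it, any index should.
  for (size_t i = 0; i < torch::cuda::device_count(); ++i) {
    run_smoke_test(torch::Device(torch::kCUDA, static_cast<c10::DeviceIndex>(i)));
  }
  return 0;
}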
30 changes: 25 additions & 5 deletions torchani/cuaev/aev.cu
@@ -1,12 +1,15 @@
#include <aev.h>
#include <torch/extension.h>
#include <cuaev_cub.cuh>
-#include <vector>

#include <ATen/Context.h>
#include <THC/THC.h>
#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
#include <THC/THCThrustAllocator.cuh>
+#include <vector>

#define PI 3.141592653589793
using torch::Tensor;
@@ -742,9 +745,13 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
  TORCH_CHECK(
      (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
  TORCH_CHECK(
-      aev_params.EtaR_t.size(0) == 1 || aev_params.EtaA_t.size(0) == 1 || aev_params.Zeta_t.size(0) == 1,
+      aev_params.EtaR_t.size(0) == 1 && aev_params.EtaA_t.size(0) == 1 && aev_params.Zeta_t.size(0) == 1,
      "cuda extension is currently not supported for the specified "
      "configuration");
+  TORCH_CHECK(
+      coordinates_t.device() == species_t.device() && coordinates_t.device() == aev_params.EtaR_t.device() &&
+          coordinates_t.device() == aev_params.EtaA_t.device(),
+      "coordinates, species, and aev_params should be on the same device");

  float Rcr = aev_params.Rcr;
  float Rca = aev_params.Rca;
@@ -759,7 +766,8 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
        aev_t, Tensor(), Tensor(), Tensor(), 0, 0, 0, Tensor(), Tensor(), Tensor(), 0, 0, 0, coordinates_t, species_t};
  }

-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // buffer to store all the pairwise distance (Rij)
@@ -790,6 +798,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
        coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        d_Rij,
        max_natoms_per_mol);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    dim3 block(8, 8, 1);
    // Compute pairwise distance (Rij) for all atom pairs in a molecule
@@ -800,6 +809,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
        coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        d_Rij,
        max_natoms_per_mol);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  // Extract Rijs that are needed for RadialAEV computation, i.e. all the Rij <= Rcr
@@ -822,6 +832,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
      aev_params.radial_length,
      aev_params.radial_sublength,
      nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // reuse buffer allocated for all Rij
  // d_angularRij will store all the Rij required in Angular AEV computation
@@ -890,6 +901,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
      maxnbrs_per_atom_aligned,
      angular_length_aligned,
      ncenter_atoms);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return {
      aev_t,
@@ -917,7 +929,8 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para

  const int n_molecules = coordinates_t.size(0);
  const int max_natoms_per_mol = coordinates_t.size(1);
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  auto grad_coord = torch::zeros(coordinates_t.sizes(), coordinates_t.options().requires_grad(false)); // [2, 5, 3]

@@ -943,6 +956,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
      aev_params.radial_length,
      aev_params.radial_sublength,
      result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // For best result, block_size should match average molecule size (no padding) to avoid atomicAdd
  nblocks = (result.nRadialRij + block_size - 1) / block_size;
@@ -952,6 +966,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
      grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
    int sxyz = sizeof(float) * max_nbrs * 3;
@@ -991,6 +1006,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
      result.maxnbrs_per_atom_aligned,
      result.angular_length_aligned,
      result.ncenter_atoms);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_coord;
}
@@ -1002,7 +1018,8 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae

  const int n_molecules = coordinates_t.size(0);
  const int max_natoms_per_mol = coordinates_t.size(1);
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  int aev_length = aev_params.radial_length + aev_params.angular_length;

@@ -1027,6 +1044,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
      grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  nblocks = (result.nRadialRij * 8 + block_size - 1) / block_size;
  cuRadialAEVs_backward_or_doublebackward<true, int, float, 8><<<nblocks, block_size, 0, stream>>>(
@@ -1040,6 +1058,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
      aev_params.radial_length,
      aev_params.radial_sublength,
      result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
    int sxyz = sizeof(float) * max_nbrs * 3;
@@ -1078,6 +1097,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
      result.maxnbrs_per_atom_aligned,
      result.angular_length_aligned,
      result.ncenter_atoms);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_grad_aev;
}
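Two smaller hardening changes ride along in this diff: a C10_CUDA_KERNEL_LAUNCH_CHECK() after every kernel launch, and an up-front TORCH_CHECK in cuaev_forward that all inputs share one device. The latter turns a wrong-device call into a readable error instead of a CUDA crash. A standalone sketch of that validation pattern, with check_inputs_on_same_device as a hypothetical helper name:

#include <torch/extension.h>

// Hypothetical helper mirroring the device check added to cuaev_forward:
// fail fast, with a readable message, when tensors live on different devices.
void check_inputs_on_same_device(
    const torch::Tensor& coordinates,
    const torch::Tensor& species,
    const torch::Tensor& eta_r) {
  TORCH_CHECK(
      coordinates.device() == species.device() &&
          coordinates.device() == eta_r.device(),
      "coordinates, species, and aev_params should be on the same device");
}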
