Merge pull request #165 from rapidsai/branch-0.5
[gpuCI] Auto-merge branch-0.5 to branch-0.6 [skip ci]
GPUtester authored Jan 30, 2019
2 parents 6277f1e + 7fc5e86 commit c651e6d
Showing 8 changed files with 96 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@
 - PR #94: Add cmake flag to set ABI compatibility
 - PR #139: Move thirdparty submodules to root and add symlinks to new locations
 - PR #151: Replace TravisCI testing and conda pkg builds with gpuCI
+- PR #164: Add numba kernel for faster column to row major transform
 
 ## Bug Fixes
 
2 changes: 2 additions & 0 deletions python/cuML/__init__.py
@@ -1,5 +1,7 @@
 # Copyright (c) 2018, NVIDIA CORPORATION.
 # Versioneer
+from cuML import numba_utils
+
 from ._version import get_versions
 __version__ = get_versions()['version']
 del get_versions
2 changes: 1 addition & 1 deletion python/cuML/_version.py
@@ -43,7 +43,7 @@ def get_config():
     cfg.style = "pep440"
     cfg.tag_prefix = "v"
     cfg.parentdir_prefix = "cuml-"
-    cfg.versionfile_source = "python/cuML/_version.py"
+    cfg.versionfile_source = "cuML/_version.py"
     cfg.verbose = False
     return cfg

4 changes: 3 additions & 1 deletion python/cuML/dbscan/dbscan_wrapper.pyx
@@ -22,6 +22,8 @@ from libcpp cimport bool
 import ctypes
 from libc.stdint cimport uintptr_t
 from c_dbscan cimport *
+# temporary import for numba_utils
+from cuML import numba_utils
 
 
 class DBSCAN:
@@ -82,7 +84,7 @@ class DBSCAN:
         cdef uintptr_t input_ptr
         if (isinstance(X, cudf.DataFrame)):
             self.gdf_datatype = np.dtype(X[X.columns[0]]._column.dtype)
-            X_m = X.as_gpu_matrix(order = "C")
+            X_m = numba_utils.row_matrix(X)
             self.n_rows = len(X)
             self.n_cols = len(X._cols)
 
4 changes: 3 additions & 1 deletion python/cuML/kalman/kalman_filter.pyx
@@ -15,6 +15,8 @@

 import numpy as np
 from numba import cuda
+# temporary import for numba_utils
+from cuML import numba_utils
 
 
 cdef extern from "kalman_filter/kf_variables.h" namespace "kf::linear":
@@ -520,7 +522,7 @@ class KalmanFilter:
     def __setattr__(self, name, value):
         if name in ["F", "x_up", "x", "P_up", "P", "Q", "H", "R", "z"]:
             if (isinstance(value, cudf.DataFrame)):
-                val = value.as_gpu_matrix(order='C')
+                val = numba_utils.row_matrix(value)
 
             elif (isinstance(value, cudf.Series)):
                 val = value.to_gpu_array()
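Note: both substitutions above (in dbscan_wrapper.pyx and kalman_filter.pyx) swap cudf's as_gpu_matrix(order='C') for the new numba_utils.row_matrix, the faster column-to-row-major transform that PR #164 adds. The two calls should produce the same matrix; a minimal equivalence check, assuming a small all-float32 cudf DataFrame and the 0.5-era cudf/rmm APIs used in this diff:

import numpy as np
import cudf
from cuML import numba_utils

# Small all-float32 DataFrame (row_matrix's shared-memory tile is float32).
df = cudf.DataFrame()
df['x'] = np.arange(1000, dtype=np.float32)
df['y'] = np.arange(1000, dtype=np.float32) * 2.0

expected = df.as_gpu_matrix(order='C').copy_to_host()  # slower path being replaced
result = numba_utils.row_matrix(df).copy_to_host()     # new numba kernel path
assert np.array_equal(expected, result)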
82 changes: 82 additions & 0 deletions python/cuML/numba_utils.py
@@ -0,0 +1,82 @@
# Copyright (c) 2018, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pandas as pd
import cudf
import numba
from librmm_cffi import librmm as rmm
from numba.cuda.cudadrv.driver import driver
import math
from numba import cuda


def row_matrix(df):
    """Compute the C (row-major) GPU matrix version of df.

    This implements the tiled-transpose algorithm documented in
    http://devblogs.nvidia.com/parallelforall/efficient-matrix-transpose-cuda-cc/

    :param df: a cudf DataFrame. Its columns are gathered into a
        column-major (Fortran-order) device matrix, which the kernel
        below transposes into a row-major copy.

    Adapted from numba:
    https://github.com/numba/numba/blob/master/numba/cuda/kernels/transpose.py
    To be replaced by a CUDA ml-prim in an upcoming version.
    """

    cols = [df._cols[k] for k in df._cols]
    ncol = len(cols)
    nrow = len(df)
    dtype = cols[0].dtype

    a = df.as_gpu_matrix(order='F')
    b = rmm.device_array((nrow, ncol), dtype=dtype, order='C')

    tpb = driver.get_device().MAX_THREADS_PER_BLOCK

    # square-ish tile sized so one thread block covers one tile
    tile_width = int(math.pow(2, math.log(tpb, 2) / 2))
    tile_height = int(tpb / tile_width)

    # the extra column pads shared memory to avoid bank conflicts
    tile_shape = (tile_height, tile_width + 1)

    @cuda.jit
    def kernel(input, output):

        # note: the shared tile is float32 regardless of the input dtype
        tile = cuda.shared.array(shape=tile_shape, dtype=numba.float32)

        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x * cuda.blockDim.x
        by = cuda.blockIdx.y * cuda.blockDim.y
        y = by + tx
        x = bx + ty

        # stage a tile of the input in shared memory, then write it back
        # out transposed, keeping global reads and writes coalesced
        if by + ty < input.shape[0] and bx + tx < input.shape[1]:
            tile[ty, tx] = input[by + ty, bx + tx]
        cuda.syncthreads()
        if y < output.shape[0] and x < output.shape[1]:
            output[y, x] = tile[tx, ty]

    # one block per tile, plus one for remainders
    blocks = int(b.shape[1] / tile_height + 1), int(b.shape[0] / tile_width + 1)
    # one thread per tile element
    threads = tile_height, tile_width
    kernel[blocks, threads](a, b)

    return b
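For reference, the tile-size arithmetic above works out as follows on a device reporting MAX_THREADS_PER_BLOCK = 1024 (a common value; this worked example is not part of the commit):

import math

tpb = 1024                                            # MAX_THREADS_PER_BLOCK
tile_width = int(math.pow(2, math.log(tpb, 2) / 2))   # 2 ** (10 / 2) = 32
tile_height = int(tpb / tile_width)                   # 1024 // 32 = 32
tile_shape = (tile_height, tile_width + 1)            # (32, 33); +1 column avoids
                                                      # shared-memory bank conflicts
assert (tile_width, tile_height) == (32, 32)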
4 changes: 2 additions & 2 deletions python/setup.cfg
@@ -9,7 +9,7 @@ exclude = cuML,ml-prims,__init__.py,versioneer.py
 [versioneer]
 VCS = git
 style = pep440
-versionfile_source = python/cuML/_version.py
-versionfile_build = python/cuML/_version.py
+versionfile_source = cuML/_version.py
+versionfile_build = cuML/_version.py
 tag_prefix = v
 parentdir_prefix = cuml-
3 changes: 2 additions & 1 deletion python/setup.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from setuptools import setup
+from setuptools import setup, find_packages
 from setuptools.extension import Extension
 from Cython.Build import cythonize
 import numpy
@@ -64,6 +64,7 @@
     author="NVIDIA Corporation",
     setup_requires=['cython'],
     ext_modules=cythonize(extensions),
+    packages=find_packages(include=['cuML', 'cuML.*']),
     install_requires=install_requires,
     license="Apache",
     cmdclass=versioneer.get_cmdclass(),
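For context: find_packages(include=['cuML', 'cuML.*']) discovers the cuML package (any directory with an __init__.py matching the patterns), so the new cuML/numba_utils.py module ships with the installed package instead of packages being listed by hand. A minimal sketch of the discovery call, run from the python/ directory; the printed list is illustrative, not taken from this repo:

from setuptools import find_packages

# Discover cuML and any subpackages that contain an __init__.py.
print(find_packages(include=['cuML', 'cuML.*']))
# hypothetical output: ['cuML', 'cuML.dbscan', 'cuML.kalman']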
