Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

at() method was rewritten in C. #145

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/cumo/include/cumo/intern.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ void cumo_na_parse_enumerator_step(VALUE enum_obj, VALUE *pstep);
// used in aref, aset
int cumo_na_get_result_dimension(VALUE self, int argc, VALUE *argv, ssize_t stride, size_t *pos_idx);
VALUE cumo_na_aref_main(int nidx, VALUE *idx, VALUE self, int keep_dim, int result_nd, size_t pos);
VALUE cumo_na_at_main(int nidx, VALUE *idx, VALUE self, int keep_dim, int result_nd, size_t pos);

// defined in array, used in math
VALUE cumo_na_ary_composition_dtype(VALUE ary);
Expand Down
1 change: 1 addition & 0 deletions ext/cumo/narray/gen/spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
def_method "aref", op:"[]"
def_method "aref_cpu"
def_method "aset", op:"[]="
def_method "at"

def_method "coerce_cast"
def_method "to_a"
Expand Down
34 changes: 34 additions & 0 deletions ext/cumo/narray/gen/tmpl/at.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
Multi-dimensional array indexing.
Same as [] for one-dimensional NArray.
Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
@overload at(*indices)
@param [Numeric,Range,etc] *indices Multi-dimensional Index Arrays.
@return [Cumo::NArray::<%=class_name%>] one-dimensional NArray view.

@example
x = Cumo::DFloat.new(3,3,3).seq
=> Cumo::DFloat#shape=[3,3,3]
[[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]],
[[9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]]

x.at([0,1,2],[0,1,2],[-1,-2,-3])
=> Cumo::DFloat(view)#shape=[3]
[2, 13, 24]
*/
static VALUE
<%=c_func(-1)%>(int argc, VALUE *argv, VALUE self)
{
int result_nd;
size_t pos;

result_nd = cumo_na_get_result_dimension(self, argc, argv, sizeof(dtype), &pos);
return cumo_na_at_main(argc, argv, self, 0, result_nd, pos);
}
221 changes: 214 additions & 7 deletions ext/cumo/narray/index.c
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,187 @@ cumo_na_index_aref_naview(cumo_narray_view_t *na1, cumo_narray_view_t *na2,
na2->base.size = total;
}

void cumo_na_index_at_nadata_index_stride_add_kernel_launch(size_t *idx, size_t *idx1, ssize_t s1, uint64_t n);
void cumo_na_index_at_nadata_index_beg_step_stride_kernel_launch(size_t *idx, size_t beg, ssize_t step, ssize_t s1, uint64_t n);
void cumo_na_index_at_nadata_index_beg_step_stride_add_kernel_launch(size_t *idx, size_t beg, ssize_t step, ssize_t s1, uint64_t n);

static void
cumo_na_index_at_nadata(cumo_narray_data_t *na1, cumo_narray_view_t *na2,
cumo_na_index_arg_t *q, ssize_t elmsz, int ndim, int keep_dim)
{
int i;
ssize_t size = q[ndim-1].n;
ssize_t stride1;
ssize_t *strides_na1;
size_t *index;
ssize_t beg, step;
int use_cumo_cuda_runtime_malloc = 0;

strides_na1 = ALLOCA_N(ssize_t, na1->base.ndim);
cumo_na_get_strides_nadata(na1, strides_na1, elmsz);

if (q[ndim-1].idx != NULL) {
index = q[ndim-1].idx;
} else {
//index = ALLOC_N(size_t, size);
index = (size_t*)cumo_cuda_runtime_malloc(sizeof(size_t)*size);
use_cumo_cuda_runtime_malloc = 1;
}
CUMO_SDX_SET_INDEX(na2->stridx[0], index);

for (i=ndim-1; i>=0; i--) {
stride1 = strides_na1[q[i].orig_dim];
if (i==ndim-1) {
if (size == 0) {
rb_raise(cumo_na_eShapeError, "cannot get element of empty array");
}
} else {
if (size != q[i].n) {
rb_raise(cumo_na_eShapeError, "index array sizes mismatch");
}
}

if (q[i].idx != NULL) {
if (i==ndim-1) {
cumo_na_index_aref_nadata_index_stride_kernel_launch(index, stride1, size);
} else {
cumo_na_index_at_nadata_index_stride_add_kernel_launch(index, q[i].idx, stride1, size);
}
q[i].idx = NULL;
} else {
beg = q[i].beg;
step = q[i].step;
if (i==ndim-1) {
cumo_na_index_at_nadata_index_beg_step_stride_kernel_launch(index, beg, step, stride1, size);
} else {
cumo_na_index_at_nadata_index_beg_step_stride_add_kernel_launch(index, beg, step, stride1, size);
}
}

}
na2->base.size = size;
na2->base.shape[0] = size;
if (use_cumo_cuda_runtime_malloc) {
CUMO_SHOW_SYNCHRONIZE_FIXME_WARNING_ONCE("index", "cumo_na_index_at_nadata");
cumo_cuda_runtime_check_status(cudaDeviceSynchronize());
}
}

void cumo_na_index_at_naview_index_index_index_add_kernel_launch(size_t *idx, size_t *idx1, size_t *idx2, uint64_t n);
void cumo_na_index_at_naview_index_index_beg_step_add_kernel_launch(size_t *idx, size_t *idx1, size_t beg, ssize_t step, uint64_t n);
void cumo_na_index_at_naview_index_stride_last_add_kernel_launch(size_t *idx, ssize_t s1, size_t last, uint64_t n);

static void
cumo_na_index_at_naview(cumo_narray_view_t *na1, cumo_narray_view_t *na2,
cumo_na_index_arg_t *q, ssize_t elmsz, int ndim, int keep_dim)
{
int i;
size_t *index;
ssize_t size = q[ndim-1].n;
int use_cumo_cuda_runtime_malloc = 0;

if (q[ndim-1].idx != NULL) {
index = q[ndim-1].idx;
} else {
//index = ALLOC_N(size_t, size);
index = (size_t*)cumo_cuda_runtime_malloc(sizeof(size_t)*size);
use_cumo_cuda_runtime_malloc = 1;
}
CUMO_SDX_SET_INDEX(na2->stridx[0], index);

for (i=ndim-1; i>=0; i--) {
cumo_stridx_t sdx1 = na1->stridx[q[i].orig_dim];
if (i==ndim-1) {
if (size == 0) {
rb_raise(cumo_na_eShapeError, "cannot get element of empty array");
}
} else {
if (size != q[i].n) {
rb_raise(cumo_na_eShapeError, "index array sizes mismatch");
}
}

if (q[i].idx != NULL && CUMO_SDX_IS_INDEX(sdx1)) {
// index <- index
size_t *index1 = CUMO_SDX_GET_INDEX(sdx1);
if (i==ndim-1) {
cumo_na_index_aref_naview_index_index_kernel_launch(index, index1, size);
} else {
cumo_na_index_at_naview_index_index_index_add_kernel_launch(index, index1, q[i].idx, size);
}
q[i].idx = NULL;
}
else if (q[i].idx == NULL && CUMO_SDX_IS_INDEX(sdx1)) {
// step <- index
size_t beg = q[i].beg;
ssize_t step = q[i].step;
size_t *index1 = CUMO_SDX_GET_INDEX(sdx1);
if (i==ndim-1) {
cumo_na_index_aref_naview_index_index_beg_step_kernel_launch(index, index1, beg, step, size);
} else {
cumo_na_index_at_naview_index_index_beg_step_add_kernel_launch(index, index1, beg, step, size);
}
}
else if (q[i].idx != NULL && CUMO_SDX_IS_STRIDE(sdx1)) {
// index <- step
ssize_t stride1 = CUMO_SDX_GET_STRIDE(sdx1);
if (stride1<0) {
size_t last;
stride1 = -stride1;
last = na1->base.shape[q[i].orig_dim] - 1;
if (na2->offset < last * stride1) {
rb_raise(rb_eStandardError,"bug: negative offset");
}
na2->offset -= last * stride1;
if (i==ndim-1) {
cumo_na_index_aref_naview_index_stride_last_kernel_launch(index, stride1, last, size);
} else {
cumo_na_index_at_naview_index_stride_last_add_kernel_launch(index, stride1, last, size);
}
} else {
if (i==ndim-1) {
cumo_na_index_aref_nadata_index_stride_kernel_launch(index, stride1, size);
} else {
cumo_na_index_at_nadata_index_stride_add_kernel_launch(index, q[i].idx, stride1, size);
}
}
q[i].idx = NULL;
}
else if (q[i].idx == NULL && CUMO_SDX_IS_STRIDE(sdx1)) {
// step <- step
size_t beg = q[i].beg;
ssize_t step = q[i].step;
ssize_t stride1 = CUMO_SDX_GET_STRIDE(sdx1);
if (stride1<0) {
size_t last;
stride1 = -stride1;
last = na1->base.shape[q[i].orig_dim] - 1;
if (na2->offset < last * stride1) {
rb_raise(rb_eStandardError,"bug: negative offset");
}
na2->offset -= last * stride1;
if (i==ndim-1) {
cumo_na_index_at_nadata_index_beg_step_stride_kernel_launch(index, last - beg, -step, stride1, size);
} else {
cumo_na_index_at_nadata_index_beg_step_stride_add_kernel_launch(index, last - beg, -step, stride1, size);
}
} else {
if (i==ndim-1) {
cumo_na_index_at_nadata_index_beg_step_stride_kernel_launch(index, beg, step, stride1, size);
} else {
cumo_na_index_at_nadata_index_beg_step_stride_add_kernel_launch(index, beg, step, stride1, size);
}
}
}
}
na2->base.size = size;
na2->base.shape[0] = size;
if (use_cumo_cuda_runtime_malloc) {
CUMO_SHOW_SYNCHRONIZE_FIXME_WARNING_ONCE("index", "cumo_na_index_at_naview");
cumo_cuda_runtime_check_status(cudaDeviceSynchronize());
}
}

static int
cumo_na_ndim_new_narray(int ndim, const cumo_na_index_arg_t *q)
{
Expand All @@ -587,6 +768,7 @@ typedef struct {
cumo_narray_t *na1;
int keep_dim;
size_t pos; // offset position for 0-dimensional narray. 0-dimensional array does not use q.
int at_mode; // 0: aref, 1: at
} cumo_na_aref_md_data_t;

static cumo_na_index_arg_t*
Expand Down Expand Up @@ -614,6 +796,7 @@ VALUE cumo_na_aref_md_protected(VALUE data_value)
cumo_na_index_arg_t *q = data->q;
cumo_narray_t *na1 = data->na1;
int keep_dim = data->keep_dim;
int at_mode = data->at_mode;

int ndim_new;
VALUE view;
Expand All @@ -624,10 +807,14 @@ VALUE cumo_na_aref_md_protected(VALUE data_value)

if (cumo_na_debug_flag) print_index_arg(q,ndim);

if (keep_dim) {
ndim_new = ndim;
if (at_mode) {
ndim_new = 1;
} else {
ndim_new = cumo_na_ndim_new_narray(ndim, q);
if (keep_dim) {
ndim_new = ndim;
} else {
ndim_new = cumo_na_ndim_new_narray(ndim, q);
}
}
view = cumo_na_s_allocate_view(rb_obj_class(self));

Expand All @@ -647,7 +834,11 @@ VALUE cumo_na_aref_md_protected(VALUE data_value)
na2->offset = data->pos;
na2->base.size = 1;
} else {
cumo_na_index_aref_nadata((cumo_narray_data_t *)na1,na2,q,elmsz,ndim,keep_dim);
if (at_mode) {
cumo_na_index_at_nadata((cumo_narray_data_t *)na1,na2,q,elmsz,ndim,keep_dim);
} else {
cumo_na_index_aref_nadata((cumo_narray_data_t *)na1,na2,q,elmsz,ndim,keep_dim);
}
}
na2->data = self;
break;
Expand All @@ -659,7 +850,11 @@ VALUE cumo_na_aref_md_protected(VALUE data_value)
} else {
na2->offset = ((cumo_narray_view_t *)na1)->offset;
na2->data = ((cumo_narray_view_t *)na1)->data;
cumo_na_index_aref_naview((cumo_narray_view_t *)na1,na2,q,elmsz,ndim,keep_dim);
if (at_mode) {
cumo_na_index_at_naview((cumo_narray_view_t *)na1,na2,q,elmsz,ndim,keep_dim);
} else {
cumo_na_index_aref_naview((cumo_narray_view_t *)na1,na2,q,elmsz,ndim,keep_dim);
}
}
break;
}
Expand All @@ -684,7 +879,7 @@ cumo_na_aref_md_ensure(VALUE data_value)
}

static VALUE
cumo_na_aref_md(int argc, VALUE *argv, VALUE self, int keep_dim, int result_nd, size_t pos)
cumo_na_aref_md(int argc, VALUE *argv, VALUE self, int keep_dim, int result_nd, size_t pos, int at_mode)
{
VALUE args; // should be GC protected
cumo_narray_t *na1;
Expand All @@ -696,6 +891,9 @@ cumo_na_aref_md(int argc, VALUE *argv, VALUE self, int keep_dim, int result_nd,
CumoGetNArray(self,na1);

args = rb_ary_new4(argc,argv);
if (at_mode && na1->ndim == 0) {
rb_raise(cumo_na_eDimensionError,"argument length does not match dimension size");
}

if (argc == 1 && result_nd == 1) {
idx = argv[0];
Expand Down Expand Up @@ -724,6 +922,7 @@ cumo_na_aref_md(int argc, VALUE *argv, VALUE self, int keep_dim, int result_nd,
data.q = cumo_na_allocate_index_args(result_nd);
data.na1 = na1;
data.keep_dim = keep_dim;
data.at_mode = at_mode;

switch(na1->type) {
case CUMO_NARRAY_DATA_T:
Expand Down Expand Up @@ -760,7 +959,15 @@ cumo_na_aref_main(int nidx, VALUE *idx, VALUE self, int keep_dim, int result_nd,
return rb_funcall(*idx,cumo_id_mask,1,self);
}
}
return cumo_na_aref_md(nidx, idx, self, keep_dim, result_nd, pos);
return cumo_na_aref_md(nidx, idx, self, keep_dim, result_nd, pos, 0);
}

/* method: at([idx1,idx2,...,idxN], [idx1,idx2,...,idxN]) */
VALUE
cumo_na_at_main(int nidx, VALUE *idx, VALUE self, int keep_dim, int result_nd, size_t pos)
{
cumo_na_index_arg_to_internal_order(nidx, idx, self);
return cumo_na_aref_md(nidx, idx, self, keep_dim, result_nd, pos, 1);
}


Expand Down
Loading