Skip to content

Commit

Permalink
Fix TBB support.
Browse files Browse the repository at this point in the history
  • Loading branch information
devinamatthews committed Oct 31, 2017
1 parent 4dd52a2 commit 5a9e743
Show file tree
Hide file tree
Showing 34 changed files with 289 additions and 360 deletions.
Binary file modified src/external/tci/.git.bak
Binary file not shown.
337 changes: 171 additions & 166 deletions src/external/tci/tci/communicator.c

Large diffs are not rendered by default.

84 changes: 33 additions & 51 deletions src/external/tci/tci/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

#include "context.h"

#ifdef __cplusplus
extern "C" {
#endif

typedef struct tci_comm
{
tci_context* context;
Expand All @@ -30,15 +26,28 @@ enum
typedef struct tci_range
{
uint64_t size;
uint64_t chunk;
uint64_t grain;

#ifdef __cplusplus
tci_range() : size(0), grain(1) {}

template <typename T>
tci_range(const T& size) : size(size), grain(1) {}

template <typename T, typename U>
tci_range(const T& size, const U& grain) : size(size), grain(grain) {}
#endif
} tci_range;

typedef void (*tci_range_func)(tci_comm*, uint64_t, uint64_t, void*);

typedef void (*tci_range_2d_func)(tci_comm*, uint64_t, uint64_t,
uint64_t, uint64_t, void*);

#ifdef __cplusplus
extern "C" {
#endif

extern tci_comm* const tci_single;

int tci_comm_init_single(tci_comm* comm);
Expand Down Expand Up @@ -146,40 +155,6 @@ using index_sequence_for = make_index_sequence<sizeof...(T)>;

}

class range
{
public:
range(uint64_t size = 0)
{
range_.size = size;
range_.chunk = 1;
range_.grain = 1;
}

uint64_t size() const { return range_.size; }

uint64_t chunk() const { return range_.chunk; }

uint64_t grain() const { return range_.grain; }

range& chunk(uint64_t c)
{
range_.chunk = c;
return *this;
}

range& grain(uint64_t g)
{
range_.grain = g;
return *this;
}

operator tci_range() const { return range_; }

protected:
tci_range range_;
};

class communicator
{
protected:
Expand Down Expand Up @@ -213,11 +188,15 @@ class communicator
template <typename Func>
void visit(unsigned task, Func&& func)
{
typedef typename std::decay<Func>::type RealFunc;
RealFunc* payload = new RealFunc(std::forward<Func>(func));
tci_task_set_visit(&_tasks,
[](tci_comm* comm, unsigned, void* payload)
[](tci_comm* comm, unsigned, void* payload_)
{
(*(Func*)payload)(*reinterpret_cast<const communicator*>(comm));
}, task, &func);
RealFunc* payload = (RealFunc*)payload_;
(*payload)(*reinterpret_cast<const communicator*>(comm));
delete payload;
}, task, payload);
}

protected:
Expand All @@ -233,7 +212,8 @@ class communicator
tci_task_set_visit_all(&_tasks,
[](tci_comm* comm, unsigned task, void* payload)
{
(*(Func*)payload)(*reinterpret_cast<const communicator*>(comm), task);
(*(typename std::decay<Func>::type*)payload)
(*reinterpret_cast<const communicator*>(comm), task);
}, &func);
}

Expand Down Expand Up @@ -340,46 +320,48 @@ class communicator
}

template <typename Func>
void distribute_over_gangs(const range& n, Func&& func) const
void distribute_over_gangs(const tci_range& n, Func&& func) const
{
tci_comm_distribute_over_gangs(*this, n,
[](tci_comm* comm, uint64_t first, uint64_t last, void* payload)
{
(*(Func*)payload)(first, last);
(*(typename std::decay<Func>::type*)payload)(first, last);
}, &func);
}

template <typename Func>
void distribute_over_threads(const range& n, Func&& func) const
void distribute_over_threads(const tci_range& n, Func&& func) const
{
tci_comm_distribute_over_threads(*this, n,
[](tci_comm*, uint64_t first, uint64_t last, void* payload)
{
(*(Func*)payload)(first, last);
(*(typename std::decay<Func>::type*)payload)(first, last);
}, &func);
}

template <typename Func>
void distribute_over_gangs(const range& m, const range& n,
void distribute_over_gangs(const tci_range& m, const tci_range& n,
Func&& func) const
{
tci_comm_distribute_over_gangs_2d(*this, m, n,
[](tci_comm* comm, uint64_t mfirst, uint64_t mlast,
uint64_t nfirst, uint64_t nlast, void* payload)
{
(*(Func*)payload)(mfirst, mlast, nfirst, nlast);
(*(typename std::decay<Func>::type*)payload)
(mfirst, mlast, nfirst, nlast);
}, &func);
}

template <typename Func>
void distribute_over_threads(const range& m, const range& n,
void distribute_over_threads(const tci_range& m, const tci_range& n,
Func&& func) const
{
tci_comm_distribute_over_threads_2d(*this, m, n,
[](tci_comm*, uint64_t mfirst, uint64_t mlast,
uint64_t nfirst, uint64_t nlast, void* payload)
{
(*(Func*)payload)(mfirst, mlast, nfirst, nlast);
(*(typename std::decay<Func>::type*)payload)
(mfirst, mlast, nfirst, nlast);
}, &func);
}

Expand Down
3 changes: 2 additions & 1 deletion src/external/tci/tci/parallel.c
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ int tci_parallelize(tci_thread_func func, void* payload,
int tci_parallelize(tci_thread_func func, void* payload,
unsigned nthread, unsigned arity)
{
func(tci_single, payload);
tci_comm comm = {NULL, 1, 0, nthread, 0};
func(&comm, payload);
return 0;
}

Expand Down
1 change: 1 addition & 0 deletions src/external/tci/tci/task_set.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "communicator.h"
#include "task_set.h"

#ifdef __cplusplus
Expand Down
6 changes: 2 additions & 4 deletions src/internal/1m/add.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ void add(const communicator& comm, const config& cfg, len_type m, len_type n,
std::swap(rs_B, cs_B);
}

comm.distribute_over_threads(tci::range(m).chunk(50).grain(MR),
tci::range(n).chunk(50).grain(NR),
comm.distribute_over_threads({m, MR}, {n, NR},
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
if (beta == T(0))
Expand Down Expand Up @@ -73,8 +72,7 @@ void add(const communicator& comm, const config& cfg, len_type m, len_type n,
std::swap(rs_B, cs_B);
}

comm.distribute_over_threads(tci::range(m).chunk(1000),
tci::range(n).chunk(1000/m),
comm.distribute_over_threads(m, n,
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
if (beta == T(0))
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1m/dot.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ void dot(const communicator& comm, const config& cfg, len_type m, len_type n,

atomic_accumulator<T> local_result;

comm.distribute_over_threads(tci::range(m).chunk(1000),
tci::range(n).chunk(1000/m),
comm.distribute_over_threads(m, n,
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
T micro_result = T();
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1m/reduce.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ void reduce(const communicator& comm, const config& cfg, reduce_t op,
atomic_reducer<T> local_result;
reduce_init(op, local_result);

comm.distribute_over_threads(tci::range(m).chunk(1000),
tci::range(n).chunk(1000/m),
comm.distribute_over_threads(m, n,
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
T micro_result;
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1m/scale.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ void scale(const communicator& comm, const config& cfg, len_type m, len_type n,
std::swap(rs_A, cs_A);
}

comm.distribute_over_threads(tci::range(m).chunk(1000),
tci::range(n).chunk(1000/m),
comm.distribute_over_threads(m, n,
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
for (len_type j = n_min;j < n_max;j++)
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1m/set.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ void set(const communicator& comm, const config& cfg, len_type m, len_type n,
std::swap(rs_A, cs_A);
}

comm.distribute_over_threads(tci::range(m).chunk(1000),
tci::range(n).chunk(1000/m),
comm.distribute_over_threads(m, n,
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
for (len_type j = n_min;j < n_max;j++)
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1m/shift.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ void shift(const communicator& comm, const config& cfg, len_type m, len_type n,
std::swap(rs_A, cs_A);
}

comm.distribute_over_threads(tci::range(m).chunk(1000),
tci::range(n).chunk(1000/m),
comm.distribute_over_threads(m, n,
[&](len_type m_min, len_type m_max, len_type n_min, len_type n_max)
{
for (len_type j = n_min;j < n_max;j++)
Expand Down
7 changes: 3 additions & 4 deletions src/internal/1t/dense/add.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void add(const communicator& comm, const config& cfg,
//TODO sum (reduce?) ukr
//TODO fused ukr

comm.distribute_over_threads(tci::range(n_AB).chunk(1000/n_A),
comm.distribute_over_threads(n_AB,
[&](len_type n_min, len_type n_max)
{
auto A1 = A;
Expand Down Expand Up @@ -126,7 +126,7 @@ void add(const communicator& comm, const config& cfg,
//TODO replicate ukr
//TODO fused ukr

comm.distribute_over_threads(tci::range(n_AB).chunk(1000/n_B),
comm.distribute_over_threads(n_AB,
[&](len_type n_min, len_type n_max)
{
auto A1 = A;
Expand Down Expand Up @@ -170,8 +170,7 @@ void add(const communicator& comm, const config& cfg,
stride_type stride_B0 = stride_B_AB[0];
stride_vector stride_B1(stride_B_AB.begin()+1, stride_B_AB.end());

comm.distribute_over_threads(tci::range(n0).chunk(1000),
tci::range(n1).chunk(1000/n0),
comm.distribute_over_threads(n0, n1,
[&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
{
auto A1 = A;
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1t/dense/dot.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void dot(const communicator& comm, const config& cfg,

if (conj_A) conj_B = !conj_B;

comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
auto A1 = A;
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1t/dense/reduce.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ void reduce(const communicator& comm, const config& cfg, reduce_t op,
atomic_reducer<T> local_result;
reduce_init(op, local_result);

comm.distribute_over_threads(tci::range(n0).chunk(1000),
tci::range(n1).chunk(1000/n0),
comm.distribute_over_threads(n0, n1,
[&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
{
auto A1 = A;
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1t/dense/scale.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ void scale(const communicator& comm, const config& cfg,
stride_type stride0 = (empty ? 1 : stride_A[0]);
len_vector stride1(stride_A.begin() + !empty, stride_A.end());

comm.distribute_over_threads(tci::range(n0).chunk(1000),
tci::range(n1).chunk(1000/n0),
comm.distribute_over_threads(n0, n1,
[&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
{
auto A1 = A;
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1t/dense/set.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ void set(const communicator& comm, const config& cfg,
stride_type stride0 = (empty ? 1 : stride_A[0]);
len_vector stride1(stride_A.begin() + !empty, stride_A.end());

comm.distribute_over_threads(tci::range(n0).chunk(1000),
tci::range(n1).chunk(1000/n0),
comm.distribute_over_threads(n0, n1,
[&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
{
auto A1 = A;
Expand Down
3 changes: 1 addition & 2 deletions src/internal/1t/dense/shift.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ void shift(const communicator& comm, const config& cfg,
stride_type stride0 = (empty ? 1 : stride_A[0]);
len_vector stride1(stride_A.begin() + !empty, stride_A.end());

comm.distribute_over_threads(tci::range(n0).chunk(1000),
tci::range(n1).chunk(1000/n0),
comm.distribute_over_threads(n0, n1,
[&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
{
auto A1 = A;
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1v/add.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ void add(const communicator& comm, const config& cfg, len_type n,
T alpha, bool conj_A, const T* A, stride_type inc_A,
T beta, bool conj_B, T* B, stride_type inc_B)
{
comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
if (beta == T(0))
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1v/dot.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void dot(const communicator& comm, const config& cfg, len_type n,
{
atomic_accumulator<T> local_result;

comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
T micro_result = T();
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1v/reduce.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void reduce(const communicator& comm, const config& cfg, reduce_t op, len_type n
atomic_reducer<T> local_result;
reduce_init(op, local_result);

comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
T micro_result;
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1v/scale.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ template <typename T>
void scale(const communicator& comm, const config& cfg, len_type n,
T alpha, bool conj_A, T* A, stride_type inc_A)
{
comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
cfg.scale_ukr.call<T>(n_max-n_min, alpha, conj_A, A+n_min*inc_A, inc_A);
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1v/set.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ template <typename T>
void set(const communicator& comm, const config& cfg, len_type n,
T alpha, T* A, stride_type inc_A)
{
comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
cfg.set_ukr.call<T>(n_max-n_min, alpha, A+n_min*inc_A, inc_A);
Expand Down
2 changes: 1 addition & 1 deletion src/internal/1v/shift.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ template <typename T>
void shift(const communicator& comm, const config& cfg, len_type n,
T alpha, T beta, bool conj_A, T* A, stride_type inc_A)
{
comm.distribute_over_threads(tci::range(n).chunk(1000),
comm.distribute_over_threads(n,
[&](len_type n_min, len_type n_max)
{
cfg.shift_ukr.call<T>(n_max-n_min, alpha, beta, conj_A, A+n_min*inc_A, inc_A);
Expand Down
8 changes: 4 additions & 4 deletions src/internal/3m/mult.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ void mult(const communicator& comm, const config& cfg,
gemm_thread_config tc = make_gemm_thread_config<T>(cfg, nt, m, n, k);

GotoGEMM gemm;
step<0>(gemm).distribute = tc.jc_nt;
step<3>(gemm).distribute = tc.ic_nt;
step<5>(gemm).distribute = tc.jr_nt;
step<6>(gemm).distribute = tc.ir_nt;
step<0>(gemm).subcomm = comm.gang(TCI_EVENLY, tc.jc_nt);
step<3>(gemm).subcomm = step<0>(gemm).subcomm.gang(TCI_EVENLY, tc.ic_nt);
step<5>(gemm).subcomm = step<3>(gemm).subcomm.gang(TCI_EVENLY, tc.jr_nt);
step<6>(gemm).subcomm = step<5>(gemm).subcomm.gang(TCI_EVENLY, tc.ir_nt);

gemm(comm, cfg, alpha, Av, Bv, beta, Cv);

Expand Down
Loading

0 comments on commit 5a9e743

Please sign in to comment.