template DistributedNDArray so that different chunk types can be used
philippwindischhofer committed Feb 6, 2024
1 parent 5aadc59 commit 02c3baf
Showing 4 changed files with 37 additions and 34 deletions.
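
The change turns the chunk type of DistributedNDArray into a template template parameter ChunkT, so the chunk implementation is selected at compile time instead of being hard-wired to DenseNDArray; a DistributedDenseNDArray alias keeps the previous dense-chunk behaviour available under a convenient name. The following minimal, self-contained sketch illustrates the pattern in isolation; DenseChunk and Distributed are simplified stand-ins for this illustration, not the repository's actual classes:

#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-in for a chunk type with the <T, dims> shape of DenseNDArray
template <class T, std::size_t dims>
struct DenseChunk {
  std::array<std::size_t, dims> shape{};
};

// Container parameterized over the chunk implementation through a template template
// parameter, mirroring the new DistributedNDArray signature (serializer omitted here)
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT>
class Distributed {
public:
  using chunk_t = ChunkT<T, dims>;  // concrete chunk type derived from the parameter
  void RegisterChunk(const chunk_t& chunk) { m_chunks.push_back(chunk); }
  std::size_t NumChunks() const { return m_chunks.size(); }
private:
  std::vector<chunk_t> m_chunks;
};

int main() {
  Distributed<float, 2, DenseChunk> arr;  // chunk implementation chosen at compile time
  arr.RegisterChunk(DenseChunk<float, 2>{});
  std::cout << arr.NumChunks() << std::endl;
  return 0;
}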
9 changes: 6 additions & 3 deletions src/DistributedNDArray.hh
@@ -24,16 +24,16 @@ struct ChunkMetadata {
IndexVector stop_ind;
};

template <class T, std::size_t dims, class SerializerT>
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
class DistributedNDArray : public NDArray<T, dims> {

public:

using chunk_t = ChunkT<T, dims>;

DistributedNDArray(std::string dirpath, std::size_t max_cache_size, SerializerT& ser);
~DistributedNDArray();

using chunk_t = DenseNDArray<T, dims>;

// For assembling and indexing a distributed array
void RegisterChunk(const chunk_t& chunk, const IndexVector start_ind, bool require_nonoverlapping = false);
void MakeIndexPersistent();
@@ -83,4 +83,7 @@ private:

#include "DistributedNDArray.hxx"

template <class T, std::size_t dims, class SerializerT>
using DistributedDenseNDArray = DistributedNDArray<T, dims, DenseNDArray, SerializerT>;

#endif
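
For call sites, only the spelling changes: dense-chunk users move from DistributedNDArray<T, dims, SerializerT> to DistributedDenseNDArray<T, dims, SerializerT> (as the FieldStorage and test updates below show), while the fully templated form accepts any chunk template with a <class, std::size_t> parameter list. A hedged usage sketch, assuming this repository's headers and a default-constructible stor::DefaultSerializer; SparseNDArray is a hypothetical alternative chunk template, not part of this commit:

// Usage sketch (fragment; assumes the repository's headers, not a standalone program)
stor::DefaultSerializer ser;

// Dense chunks through the new alias (equivalent to the pre-commit behaviour)
DistributedDenseNDArray<float, 2, stor::DefaultSerializer> dense_arr("./distarr/", 10, ser);

// A different chunk implementation (hypothetical) could be plugged in directly:
// DistributedNDArray<float, 2, SparseNDArray, stor::DefaultSerializer> sparse_arr("./distarr/", 10, ser);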
56 changes: 28 additions & 28 deletions src/DistributedNDArray.hxx
@@ -20,8 +20,8 @@ namespace stor {
};
}

template <class T, std::size_t dims, class SerializerT>
DistributedNDArray<T, dims, SerializerT>::DistributedNDArray(std::string dirpath, std::size_t max_cache_size, SerializerT& ser) :
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
DistributedNDArray<T, dims, ChunkT, SerializerT>::DistributedNDArray(std::string dirpath, std::size_t max_cache_size, SerializerT& ser) :
NDArray<T, dims>(), m_dirpath(dirpath), m_indexpath(dirpath + "/index.bin"), m_max_cache_size(max_cache_size),
m_global_start_ind(dims, 0), m_ser(ser) {

@@ -47,11 +47,11 @@ DistributedNDArray<T, dims, SerializerT>::DistributedNDArray(std::string dirpath
calculateShape();
}

template <class T, std::size_t dims, class SerializerT>
DistributedNDArray<T, dims, SerializerT>::~DistributedNDArray() { }
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
DistributedNDArray<T, dims, ChunkT, SerializerT>::~DistributedNDArray() { }

template <class T, std::size_t dims, class SerializerT>
void DistributedNDArray<T, dims, SerializerT>::RegisterChunk(const DenseNDArray<T, dims>& chunk, const IndexVector start_ind, bool require_nonoverlapping) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
void DistributedNDArray<T, dims, ChunkT, SerializerT>::RegisterChunk(const chunk_t& chunk, const IndexVector start_ind, bool require_nonoverlapping) {

IndexVector stop_ind = start_ind + chunk.shape();

@@ -89,8 +89,8 @@ void DistributedNDArray<T, dims, SerializerT>::RegisterChunk(const DenseNDArray<
calculateShape();
}

template <class T, std::size_t dims, class SerializerT>
void DistributedNDArray<T, dims, SerializerT>::MakeIndexPersistent() {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
void DistributedNDArray<T, dims, ChunkT, SerializerT>::MakeIndexPersistent() {
// Update index on disk
std::fstream ofs;
ofs.open(m_indexpath, std::ios::out | std::ios::binary);
@@ -99,8 +99,8 @@ void DistributedNDArray<T, dims, SerializerT>::MakeIndexPersistent() {
ofs.close();
}

template <class T, std::size_t dims, class SerializerT>
void DistributedNDArray<T, dims, SerializerT>::rebuildIndex() {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
void DistributedNDArray<T, dims, ChunkT, SerializerT>::rebuildIndex() {
m_chunk_index.clear();

// With the index gone, also the cache is now out of scope
@@ -118,8 +118,8 @@ void DistributedNDArray<T, dims, SerializerT>::rebuildIndex() {
}
}

template <class T, std::size_t dims, class SerializerT>
T DistributedNDArray<T, dims, SerializerT>::operator()(IndexVector& inds) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
T DistributedNDArray<T, dims, ChunkT, SerializerT>::operator()(IndexVector& inds) {

// check to which chunk this index belongs
std::size_t chunk_ind = getChunkIndex(inds);
@@ -132,13 +132,13 @@ T DistributedNDArray<T, dims, SerializerT>::operator()(IndexVector& inds) {
return found_chunk(inds_within_chunk);
}

template <class T, std::size_t dims, class SerializerT>
bool DistributedNDArray<T, dims, SerializerT>::chunkContainsInds(const ChunkMetadata& chunk_meta, const IndexVector& inds) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
bool DistributedNDArray<T, dims, ChunkT, SerializerT>::chunkContainsInds(const ChunkMetadata& chunk_meta, const IndexVector& inds) {
return isInIndexRange(inds, chunk_meta.start_ind, chunk_meta.stop_ind);
}

template <class T, std::size_t dims, class SerializerT>
std::size_t DistributedNDArray<T, dims, SerializerT>::getChunkIndex(const IndexVector& inds) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
std::size_t DistributedNDArray<T, dims, ChunkT, SerializerT>::getChunkIndex(const IndexVector& inds) {
std::size_t chunk_ind = 0;
for(chunk_ind = 0; chunk_ind < m_chunk_index.size(); chunk_ind++) {
if(chunkContainsInds(m_chunk_index[chunk_ind], inds)) {
@@ -149,8 +149,8 @@ std::size_t DistributedNDArray<T, dims, SerializerT>::getChunkIndex(const IndexV
throw std::runtime_error("No chunk provides these indices!");
}

template <class T, std::size_t dims, class SerializerT>
DistributedNDArray<T, dims, SerializerT>::chunk_t& DistributedNDArray<T, dims, SerializerT>::retrieveChunk(std::size_t chunk_ind) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
DistributedNDArray<T, dims, ChunkT, SerializerT>::chunk_t& DistributedNDArray<T, dims, ChunkT, SerializerT>::retrieveChunk(std::size_t chunk_ind) {

ChunkMetadata& chunk_meta = m_chunk_index[chunk_ind];

@@ -180,8 +180,8 @@ DistributedNDArray<T, dims, SerializerT>::chunk_t& DistributedNDArray<T, dims, S
return m_chunk_cache.find(chunk_ind) -> second;
}

template <class T, std::size_t dims, class SerializerT>
void DistributedNDArray<T, dims, SerializerT>::calculateShape() {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
void DistributedNDArray<T, dims, ChunkT, SerializerT>::calculateShape() {

if(m_chunk_index.empty()) {
// Nothing to do if everything is empty
@@ -199,8 +199,8 @@ void DistributedNDArray<T, dims, SerializerT>::calculateShape() {
}
}

template <class T, std::size_t dims, class SerializerT>
bool DistributedNDArray<T, dims, SerializerT>::isGloballyContiguous(IndexVector& global_start_inds, IndexVector& global_stop_inds) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
bool DistributedNDArray<T, dims, ChunkT, SerializerT>::isGloballyContiguous(IndexVector& global_start_inds, IndexVector& global_stop_inds) {

std::size_t global_volume = getVolume(global_start_inds, global_stop_inds);

@@ -212,17 +212,17 @@ bool DistributedNDArray<T, dims, SerializerT>::isGloballyContiguous(IndexVector&
return global_volume == total_chunk_volume;
}

template <class T, std::size_t dims, class SerializerT>
std::size_t DistributedNDArray<T, dims, SerializerT>::getVolume(IndexVector& start_inds, IndexVector& stop_inds) {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
std::size_t DistributedNDArray<T, dims, ChunkT, SerializerT>::getVolume(IndexVector& start_inds, IndexVector& stop_inds) {
std::size_t volume = 1;
for(std::size_t cur_dim = 0; cur_dim < dims; cur_dim++) {
volume *= (stop_inds(cur_dim) - start_inds(cur_dim));
}
return volume;
}

template <class T, std::size_t dims, class SerializerT>
IndexVector& DistributedNDArray<T, dims, SerializerT>::getGlobalStartInd() {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
IndexVector& DistributedNDArray<T, dims, ChunkT, SerializerT>::getGlobalStartInd() {
if(m_chunk_index.empty()) {
throw std::runtime_error("Trying to compute start index of an empty array!");
}
@@ -235,8 +235,8 @@ IndexVector& DistributedNDArray<T, dims, SerializerT>::getGlobalStartInd() {
return *global_start_ind;
}

template <class T, std::size_t dims, class SerializerT>
IndexVector& DistributedNDArray<T, dims, SerializerT>::getGlobalStopInd() {
template <class T, std::size_t dims, template<class, std::size_t> class ChunkT, class SerializerT>
IndexVector& DistributedNDArray<T, dims, ChunkT, SerializerT>::getGlobalStopInd() {
if(m_chunk_index.empty()) {
throw std::runtime_error("Trying to compute stop index of an empty array!");
}
2 changes: 1 addition & 1 deletion src/FieldStorage.hh
@@ -15,7 +15,7 @@ class RZFieldStorage : public FieldStorage {

public:
using serializer_t = stor::DefaultSerializer;
using storage_t = DistributedNDArray<scalar_t, 3, serializer_t>;
using storage_t = DistributedDenseNDArray<scalar_t, 3, serializer_t>;
using chunk_t = DenseNDArray<scalar_t, 3>;

public:
4 changes: 2 additions & 2 deletions tests/io/testDistributedNDArray.cxx
@@ -31,12 +31,12 @@ int main(int argc, char* argv[]) {

std::shared_ptr<stor::DefaultSerializer> ser = std::make_shared<stor::DefaultSerializer>();

DistributedNDArray<float, 2, stor::DefaultSerializer> darr_save("./distarr/", 10, *ser);
DistributedDenseNDArray<float, 2, stor::DefaultSerializer> darr_save("./distarr/", 10, *ser);
darr_save.RegisterChunk(chunk1, start_ind1);
darr_save.RegisterChunk(chunk2, start_ind2);
darr_save.MakeIndexPersistent();

DistributedNDArray<float, 2, stor::DefaultSerializer> darr_load("./distarr/", 10, *ser);
DistributedDenseNDArray<float, 2, stor::DefaultSerializer> darr_load("./distarr/", 10, *ser);

IndexVector acc_ind1 = {1,1};
std::cout << darr_load(acc_ind1) << std::endl;
