diff --git a/yateto/input.py b/yateto/input.py index 182f370..76e0194 100644 --- a/yateto/input.py +++ b/yateto/input.py @@ -17,21 +17,22 @@ def __transposeMatrix(matrix): matrixT = dict() for entry,value in matrix.items(): - matrixT[(entry[1], entry[0])] = value + matrixT[tuple(entry[::-1])] = value return matrixT -def __processMatrix(name, rows, columns, entries, clones, transpose, alignStride, namespace=None): +def __processMatrix(name, shape, entries, clones, transpose, alignStride, namespace=None): matrix = dict() + dim = len(shape) + # traverse a list of matrix entries and generate a matrix description # as a hash table for entry in entries: # adjust row and column numbers - row = int(entry[0])-1 - col = int(entry[1])-1 + index = tuple(int(entry[i]) - 1 for i in range(dim)) # allocate a matrix element inside of a table - matrix[(row, col)] = entry[2] + matrix[index] = entry[-1] # allocate an empty hash table to hold tensors (matrices) which are going to be generated # using the matrix description @@ -42,16 +43,16 @@ def __processMatrix(name, rows, columns, entries, clones, transpose, alignStride # generate tensors using description of a give matrix for name in names: - # compute a shape of a tensor - shape = (columns, rows) if transpose(name) else (rows, columns) - if shape[1] == 1: + # compute a shape of a tensor (for now, assume transpose == invert dimensions) + shape = shape[::-1] if transpose(name) else shape + if shape[1] == 1 and len(shape) == 2: # TODO: remove once all files are converted shape = (shape[0],) # transpose matrix if it is needed mtx = __transposeMatrix(matrix) if transpose(name) else matrix # adjust layout description in case if a given matrix is a vector - if len(shape) == 1: + if len(shape) == 1: # TODO: remove once all files are converted mtx = {(i[0],): val for i,val in mtx.items()} # Create an tensor(matrix) using the matrix description and append the hash table @@ -91,7 +92,7 @@ def parseXMLMatrixFile(xmlFile, clones=dict(), transpose=lambda name: False, ali else: __complain(child) - matrices.update( __processMatrix(name, rows, columns, entries, clones, transpose, alignStride, namespace) ) + matrices.update( __processMatrix(name, (rows, columns), entries, clones, transpose, alignStride, namespace) ) else: __complain(node) @@ -104,9 +105,16 @@ def parseJSONMatrixFile(jsonFile, clones=dict(), transpose=lambda name: False, a content = json.load(j) for m in content: entries = m['entries'] - if len(next(iter(entries))) == 2: - entries = [(entry[0], entry[1], True) for entry in entries] - matrices.update( __processMatrix(m['name'], m['rows'], m['columns'], entries, clones, transpose, alignStride, namespace) ) + if 'rows' in m: + shape = [m['rows']] + if 'columns' in m: + shape += [m['columns']] + else: + shape = m['shape'] + dim = len(shape) + if len(next(iter(entries))) == dim: + entries = [(*entry, True) for entry in entries] + matrices.update( __processMatrix(m['name'], shape, entries, clones, transpose, alignStride, namespace) ) return create_collection(matrices)