Skip to content

Commit

Permalink
Add support for higher-order JSON descriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
davschneller committed Feb 14, 2024
1 parent 88a5803 commit 5e1f987
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions yateto/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,22 @@
def __transposeMatrix(matrix):
matrixT = dict()
for entry,value in matrix.items():
matrixT[(entry[1], entry[0])] = value
matrixT[tuple(entry[::-1])] = value
return matrixT

def __processMatrix(name, rows, columns, entries, clones, transpose, alignStride, namespace=None):
def __processMatrix(name, shape, entries, clones, transpose, alignStride, namespace=None):
matrix = dict()

dim = len(shape)

# traverse a list of matrix entries and generate a matrix description
# as a hash table
for entry in entries:
# adjust row and column numbers
row = int(entry[0])-1
col = int(entry[1])-1
index = tuple(int(entry[i]) - 1 for i in range(dim))

# allocate a matrix element inside of a table
matrix[(row, col)] = entry[2]
matrix[index] = entry[-1]

# allocate an empty hash table to hold tensors (matrices) which are going to be generated
# using the matrix description
Expand All @@ -42,16 +43,16 @@ def __processMatrix(name, rows, columns, entries, clones, transpose, alignStride

# generate tensors using description of a give matrix
for name in names:
# compute a shape of a tensor
shape = (columns, rows) if transpose(name) else (rows, columns)
if shape[1] == 1:
# compute a shape of a tensor (for now, assume transpose == invert dimensions)
shape = shape[::-1] if transpose(name) else shape
if shape[1] == 1 and len(shape) == 2: # TODO: remove once all files are converted
shape = (shape[0],)

# transpose matrix if it is needed
mtx = __transposeMatrix(matrix) if transpose(name) else matrix

# adjust layout description in case if a given matrix is a vector
if len(shape) == 1:
if len(shape) == 1: # TODO: remove once all files are converted
mtx = {(i[0],): val for i,val in mtx.items()}

# Create an tensor(matrix) using the matrix description and append the hash table
Expand Down Expand Up @@ -91,7 +92,7 @@ def parseXMLMatrixFile(xmlFile, clones=dict(), transpose=lambda name: False, ali
else:
__complain(child)

matrices.update( __processMatrix(name, rows, columns, entries, clones, transpose, alignStride, namespace) )
matrices.update( __processMatrix(name, (rows, columns), entries, clones, transpose, alignStride, namespace) )
else:
__complain(node)

Expand All @@ -104,9 +105,16 @@ def parseJSONMatrixFile(jsonFile, clones=dict(), transpose=lambda name: False, a
content = json.load(j)
for m in content:
entries = m['entries']
if len(next(iter(entries))) == 2:
entries = [(entry[0], entry[1], True) for entry in entries]
matrices.update( __processMatrix(m['name'], m['rows'], m['columns'], entries, clones, transpose, alignStride, namespace) )
if 'rows' in m:
shape = [m['rows']]
if 'columns' in m:
shape += [m['columns']]
else:
shape = m['shape']
dim = len(shape)
if len(next(iter(entries))) == dim:
entries = [(*entry, True) for entry in entries]
matrices.update( __processMatrix(m['name'], shape, entries, clones, transpose, alignStride, namespace) )

return create_collection(matrices)

Expand Down

0 comments on commit 5e1f987

Please sign in to comment.