Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove some deepcopy calls to speed up strengthReduction #66

Merged
merged 7 commits on Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions yateto/ast/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ class FusedGemmsBoundingBoxCostEstimator(BoundingBoxCostEstimator):
def __init__(self):
super().__init__()
self._lead_dim = 0
self._loaded_to_gpu_cache = set()
self._loaded_to_gpu_cache = {}

def generic_estimate(self, node):
result = super().generic_estimate(node)
self._loaded_to_gpu_cache[node] = set()
return result

def _get_terms(self, node):
left_indices = node.leftTerm().indices
Expand All @@ -117,15 +122,18 @@ def estimate_Product(self, node):
bb = self._cache[left_term]
cost /= bb[self._lead_dim].size()

# take the union of all cached nodes
self._loaded_to_gpu_cache[node] = self._loaded_to_gpu_cache[left_term].union(self._loaded_to_gpu_cache[right_term])

extra_cost = 0
if not right_term in self._loaded_to_gpu_cache:
self._loaded_to_gpu_cache.add(right_term)
if not right_term in self._loaded_to_gpu_cache[node]:
self._loaded_to_gpu_cache[node].add(right_term)
rbb = self._cache[right_term]
extra_cost += rbb.size()

if node.indices[self._lead_dim] != left_term.indices[self._lead_dim]:
if not node.leftTerm in self._loaded_to_gpu_cache:
self._loaded_to_gpu_cache.add(left_term)
if not node.leftTerm in self._loaded_to_gpu_cache[node]:
self._loaded_to_gpu_cache[node].add(left_term)
lbb = self._cache[left_term]
extra_cost += lbb.size()
return cost + extra_cost
Expand All @@ -143,7 +151,12 @@ def estimate_IndexSum(self, node):

left_term, _ = self._get_terms(child)
bb = self._cache[left_term]
self._loaded_to_gpu_cache.add(node)

# we will have visited node.term() as well at this point
# (but we need to add ourselves as well)
self._loaded_to_gpu_cache[node] = set(self._loaded_to_gpu_cache[node.term()])
self._loaded_to_gpu_cache[node].add(node)

return cost / bb[self._lead_dim].size()


Expand Down
14 changes: 7 additions & 7 deletions yateto/ast/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def shape(self):
return self.subShape(self._indices)

def subShape(self, indexNames):
return tuple([self._size[index] for index in indexNames])
return tuple(self._size[index] for index in indexNames)

def indexSize(self, index):
return self._size[index]
Expand Down Expand Up @@ -70,7 +70,7 @@ def __rand__(self, other):

def __le__(self, other):
indexNamesContained = set(self._indices) <= set(other._indices)
return indexNamesContained and all([self._size[index] == other._size[index] for index in self._indices])
return indexNamesContained and all(self._size[index] == other._size[index] for index in self._indices)

def __sub__(self, other):
indexNames = [index for index in self._indices if index not in other]
Expand All @@ -89,7 +89,7 @@ def __str__(self):
return self.tostring()

def __repr__(self):
return '({})'.format(','.join(['{}={}'.format(index, self._size[index]) for index in self._indices]))
return '({})'.format(','.join('{}={}'.format(index, self._size[index]) for index in self._indices))

def size(self):
return self._size
Expand Down Expand Up @@ -143,8 +143,8 @@ def __contains__(self, entry):
if len(self) == 0:
return True
if isinstance(entry[0], Range):
return all([e in self[i] for i,e in enumerate(entry)])
return all([e >= self[i].start and e <= self[i].stop for i,e in enumerate(entry)])
return all(e in self[i] for i,e in enumerate(entry))
return all(e >= self[i].start and e <= self[i].stop for i,e in enumerate(entry))

def __getitem__(self, key):
return self._box[key]
Expand All @@ -156,10 +156,10 @@ def __iter__(self):
return iter(self._box)

def __eq__(self, other):
return all([s == o for s,o in zip(self,other)])
return all(s == o for s,o in zip(self,other))

def __str__(self):
return '{}({})'.format(type(self).__name__, ', '.join([str(r) for r in self]))
return '{}({})'.format(type(self).__name__, ', '.join(str(r) for r in self))

@functools.total_ordering
class LoGCost(object):
Expand Down
10 changes: 5 additions & 5 deletions yateto/ast/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ def splitByDistance(p):

def fusedVariants(memLayout, I, P, M, prune = False):
D = list()
indices = sorted([P[p] for p in I])
indices = sorted(P[p] for p in I)
groups = splitByDistance(indices)
groupStrings = [''.join([M[p] for p in sorted(g)]) for g in groups]
D = set([s for g in groupStrings for s in allSubstrings(g)])
groupStrings = [''.join(M[p] for p in sorted(g)) for g in groups]
D = set(s for g in groupStrings for s in allSubstrings(g))
if prune:
D = set([d for d in D if d[0] == M[0]])
D = set([d for d in D if memLayout.mayFuse(sorted([P[i] for i in d]))])
D = set(d for d in D if d[0] == M[0])
D = set(d for d in D if memLayout.mayFuse(sorted(P[i] for i in d)))
return D

def LoG(contraction, Aperm = None, Bperm = None, Cperm = None):
Expand Down
4 changes: 2 additions & 2 deletions yateto/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def setIndexPermutation(self, indices, permuteEqspp=True):
pass

def permute(self, indices, spp):
perm = tuple([indices.find(idx) for idx in self.indices])
perm = tuple(indices.find(idx) for idx in self.indices)
return spp.transposed(perm)

def _checkMultipleScalarMults(self):
Expand Down Expand Up @@ -174,7 +174,7 @@ def setIndexPermutation(self, indices, permuteEqspp=True):
if str(indices) == str(self.indices):
return

p = tuple([self.indices.find(idx) for idx in indices])
p = tuple(self.indices.find(idx) for idx in indices)
if self._eqspp is not None:
if permuteEqspp:
self._eqspp = self._eqspp.transposed(p)
Expand Down
12 changes: 4 additions & 8 deletions yateto/ast/opt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import sys
from .node import IndexSum, Product
from copy import deepcopy


def strengthReduction(terms, target_indices, cost_estimator, split = 0):
n = len(terms)

indexList = [index for term in terms for index in term.indices]
uniqueIndices = set(indexList)
summationIndices = set([index for index in uniqueIndices if indexList.count(index) == 1]) - set(target_indices)
summationIndices = set(index for index in uniqueIndices if indexList.count(index) == 1) - set(target_indices)

while len(summationIndices) != 0:
i = split
Expand All @@ -31,19 +29,17 @@ def strengthReduction(terms, target_indices, cost_estimator, split = 0):
for i in range(n):
for j in range(max(i+1,split),n):
mulTerm = Product(terms[i], terms[j])
prodCost = deepcopy(cost_estimator).estimate(mulTerm)
prodCost = cost_estimator.estimate(mulTerm)
if best == None or prodCost < minCost:
selection = set(range(n)) - set([i,j])
tree = strengthReduction([terms[i] for i in selection] + [mulTerm],
deepcopy(target_indices),
target_indices,
cost_estimator,
j-1)

cost_estimator_copy = deepcopy(cost_estimator)
treeCost = cost_estimator_copy.estimate(tree)
treeCost = cost_estimator.estimate(tree)
if best == None or treeCost < minCost:
best = tree
minCost = treeCost
cost_estimator = cost_estimator_copy

return best
6 changes: 3 additions & 3 deletions yateto/ast/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def visit_Einsum(self, node, bound):
g = Indices()
for child in node:
overlap = g & child.indices
if any([g.size()[index] != child.size()[index] for index in overlap]):
if any(g.size()[index] != child.size()[index] for index in overlap):
PrettyPrinter().visit(node)
raise ValueError('Einsum: Index dimensions do not match: ', g, child.indices, str(child))
g = g.merged(child.indices - overlap)
Expand All @@ -76,7 +76,7 @@ def visit_Add(self, node, bound):
for child in node:
self.visit(child, bound)

ok = all([node[0].indices <= child.indices and child.indices <= node[0].indices for child in node])
ok = all(node[0].indices <= child.indices and child.indices <= node[0].indices for child in node)
if not ok:
raise ValueError('Add: Indices do not match: ', *[child.indices for child in node])

Expand Down Expand Up @@ -198,7 +198,7 @@ def visit_Assign(self, node):

def getEqspp(self, terms, targetIndices):
# Shortcut if all terms have dense eqspps
if all([term.eqspp().is_dense() for term in terms]):
if all(term.eqspp().is_dense() for term in terms):
return aspp.dense(targetIndices.shape())

minTree = opt.strengthReduction(terms, targetIndices, ShapeCostEstimator())
Expand Down
12 changes: 6 additions & 6 deletions yateto/ast/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def visit(self, node, **kwargs):
return result

def addIndent(string, indent):
return '\n'.join([indent + line for line in string.splitlines()])
return '\n'.join(indent + line for line in string.splitlines())

class PrettyPrinter(Visitor):
def __init__(self):
Expand Down Expand Up @@ -236,9 +236,9 @@ def generic_visit(self, node):
f.write('%%TensorMarket tensor coordinate real general\n')
nzs = pattern.nonzero()
if nzs:
f.write('{} {}\n'.format(' '.join([str(s) for s in pattern.shape]), len(nzs[0])))
f.write('{} {}\n'.format(' '.join(str(s) for s in pattern.shape), len(nzs[0])))
for idx in zip(*nzs):
f.write('{} {}\n'.format(' '.join([str(i) for i in idx]), float(pattern[idx])))
f.write('{} {}\n'.format(' '.join(str(i) for i in idx), float(pattern[idx])))
nSubplots = 1
for dim in range(2, eqspp.ndim):
nSubplots *= eqspp.shape[dim]
Expand All @@ -254,7 +254,7 @@ def generic_visit(self, node):
for index in ndindex(*list(eqspp.shape)[2:]):
sl = pattern[(slice(None, None), slice(None, None)) + index]
axs[nSubplot].imshow(sl.astype(bool), cmap=self._cmap, norm=self._norm)
axs[nSubplot].set_title('(:,:,{})'.format(','.join([str(i) for i in index])), y=1.2)
axs[nSubplot].set_title('(:,:,{})'.format(','.join(str(i) for i in index)), y=1.2)
nSubplot = nSubplot + 1
#plt.setp(axs, xticks=arange(eqspp.shape[1]), yticks=arange(eqspp.shape[0]))
fig.tight_layout()
Expand Down Expand Up @@ -286,14 +286,14 @@ def visit_Einsum(self, node):
terms = self.generic_visit(node)
childIndices = [child.indices for child in node]
assert None not in childIndices and node.indices is not None, 'Use DeduceIndices before {}.'.format(self.__class__.__name__)
einsumDescription = ','.join([indices.tostring() for indices in childIndices])
einsumDescription = ','.join(indices.tostring() for indices in childIndices)
einsumDescription = '{}->{}'.format(einsumDescription, node.indices.tostring())
return einsum(einsumDescription, *terms)

def visit_Add(self, node):
terms = self.generic_visit(node)
assert len(terms) > 1
permute = lambda indices, tensor: tensor.transpose(tuple([indices.find(idx) for idx in node.indices]))
permute = lambda indices, tensor: tensor.transpose(tuple(indices.find(idx) for idx in node.indices))
return reduce(add, [permute(child.indices, terms[i]) for i,child in enumerate(node)])

def visit_ScalarMultiplication(self, node):
Expand Down
5 changes: 2 additions & 3 deletions yateto/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,12 @@ def unit_test_body(cpp, testFramework):

print('Optimizing ASTs...')
for kernel in self._kernels:
print(kernel.name)
print(f'{kernel.name} ({len(kernel.ast)} AST(s))')
kernel.prepareUntilCodeGen(cost_estimator)
for family in self._kernelFamilies.values():
print(family.name)
print(f'{family.name} ({sum(len(kernel.ast) for kernel in family.kernels())} AST(s))')
family.prepareUntilCodeGen(cost_estimator)


# Create mapping from namespace to kernel/family
kernel_dict = {}
for kernel in self._kernels:
Expand Down
5 changes: 1 addition & 4 deletions yateto/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,7 @@ def permuted(self, permutation):

def address(self, entry):
assert entry in self._bbox
a = 0
for i, e in enumerate(entry):
a += (e - self._bbox[i].start) * self._stride[i]
return a
return sum((e - self._bbox[i].start) * self._stride[i] for i, e in enumerate(entry))

def subtensorOffset(self, topLeftEntry):
return self.address(topLeftEntry)
Expand Down
Loading