diff --git a/yateto/ast/cost.py b/yateto/ast/cost.py index b088966..9c62ca7 100644 --- a/yateto/ast/cost.py +++ b/yateto/ast/cost.py @@ -94,7 +94,12 @@ class FusedGemmsBoundingBoxCostEstimator(BoundingBoxCostEstimator): def __init__(self): super().__init__() self._lead_dim = 0 - self._loaded_to_gpu_cache = set() + self._loaded_to_gpu_cache = {} + + def generic_estimate(self, node): + result = super().generic_estimate(node) + self._loaded_to_gpu_cache[node] = set() + return result def _get_terms(self, node): left_indices = node.leftTerm().indices @@ -117,15 +122,18 @@ def estimate_Product(self, node): bb = self._cache[left_term] cost /= bb[self._lead_dim].size() + # take the union of all cached nodes + self._loaded_to_gpu_cache[node] = self._loaded_to_gpu_cache[left_term].union(self._loaded_to_gpu_cache[right_term]) + extra_cost = 0 - if not right_term in self._loaded_to_gpu_cache: - self._loaded_to_gpu_cache.add(right_term) + if not right_term in self._loaded_to_gpu_cache[node]: + self._loaded_to_gpu_cache[node].add(right_term) rbb = self._cache[right_term] extra_cost += rbb.size() if node.indices[self._lead_dim] != left_term.indices[self._lead_dim]: - if not node.leftTerm in self._loaded_to_gpu_cache: - self._loaded_to_gpu_cache.add(left_term) + if not node.leftTerm in self._loaded_to_gpu_cache[node]: + self._loaded_to_gpu_cache[node].add(left_term) lbb = self._cache[left_term] extra_cost += lbb.size() return cost + extra_cost @@ -143,7 +151,12 @@ def estimate_IndexSum(self, node): left_term, _ = self._get_terms(child) bb = self._cache[left_term] - self._loaded_to_gpu_cache.add(node) + + # we will have visited node.term() as well at this point + # (but we need to add ourselves as well) + self._loaded_to_gpu_cache[node] = set(self._loaded_to_gpu_cache[node.term()]) + self._loaded_to_gpu_cache[node].add(node) + return cost / bb[self._lead_dim].size() diff --git a/yateto/ast/indices.py b/yateto/ast/indices.py index a1f42ee..d9baadb 100644 --- a/yateto/ast/indices.py +++ b/yateto/ast/indices.py @@ -25,7 +25,7 @@ def shape(self): return self.subShape(self._indices) def subShape(self, indexNames): - return tuple([self._size[index] for index in indexNames]) + return tuple(self._size[index] for index in indexNames) def indexSize(self, index): return self._size[index] @@ -70,7 +70,7 @@ def __rand__(self, other): def __le__(self, other): indexNamesContained = set(self._indices) <= set(other._indices) - return indexNamesContained and all([self._size[index] == other._size[index] for index in self._indices]) + return indexNamesContained and all(self._size[index] == other._size[index] for index in self._indices) def __sub__(self, other): indexNames = [index for index in self._indices if index not in other] @@ -89,7 +89,7 @@ def __str__(self): return self.tostring() def __repr__(self): - return '({})'.format(','.join(['{}={}'.format(index, self._size[index]) for index in self._indices])) + return '({})'.format(','.join('{}={}'.format(index, self._size[index]) for index in self._indices)) def size(self): return self._size @@ -143,8 +143,8 @@ def __contains__(self, entry): if len(self) == 0: return True if isinstance(entry[0], Range): - return all([e in self[i] for i,e in enumerate(entry)]) - return all([e >= self[i].start and e <= self[i].stop for i,e in enumerate(entry)]) + return all(e in self[i] for i,e in enumerate(entry)) + return all(e >= self[i].start and e <= self[i].stop for i,e in enumerate(entry)) def __getitem__(self, key): return self._box[key] @@ -156,10 +156,10 @@ def __iter__(self): return iter(self._box) def __eq__(self, other): - return all([s == o for s,o in zip(self,other)]) + return all(s == o for s,o in zip(self,other)) def __str__(self): - return '{}({})'.format(type(self).__name__, ', '.join([str(r) for r in self])) + return '{}({})'.format(type(self).__name__, ', '.join(str(r) for r in self)) @functools.total_ordering class LoGCost(object): diff --git a/yateto/ast/log.py b/yateto/ast/log.py index dc68b54..c48f201 100644 --- a/yateto/ast/log.py +++ b/yateto/ast/log.py @@ -13,13 +13,13 @@ def splitByDistance(p): def fusedVariants(memLayout, I, P, M, prune = False): D = list() - indices = sorted([P[p] for p in I]) + indices = sorted(P[p] for p in I) groups = splitByDistance(indices) - groupStrings = [''.join([M[p] for p in sorted(g)]) for g in groups] - D = set([s for g in groupStrings for s in allSubstrings(g)]) + groupStrings = [''.join(M[p] for p in sorted(g)) for g in groups] + D = set(s for g in groupStrings for s in allSubstrings(g)) if prune: - D = set([d for d in D if d[0] == M[0]]) - D = set([d for d in D if memLayout.mayFuse(sorted([P[i] for i in d]))]) + D = set(d for d in D if d[0] == M[0]) + D = set(d for d in D if memLayout.mayFuse(sorted(P[i] for i in d))) return D def LoG(contraction, Aperm = None, Bperm = None, Cperm = None): diff --git a/yateto/ast/node.py b/yateto/ast/node.py index ae8313d..93da011 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -59,7 +59,7 @@ def setIndexPermutation(self, indices, permuteEqspp=True): pass def permute(self, indices, spp): - perm = tuple([indices.find(idx) for idx in self.indices]) + perm = tuple(indices.find(idx) for idx in self.indices) return spp.transposed(perm) def _checkMultipleScalarMults(self): @@ -174,7 +174,7 @@ def setIndexPermutation(self, indices, permuteEqspp=True): if str(indices) == str(self.indices): return - p = tuple([self.indices.find(idx) for idx in indices]) + p = tuple(self.indices.find(idx) for idx in indices) if self._eqspp is not None: if permuteEqspp: self._eqspp = self._eqspp.transposed(p) diff --git a/yateto/ast/opt.py b/yateto/ast/opt.py index e006312..54884f8 100644 --- a/yateto/ast/opt.py +++ b/yateto/ast/opt.py @@ -1,14 +1,12 @@ import sys from .node import IndexSum, Product -from copy import deepcopy - def strengthReduction(terms, target_indices, cost_estimator, split = 0): n = len(terms) indexList = [index for term in terms for index in term.indices] uniqueIndices = set(indexList) - summationIndices = set([index for index in uniqueIndices if indexList.count(index) == 1]) - set(target_indices) + summationIndices = set(index for index in uniqueIndices if indexList.count(index) == 1) - set(target_indices) while len(summationIndices) != 0: i = split @@ -31,19 +29,17 @@ def strengthReduction(terms, target_indices, cost_estimator, split = 0): for i in range(n): for j in range(max(i+1,split),n): mulTerm = Product(terms[i], terms[j]) - prodCost = deepcopy(cost_estimator).estimate(mulTerm) + prodCost = cost_estimator.estimate(mulTerm) if best == None or prodCost < minCost: selection = set(range(n)) - set([i,j]) tree = strengthReduction([terms[i] for i in selection] + [mulTerm], - deepcopy(target_indices), + target_indices, cost_estimator, j-1) - cost_estimator_copy = deepcopy(cost_estimator) - treeCost = cost_estimator_copy.estimate(tree) + treeCost = cost_estimator.estimate(tree) if best == None or treeCost < minCost: best = tree minCost = treeCost - cost_estimator = cost_estimator_copy return best diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index d47aac3..152ce50 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -63,7 +63,7 @@ def visit_Einsum(self, node, bound): g = Indices() for child in node: overlap = g & child.indices - if any([g.size()[index] != child.size()[index] for index in overlap]): + if any(g.size()[index] != child.size()[index] for index in overlap): PrettyPrinter().visit(node) raise ValueError('Einsum: Index dimensions do not match: ', g, child.indices, str(child)) g = g.merged(child.indices - overlap) @@ -76,7 +76,7 @@ def visit_Add(self, node, bound): for child in node: self.visit(child, bound) - ok = all([node[0].indices <= child.indices and child.indices <= node[0].indices for child in node]) + ok = all(node[0].indices <= child.indices and child.indices <= node[0].indices for child in node) if not ok: raise ValueError('Add: Indices do not match: ', *[child.indices for child in node]) @@ -198,7 +198,7 @@ def visit_Assign(self, node): def getEqspp(self, terms, targetIndices): # Shortcut if all terms have dense eqspps - if all([term.eqspp().is_dense() for term in terms]): + if all(term.eqspp().is_dense() for term in terms): return aspp.dense(targetIndices.shape()) minTree = opt.strengthReduction(terms, targetIndices, ShapeCostEstimator()) diff --git a/yateto/ast/visitor.py b/yateto/ast/visitor.py index 5f87c53..dd6562c 100644 --- a/yateto/ast/visitor.py +++ b/yateto/ast/visitor.py @@ -46,7 +46,7 @@ def visit(self, node, **kwargs): return result def addIndent(string, indent): - return '\n'.join([indent + line for line in string.splitlines()]) + return '\n'.join(indent + line for line in string.splitlines()) class PrettyPrinter(Visitor): def __init__(self): @@ -236,9 +236,9 @@ def generic_visit(self, node): f.write('%%TensorMarket tensor coordinate real general\n') nzs = pattern.nonzero() if nzs: - f.write('{} {}\n'.format(' '.join([str(s) for s in pattern.shape]), len(nzs[0]))) + f.write('{} {}\n'.format(' '.join(str(s) for s in pattern.shape), len(nzs[0]))) for idx in zip(*nzs): - f.write('{} {}\n'.format(' '.join([str(i) for i in idx]), float(pattern[idx]))) + f.write('{} {}\n'.format(' '.join(str(i) for i in idx), float(pattern[idx]))) nSubplots = 1 for dim in range(2, eqspp.ndim): nSubplots *= eqspp.shape[dim] @@ -254,7 +254,7 @@ def generic_visit(self, node): for index in ndindex(*list(eqspp.shape)[2:]): sl = pattern[(slice(None, None), slice(None, None)) + index] axs[nSubplot].imshow(sl.astype(bool), cmap=self._cmap, norm=self._norm) - axs[nSubplot].set_title('(:,:,{})'.format(','.join([str(i) for i in index])), y=1.2) + axs[nSubplot].set_title('(:,:,{})'.format(','.join(str(i) for i in index)), y=1.2) nSubplot = nSubplot + 1 #plt.setp(axs, xticks=arange(eqspp.shape[1]), yticks=arange(eqspp.shape[0])) fig.tight_layout() @@ -286,14 +286,14 @@ def visit_Einsum(self, node): terms = self.generic_visit(node) childIndices = [child.indices for child in node] assert None not in childIndices and node.indices is not None, 'Use DeduceIndices before {}.'.format(self.__class__.__name__) - einsumDescription = ','.join([indices.tostring() for indices in childIndices]) + einsumDescription = ','.join(indices.tostring() for indices in childIndices) einsumDescription = '{}->{}'.format(einsumDescription, node.indices.tostring()) return einsum(einsumDescription, *terms) def visit_Add(self, node): terms = self.generic_visit(node) assert len(terms) > 1 - permute = lambda indices, tensor: tensor.transpose(tuple([indices.find(idx) for idx in node.indices])) + permute = lambda indices, tensor: tensor.transpose(tuple(indices.find(idx) for idx in node.indices)) return reduce(add, [permute(child.indices, terms[i]) for i,child in enumerate(node)]) def visit_ScalarMultiplication(self, node): diff --git a/yateto/generator.py b/yateto/generator.py index 2a36f75..3b9d47d 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -298,13 +298,12 @@ def unit_test_body(cpp, testFramework): print('Optimizing ASTs...') for kernel in self._kernels: - print(kernel.name) + print(f'{kernel.name} ({len(kernel.ast)} AST(s))') kernel.prepareUntilCodeGen(cost_estimator) for family in self._kernelFamilies.values(): - print(family.name) + print(f'{family.name} ({sum(len(kernel.ast) for kernel in family.kernels())} AST(s))') family.prepareUntilCodeGen(cost_estimator) - # Create mapping from namespace to kernel/family kernel_dict = {} for kernel in self._kernels: diff --git a/yateto/memory.py b/yateto/memory.py index e66c861..0292044 100644 --- a/yateto/memory.py +++ b/yateto/memory.py @@ -115,10 +115,7 @@ def permuted(self, permutation): def address(self, entry): assert entry in self._bbox - a = 0 - for i, e in enumerate(entry): - a += (e - self._bbox[i].start) * self._stride[i] - return a + return sum((e - self._bbox[i].start) * self._stride[i] for i, e in enumerate(entry)) def subtensorOffset(self, topLeftEntry): return self.address(topLeftEntry)