diff --git a/yateto/ast/cost.py b/yateto/ast/cost.py index a2ba456..9c62ca7 100644 --- a/yateto/ast/cost.py +++ b/yateto/ast/cost.py @@ -126,13 +126,13 @@ def estimate_Product(self, node): self._loaded_to_gpu_cache[node] = self._loaded_to_gpu_cache[left_term].union(self._loaded_to_gpu_cache[right_term]) extra_cost = 0 - if not right_term in self._loaded_to_gpu_cache: + if not right_term in self._loaded_to_gpu_cache[node]: self._loaded_to_gpu_cache[node].add(right_term) rbb = self._cache[right_term] extra_cost += rbb.size() if node.indices[self._lead_dim] != left_term.indices[self._lead_dim]: - if not node.leftTerm in self._loaded_to_gpu_cache: + if not node.leftTerm in self._loaded_to_gpu_cache[node]: self._loaded_to_gpu_cache[node].add(left_term) lbb = self._cache[left_term] extra_cost += lbb.size()