Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

substitute makes expr more complicated #59

Open
xqdan opened this issue May 8, 2019 · 1 comment
Open

substitute makes expr more complicated #59

xqdan opened this issue May 8, 2019 · 1 comment

Comments

@xqdan
Copy link
Contributor

xqdan commented May 8, 2019

import tvm

def register_mem(scope_tb, max_bits):
    """Register a ``MemoryInfo`` provider for the storage scope *scope_tb*.

    A global function named ``"tvm.info.mem.<scope_tb>"`` is registered via
    ``tvm.register_func``.  When invoked, it builds a ``"MemoryInfo"`` node
    with ``unit_bits=16``, ``max_simd_bits=32``, ``max_num_bits=max_bits``
    and ``head_address=None``.

    Returns None; the registration itself is the only effect.
    """
    func_name = "tvm.info.mem.%s" % scope_tb

    @tvm.register_func(func_name)
    def _make_mem_info():
        # The inner function's Python name is irrelevant: registration uses
        # the explicit func_name passed to the decorator above.
        return tvm.make.node(
            "MemoryInfo",
            unit_bits=16,
            max_simd_bits=32,
            max_num_bits=max_bits,
            head_address=None,
        )


def test():
    """Reproduce Simplify rewriting a store index under ``if (j == A[i])``.

    Builds two 10x10 loop nests over int32 buffers ``A`` and ``B`` (200
    elements each, storage scope "local.L0v").  The second nest guards its
    store with ``j == A[i]``.  The result of ``tvm.ir_pass.Simplify`` on the
    whole statement is printed.
    """
    scope_tb = "local.L0v"
    # NOTE(review): max_bits is never used and register_mem() is never called
    # here; presumably intended as register_mem(scope_tb, max_bits) -- confirm.
    max_bits = 1 << 30

    builder = tvm.ir_builder.create()

    buf_a = builder.allocate("int32", 200, name="A", scope=scope_tb)
    with builder.for_range(0, 10, name="i") as outer:
        with builder.for_range(0, 10, name="j") as inner:
            buf_a[outer * 10 + inner] = 1

    buf_b = builder.allocate("int32", 200, name="B", scope=scope_tb)
    with builder.for_range(0, 10, name="i") as outer:
        with builder.for_range(0, 10, name="j") as inner:
            # Simplify uses this equality to rewrite the index of the store
            # below, producing a more complex expression (the reported issue).
            with builder.if_scope(inner == buf_a[outer]):
                buf_b[outer * 10 + inner] = 2

    stmt = builder.get()
    print(tvm.ir_pass.Simplify(stmt))

if __name__ == "__main__":
    # Run the reproduction only when executed as a script, not when the
    # module is imported (e.g. collected by a test runner).
    test()

Before Simplify:

// attr [A] storage_scope = "local.L0v"
allocate A[int32 * 200]
for (i, 0, 10) {
  for (j, 0, 10) {
    A[((i*10) + j)] = 1
  }
}
// attr [B] storage_scope = "local.L0v"
allocate B[int32 * 200]
for (j, 0, 10) {
  for (j, 0, 10) {
    if ((j == A[j])) {
      B[((j*10) + j)] = 2
    }
  }
}

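After Simplify, the store index becomes B[((j*10) + A[j])], which is a more complicated expression than the original B[((j*10) + j)]: the condition j == A[j] caused j to be substituted by A[j].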

// attr [A] storage_scope = "local.L0v"
allocate A[int32 * 200]
for (i, 0, 10) {
  for (j, 0, 10) {
    A[((i*10) + j)] = 1
  }
}
// attr [B] storage_scope = "local.L0v"
allocate B[int32 * 200]
for (j, 0, 10) {
  for (j, 0, 10) {
    if ((j == A[j])) {
      B[((j*10) + A[j])] = 2
    }
  }
}
@xqdan
Copy link
Contributor Author

xqdan commented May 8, 2019

diff --git a/src/arithmetic/Simplify.cpp b/src/arithmetic/Simplify.cpp
index 8a0d6e3..0dbb63e 100644
--- a/src/arithmetic/Simplify.cpp
+++ b/src/arithmetic/Simplify.cpp
@@ -3917,7 +3917,7 @@ private:
                 const Variable *var = eq ? eq->a.as<Variable>() : next.as<Variable>();
 
                 if (eq && var) {
-                    if (!or_chain) {
+                    if (!or_chain && is_const(eq->b)) {
                         then_case = substitute(var, eq->b, then_case);
                     }
                     if (!and_chain && eq->b.type().is_bool()) {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant