Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sum(idx is not None for idx in indices) == 1 check to fix incorrect mapping of multi-indexing in index function #51

Merged
merged 1 commit into from
Dec 5, 2024

Conversation

kamalrajkannan78
Copy link
Contributor

@kamalrajkannan78 kamalrajkannan78 commented Dec 3, 2024

This PR fixes tenstorrent/tt-forge-fe#817

Root cause

  • This Multi indexing operation is root cause of this issue
  • The current logic only supports single indexing for 3D and 4D tensors.
  • Since there is no check to ensure exactly one valid (non-None) index (e.g., sum(idx is not None for idx in indices) == 1), multi-indexing scenarios are incorrectly mapped to _op.take which is not designed to handle multi-indexing, leading to shape mismatches in subsequent passes.

Fix

  • A condition sum(idx is not None for idx in indices) has been added to ensure only single-indexing scenarios are mapped to _op.take.
  • Multi-indexing cases are now correctly routed to adv_index, which is equipped to handle such scenarios.

@kamalrajkannan78 kamalrajkannan78 force-pushed the kkannan/opt_multi_indexing_fix branch from 73c0dd5 to d85728f Compare December 4, 2024 03:03
@kamalrajkannan78 kamalrajkannan78 changed the title Add len(indices) == 1 check to fix incorrect mapping of multi-indexing in index function Add sum(idx is not None for idx in indices) == 1 check to fix incorrect mapping of multi-indexing in index function Dec 4, 2024
@kamalrajkannan78 kamalrajkannan78 marked this pull request as ready for review December 4, 2024 07:53
@kamalrajkannan78 kamalrajkannan78 merged commit 38d2522 into main Dec 5, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

opt variants Failed on "RemoveRedundantReshape" TVM callback in seq classification task
2 participants