Remove embedding dataformat cast in tvm and update llama backward test (#1111)

### Ticket
Close #1112

### Problem description
We no longer need an explicit embedding dataformat cast in tvm (from float32 to bf16), because the dataformat workaround for this case is implemented in mlir. PRs for reference:
- [TVM change](tenstorrent/tt-tvm#59)
- [Embedding Op workaround](tenstorrent/tt-mlir#1583)
- [EmbeddingBackward Op workaround](tenstorrent/tt-mlir#1756)

### What's changed
- Removed the explicit cast to bfloat16 when the dataformat for embedding weights is float32.
- Updated the llama backward test to reflect the new forge API for training (setting the training argument).

### Checklist
- [x] Remove explicit cast in third_party/tvm/python/tvm/relay/frontend/pytorch.py
- [x] Update test_llama_backward.py

---------

Co-authored-by: Vladimir Milosevic <[email protected]>
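The change described above can be sketched as follows. This is a minimal, hypothetical illustration of the dtype decision that was removed from the TVM frontend, not the actual code in `third_party/tvm/python/tvm/relay/frontend/pytorch.py`; the function name and the `cast_in_tvm` flag are invented for this sketch.

```python
# Hypothetical sketch of the removed workaround: before this PR, the TVM
# frontend cast float32 embedding weights to bfloat16; after it, the dtype
# is passed through unchanged and the bf16 workaround is applied in mlir.

def embedding_weight_dtype(weight_dtype: str, cast_in_tvm: bool) -> str:
    """Return the dtype the embedding weight is lowered with."""
    if cast_in_tvm and weight_dtype == "float32":
        return "bfloat16"  # old behaviour: explicit cast in tvm
    return weight_dtype    # new behaviour: mlir handles the dataformat


# Old behaviour: float32 weights were cast to bfloat16 in tvm.
assert embedding_weight_dtype("float32", cast_in_tvm=True) == "bfloat16"
# New behaviour: no cast in tvm; the workaround lives in mlir.
assert embedding_weight_dtype("float32", cast_in_tvm=False) == "float32"
# Non-float32 weights were never affected either way.
assert embedding_weight_dtype("bfloat16", cast_in_tvm=True) == "bfloat16"
```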