Remove embedding dataformat cast in tvm and update llama backward test (#1111)

### Ticket
Close #1112

### Problem description
The explicit embedding dataformat cast in TVM (from float32 to bfloat16) is no longer needed, since the dataformat workaround for this case is implemented in MLIR.

PRs for reference:
- [TVM change](tenstorrent/tt-tvm#59)
- [Embedding Op workaround](tenstorrent/tt-mlir#1583)
- [EmbeddingBackward Op workaround](tenstorrent/tt-mlir#1756)
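
For intuition, here is a minimal, hypothetical sketch of what dropping the cast in the Relay PyTorch frontend amounts to (this is not the actual tt-tvm code; the function name and signature are illustrative): float32 embedding weights are no longer cast to bfloat16 up front, and dataformat handling is left to the MLIR workarounds referenced above.

```python
from tvm import relay


def embedding_without_cast(weight: relay.Expr, indices: relay.Expr) -> relay.Expr:
    # Previously the frontend inserted roughly:
    #     weight = relay.cast(weight, "bfloat16")  # when the weight dtype was float32
    # That explicit cast is gone; float32 weights flow through unchanged and the
    # embedding lookup is lowered as a gather along the vocabulary axis.
    return relay.take(weight, indices, axis=0)
```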

### What's changed
Removed the explicit cast to bfloat16 when the dataformat of the embedding weights is float32.
Updated the llama backward test to reflect the new Forge API for training (setting the `training` argument), as sketched below.
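
In short, the test-side API change looks like this (a sketch; `compile_for_training` is a hypothetical wrapper, while both `forge.compile` calls are taken from the diff below):

```python
import forge


def compile_for_training(framework_model, input_ids):
    # Old API (removed): loss and optimizer were handed to forge.compile, e.g.
    #   loss_fn = torch.nn.CrossEntropyLoss()
    #   framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=1e-3)
    #   forge.compile(framework_model, input_ids, loss=loss_fn, optimizer=framework_optimizer)
    # New API: the training argument alone invokes the autograd pass that
    # produces the backward graph.
    return forge.compile(framework_model, input_ids, training=True)
```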

### Checklist
- [x] Remove explicit cast in third_party/tvm/python/tvm/relay/frontend/pytorch.py
- [x] Update test_llama_backward.py

---------

Co-authored-by: Vladimir Milosevic <[email protected]>
pmarkovicTT and vmilosevic authored Feb 7, 2025
1 parent 513203c commit fb10a81
Showing 2 changed files with 2 additions and 10 deletions.
10 changes: 1 addition & 9 deletions forge/test/mlir/llama/test_llama_backward.py
```diff
@@ -2,27 +2,19 @@
 
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
 import pytest
 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import forge
 from test.mlir.llama.utils.utils import load_model
 
 
-# TODO(tt-mlir issue #1503): This test is failing because the embedding op doesn't work with FP32.
-# It should be fixed in the tt-mlir compiler soon.
 @pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b"])
 @pytest.mark.xfail()
 def test_llama_backward(model_path):
     # Load Model and Tokenizer
     framework_model, tokenizer = load_model(model_path)
 
     prompt = "Q: What is the largest animal?\nA:"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
-    loss_fn = torch.nn.CrossEntropyLoss()
-    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=1e-3)
-
-    # Compile the model with loss and optimizer, this will invoke an autograd pass which produces bwd graph.
-    compiled_model = forge.compile(framework_model, input_ids, loss=loss_fn, optimizer=framework_optimizer)
+    compiled_model = forge.compile(framework_model, input_ids, training=True)
```
2 changes: 1 addition & 1 deletion third_party/tvm
