diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index b1e65457c42..64bdb6682a2 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -186,7 +186,6 @@ def test_llama_model_inference( tt_prefill_input = model_args.prepare_residual_tensor_prefill( pt_prefill_input, - force_replicated=False if model_args.is_galaxy else True, ) for i in range(1): start_pos = 0 diff --git a/models/demos/llama3/tt/distributed_norm.py b/models/demos/llama3/tt/distributed_norm.py index 49c98f15cd1..8fe6c9b7fa4 100644 --- a/models/demos/llama3/tt/distributed_norm.py +++ b/models/demos/llama3/tt/distributed_norm.py @@ -8,7 +8,7 @@ class DistributedNorm(LightweightModule): - def __init__(self, norm, args, TG): + def __init__(self, norm, args, TG=False): self.norm = norm self.args = args