actually add the llama model file
Jemoka committed Nov 20, 2023
1 parent 50c4b73 commit 28c1718
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions stanza/models/coref/llama.py
@@ -0,0 +1,34 @@
"""Functions related to BERT or similar models"""

import logging
from typing import List, Tuple
from peft import prepare_model_for_kbit_training


import numpy as np # type: ignore
from transformers import LlamaModel, AutoTokenizer # type: ignore

from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc

import torch.nn as nn

logger = logging.getLogger('stanza')

def load_llama(config: Config):
    logger.debug(f"Loading {config.llama_model}...")

    base_llama_name = config.llama_model.split("/")[-1]
    tokenizer_kwargs = config.tokenizer_kwargs.get(base_llama_name, {})
    if tokenizer_kwargs:
        logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}")
    tokenizer = AutoTokenizer.from_pretrained(config.llama_model, **tokenizer_kwargs)
    model = prepare_model_for_kbit_training(LlamaModel.from_pretrained(config.llama_model, load_in_8bit=True))
    # .to(config.device) omitted on purpose; see the note below

Jemoka (Author Member), Nov 20, 2023:

because bitsandbytes, through load_in_8bit, handles device placement on your behalf

logger.debug("Llama successfully loaded.")

return model, tokenizer
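
For orientation, here is a minimal usage sketch of the new loader (not part of the commit). It assumes bitsandbytes and a CUDA device are available, since load_in_8bit requires them; the model id, the example sentence, and the SimpleNamespace stand-in for the coref Config are hypothetical.

import torch
from types import SimpleNamespace
from stanza.models.coref.llama import load_llama

# hypothetical stand-in for the coref Config; a real run would use stanza's own Config object
config = SimpleNamespace(llama_model="meta-llama/Llama-2-7b-hf", tokenizer_kwargs={})

model, tokenizer = load_llama(config)  # model is already placed on GPU by load_in_8bit

with torch.no_grad():
    inputs = tokenizer("Alice said she would arrive late.", return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # per-token hidden states, usable as embeddings by the coref scorer
    hidden_states = model(**inputs).last_hidden_state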


