diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py
index d42d22b5..898bcfd0 100644
--- a/lagent/llms/base_api.py
+++ b/lagent/llms/base_api.py
@@ -8,85 +8,6 @@ from .base_llm import BaseModel
 
 
-class BaseAPIModel(BaseModel):
-    """Base class for API model wrapper.
-
-    Args:
-        model_type (str): The type of model.
-        query_per_second (int): The maximum queries allowed per second
-            between two consecutive calls of the API. Defaults to 1.
-        retry (int): Number of retires if the API call fails. Defaults to 2.
-        max_seq_len (int): The maximum sequence length of the model. Defaults
-            to 2048.
-        meta_template (Dict, optional): The model's meta prompt
-            template if needed, in case the requirement of injecting or
-            wrapping of any meta instructions.
-    """
-
-    is_api: bool = True
-
-    def __init__(self,
-                 model_type: str,
-                 query_per_second: int = 1,
-                 retry: int = 2,
-                 max_seq_len: int = 2048,
-                 meta_template: Optional[Dict] = None):
-        self.model_type = model_type
-        self.max_seq_len = max_seq_len
-        self.meta_template = meta_template
-        self.retry = retry
-        self.query_per_second = query_per_second
-        self.token_bucket = TokenBucket(query_per_second)
-        self.template_parser = APITemplateParser(meta_template)
-
-    @abstractclassmethod
-    def generate(self, inputs, max_out_len: int) -> List[str]:
-        """Generate results given a list of inputs.
-
-        Args:
-            inputs (List[str or list]): A list of strings or PromptDicts.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-
-        Returns:
-            List[str]: A list of generated strings.
-        """
-
-    def get_token_len(self, prompt: str) -> int:
-        """Get lengths of the tokenized string. Only English and Chinese
-        characters are counted for now. Users are encouraged to override this
-        method if more accurate length is needed.
-
-        Args:
-            prompt (str): Input string.
-
-        Returns:
-            int: Length of the input tokens
-        """
-
-        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
-        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
-
-        # Count English words
-        english_count = sum(len(part.split()) for part in english_parts)
-
-        # Count Chinese words
-        chinese_count = sum(len(part) for part in chinese_parts)
-
-        return english_count + chinese_count
-
-    def wait(self):
-        """Wait till the next query can be sent.
-
-        Applicable in both single-thread and multi-thread environments.
-        """
-        return self.token_bucket.get_token()
-
-    def to(self, device):
-        pass
-
-
 class APITemplateParser:
     """Intermidate prompt template parser, specifically for API models.
 
@@ -199,11 +120,10 @@ def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
         return res
 
     def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
-
-        merged_prompt = self.roles.get(
-            role_prompt['role'],
-            self.roles.get(
-                self.roles[role_prompt['role']].get('fallback_role')))
+        merged_prompt = self.roles[role_prompt['role']]
+        if merged_prompt.get('fallback_role'):
+            merged_prompt = self.roles[
+                merged_prompt['fallback_role']]
         res = {}
         res['role'] = merged_prompt['api_role']
         res['content'] = merged_prompt.get('begin', '')
@@ -212,6 +132,87 @@ def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
         return res
 
 
+class BaseAPIModel(BaseModel):
+    """Base class for API model wrapper.
+
+    Args:
+        model_type (str): The type of model.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+        max_seq_len (int): The maximum sequence length of the model. Defaults
+            to 2048.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case meta instructions need to be
+            injected or wrapped.
+    """
+
+    is_api: bool = True
+
+    def __init__(self,
+                 model_type: str,
+                 query_per_second: int = 1,
+                 retry: int = 2,
+                 max_seq_len: int = 2048,
+                 template_parser: 'APITemplateParser' = APITemplateParser,
+                 meta_template: Optional[Dict] = None):
+        self.model_type = model_type
+        self.max_seq_len = max_seq_len
+        self.meta_template = meta_template
+        self.retry = retry
+        self.query_per_second = query_per_second
+        self.token_bucket = TokenBucket(query_per_second)
+        if template_parser:
+            self.template_parser = template_parser(meta_template)
+
+    @abstractclassmethod
+    def generate(self, inputs, max_out_len: int) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str or list]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized string. Only English and Chinese
+        characters are counted for now. Users are encouraged to override this
+        method if more accurate length is needed.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+        """
+
+        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
+        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)
+
+        # Count English words
+        english_count = sum(len(part.split()) for part in english_parts)
+
+        # Count Chinese words
+        chinese_count = sum(len(part) for part in chinese_parts)
+
+        return english_count + chinese_count
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
+
+    def to(self, device):
+        pass
+
+
 class TokenBucket:
     """A token bucket for rate limiting.
 
diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py
index 45cd8d2c..34729ce8 100644
--- a/lagent/llms/base_llm.py
+++ b/lagent/llms/base_llm.py
@@ -2,76 +2,6 @@ from typing import Dict, List, Optional, Tuple, Union
 
 
-class BaseModel:
-    """Base class for model wrapper.
-
-    Args:
-        path (str): The path to the model.
-        max_seq_len (int): The maximum sequence length of the model. Defaults
-            to 2048.
-        tokenizer_only (bool): If True, only the tokenizer will be initialized.
-            Defaults to False.
-        meta_template (list of dict, optional): The model's meta prompt
-            template if needed, in case the requirement of injecting or
-            wrapping of any meta instructions.
-    """
-
-    is_api: bool = False
-
-    def __init__(self,
-                 path: str,
-                 max_seq_len: int = 2048,
-                 tokenizer_only: bool = False,
-                 meta_template: Optional[List[Dict]] = None):
-        self.path = path
-        self.max_seq_len = max_seq_len
-        self.tokenizer_only = tokenizer_only
-        # meta template
-        self.template_parser = LMTemplateParser(meta_template)
-        self.eos_token_id = None
-        if meta_template and 'eos_token_id' in meta_template:
-            self.eos_token_id = meta_template['eos_token_id']
-
-    @abstractclassmethod
-    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
-        """Generate results given a list of inputs.
-
-        Args:
-            inputs (List[str]): A list of strings.
-            max_out_len (int): The maximum length of the output.
-
-        Returns:
-            List[str]: A list of generated strings.
-        """
-
-    def parse_template(self, dialog) -> str:
-        """Parse a prompt template, and wrap it with meta template if
-        applicable.
-
-        Args:
-            dialog (List[str or PromptList]): A prompt
-                template (potentially before being wrapped by meta template).
-            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
-
-        Returns:
-            str: The final string.
-        """
-        return self.template_parser.parse_template(dialog)
-
-    def generate_from_template(self, templates, max_out_len: int, **kwargs):
-        """Generate completion from a list of templates.
-
-        Args:
-            templates (List[PromptType]): A list of templates.
-            max_out_len (int): The maximum length of the output.
-        """
-        inputs = self.parse_template(templates)
-        return self.generate(inputs, max_out_len=max_out_len, **kwargs)
-
-    def to(self, device):
-        self.model.to(device)
-
-
 class LMTemplateParser:
     """Intermidate prompt template parser, specifically for language models.
 
@@ -127,20 +57,108 @@ def parse_template(self, dialog) -> str:
                 last_sep = '\n'
         return prompt
 
+    def _format_begin(self, role_cfg, message):
+        name = message.get('name', None)
+        if name is not None:
+            begin = role_cfg['begin'].get('with_name', '')
+            if name in role_cfg['begin'].get('name', {}):
+                begin = begin.format(name=role_cfg['begin']['name'][name])
+            else:
+                begin = begin.format(name=name)
+        else:
+            if isinstance(role_cfg.get('begin', ''), str):
+                begin = role_cfg.get('begin', '')
+            elif isinstance(role_cfg['begin'], dict):
+                begin = role_cfg['begin'].get('without_name', '')
+        return begin
+
     def _prompt2str(self,
                     prompt: Union[str, Dict],
                     last: bool = False) -> Tuple[str, bool]:
         if isinstance(prompt, str):
             return prompt
-        merged_prompt = self.roles.get(
-            prompt['role'],
-            self.roles.get(self.roles[prompt['role']].get('fallback_role')))
-        res = merged_prompt.get('begin', '')
+        merged_prompt = self.roles.get(prompt['role'])
+
+        if merged_prompt.get('fallback_role'):
+            merged_prompt = self.roles.get(merged_prompt['fallback_role'])
+        begin = self._format_begin(merged_prompt, prompt)
+        res = begin
         if last and merged_prompt.get('generate', False):
             res += prompt.get('content', '')
             return res
         res += prompt.get('content', '') + merged_prompt.get('end', '')
         if last and merged_prompt['role'] != 'assistant':
-            res += self.roles['assistant']['begin']
+            res += self._format_begin(self.roles['assistant'], {})
             return res
         return res
+
+
+class BaseModel:
+    """Base class for model wrapper.
+
+    Args:
+        path (str): The path to the model.
+        max_seq_len (int): The maximum sequence length of the model. Defaults
+            to 2048.
+        tokenizer_only (bool): If True, only the tokenizer will be initialized.
+            Defaults to False.
+        meta_template (list of dict, optional): The model's meta prompt
+            template if needed, in case meta instructions need to be
+            injected or wrapped.
+    """
+
+    is_api: bool = False
+
+    def __init__(self,
+                 path: str,
+                 max_seq_len: int = 2048,
+                 tokenizer_only: bool = False,
+                 template_parser: 'LMTemplateParser' = LMTemplateParser,
+                 meta_template: Optional[List[Dict]] = None):
+        self.path = path
+        self.max_seq_len = max_seq_len
+        self.tokenizer_only = tokenizer_only
+        # meta template
+        self.template_parser = template_parser(meta_template)
+        self.eos_token_id = None
+        if meta_template and 'eos_token_id' in meta_template:
+            self.eos_token_id = meta_template['eos_token_id']
+
+    @abstractclassmethod
+    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+
+    def parse_template(self, dialog) -> str:
+        """Parse a prompt template, and wrap it with meta template if
+        applicable.
+
+        Args:
+            dialog (List[str or PromptList]): A prompt
+                template (potentially before being wrapped by meta template).
+            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
+
+        Returns:
+            str: The final string.
+        """
+        return self.template_parser.parse_template(dialog)
+
+    def generate_from_template(self, templates, max_out_len: int, **kwargs):
+        """Generate completion from a list of templates.
+
+        Args:
+            templates (List[PromptType]): A list of templates.
+            max_out_len (int): The maximum length of the output.
+        """
+        inputs = self.parse_template(templates)
+        return self.generate(inputs, max_out_len=max_out_len, **kwargs)
+
+    def to(self, device):
+        self.model.to(device)
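
Usage sketch (illustrative, not part of the diff): the `template_parser` argument now takes a parser class, and `_format_begin` lets a role's `begin` be a dict with `with_name`/`without_name` variants. `EchoModel`, `LoggingParser`, and the meta template below are hypothetical, and `LMTemplateParser.parse_template` is assumed to route each message dict through `_prompt2str`/`_format_begin` as the hunk above suggests.

    from typing import List

    from lagent.llms.base_llm import BaseModel, LMTemplateParser


    class LoggingParser(LMTemplateParser):
        """Hypothetical parser, injected via the new `template_parser` arg."""

        def parse_template(self, dialog) -> str:
            prompt = super().parse_template(dialog)
            print(f'[prompt] {prompt!r}')  # inspect the final prompt string
            return prompt


    class EchoModel(BaseModel):
        """Hypothetical stub backend; `generate` is abstract in BaseModel."""

        def generate(self, inputs, max_out_len: int) -> List[str]:
            # parse_template returns a single string for one dialog
            if isinstance(inputs, str):
                inputs = [inputs]
            return [inp[:max_out_len] for inp in inputs]


    # With `_format_begin`, a role's `begin` may be a dict: `with_name` is
    # formatted with the message's `name`, `without_name` is used otherwise.
    meta_template = [
        dict(role='system', begin='<|System|>:', end='\n'),
        dict(
            role='user',
            begin=dict(with_name='<|{name}|>:', without_name='<|User|>:'),
            end='\n'),
        dict(role='assistant', begin='<|Bot|>:', end='\n', generate=True),
    ]

    model = EchoModel(
        path='dummy/path',  # placeholder; BaseModel only stores the path
        template_parser=LoggingParser,  # a class; __init__ instantiates it
        meta_template=meta_template)

    # The user message carries a `name`, so its begin renders as '<|Alice|>:'.
    print(model.generate_from_template(
        [dict(role='user', content='hi', name='Alice')], max_out_len=64))

Passing the parser as a class rather than an instance keeps the pairs symmetric (`BaseModel`/`LMTemplateParser`, `BaseAPIModel`/`APITemplateParser`) and lets subclasses swap in a custom parser without overriding `__init__`.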