【Bug】Fix templateparser (#77)
[Fix] fix template parser
Harold-lkk authored Jan 3, 2024
1 parent 8ddde9b commit 987618c
Showing 2 changed files with 178 additions and 159 deletions.
169 changes: 85 additions & 84 deletions lagent/llms/base_api.py
@@ -8,85 +8,6 @@
from .base_llm import BaseModel


class BaseAPIModel(BaseModel):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        meta_template (Dict, optional): The model's meta prompt template,
            if needed, for injecting or wrapping any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None):
        self.model_type = model_type
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        self.token_bucket = TokenBucket(query_per_second)
        self.template_parser = APITemplateParser(meta_template)

    @abstractclassmethod
    def generate(self, inputs, max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or list]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if a more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """
        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        # Count English words
        english_count = sum(len(part.split()) for part in english_parts)

        # Count Chinese characters
        chinese_count = sum(len(part) for part in chinese_parts)

        return english_count + chinese_count

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def to(self, device):
        pass


class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.
@@ -199,11 +120,10 @@ def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
        return res

    def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:

        merged_prompt = self.roles.get(
            role_prompt['role'],
            self.roles.get(
                self.roles[role_prompt['role']].get('fallback_role')))
        merged_prompt = self.roles[self.roles[role_prompt['role']]]
        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles[self.roles[
                merged_prompt['fallback_role']]]
        res = {}
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
@@ -212,6 +132,87 @@ def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
        return res
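
The rewritten lookup resolves a role through self.roles and, when the entry carries a fallback_role, re-resolves through that role's configuration. A minimal standalone sketch of the fallback idea; the roles table and resolve helper below are illustrative, not lagent's internal layout:

# Hypothetical role table: a role either carries its own API mapping or
# names another role to fall back to.
roles = {
    'system': {'api_role': 'system', 'begin': '', 'end': ''},
    'tool': {'fallback_role': 'system'},
}

def resolve(role_name: str) -> dict:
    cfg = roles[role_name]
    if cfg.get('fallback_role'):
        cfg = roles[cfg['fallback_role']]
    return cfg

print(resolve('tool')['api_role'])  # -> system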


class BaseAPIModel(BaseModel):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        meta_template (Dict, optional): The model's meta prompt template,
            if needed, for injecting or wrapping any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 template_parser: 'APITemplateParser' = APITemplateParser,
                 meta_template: Optional[Dict] = None):
        self.model_type = model_type
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        self.token_bucket = TokenBucket(query_per_second)
        if template_parser:
            self.template_parser = template_parser(meta_template)

    @abstractclassmethod
    def generate(self, inputs, max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or list]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if a more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """
        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        # Count English words
        english_count = sum(len(part.split()) for part in english_parts)

        # Count Chinese characters
        chinese_count = sum(len(part) for part in chinese_parts)

        return english_count + chinese_count

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def to(self, device):
        pass
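
Since the constructor now receives the parser class itself, subclasses and tests can swap parsers without touching __init__. A hedged usage sketch (EchoAPI and its trivial generate are hypothetical, not part of lagent; APITemplateParser is assumed to accept meta_template=None, as the default call above implies):

from lagent.llms.base_api import APITemplateParser, BaseAPIModel

class EchoAPI(BaseAPIModel):

    def generate(self, inputs, max_out_len: int):
        self.wait()  # honor the query_per_second rate limit
        return [str(i)[:max_out_len] for i in inputs]

model = EchoAPI(model_type='echo', query_per_second=2)

Note that the guard is written as "if template_parser:", so passing None leaves the instance without a template_parser attribute, which callers of parse_template need to keep in mind.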


class TokenBucket:
"""A token bucket for rate limiting.
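
The TokenBucket body is collapsed in this diff. For orientation, here is a minimal sketch of the usual blocking token-bucket pattern that wait() and get_token() imply; an assumed illustration, not lagent's exact implementation:

import threading
import time

class SimpleTokenBucket:
    """A blocking token bucket: get_token() returns once a token is free."""

    def __init__(self, rate: float):
        self._rate = rate  # tokens refilled per second
        self._tokens = 0.0
        self._last = time.time()
        self._lock = threading.Lock()

    def get_token(self):
        # Refill by elapsed time, then consume one token or sleep and retry.
        while True:
            with self._lock:
                now = time.time()
                self._tokens = min(
                    self._rate, self._tokens + (now - self._last) * self._rate)
                self._last = now
                if self._tokens >= 1:
                    self._tokens -= 1
                    return
            time.sleep(1.0 / self._rate)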
168 changes: 93 additions & 75 deletions lagent/llms/base_llm.py
@@ -2,76 +2,6 @@
from typing import Dict, List, Optional, Tuple, Union


class BaseModel:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (list of dict, optional): The model's meta prompt
            template, if needed, for injecting or wrapping any meta
            instructions.
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[List[Dict]] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

    @abstractclassmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    def parse_template(self, dialog) -> str:
        """Parse a prompt template, and wrap it with the meta template if
        applicable.

        Args:
            dialog (List[str or PromptList]): A prompt template
                (potentially before being wrapped by the meta template).

        Returns:
            str: The final string.
        """
        return self.template_parser.parse_template(dialog)

    def generate_from_template(self, templates, max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates)
        return self.generate(inputs, max_out_len=max_out_len, **kwargs)

    def to(self, device):
        self.model.to(device)


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.
@@ -127,20 +57,108 @@ def parse_template(self, dialog) -> str:
                last_sep = '\n'
        return prompt

    def _format_begin(self, role_cfg, message):
        name = message.get('name', None)
        if name is not None:
            begin = role_cfg['begin'].get('with_name', '')
            if name in role_cfg['begin'].get('name', {}):
                begin = begin.format(name=role_cfg['begin']['name'][name])
            else:
                begin = begin.format(name=name)
        else:
            if isinstance(role_cfg.get('begin', ''), str):
                begin = role_cfg.get('begin', '')
            elif isinstance(role_cfg['begin'], dict):
                begin = role_cfg['begin'].get('without_name', '')
        return begin

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> Tuple[str, bool]:
        if isinstance(prompt, str):
            return prompt
        merged_prompt = self.roles.get(
            prompt['role'],
            self.roles.get(self.roles[prompt['role']].get('fallback_role')))
        res = merged_prompt.get('begin', '')
        merged_prompt = self.roles.get(prompt['role'])

        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles.get(merged_prompt['fallback_role'])
        begin = self._format_begin(merged_prompt, prompt)
        res = begin
        if last and merged_prompt.get('generate', False):
            res += prompt.get('content', '')
            return res
        res += prompt.get('content', '') + merged_prompt.get('end', '')
        if last and merged_prompt['role'] != 'assistant':
            res += self.roles['assistant']['begin']
            res += self._format_begin(self.roles['assistant'], {})
            return res
        return res
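
The new _format_begin accepts either a plain-string begin or a dict with with_name/without_name variants plus an optional per-name substitution table. An illustrative config follows; the <|im_start|> markers and the 'interpreter' mapping are assumptions for the demo, not values mandated by lagent. Since _format_begin never touches self, the unbound method can be called directly here:

role_cfg = {
    'begin': {
        'with_name': '<|im_start|>{name}\n',
        'without_name': '<|im_start|>assistant\n',
        'name': {'interpreter': 'code interpreter'},
    },
}

# self is unused by _format_begin, so pass None for the demo.
print(LMTemplateParser._format_begin(None, role_cfg, {'name': 'interpreter'}))
# -> <|im_start|>code interpreter
print(LMTemplateParser._format_begin(None, role_cfg, {}))
# -> <|im_start|>assistant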

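Taken together, parse_template and _prompt2str flatten a chat-style dialog into a single string, opening an assistant turn at the end whenever the last message is not already from the assistant. A reduced sketch of that flattening logic (the role markers and the simplified roles table are illustrative; the real method additionally handles fallback roles and dict-valued begin fields):

roles = {
    'user': {'begin': '<|im_start|>user\n', 'end': '<|im_end|>\n'},
    'assistant': {'begin': '<|im_start|>assistant\n', 'end': '<|im_end|>\n',
                  'generate': True},
}

def prompt2str(msg: dict, last: bool = False) -> str:
    cfg = roles[msg['role']]
    res = cfg['begin']
    if last and cfg.get('generate', False):
        # Leave the turn open so the model continues from here.
        return res + msg.get('content', '')
    res += msg.get('content', '') + cfg['end']
    if last and msg['role'] != 'assistant':
        # Open an assistant turn for the model to generate into.
        res += roles['assistant']['begin']
    return res

print(prompt2str({'role': 'user', 'content': 'hi'}, last=True))
# <|im_start|>user
# hi<|im_end|>
# <|im_start|>assistant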

class BaseModel:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (list of dict, optional): The model's meta prompt
            template, if needed, for injecting or wrapping any meta
            instructions.
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

    @abstractclassmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    def parse_template(self, dialog) -> str:
        """Parse a prompt template, and wrap it with the meta template if
        applicable.

        Args:
            dialog (List[str or PromptList]): A prompt template
                (potentially before being wrapped by the meta template).

        Returns:
            str: The final string.
        """
        return self.template_parser.parse_template(dialog)

    def generate_from_template(self, templates, max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates)
        return self.generate(inputs, max_out_len=max_out_len, **kwargs)

    def to(self, device):
        self.model.to(device)
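
The same injection hook exists on the base class: BaseModel instantiates whatever parser class it is handed. A hypothetical stub (StubParser and EchoModel are illustrative, not lagent classes) showing the generate_from_template flow end to end:

class StubParser:

    def __init__(self, meta_template=None):
        self.meta_template = meta_template

    def parse_template(self, dialog):
        # A real parser applies the meta template; the stub just joins.
        return '\n'.join(map(str, dialog))

class EchoModel(BaseModel):

    def generate(self, inputs, max_out_len: int):
        # Echo back the flattened prompt, truncated: enough to show the flow.
        return inputs[:max_out_len]

model = EchoModel(path='dummy', template_parser=StubParser)
print(model.generate_from_template(['You are helpful.', 'Hi!'], max_out_len=64))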
