Skip to content

Commit

Permalink
#35 added safe asking mode that wraps the query into try-catch block.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Nov 26, 2024
1 parent 7d6b5d3 commit 8398fb6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
26 changes: 24 additions & 2 deletions bulk_chain/core/llm_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
import logging
import time

from bulk_chain.core.utils import format_model_name


class BaseLM(object):

def __init__(self, name):
def __init__(self, name, attempts=None, delay_sec=1, enable_log=True):
self.__name = name
self.__attempts = 1 if attempts is None else attempts
self.__delay_sec = delay_sec

if enable_log:
self.__logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def ask_safe(self, prompt):

for i in range(self.__attempts):
try:
response = self.ask(prompt)
return response
except:
if self.__logger is not None:
self.__logger.info("Unable to infer the result. Try {} out of {}.".format(i, self.__attempts))
time.sleep(self.__delay_sec)

raise Exception("Can't infer")

def ask(self, prompt):
raise NotImplemented()

def name(self):
return format_model_name(self.__name)
return format_model_name(self.__name)
5 changes: 3 additions & 2 deletions bulk_chain/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def iter_content(input_dicts_iter, llm, schema, cache_target, cache_table, id_co
assert (isinstance(cache_table, str))

infer_modes = {
"default": lambda prompt: llm.ask(prompt[:args.limit_prompt] if args.limit_prompt is not None else prompt)
"default": lambda prompt: llm.ask_safe(prompt[:args.limit_prompt] if args.limit_prompt is not None else prompt)
}

def optional_update_data_records(c, data):
Expand Down Expand Up @@ -99,6 +99,7 @@ def optional_update_data_records(c, data):

parser = argparse.ArgumentParser(description="Infer Instruct LLM inference based on CoT schema")
parser.add_argument('--adapter', dest='adapter', type=str, default=None)
parser.add_argument('--attempts', dest='attempts', type=int, default=None)
parser.add_argument('--id-col', dest='id_col', type=str, default="uid")
parser.add_argument('--src', dest='src', type=str, default=None)
parser.add_argument('--schema', dest='schema', type=str, default=None,
Expand All @@ -115,7 +116,7 @@ def optional_update_data_records(c, data):
args = parser.parse_args(args=native_args[1:])

# Initialize Large Language Model.
model_args_dict = CmdArgsService.args_to_dict(model_args)
model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
llm, llm_model_name = init_llm(**model_args_dict)

# Setup schema.
Expand Down

0 comments on commit 8398fb6

Please sign in to comment.