From 68db0a6ca74d0d2488049c25ff4ea7d9eaf089b6 Mon Sep 17 00:00:00 2001
From: Taurus3301 <58012753+Taurus3301@users.noreply.github.com>
Date: Wed, 15 Jan 2025 13:28:41 -0800
Subject: [PATCH 1/2] Added Reka model wrapper

---
 knowledge_storm/lm.py | 118 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 118 insertions(+)

diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py
index 0cae49be..e0d2f8fe 100644
--- a/knowledge_storm/lm.py
+++ b/knowledge_storm/lm.py
@@ -942,3 +942,121 @@ def __call__(
             completions.append(response.parts[0].text)
 
         return completions
+
+
+class RekaModel(dspy.dsp.modules.lm.LM):
+    """A wrapper class for dspy.LM using Reka."""
+
+    def __init__(self, model: str, api_key: Optional[str] = None, **kwargs):
+        """Initialize the Reka model.
+
+        Args:
+            model: The name of the Reka model to use
+            api_key: Optional API key for authentication
+            **kwargs: Additional arguments passed to the model
+        """
+        super().__init__(model)
+        try:
+            from reka.client import Reka
+        except ImportError as err:
+            raise ImportError("Reka requires `pip install reka-api`.") from err
+
+        self.provider = "reka"
+        self.api_key = api_key = (
+            os.environ.get("REKA_API_KEY") if api_key is None else api_key
+        )
+
+        self.kwargs = {
+            "temperature": kwargs.get("temperature", 0.0),
+            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
+            "top_p": kwargs.get("top_p", 1.0),
+            "top_k": kwargs.get("top_k", 1),
+            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
+            **kwargs,
+            "model": model,
+        }
+        self.history: list[dict[str, Any]] = []
+        self.client = Reka(api_key=api_key)
+        self.model = model
+
+        self._token_usage_lock = threading.Lock()
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+
+    def log_usage(self, response):
+        """Log the total tokens from the Reka API response."""
+        usage_data = response.usage
+        if usage_data:
+            with self._token_usage_lock:
+                self.prompt_tokens += usage_data.prompt_tokens
+                self.completion_tokens += usage_data.completion_tokens
+
+    def get_usage_and_reset(self):
+        """Get the total tokens used and reset the token usage."""
+        usage = {
+            self.model: {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
+        }
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        return usage
+
+    def basic_request(self, prompt: str, **kwargs):
+        raw_kwargs = kwargs
+        kwargs = {**self.kwargs, **kwargs}
+        # caching mechanism requires hashable kwargs
+        kwargs["messages"] = [{"role": "user", "content": prompt}]
+        kwargs.pop("n")
+        response = self.client.completions.create(**kwargs)
+
+        json_serializable_history = {
+            "prompt": prompt,
+            "response": {
+                "text": response.choices[0].text,
+                "model": response.model,
+                "usage": {
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                },
+            },
+            "kwargs": kwargs,
+            "raw_kwargs": raw_kwargs,
+        }
+        self.history.append(json_serializable_history)
+        return response
+
+    @backoff.on_exception(
+        backoff.expo,
+        (Exception,),
+        max_time=1000,
+        max_tries=8,
+        on_backoff=backoff_hdlr,
+        giveup=giveup_hdlr,
+    )
+    def request(self, prompt: str, **kwargs):
+        """Handles retrieval of completions from Reka whilst handling API errors."""
+        return self.basic_request(prompt, **kwargs)
+
+    def __call__(self, prompt: str, only_completed=True, return_sorted=False, **kwargs):
+        """Retrieves completions from Reka.
+
+        Args:
+            prompt (str): prompt to send to Reka
+            only_completed (bool, optional): return only completed responses. Defaults to True.
+            return_sorted (bool, optional): sort completion choices by probability. Defaults to False.
+
+        Returns:
+            list[str]: list of completion choices
+        """
+        assert only_completed, "for now"
+        assert return_sorted is False, "for now"
+
+        n = kwargs.pop("n", 1)
+        completions = []
+        for _ in range(n):
+            response = self.request(prompt, **kwargs)
+            self.log_usage(response)
+            completions = [choice.text for choice in response.choices]
+        return completions
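
The wrapper above follows the same calling convention as the other LM wrappers in `knowledge_storm/lm.py`, so it can also be exercised on its own outside the STORM pipeline. The following is a minimal direct-usage sketch, not part of the patch itself: it assumes both patches in this series are applied (the second patch below fixes the API endpoint), `reka-api` is installed, `REKA_API_KEY` is set, and it borrows the `reka-flash` model name from the example script added in the second patch.

```python
import os

from knowledge_storm.lm import RekaModel

# Illustrative sketch only; requires `pip install reka-api` and a valid REKA_API_KEY.
lm = RekaModel(
    model="reka-flash",                 # model name borrowed from the example script
    api_key=os.getenv("REKA_API_KEY"),
    temperature=0.7,
    max_tokens=500,
)

# __call__ returns a list of completion strings (one per requested generation).
completions = lm("In one sentence, what does the STORM pipeline do?")
print(completions[0])

# Token usage accumulated by log_usage(); reading it resets the counters.
print(lm.get_usage_and_reset())
```
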
From bc720f5d8f2922bef6df6f9410ac8b9f86a44f Mon Sep 17 00:00:00 2001
From: Taurus3301 <58012753+Taurus3301@users.noreply.github.com>
Date: Fri, 17 Jan 2025 13:50:37 -0800
Subject: [PATCH 2/2] Reka LM fix, add example

---
 .../storm_examples/run_storm_wiki_reka.py | 205 ++++++++++++++++++
 knowledge_storm/lm.py                     |  14 +-
 2 files changed, 212 insertions(+), 7 deletions(-)
 create mode 100644 examples/storm_examples/run_storm_wiki_reka.py

diff --git a/examples/storm_examples/run_storm_wiki_reka.py b/examples/storm_examples/run_storm_wiki_reka.py
new file mode 100644
index 00000000..187c7fc8
--- /dev/null
+++ b/examples/storm_examples/run_storm_wiki_reka.py
@@ -0,0 +1,205 @@
+"""
+STORM Wiki pipeline powered by Reka API and search engine.
+You need to set up the following environment variables to run this script:
+    - REKA_API_KEY: Reka API key
+    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key,
+      BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
+
+Output will be structured as below
+args.output_dir/
+    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash
+        conversation_log.json           # Log of information-seeking conversation
+        raw_search_results.json         # Raw search results from search engine
+        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge
+        storm_gen_outline.txt           # Outline refined with collected information
+        url_to_info.json                # Sources that are used in the final article
+        storm_gen_article.txt           # Final article generated
+        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
+"""
+
+import os
+from argparse import ArgumentParser
+
+from knowledge_storm import (
+    STORMWikiRunnerArguments,
+    STORMWikiRunner,
+    STORMWikiLMConfigs,
+)
+from knowledge_storm.rm import (
+    YouRM,
+    BingSearch,
+    BraveRM,
+    SerperRM,
+    DuckDuckGoSearchRM,
+    TavilySearchRM,
+    SearXNG,
+)
+from knowledge_storm.utils import load_api_key
+from knowledge_storm.lm import RekaModel
+
+
+def main(args):
+    load_api_key(toml_file_path="secrets.toml")
+    lm_configs = STORMWikiLMConfigs()
+    reka_kwargs = {
+        "api_key": os.getenv("REKA_API_KEY"),
+        "temperature": 1.0,
+        "top_p": 0.9,
+    }
+
+    # STORM is an LM system, so different components can be powered by different models.
+    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
+    # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
+    # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
+    # which is responsible for generating sections with citations.
+    conv_simulator_lm = RekaModel(model="reka-flash", max_tokens=500, **reka_kwargs)
+    question_asker_lm = RekaModel(model="reka-core", max_tokens=500, **reka_kwargs)
+    outline_gen_lm = RekaModel(model="reka-core", max_tokens=400, **reka_kwargs)
+    article_gen_lm = RekaModel(model="reka-core", max_tokens=700, **reka_kwargs)
+    article_polish_lm = RekaModel(model="reka-core", max_tokens=4000, **reka_kwargs)
+
+    lm_configs.set_conv_simulator_lm(conv_simulator_lm)
+    lm_configs.set_question_asker_lm(question_asker_lm)
+    lm_configs.set_outline_gen_lm(outline_gen_lm)
+    lm_configs.set_article_gen_lm(article_gen_lm)
+    lm_configs.set_article_polish_lm(article_polish_lm)
+
+    engine_args = STORMWikiRunnerArguments(
+        output_dir=args.output_dir,
+        max_conv_turn=args.max_conv_turn,
+        max_perspective=args.max_perspective,
+        search_top_k=args.search_top_k,
+        max_thread_num=args.max_thread_num,
+    )
+
+    # STORM is a knowledge curation system which consumes information from the retrieval module.
+    # Currently, the information source is the Internet and we use a search engine API as the retrieval module.
+    match args.retriever:
+        case "bing":
+            rm = BingSearch(
+                bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
+                k=engine_args.search_top_k,
+            )
+        case "you":
+            rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
+        case "brave":
+            rm = BraveRM(
+                brave_search_api_key=os.getenv("BRAVE_API_KEY"),
+                k=engine_args.search_top_k,
+            )
+        case "duckduckgo":
+            rm = DuckDuckGoSearchRM(
+                k=engine_args.search_top_k, safe_search="On", region="us-en"
+            )
+        case "serper":
+            rm = SerperRM(
+                serper_search_api_key=os.getenv("SERPER_API_KEY"),
+                query_params={"autocorrect": True, "num": 10, "page": 1},
+            )
+        case "tavily":
+            rm = TavilySearchRM(
+                tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
+                k=engine_args.search_top_k,
+                include_raw_content=True,
+            )
+        case "searxng":
+            rm = SearXNG(
+                searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
+            )
+        case _:
+            raise ValueError(
+                f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
+            )
+
+    runner = STORMWikiRunner(engine_args, lm_configs, rm)
+
+    topic = input("Topic: ")
+    runner.run(
+        topic=topic,
+        do_research=args.do_research,
+        do_generate_outline=args.do_generate_outline,
+        do_generate_article=args.do_generate_article,
+        do_polish_article=args.do_polish_article,
+    )
+    runner.post_run()
+    runner.summary()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    # global arguments
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./results/reka",
+        help="Directory to store the outputs.",
+    )
+    parser.add_argument(
+        "--max-thread-num",
+        type=int,
+        default=3,
+        help="Maximum number of threads to use. The information seeking part and the article generation "
+        "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
+        '"Exceed rate limit" errors when calling the LM API.',
+    )
+    parser.add_argument(
+        "--retriever",
+        type=str,
+        choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
+        help="The search engine API to use for retrieving information.",
+    )
+    # stage of the pipeline
+    parser.add_argument(
+        "--do-research",
+        action="store_true",
+        help="If True, simulate conversation to research the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-generate-outline",
+        action="store_true",
+        help="If True, generate an outline for the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-generate-article",
+        action="store_true",
+        help="If True, generate an article for the topic; otherwise, load the results.",
+    )
+    parser.add_argument(
+        "--do-polish-article",
+        action="store_true",
+        help="If True, polish the article by adding a summarization section and (optionally) removing "
+        "duplicate content.",
+    )
+    # hyperparameters for the pre-writing stage
+    parser.add_argument(
+        "--max-conv-turn",
+        type=int,
+        default=3,
+        help="Maximum number of questions in conversational question asking.",
+    )
+    parser.add_argument(
+        "--max-perspective",
+        type=int,
+        default=3,
+        help="Maximum number of perspectives to consider in perspective-guided question asking.",
+    )
+    parser.add_argument(
+        "--search-top-k",
+        type=int,
+        default=3,
+        help="Top k search results to consider for each search query.",
+    )
+    # hyperparameters for the writing stage
+    parser.add_argument(
+        "--retrieve-top-k",
+        type=int,
+        default=3,
+        help="Top k collected references for each section title.",
+    )
+    parser.add_argument(
+        "--remove-duplicate",
+        action="store_true",
+        help="If True, remove duplicate content from the article.",
+    )
+
+    main(parser.parse_args())
diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py
index e0d2f8fe..1f4e1798 100644
--- a/knowledge_storm/lm.py
+++ b/knowledge_storm/lm.py
@@ -988,8 +988,8 @@ def log_usage(self, response):
         usage_data = response.usage
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.prompt_tokens
-                self.completion_tokens += usage_data.completion_tokens
+                self.prompt_tokens += usage_data.input_tokens
+                self.completion_tokens += usage_data.output_tokens
 
     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
@@ -1009,16 +1009,16 @@ def basic_request(self, prompt: str, **kwargs):
         # caching mechanism requires hashable kwargs
         kwargs["messages"] = [{"role": "user", "content": prompt}]
         kwargs.pop("n")
-        response = self.client.completions.create(**kwargs)
+        response = self.client.chat.create(**kwargs)
 
         json_serializable_history = {
             "prompt": prompt,
             "response": {
-                "text": response.choices[0].text,
+                "text": response.responses[0].message.content,
                 "model": response.model,
                 "usage": {
-                    "prompt_tokens": response.usage.prompt_tokens,
-                    "completion_tokens": response.usage.completion_tokens,
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
                 },
             },
             "kwargs": kwargs,
             "raw_kwargs": raw_kwargs,
@@ -1058,5 +1058,5 @@ def __call__(self, prompt: str, only_completed=True, return_sorted=False, **kwar
         for _ in range(n):
             response = self.request(prompt, **kwargs)
             self.log_usage(response)
-            completions = [choice.text for choice in response.choices]
+            completions = [choice.message.content for choice in response.responses]
         return completions
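
The second patch moves the wrapper from a completions-style endpoint to Reka's chat endpoint and to the `input_tokens`/`output_tokens` usage fields. One way to sanity-check that wiring without network access is to stub the client, as in the rough sketch below; the fake response only mimics the attributes the wrapper reads (`responses[0].message.content`, `model`, `usage.input_tokens`, `usage.output_tokens`) and is not the real SDK type, and `reka-api` still needs to be installed so the import inside `__init__` succeeds.

```python
from types import SimpleNamespace
from unittest.mock import MagicMock

from knowledge_storm.lm import RekaModel

# Fake chat response exposing only the attributes the patched wrapper reads.
fake_response = SimpleNamespace(
    responses=[SimpleNamespace(message=SimpleNamespace(content="Hello from Reka"))],
    model="reka-core",
    usage=SimpleNamespace(input_tokens=12, output_tokens=5),
)

lm = RekaModel(model="reka-core", api_key="dummy-key")
lm.client = MagicMock()
lm.client.chat.create.return_value = fake_response

print(lm("Say hello"))           # expected: ['Hello from Reka']
print(lm.get_usage_and_reset())  # expected: {'reka-core': {'prompt_tokens': 12, 'completion_tokens': 5}}
```

With real credentials, the new example script can then be run end to end, for example `python examples/storm_examples/run_storm_wiki_reka.py --retriever you --do-research --do-generate-outline --do-generate-article --do-polish-article`.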