Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New LM Reka #306

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions examples/storm_examples/run_storm_wiki_reka.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""
STORM Wiki pipeline powered by Reka API and search engine.
You need to set up the following environment variables to run this script:
- REKA_API_KEY: Reka API key
- YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key,
BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key

Output will be structured as below
args.output_dir/
topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
conversation_log.json # Log of information-seeking conversation
raw_search_results.json # Raw search results from search engine
direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
storm_gen_outline.txt # Outline refined with collected information
url_to_info.json # Sources that are used in the final article
storm_gen_article.txt # Final article generated
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""

import os
from argparse import ArgumentParser

from knowledge_storm import (
STORMWikiRunnerArguments,
STORMWikiRunner,
STORMWikiLMConfigs,
)
from knowledge_storm.rm import (
YouRM,
BingSearch,
BraveRM,
SerperRM,
DuckDuckGoSearchRM,
TavilySearchRM,
SearXNG,
)
from knowledge_storm.utils import load_api_key
from knowledge_storm.lm import RekaModel


def main(args):
    """Run the STORM Wiki pipeline end-to-end with Reka models.

    Loads API keys, wires Reka LMs into each pipeline stage, selects the
    retrieval module named by ``args.retriever``, then runs the requested
    stages on a topic read from stdin.
    """
    load_api_key(toml_file_path="secrets.toml")
    llm_configs = STORMWikiLMConfigs()
    shared_reka_kwargs = {
        "api_key": os.getenv("REKA_API_KEY"),
        "temperature": 1.0,
        "top_p": 0.9,
    }

    # STORM is a LM system so different components can be powered by different models.
    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
    # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
    # which is responsible for generating sections with citations.
    llm_configs.set_conv_simulator_lm(
        RekaModel(model="reka-flash", max_tokens=500, **shared_reka_kwargs)
    )
    llm_configs.set_question_asker_lm(
        RekaModel(model="reka-core", max_tokens=500, **shared_reka_kwargs)
    )
    llm_configs.set_outline_gen_lm(
        RekaModel(model="reka-core", max_tokens=400, **shared_reka_kwargs)
    )
    llm_configs.set_article_gen_lm(
        RekaModel(model="reka-core", max_tokens=700, **shared_reka_kwargs)
    )
    llm_configs.set_article_polish_lm(
        RekaModel(model="reka-core", max_tokens=4000, **shared_reka_kwargs)
    )

    engine_args = STORMWikiRunnerArguments(
        output_dir=args.output_dir,
        max_conv_turn=args.max_conv_turn,
        max_perspective=args.max_perspective,
        search_top_k=args.search_top_k,
        max_thread_num=args.max_thread_num,
    )

    # STORM is a knowledge curation system which consumes information from the retrieval module.
    # Currently, the information source is the Internet and we use search engine API as the retrieval module.
    # Factories are lazy so only the chosen retriever's API key is read.
    retriever_factories = {
        "bing": lambda: BingSearch(
            bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
            k=engine_args.search_top_k,
        ),
        "you": lambda: YouRM(
            ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k
        ),
        "brave": lambda: BraveRM(
            brave_search_api_key=os.getenv("BRAVE_API_KEY"),
            k=engine_args.search_top_k,
        ),
        "duckduckgo": lambda: DuckDuckGoSearchRM(
            k=engine_args.search_top_k, safe_search="On", region="us-en"
        ),
        "serper": lambda: SerperRM(
            serper_search_api_key=os.getenv("SERPER_API_KEY"),
            query_params={"autocorrect": True, "num": 10, "page": 1},
        ),
        "tavily": lambda: TavilySearchRM(
            tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
            k=engine_args.search_top_k,
            include_raw_content=True,
        ),
        "searxng": lambda: SearXNG(
            searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
        ),
    }
    make_retriever = retriever_factories.get(args.retriever)
    if make_retriever is None:
        raise ValueError(
            f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
        )
    rm = make_retriever()

    runner = STORMWikiRunner(engine_args, llm_configs, rm)

    topic = input("Topic: ")
    runner.run(
        topic=topic,
        do_research=args.do_research,
        do_generate_outline=args.do_generate_outline,
        do_generate_article=args.do_generate_article,
        do_polish_article=args.do_polish_article,
    )
    runner.post_run()
    runner.summary()


if __name__ == "__main__":
    parser = ArgumentParser()
    # global arguments
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./results/reka",
        help="Directory to store the outputs.",
    )
    parser.add_argument(
        "--max-thread-num",
        type=int,
        default=3,
        help="Maximum number of threads to use. The information seeking part and the article generation"
        "part can speed up by using multiple threads. Consider reducing it if keep getting "
        '"Exceed rate limit" error when calling LM API.',
    )
    # Marked required so a missing retriever fails fast with an argparse usage
    # error instead of a ValueError deep inside main().
    parser.add_argument(
        "--retriever",
        type=str,
        required=True,
        choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
        help="The search engine API to use for retrieving information.",
    )
    # stage of the pipeline
    parser.add_argument(
        "--do-research",
        action="store_true",
        help="If True, simulate conversation to research the topic; otherwise, load the results.",
    )
    parser.add_argument(
        "--do-generate-outline",
        action="store_true",
        help="If True, generate an outline for the topic; otherwise, load the results.",
    )
    parser.add_argument(
        "--do-generate-article",
        action="store_true",
        help="If True, generate an article for the topic; otherwise, load the results.",
    )
    parser.add_argument(
        "--do-polish-article",
        action="store_true",
        help="If True, polish the article by adding a summarization section and (optionally) removing "
        "duplicate content.",
    )
    # hyperparameters for the pre-writing stage
    parser.add_argument(
        "--max-conv-turn",
        type=int,
        default=3,
        help="Maximum number of questions in conversational question asking.",
    )
    parser.add_argument(
        "--max-perspective",
        type=int,
        default=3,
        help="Maximum number of perspectives to consider in perspective-guided question asking.",
    )
    parser.add_argument(
        "--search-top-k",
        type=int,
        default=3,
        help="Top k search results to consider for each search query.",
    )
    # hyperparameters for the writing stage
    # NOTE(review): --retrieve-top-k and --remove-duplicate are parsed but never
    # read by main() in this script — confirm whether they should be forwarded
    # to the runner (other STORM examples pass them to the engine arguments).
    parser.add_argument(
        "--retrieve-top-k",
        type=int,
        default=3,
        help="Top k collected references for each section title.",
    )
    parser.add_argument(
        "--remove-duplicate",
        action="store_true",
        help="If True, remove duplicate content from the article.",
    )

    main(parser.parse_args())
118 changes: 118 additions & 0 deletions knowledge_storm/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,3 +942,121 @@ def __call__(
completions.append(response.parts[0].text)

return completions


class RekaModel(dspy.dsp.modules.lm.LM):
    """A wrapper class for dspy.LM using Reka."""

    def __init__(self, model: str, api_key: Optional[str] = None, **kwargs):
        """Initialize the Reka model.

        Args:
            model: The name of the Reka model to use
            api_key: Optional API key for authentication; falls back to the
                REKA_API_KEY environment variable when not provided.
            **kwargs: Additional arguments passed to the model
                (temperature, max_tokens, top_p, top_k, n/num_generations, ...).

        Raises:
            ImportError: If the `reka-api` package is not installed.
        """
        super().__init__(model)
        try:
            from reka.client import Reka
        except ImportError as err:
            raise ImportError("Reka requires `pip install reka-api`.") from err

        self.provider = "reka"
        self.api_key = api_key = (
            os.environ.get("REKA_API_KEY") if api_key is None else api_key
        )

        self.kwargs = {
            "temperature": kwargs.get("temperature", 0.0),
            # Clamp max_tokens to 4096 — presumably the Reka generation
            # limit; TODO confirm against the Reka API docs.
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "top_k": kwargs.get("top_k", 1),
            # Accept either "n" or "num_generations"; both are popped from
            # kwargs so the trailing **kwargs does not re-introduce them.
            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
            **kwargs,
            "model": model,
        }
        self.history: list[dict[str, Any]] = []
        self.client = Reka(api_key=api_key)
        self.model = model

        # Token counters may be updated from multiple threads (STORM runs
        # requests concurrently), so guard them with a lock.
        self._token_usage_lock = threading.Lock()
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def log_usage(self, response):
        """Log the total tokens from the Reka API response."""
        usage_data = response.usage
        if usage_data:
            with self._token_usage_lock:
                self.prompt_tokens += usage_data.input_tokens
                self.completion_tokens += usage_data.output_tokens

    def get_usage_and_reset(self):
        """Get the total tokens used and reset the token usage."""
        # Read and reset atomically so a concurrent log_usage() update
        # cannot be lost between the read and the reset.
        with self._token_usage_lock:
            usage = {
                self.model: {
                    "prompt_tokens": self.prompt_tokens,
                    "completion_tokens": self.completion_tokens,
                }
            }
            self.prompt_tokens = 0
            self.completion_tokens = 0
        return usage

    def basic_request(self, prompt: str, **kwargs):
        """Send a single chat request to Reka and record it in history."""
        raw_kwargs = kwargs
        kwargs = {**self.kwargs, **kwargs}
        # caching mechanism requires hashable kwargs
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # "n" is a dspy-side fan-out knob handled in __call__, not a Reka
        # chat parameter; drop it before calling the API.
        kwargs.pop("n", None)
        response = self.client.chat.create(**kwargs)

        json_serializable_history = {
            "prompt": prompt,
            "response": {
                "text": response.responses[0].message.content,
                "model": response.model,
                "usage": {
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                },
            },
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(json_serializable_history)
        return response

    @backoff.on_exception(
        backoff.expo,
        (Exception,),
        max_time=1000,
        max_tries=8,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Reka whilst handling API errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt: str, only_completed=True, return_sorted=False, **kwargs):
        """Retrieves completions from Reka.

        Args:
            prompt (str): prompt to send to Reka
            only_completed (bool, optional): return only completed responses. Defaults to True.
            return_sorted (bool, optional): sort completion choices by probability. Defaults to False.

        Returns:
            list[str]: list of completion choices
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            response = self.request(prompt, **kwargs)
            self.log_usage(response)
            # BUG FIX: accumulate choices across all n requests; the previous
            # rebinding (`completions = [...]`) discarded every response but
            # the last one when n > 1.
            completions.extend(
                choice.message.content for choice in response.responses
            )
        return completions