Add chat API with system prompt
Summary: The chat API was missing an option to supply a system prompt. Added a chat API that accepts a system prompt.

Reviewed By: csahana95

Differential Revision: D59473748

fbshipit-source-id: 56f53d338a8f090180100544716d4814331f938e
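
For context, a minimal usage sketch of the new API added by this commit. The provider class and model name below are illustrative placeholders, not taken from this diff; any concrete LLM subclass in llm.py that overrides chat_with_system_prompt would be called the same way:

    # Hypothetical provider subclass and constructor; not part of this commit.
    llm = SomeLLMProvider(model="example-model")
    response = llm.chat_with_system_prompt_with_retries(
        system_prompt="You are a helpful security analyst.",
        prompt_with_history=[
            "What is a buffer overflow?",  # user turn (even index)
            "A write past the end of a fixed-size buffer.",  # assistant turn (odd index)
            "How is it typically exploited?",  # current prompt (last entry)
        ],
    )
    print(response)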
dwjsong authored and facebook-github-bot committed Jul 8, 2024
1 parent 6c8d4d0 commit 67e5201
Showing 1 changed file with 125 additions and 1 deletion.
126 changes: 125 additions & 1 deletion CybersecurityBenchmarks/benchmark/llm.py
@@ -65,7 +65,24 @@ def chat(
        guided_decode_json_schema: Optional[str] = None,
    ) -> str:
        """
-        Abstract method to query an LLM with a given prompt and return the response.
+        Abstract method to chat with an LLM using the given prompts and return the response.
        Args:
            prompt_with_history ([str]): The prompts to send to the LLM, including the chat history. The last string in the list is the current prompt.
        Returns:
            str: The response from the LLM.
        """
        pass

    @abstractmethod
    def chat_with_system_prompt(
        self,
        system_prompt: str,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
    ) -> str:
        """
        Abstract method to chat with an LLM using the given prompts and a system prompt, and return the response.
        Args:
            system_prompt (str): The system prompt to send to the LLM.
            prompt_with_history ([str]): The prompts to send to the LLM, including the chat history. The last string in the list is the current prompt.
@@ -196,6 +213,20 @@ def query_with_system_prompt_with_retries(

        return response

    def chat_with_system_prompt_with_retries(
        self,
        system_prompt: str,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
    ) -> str:
        # add cache handler
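        # Delegates to the shared retry helper (collapsed in this diff), which
        # is assumed to retry the chat call on transient failures.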
        return self._chat_with_retries(
            self.chat_with_system_prompt,
            system_prompt,
            prompt_with_history,
            guided_decode_json_schema,
        )

    def query_with_system_prompt_and_image(
        self,
        system_prompt: str,
@@ -299,6 +330,37 @@ def chat(
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def chat_with_system_prompt(
        self,
        system_prompt: str,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
    ) -> str:
        messages = [{"role": "system", "content": system_prompt}]
        for i in range(len(prompt_with_history)):
            if i % 2 == 0:
                messages.append({"role": "user", "content": prompt_with_history[i]})
            else:
                messages.append(
                    {"role": "assistant", "content": prompt_with_history[i]}
                )

        # Temporarily raise the root logger to WARNING to quiet verbose client
        # logging during the request; the previous level is restored below.
        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=MAX_TOKENS,
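            # Requests a generic JSON object response when a schema is
            # supplied; the schema itself is not enforced here.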
            response_format=(
                {"type": "json_object"}
                if guided_decode_json_schema is not None
                else None
            ),
        )
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def query(
        self, prompt: str, guided_decode_json_schema: Optional[str] = None
@@ -424,6 +486,37 @@ def chat(
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def chat_with_system_prompt(
        self,
        system_prompt: str,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
    ) -> str:
        messages = [{"role": "system", "content": system_prompt}]
        for i in range(len(prompt_with_history)):
            if i % 2 == 0:
                messages.append({"role": "user", "content": prompt_with_history[i]})
            else:
                messages.append(
                    {"role": "assistant", "content": prompt_with_history[i]}
                )

        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=MAX_TOKENS,
            response_format=(
                {"type": "json_object"}
                if guided_decode_json_schema is not None
                else None
            ),
        )
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def query(
        self, prompt: str, guided_decode_json_schema: Optional[str] = None
@@ -525,6 +618,37 @@ def chat(
        )
        return response.choices[0].message.content

    @override
    def chat_with_system_prompt(
        self,
        system_prompt: str,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
    ) -> str:
        messages = [{"role": "system", "content": system_prompt}]
        for i in range(len(prompt_with_history)):
            if i % 2 == 0:
                messages.append({"role": "user", "content": prompt_with_history[i]})
            else:
                messages.append(
                    {"role": "assistant", "content": prompt_with_history[i]}
                )

        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=MAX_TOKENS,
            response_format=(
                {"type": "json_object"}
                if guided_decode_json_schema is not None
                else None
            ),
        )
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def query(
        self, prompt: str, guided_decode_json_schema: Optional[str] = None
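The shared retry helper _chat_with_retries used above is collapsed in this diff. A minimal sketch of such a helper, assuming plain exponential backoff (the committed implementation may differ):

    import logging
    import time
    from typing import Callable

    def _chat_with_retries(
        chat_fn: Callable[..., str],
        *args: object,
        max_retries: int = 3,
        backoff_seconds: float = 2.0,
    ) -> str:
        # Retry the chat call on any exception, doubling the delay each attempt.
        for attempt in range(max_retries):
            try:
                return chat_fn(*args)
            except Exception as exc:
                if attempt == max_retries - 1:
                    raise
                logging.warning("Chat call failed (%s); retrying", exc)
                time.sleep(backoff_seconds * (2**attempt))
        raise AssertionError("unreachable")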
