Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge json repair and prompt tune update from microsoft #2

Merged
merged 2 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20240801001005275591.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Update Prompt Tuning meta prompts with finer examples"
}
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240802002107383210.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix json parsing when LLM returns faulty responses"
}
2 changes: 0 additions & 2 deletions graphrag/index/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .dicts import dict_has_keys_with_types
from .hashing import gen_md5_hash
from .is_null import is_null
from .json import clean_up_json
from .load_graph import load_graph
from .string import clean_str
from .tokens import num_tokens_from_string, string_from_tokens
Expand All @@ -15,7 +14,6 @@

__all__ = [
"clean_str",
"clean_up_json",
"dict_has_keys_with_types",
"gen_md5_hash",
"gen_uuid",
Expand Down
27 changes: 0 additions & 27 deletions graphrag/index/utils/json.py

This file was deleted.

25 changes: 0 additions & 25 deletions graphrag/llm/openai/_json.py

This file was deleted.

3 changes: 2 additions & 1 deletion graphrag/llm/openai/json_parsing_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ async def __call__(
"""Call the LLM with the input and kwargs."""
result = await self._delegate(input, **kwargs)
if kwargs.get("json") and result.json is None and result.output is not None:
result.json = try_parse_json_object(result.output)
_, parsed_json = try_parse_json_object(result.output)
result.json = parsed_json
return result
36 changes: 16 additions & 20 deletions graphrag/llm/openai/openai_chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""The Chat-based language model."""

import logging
from json import JSONDecodeError

from typing_extensions import Unpack

Expand All @@ -16,7 +15,6 @@
LLMOutput,
)

from ._json import clean_up_json
from ._prompts import JSON_CHECK_PROMPT
from .openai_configuration import OpenAIConfiguration
from .types import OpenAIClientTypes
Expand Down Expand Up @@ -104,11 +102,10 @@ async def _native_json(
},
)

raw_output = result.output or ""
json_output = try_parse_json_object(raw_output)
output, json_output = try_parse_json_object(result.output or "")

return LLMOutput[CompletionOutput](
output=raw_output,
output=output,
json=json_output,
history=result.history,
)
Expand All @@ -119,24 +116,23 @@ async def _manual_json(
# Otherwise, clean up the output and try to parse it as json
result = await self._invoke(input, **kwargs)
history = result.history or []
output = clean_up_json(result.output or "")
try:
json_output = try_parse_json_object(output)
output, json_output = try_parse_json_object(result.output or "")
if json_output:
return LLMOutput[CompletionOutput](
output=output, json=json_output, history=history
output=result.output, json=json_output, history=history
)
except (TypeError, JSONDecodeError):
log.warning("error parsing llm json, retrying")
# If cleaned up json is unparsable, use the LLM to reformat it (may throw)
result = await self._try_clean_json_with_llm(output, **kwargs)
output = clean_up_json(result.output or "")
json = try_parse_json_object(output)
# if not return correct formatted json, retry
log.warning("error parsing llm json, retrying")

return LLMOutput[CompletionOutput](
output=output,
json=json,
history=history,
)
# If cleaned up json is unparsable, use the LLM to reformat it (may throw)
result = await self._try_clean_json_with_llm(output, **kwargs)
output, json_output = try_parse_json_object(result.output or "")

return LLMOutput[CompletionOutput](
output=output,
json=json_output,
history=history,
)

async def _try_clean_json_with_llm(
self, output: str, **kwargs: Unpack[LLMInput]
Expand Down
50 changes: 43 additions & 7 deletions graphrag/llm/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import json
import logging
import re
from collections.abc import Callable
from typing import Any

import tiktoken
from json_repair import repair_json
from openai import (
APIConnectionError,
InternalServerError,
Expand Down Expand Up @@ -87,17 +89,51 @@ def get_completion_llm_args(
}


def try_parse_json_object(input: str) -> dict:
"""Generate JSON-string output using best-attempt prompting & parsing techniques."""
def try_parse_json_object(input: str) -> tuple[str, dict]:
    """Clean up an LLM response and best-effort parse it as a JSON object.

    LLMs sometimes wrap the JSON payload in extra prose or a ```json markdown
    fence, or emit slightly malformed JSON. This strips that noise, attempts a
    strict parse, and falls back to ``json_repair`` when strict parsing fails.

    Parameters
    ----------
    input:
        Raw LLM output expected to contain a single JSON object.

    Returns
    -------
    tuple[str, dict]:
        The cleaned-up string that was actually parsed, and the parsed dict —
        an empty dict when no JSON object could be recovered.
    """
    # Keep only the outermost {...} span, dropping surrounding prose.
    # NOTE(review): no re.DOTALL, so this only matches when both braces sit on
    # one line; multi-line payloads fall through unchanged and rely on the
    # whitespace normalization + repair fallback below.
    _pattern = r"\{(.*)\}"
    _match = re.search(_pattern, input)
    input = "{" + _match.group(1) + "}" if _match else input

    # Normalize common LLM artifacts: doubled braces (prompt-template
    # leftovers), quoted arrays, and stray backslashes / newlines.
    # NOTE(review): replacing "\\" first makes the subsequent "\\n"
    # replacement unreachable — preserved as-is to keep behavior identical;
    # confirm the intended ordering upstream.
    input = (
        input.replace("{{", "{")
        .replace("}}", "}")
        .replace('"[{', "[{")
        .replace('}]"', "}]")
        .replace("\\", " ")
        .replace("\\n", " ")
        .replace("\n", " ")
        .replace("\r", "")
        .strip()
    )

    # Remove a ```json ... ``` markdown fence if present.
    if input.startswith("```json"):
        input = input[len("```json") :]
    if input.endswith("```"):
        input = input[: len(input) - len("```")]

    try:
        result = json.loads(input)
    except json.JSONDecodeError:
        # Strict parse failed: fix up the potentially malformed JSON string
        # with json_repair, then retry once.
        input = str(repair_json(json_str=input, return_objects=False))
        try:
            result = json.loads(input)
        except json.JSONDecodeError:
            log.exception("error loading json, json=%s", input)
            return input, {}
        else:
            if not isinstance(result, dict):
                log.exception("not expected dict type. type=%s:", type(result))
                return input, {}
            return input, result
    else:
        # Guard the success path too, so the declared `dict` return type holds
        # even when the LLM returns a top-level array or scalar.
        if not isinstance(result, dict):
            log.exception("not expected dict type. type=%s:", type(result))
            return input, {}
        return input, result


def get_sleep_time_from_error(e: Any) -> float:
Expand Down
Loading
Loading