Skip to content

Commit

Permalink
better abstraction of to_dict and from_dict for PaperRemote (#23)
Browse files Browse the repository at this point in the history
* better abstraction of to_dict and from_dict for PaperRemote

* registry of paper remote types for deserialization; the registry must be populated in service.py
  • Loading branch information
wrongu authored Jun 20, 2023
1 parent 19c6667 commit a8eb7c5
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 36 deletions.
2 changes: 1 addition & 1 deletion llm4papers/editor_agents/OpenAIChatEditorAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def edit(self, paper: PaperRemote, edit: EditTrigger) -> Iterable[EditResult]:
)

edited = response["new_window"] + "\n"
logger.info(f"Edited text for document {paper.dict()}:")
logger.info(f"Edited text for document {paper.to_dict()}:")
logger.info(f"- {editable_text}")
logger.info(f"+ {edited}")

Expand Down
4 changes: 0 additions & 4 deletions llm4papers/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""
This file includes the PaperManager, types, as well as other systems to edit an
academic paper using a large language model AI.
"""
from typing import Hashable
from pydantic import BaseModel
from enum import Enum
Expand Down
38 changes: 31 additions & 7 deletions llm4papers/paper_manager/JSONFilePaperManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from llm4papers.editor_agents.OpenAIChatEditorAgent import OpenAIChatEditorAgent
from llm4papers.paper_manager import PaperManager
from llm4papers.logger import logger
from llm4papers.paper_remote.PaperRemote import PaperRemote
from llm4papers.paper_remote.OverleafGitPaperRemote import OverleafGitPaperRemote
from llm4papers.paper_remote import PaperRemote


class JSONFilePaperManager(PaperManager):
Expand All @@ -20,6 +19,7 @@ def __init__(
self._agents = agents or [OpenAIChatEditorAgent(OpenAIConfig().dict())]
self._json_path = pathlib.Path(json_path)
self._load_json()
self._paper_remote_class_lookup = {}

def _load_json(self):
if not self._json_path.exists():
Expand All @@ -30,27 +30,51 @@ def _load_json(self):
with open(self._json_path) as f:
self._json = json.load(f)

# Register a PaperRemote implementation with this manager.
# The class is stored under its __name__; papers() uses the "type" value of
# each stored paper dict as the key into this same lookup, so a class must be
# registered here before entries of that type can be reconstructed.
def register_paper_remote_class(self, cls):
self._paper_remote_class_lookup[cls.__name__] = cls

def add_paper_remote(self, remote: PaperRemote):
# Make sure it doesn't already exist.
for paper in self.papers():
if paper.dict() == remote.dict():
if paper.to_dict() == remote.to_dict():
logger.info("Paper already exists, not adding.")
return
self._json["papers"].append(remote.dict())
self._json["papers"].append(remote.to_dict())
with open(self._json_path, "w") as f:
json.dump(self._json, f)

def papers(self) -> list[PaperRemote]:
papers_json = self._json["papers"]
# TODO - support other PaperRemote classes with some abstraction here
return [OverleafGitPaperRemote(paper["git_repo"]) for paper in papers_json]
papers = []
for paper_dict in papers_json:
if "type" not in paper_dict:
logger.error(f"Paper dict {paper_dict} has no 'type' key.")
continue

if paper_dict["type"] in self._paper_remote_class_lookup:
cls = self._paper_remote_class_lookup[paper_dict["type"]]
else:
logger.error(
f"PaperRemote type {paper_dict['type']} is unknown (did "
f"you call manager.register_paper_remote_class?)."
)
continue

try:
paper = cls.from_dict(paper_dict)
except Exception as e:
logger.error(f"Error creating paper from dict {paper_dict}: {e}")
continue

papers.append(paper)
return papers

def poll_once(self):
self._load_json()
logger.info(f"Polling {len(self.papers())} papers for edits.")
is_triggered = False
for paper in self.papers():
logger.info(f"Polling paper {paper.dict()}")
logger.info(f"Polling paper {paper.to_dict()}")
paper.refresh()
is_triggered |= self._do_edits_helper(paper)
return is_triggered
Expand Down
13 changes: 4 additions & 9 deletions llm4papers/paper_remote/InMemoryPaperRemote.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,10 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
for doc_range in edit.input_ranges + edit.output_ranges:
return doc_range.doc_id in self.list_doc_ids()

def dict(self) -> dict:
"""
Return a dictionary representation of this remote.
"""
return {
"type": "InMemoryPaperRemote",
"kwargs": {"documents": self._documents},
}
# Return the dictionary representation of this remote: the parent class's
# to_dict() result with this remote's documents added under "kwargs".
def to_dict(self):
d = super().to_dict()
# The documents object is stored by reference, not copied.
d["kwargs"]["documents"] = self._documents
return d

def perform_edit(self, edit: EditResult) -> bool:
"""
Expand Down
19 changes: 16 additions & 3 deletions llm4papers/paper_remote/MultiDocumentPaperRemote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from llm4papers.models import EditTrigger, EditResult
from llm4papers.paper_remote.PaperRemote import PaperRemote, DocumentID
from llm4papers.paper_remote.PaperRemote import PaperRemote, DocumentID, PaperRemoteDict


class MultiDocumentPaperRemote(PaperRemote):
Expand Down Expand Up @@ -64,12 +64,25 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
"""
raise NotImplementedError()

def dict(self) -> dict:
def to_dict(self) -> PaperRemoteDict:
"""
Return a dictionary representation of this remote.
Subclasses should start with super().to_dict() and then update the "kwargs"
"""
raise NotImplementedError()
return {
"type": self.__class__.__name__,
"kwargs": {},
}

# Alternate constructor: build an instance of `cls` from a PaperRemoteDict.
# Raises ValueError when d["type"] does not equal the name of the class this
# is called on; otherwise the entries of d["kwargs"] are passed to the class
# constructor as keyword arguments.
@classmethod
def from_dict(cls, d: PaperRemoteDict) -> "MultiDocumentPaperRemote":
# Exact name match only: a dict written by one class is rejected by any other
# class, including its subclasses and parent classes.
if cls.__name__ != d["type"]:
raise ValueError(
f"Cannot create {cls.__name__} from dict of type {d['type']}"
)
else:
return cls(**d["kwargs"])

def perform_edit(self, edit: EditResult) -> bool:
"""
Expand Down
14 changes: 4 additions & 10 deletions llm4papers/paper_remote/OverleafGitPaperRemote.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,10 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
return False
return True

def dict(self) -> dict:
"""
Return a dictionary representation of this remote.
"""
return {
"git_cached_repo": self._git_cached_repo_arg,
"repo_slug": self._reposlug,
"type": "OverleafGitPaperRemote",
}
# Return the dictionary representation of this remote: the parent class's
# to_dict() result with "git_cached_repo" added under "kwargs".
def to_dict(self):
d = super().to_dict()
# NOTE(review): presumably _git_cached_repo_arg is the value originally passed
# to the constructor, so from_dict can rebuild the remote — confirm.
d["kwargs"]["git_cached_repo"] = self._git_cached_repo_arg
return d

def perform_edit(self, edit: EditResult) -> bool:
"""
Expand Down
13 changes: 11 additions & 2 deletions llm4papers/paper_remote/PaperRemote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from llm4papers.models import DocumentID, RevisionID, EditTrigger, EditResult

from typing import Protocol
from typing import Protocol, TypedDict, Any


# Serialized form of a PaperRemote: the return type of PaperRemote.to_dict()
# and the argument type of PaperRemote.from_dict().
class PaperRemoteDict(TypedDict):
# Identifies which PaperRemote implementation the dict describes
# (presumably the class name — confirm against the implementations).
type: str
# Implementation-specific values keyed by name; value types are unconstrained.
kwargs: dict[str, Any]


class PaperRemote(Protocol):
Expand All @@ -20,7 +25,11 @@ class PaperRemote(Protocol):

current_revision_id: RevisionID

def dict(self):
def to_dict(self) -> PaperRemoteDict:
...

@classmethod
def from_dict(cls, d: PaperRemoteDict) -> "PaperRemote":
...

def refresh(self):
Expand Down
1 change: 1 addition & 0 deletions llm4papers/paper_remote/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .MultiDocumentPaperRemote import MultiDocumentPaperRemote
from .OverleafGitPaperRemote import OverleafGitPaperRemote


__all__ = [
"PaperRemote",
"DocumentID",
Expand Down
2 changes: 2 additions & 0 deletions llm4papers/service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
from llm4papers.config import Settings
from llm4papers.paper_manager.JSONFilePaperManager import JSONFilePaperManager
from llm4papers.paper_remote import OverleafGitPaperRemote

if __name__ == "__main__":
manifest_file = (
Expand All @@ -9,4 +10,5 @@
else Settings().json_manifest_file
)
manager = JSONFilePaperManager(manifest_file)
manager.register_paper_remote_class(OverleafGitPaperRemote)
manager.poll(interval=Settings().polling_interval_sec)

0 comments on commit a8eb7c5

Please sign in to comment.