Skip to content

Commit

Permalink
Make LLM optional (for Nagini)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeetHet committed Feb 24, 2025
1 parent 8682aef commit 8ce21be
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
4 changes: 2 additions & 2 deletions verified_cogen/runners/rewriters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@


class Rewriter:
llm_with_idx: tuple[LLMConfig, int]
llm_with_idx: Optional[tuple[LLMConfig, int]]

def __init__(self, llm: tuple[LLMConfig, int]):
def __init__(self, llm: Optional[tuple[LLMConfig, int]] = None):
self.llm_with_idx = llm

def rewrite(self, prg: str, error: Optional[str] = None) -> tuple[str, str]: ...
10 changes: 5 additions & 5 deletions verified_cogen/runners/rewriters/construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from verified_cogen.runners.rewriters.verus_rewriter import VerusRewriter


def construct_nagini_rewriter(runner_types: list[str], llm: tuple[LLMConfig, int]) -> Optional[Rewriter]:
def construct_nagini_rewriter(runner_types: list[str]) -> Optional[Rewriter]:
runner = None
for runner_type in runner_types:
match runner_type:
case "NaginiRewriter":
runner = NaginiRewriter(llm)
runner = NaginiRewriter()
case "NaginiRewriterFixing":
runner = NaginiRewriterFixing(llm, runner)
runner = NaginiRewriterFixing(runner)
case "NaginiRewriterFixingAST":
runner = NaginiRewriterFixingAST(llm, runner)
runner = NaginiRewriterFixingAST(runner)
case _:
raise ValueError(f"Unexpected nagini rewriter type: {runner_type}")
return runner
Expand All @@ -38,7 +38,7 @@ def construct_verus_rewriter(runner_types: list[str], llm: tuple[LLMConfig, int]
def construct_rewriter(extension: str, llm: tuple[LLMConfig, int], runner_types: list[str]) -> Optional[Rewriter]:
match extension:
case "py":
return construct_nagini_rewriter(runner_types, llm)
return construct_nagini_rewriter(runner_types)
case "rs":
return construct_verus_rewriter(runner_types, llm)
case _ if runner_types:
Expand Down
5 changes: 2 additions & 3 deletions verified_cogen/runners/rewriters/nagini_rewriter_fixing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Optional

from verified_cogen.config import LLMConfig
from verified_cogen.runners.rewriters.__init__ import Rewriter


class NaginiRewriterFixing(Rewriter):
wrapped_rewriter: Optional[Rewriter]

def __init__(self, llm: tuple[LLMConfig, int], rewriter: Optional[Rewriter] = None):
super().__init__(llm)
def __init__(self, rewriter: Optional[Rewriter] = None):
super().__init__()
self.wrapped_rewriter = rewriter

def replace_impl(self, prg: str):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional

from verified_cogen.config import LLMConfig
from verified_cogen.runners.rewriters.__init__ import Rewriter
from verified_cogen.tools.inequality_replacer import (
contains_double_inequality,
Expand All @@ -11,8 +10,8 @@
class NaginiRewriterFixingAST(Rewriter):
wrapped_rewriter: Optional[Rewriter]

def __init__(self, llm: tuple[LLMConfig, int], rewriter: Optional[Rewriter] = None):
super().__init__(llm)
def __init__(self, rewriter: Optional[Rewriter] = None):
super().__init__()
self.wrapped_rewriter = rewriter

def rewrite(self, prg: str, error: Optional[str] = None) -> tuple[str, str]:
Expand Down
6 changes: 6 additions & 0 deletions verified_cogen/runners/rewriters/verus_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from verified_cogen.config import LLMConfig
from verified_cogen.runners.rewriters import Rewriter
from verified_cogen.tools import extract_code_from_llm_output

Expand Down Expand Up @@ -40,8 +41,13 @@


class VerusRewriter(Rewriter):
def __init__(self, llm: Optional[tuple[LLMConfig, int]] = None):
super().__init__(llm)
assert self.llm_with_idx is not None, "VerusRewriter requires LLM be set"

def rewrite(self, prg: str, error: Optional[str] = None) -> tuple[str, str]:
assert error is not None, "VerusRewriter requires error message"
assert self.llm_with_idx is not None

llm_config, idx = self.llm_with_idx
llm = llm_config.build(idx)
Expand Down

0 comments on commit 8ce21be

Please sign in to comment.