From a1052dfc5b8fa27a6343eed81a28139abdbedd95 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Fri, 18 Aug 2023 00:48:08 -0400
Subject: [PATCH 1/2] activate/deactivate model, some defaults and some cleanup

---
 prompt_selection.ipynb             |   2 +-
 rl_chain.ipynb                     |  68 +++++++++++++----
 rl_chain/__init__.py               |   2 +-
 rl_chain/pick_best_chain.py        |  18 ++++-
 rl_chain/rl_chain_base.py          | 119 +++++++++++++++++------
 tests/test_pick_best_chain_call.py |  72 ++++++++++++++++-
 tests/test_utils.py                |   2 +-
 7 files changed, 211 insertions(+), 72 deletions(-)

diff --git a/prompt_selection.ipynb b/prompt_selection.ipynb
index 0caea8b..6c17be3 100644
--- a/prompt_selection.ipynb
+++ b/prompt_selection.ipynb
@@ -334,7 +334,7 @@
     "    suffix = rl_chain.ToSelectFrom(['0']))\n",
     "\n",
     "vw_chain.metrics.to_pandas()['score'].plot(label=\"vw\")\n",
-    "rnd_chain.metrics.to_pandas()['score'].plot(label=\"slates\")\n",
+    "rnd_chain.metrics.to_pandas()['score'].plot(label=\"random\")\n",
     "plt.legend()"
    ]
   }
diff --git a/rl_chain.ipynb b/rl_chain.ipynb
index c39661d..3c8d8b0 100644
--- a/rl_chain.ipynb
+++ b/rl_chain.ipynb
@@ -30,10 +30,10 @@
     "    MealPlanner(name=\"One-Pan Beef Enchiladas Verdes with Mexican Cheese Blend & Hot Sauce Crema\", difficulty=\"Easy\", tags=\"Spicy, Easy Cleanup, Easy Prep\", desc=\"When it comes to Mexican-style cuisine, burritos typically get all the glory. In our humble opinion, enchiladas are an unsung dinner hero. They’re technically easier-to-assemble burritos that get smothered in a delicious sauce, but they’re really so much more than that! Ours start with spiced beef and charred green pepper that get rolled up in warm tortillas. This winning combo gets topped with tangy salsa verde and cheese, then baked until bubbly and melty. Hear that? That’s the sound of the dinner bell!\"),\n",
     "    MealPlanner(name=\"Chicken & Mushroom Flatbreads with Gouda Cream Sauce & Parmesan\", difficulty=\"Easy\", tags=\"\", desc=\"Yes we love our simple cheese pizza with red sauce but tonight, move over, marinara—there’s a new sauce in town. In this recipe, crispy flatbreads are slathered with a rich, creamy gouda-mustard sauce we just can’t get enough of. We top that off with a pile of caramelized onion and earthy cremini mushrooms. Shower with Parmesan, and that’s it. Simple, satisfying, and all in 30 minutes–a dinner idea you can’t pass up!\"),\n",
     "    MealPlanner(name=\"Sweet Potato & Pepper Quesadillas with Southwest Crema & Tomato Salsa\", difficulty=\"Easy\", tags=\"Veggie\", desc=\"This quesadilla is jam-packed with flavorful roasted sweet potato and green pepper, plus two types of gooey, melty cheese (how could we choose just one?!). Of course, we’d never forget the toppings—there’s a fresh tomato salsa and dollops of spiced lime crema. Now for the fun part: piling on a little bit of everything to construct the perfect bite!\"),\n",
-    "    MealPlanner(name=\"One-Pan Trattoria Tortelloni Bake with a Crispy Parmesan Panko Topping\", difficulty=\"Easy\", tags=\"Veggie, Easy Cleanup, Easy Prep\", desc=\"Think a cheesy stuffed pasta can’t get any better? What about baking it in a creamy sauce with a crispy topping? In this recipe, we toss cheese-stuffed tortelloni in an herby tomato cream sauce, then top with Parmesan and panko breadcrumbs. Once broiled, it turns into a showstopping topping that’ll earn you plenty of oohs and aahs from your lucky fellow diners.\"),\n",
+    "    MealPlanner(name=\"One-Pan Trattoria Tortelloni Bake with a Crispy vegan cheese Panko Topping\", difficulty=\"Easy\", tags=\"Veggie, Easy Cleanup, Easy Prep\", desc=\"Think a cheesy stuffed pasta can’t get any better? What about baking it in a creamy sauce with a crispy topping? In this recipe, we toss cheese-stuffed tortelloni in an herby tomato cream sauce, then top with vegan cheese and panko breadcrumbs. Once broiled, it turns into a showstopping topping that’ll earn you plenty of oohs and aahs from your lucky fellow diners.\"),\n",
     "]\n",
     "\n",
-    "meals = [f'title={action.name.replace(\":\", \"\").replace(\"|\", \"\")}' for action in actions]"
+    "meals = [[f'{action.name.replace(\":\", \"\").replace(\"|\", \"\")}', f' {action.tags}'] for action in actions]"
    ]
   },
@@ -89,7 +89,8 @@
     "    input_variables=[\"meal\", \"text_to_personalize\"], template=_PROMPT_TEMPLATE\n",
     ")\n",
     "\n",
-    "chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)\n"
+    "chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT, metrics_step=1)\n",
+    "random_chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT, metrics_step=1, policy=rl_chain.PickBestRandomPolicy)\n"
    ]
   },
   {
@@ -98,16 +99,54 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "response = chain.run(\n",
-    "    meal = rl_chain.ToSelectFrom(meals),\n",
-    "    User = rl_chain.BasedOn(\"Tom Hanks\"),\n",
-    "    preference = rl_chain.BasedOn(\"Vegetarian, regular dairy is ok\"),\n",
-    "    text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
-    ")\n",
-    "\n",
-    "print(response[\"response\"])\n",
-    "rr = response[\"selection_metadata\"]\n",
-    "print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")"
+    "for i in range(2):\n",
+    "    try:\n",
+    "        if i % 2:\n",
+    "            print(\"Tom\")\n",
+    "            response = chain.run(\n",
+    "                meal = rl_chain.ToSelectFrom(meals),\n",
+    "                User = rl_chain.BasedOn(\"Tom\"),\n",
+    "                preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
+    "                text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
+    "            )\n",
+    "            random_chain.run(\n",
+    "                meal = rl_chain.ToSelectFrom(meals),\n",
+    "                User = rl_chain.BasedOn(\"Tom\"),\n",
+    "                preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
+    "                text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
+    "            )\n",
+    "        else:\n",
+    "            print(\"Anna\")\n",
+    "            response = chain.run(\n",
+    "                meal = rl_chain.ToSelectFrom(meals),\n",
+    "                User = rl_chain.BasedOn(\"Anna\"),\n",
+    "                preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
+    "                text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
+    "            )\n",
+    "            random_chain.run(\n",
+    "                meal = rl_chain.ToSelectFrom(meals),\n",
+    "                User = rl_chain.BasedOn(\"Anna\"),\n",
+    "                preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
+    "                text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
+    "            )\n",
+    "\n",
+    "        print(response[\"response\"])\n",
+    "        rr = response[\"selection_metadata\"]\n",
+    "        print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")\n",
+    "    except Exception as e:\n",
+    "        print(e)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from matplotlib import pyplot as plt\n",
+    "chain.metrics.to_pandas()['score'].plot(label=\"vw\")\n",
+    "random_chain.metrics.to_pandas()['score'].plot(label=\"random\")\n",
+    "plt.legend()"
+   ]
+  },
   {
@@ -330,8 +369,7 @@
     "class CustomSelectionScorer(rl_chain.SelectionScorer):\n",
     "    #grade or score the response\n",
     "    def score_response(\n",
-    "        self, inputs, llm_response: str\n",
-    "    ) -> float:\n",
+    "        self, inputs, llm_response: str, event: rl_chain.PickBest.Event) -> float:\n",
     "        # do whatever you want here, use whatever inputs you supplied and return reward\n",
     "        reward = 1.0\n",
     "        return reward\n",
diff --git a/rl_chain/__init__.py b/rl_chain/__init__.py
index 1d9c216..09770a1 100644
--- a/rl_chain/__init__.py
+++ b/rl_chain/__init__.py
@@ -1,4 +1,4 @@
-from .pick_best_chain import PickBest
+from .pick_best_chain import PickBest, PickBestRandomPolicy
 from .slates_chain import (
     SlatesPersonalizerChain,
     SlatesRandomPolicy,
diff --git a/rl_chain/pick_best_chain.py b/rl_chain/pick_best_chain.py
index 6fe8e82..40e5da0 100644
--- a/rl_chain/pick_best_chain.py
+++ b/rl_chain/pick_best_chain.py
@@ -7,6 +7,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
+import random
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from sentence_transformers import SentenceTransformer
@@ -83,6 +84,21 @@ def format(self, event: PickBest.Event) -> str:
         return example_string[:-1]


+class PickBestRandomPolicy(base.Policy):
+    def __init__(self, feature_embedder: base.Embedder, *_, **__):
+        self.feature_embedder = feature_embedder
+
+    def predict(self, event: PickBest.Event) -> List[Tuple[int, float]]:
+        num_items = len(event.to_select_from)
+        return [(i, 1.0 / num_items) for i in range(num_items)]
+
+    def learn(self, event: PickBest.Event) -> Any:
+        pass
+
+    def log(self, event: PickBest.Event) -> Any:
+        pass
+
+
 class PickBest(base.RLChain):
     """
     PickBest class that utilizes the Vowpal Wabbit (VW) model for personalization.
@@ -153,7 +169,7 @@ def __init__(
                 "--quiet",
                 "--interactions=::",
                 "--coin",
-                "--epsilon=0.2",
+                "--squarecb",
             ]
         else:
             if "--cb_explore_adf" not in vw_cmd:
diff --git a/rl_chain/rl_chain_base.py b/rl_chain/rl_chain_base.py
index a20c78f..5d6bb82 100644
--- a/rl_chain/rl_chain_base.py
+++ b/rl_chain/rl_chain_base.py
@@ -204,7 +204,9 @@ class SelectionScorer(ABC, BaseModel):
     """Abstract method to grade the chosen selection or the response of the llm"""

     @abstractmethod
-    def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
+    def score_response(
+        self, inputs: Dict[str, Any], llm_response: str, event: Event
+    ) -> float:
         pass


@@ -222,7 +224,7 @@ def get_default_system_prompt() -> SystemMessagePromptTemplate:

     @staticmethod
     def get_default_prompt() -> ChatPromptTemplate:
-        human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
+        human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{rl_chain_selected}".'
         human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
         default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
         chat_prompt = ChatPromptTemplate.from_messages(
@@ -249,7 +251,9 @@ def set_prompt_and_llm_chain(cls, values):
         values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
         return values

-    def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
+    def score_response(
+        self, inputs: Dict[str, Any], llm_response: str, event: Event
+    ) -> float:
         ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
         ranking = ranking.strip()
         try:
@@ -283,10 +287,11 @@ class RLChain(Chain):
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
     policy: Optional[Policy]
-    auto_embed: bool = True
+    auto_embed: bool = False
+    selection_scorer_activated: bool = True
+    metrics: Optional[MetricsTracker] = None
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
-    metrics: Optional[MetricsTracker] = None

     def __init__(
         self,
@@ -336,6 +341,50 @@ def output_keys(self) -> List[str]:
         """
         return [self.output_key]

+    def update_with_delayed_score(
+        self, score: float, event: Event, force_score=False
+    ) -> None:
+        """
+        Learn will be called with the score specified and the actions/embeddings/etc stored in event
+
+        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
+        """
+        if self._can_use_selection_scorer() and not force_score:
+            raise RuntimeError(
+                "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
+            )
+        self.metrics.on_feedback(score)
+        self._call_after_scoring_before_learning(event=event, score=score)
+        self.policy.learn(event=event)
+        self.policy.log(event=event)
+
+    def deactivate_selection_scorer(self) -> None:
+        """
+        Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
+        """
+        self.selection_scorer_activated = False
+
+    def activate_selection_scorer(self) -> None:
+        """
+        Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
+        """
+        self.selection_scorer_activated = True
+
+    def save_progress(self) -> None:
+        """
+        This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
+
+        File Naming Convention:
+        The file will be named using the pattern `model-<checkpoint number>.vw`, where `<checkpoint number>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint number>` will be the next in the sequence.
+
+        Example:
+        If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
+
+        Note:
+        Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
+        """
+        self.policy.save()
+
     def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
         super()._validate_inputs(inputs)
         if (
@@ -346,6 +395,12 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
             f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
         )

+    def _can_use_selection_scorer(self) -> bool:
+        """
+        Returns whether the chain can use the selection scorer to score responses or not.
+        """
+        return self.selection_scorer is not None and self.selection_scorer_activated
+
     @abstractmethod
     def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
         pass
@@ -368,32 +423,6 @@ def _call_after_scoring_before_learning(
     ) -> Event:
         pass

-    def update_with_delayed_score(
-        self, score: float, event: Event, force_score=False
-    ) -> None:
-        """
-        Learn will be called with the score specified and the actions/embeddings/etc stored in event
-
-        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
-        """
-        if self.selection_scorer and not force_score:
-            raise RuntimeError(
-                "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
-            )
-        self.metrics.on_feedback(score)
-        self._call_after_scoring_before_learning(event=event, score=score)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
-
-    def set_auto_embed(self, auto_embed: bool) -> None:
-        """
-        Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.
-
-        Args:
-            auto_embed (bool): Whether the chain should auto embed the inputs or not.
-        """
-        self.auto_embed = auto_embed
-
     def _call(
         self,
         inputs: Dict[str, Any],
@@ -429,13 +458,13 @@ def _call(

         score = None
         try:
-            if self.selection_scorer:
+            if self._can_use_selection_scorer():
                 score = self.selection_scorer.score_response(
-                    inputs=next_chain_inputs, llm_response=output
+                    inputs=next_chain_inputs, llm_response=output, event=event
                 )
         except Exception as e:
             logger.info(
-                f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
+                f"The selection scorer was not able to rank and the chain was not able to adjust to this response, error: {e}"
             )
         self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
@@ -444,21 +473,6 @@ def _call(

         return {self.output_key: {"response": output, "selection_metadata": event}}

-    def save_progress(self) -> None:
-        """
-        This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
-
-        File Naming Convention:
-        The file will be named using the pattern `model-<checkpoint number>.vw`, where `<checkpoint number>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint number>` will be the next in the sequence.
-
-        Example:
-        If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
-
-        Note:
-        Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
- """ - self.policy.save() - @property def _chain_type(self) -> str: return "llm_personalizer_chain" @@ -517,6 +531,13 @@ def embed_list_type( for embed_item in item: if isinstance(embed_item, dict): ret_list.append(embed_dict_type(embed_item, model)) + elif isinstance(embed_item, list): + item_embedding = embed_list_type(embed_item, model, namespace) + # Get the first key from the first dictionary + first_key = next(iter(item_embedding[0])) + # Group the values under that key + grouping = {first_key: [item[first_key] for item in item_embedding]} + ret_list.append(grouping) else: ret_list.append(embed_string_type(embed_item, model, namespace)) return ret_list diff --git a/tests/test_pick_best_chain_call.py b/tests/test_pick_best_chain_call.py index 6c8db42..0310734 100644 --- a/tests/test_pick_best_chain_call.py +++ b/tests/test_pick_best_chain_call.py @@ -115,7 +115,9 @@ def test_user_defined_scorer(): llm, PROMPT = setup() class CustomSelectionScorer(pick_best_chain.base.SelectionScorer): - def score_response(self, inputs, llm_response: str) -> float: + def score_response( + self, inputs, llm_response: str, event: pick_best_chain.PickBest.Event + ) -> float: score = 200 return score @@ -132,11 +134,11 @@ def score_response(self, inputs, llm_response: str) -> float: assert selection_metadata.selected.score == 200.0 -def test_default_embeddings(): +def test_auto_embeddings_on(): llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( - llm=llm, prompt=PROMPT, feature_embedder=feature_embedder + llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True ) str1 = "0" @@ -165,6 +167,32 @@ def test_default_embeddings(): assert vw_str == expected +def test_default_auto_embedder_is_off(): + llm, PROMPT = setup() + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, prompt=PROMPT, feature_embedder=feature_embedder + ) + + str1 = "0" + str2 = "1" + str3 = "2" + ctx_str_1 = "context1" + ctx_str_2 = "context2" + + expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ + + actions = [str1, str2, str3] + + response = chain.run( + User=pick_best_chain.base.BasedOn(ctx_str_1), + action=pick_best_chain.base.ToSelectFrom(actions), + ) + selection_metadata = response["selection_metadata"] + vw_str = feature_embedder.format(selection_metadata) + assert vw_str == expected + + def test_default_embeddings_off(): llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) @@ -194,7 +222,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings(): llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( - llm=llm, prompt=PROMPT, feature_embedder=feature_embedder + llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True ) str1 = "0" @@ -287,3 +315,39 @@ def test_calling_chain_w_reserved_inputs_throws(): User=pick_best_chain.base.BasedOn("Context"), rl_chain_selected=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]), ) + + +def test_activate_and_deactivate_scorer(): + llm, PROMPT = setup() + scorer_llm = FakeListChatModel(responses=[300]) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, + prompt=PROMPT, + selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm), + ) + response = chain.run( + 
+        User=pick_best_chain.base.BasedOn("Context"),
+        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
+    )
+    # chain llm used for both basic prompt and for scoring
+    assert response["response"] == "hey"
+    selection_metadata = response["selection_metadata"]
+    assert selection_metadata.selected.score == 300.0
+
+    chain.deactivate_selection_scorer()
+    response = chain.run(
+        User=pick_best_chain.base.BasedOn("Context"),
+        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
+    )
+    assert response["response"] == "hey"
+    selection_metadata = response["selection_metadata"]
+    assert selection_metadata.selected.score == None
+
+    chain.activate_selection_scorer()
+    response = chain.run(
+        User=pick_best_chain.base.BasedOn("Context"),
+        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
+    )
+    assert response["response"] == "hey"
+    selection_metadata = response["selection_metadata"]
+    assert selection_metadata.selected.score == 300.0
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 8b37731..544a545 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,7 +4,7 @@


 class MockScorer(SelectionScorer):
     def score_response(
-        self, inputs: Dict[str, Any], llm_response: str, **kwargs
+        self, inputs: Dict[str, Any], llm_response: str, event: Any
     ) -> float:
         return float(llm_response)

From 63d4ab1e008eb586a1233d739f79e88302e94011 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Fri, 18 Aug 2023 00:49:54 -0400
Subject: [PATCH 2/2] remove comment

---
 rl_chain/rl_chain_base.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/rl_chain/rl_chain_base.py b/rl_chain/rl_chain_base.py
index 5d6bb82..4a7a021 100644
--- a/rl_chain/rl_chain_base.py
+++ b/rl_chain/rl_chain_base.py
@@ -374,14 +374,6 @@ def save_progress(self) -> None:
         """
         This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.

-        File Naming Convention:
-        The file will be named using the pattern `model-<checkpoint number>.vw`, where `<checkpoint number>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint number>` will be the next in the sequence.
-
-        Example:
-        If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
-
-        Note:
-        Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
         """
         self.policy.save()
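
A minimal usage sketch of the surface this series adds (activating/deactivating the selection scorer, feeding a delayed score back through `update_with_delayed_score`, and checkpointing with `save_progress`). It is illustrative only, not part of the patches: the `llm`, `PROMPT`, and `meals` objects are placeholders assumed to be constructed as in `rl_chain.ipynb`.

```python
import rl_chain

# Assumes `llm`, `PROMPT`, and `meals` already exist, built as in rl_chain.ipynb
# (hypothetical placeholders in this sketch).
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT, metrics_step=1)

# Turn off automatic scoring so feedback can be supplied later instead.
chain.deactivate_selection_scorer()
response = chain.run(
    meal=rl_chain.ToSelectFrom(meals),
    User=rl_chain.BasedOn("Tom"),
    preference=rl_chain.BasedOn(["Vegetarian", "regular dairy is ok"]),
    text_to_personalize="This is the week's specialty dish!",
)
event = response["selection_metadata"]  # selected.score stays None while the scorer is off

# Once real feedback arrives, learn from it using the stored event.
chain.update_with_delayed_score(score=1.0, event=event)

# Re-enable automatic scoring and checkpoint the VW model (model-1.vw, model-2.vw, ...).
chain.activate_selection_scorer()
chain.save_progress()
```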