From 75fcae31909760f306b6c292ebb3a2eacae6a2a3 Mon Sep 17 00:00:00 2001 From: Maria Khodorchenko Date: Tue, 28 Jan 2025 13:19:55 +0300 Subject: [PATCH] Dev/synthetic gen refactor (#50) * add unsaved changes, delete unnecessary files --- examples/.DS_Store | Bin 6148 -> 0 bytes .../examples/aspect_summarisation_example.py | 4 ++-- protollm-synthetic/examples/quiz_example.py | 4 ++-- protollm-synthetic/examples/rag_example.py | 11 +++++++++-- .../synthetic_pipelines/chains.py | 8 ++++---- protollm-synthetic/pyproject.toml | 2 +- .../tests/test_summarization_chain.py | 4 ++-- 7 files changed, 20 insertions(+), 13 deletions(-) delete mode 100644 examples/.DS_Store diff --git a/examples/.DS_Store b/examples/.DS_Store deleted file mode 100644 index 525563fb81e048a221d4da0f96e49160e391f180..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK!H&}~5FK~RZc2sJ14z9fMdDg5U3CcuE}<+3t^~mWP)Ih>65)-jCf$IlN;$*F zaOF$*9e88A72Otz6GF%njo)}YGfAG4I3^-7U1TRj10wREjGZ2uJ;LLxJJPY9L!i(# z3aTl_aZIzFXy^DB8Q{6=>p3f^q+jRn(=(~gQ>ju!@Z<3qdi+BRETSPv90^@QCv<`u z3#wpWP!2n%k9ur|m|Q;(*Xw<~R_`Xw5t$%9riK)HlCte<5O@D1CvB>kz3y-6Ocqr= z9)A~|gYJ#PBk#!TdvBxns)_2PUR2X${+fF)l#=P%PSTe}-o%63PgGeaMVZfaLXqVN zdGo3$Gu2GhqRc9t8<_#G=k?;jaJdYE2P1zHoUTUx^5Hlb`KOa$wd#3y?%jX%JpLdr zOZAxv#7l#>M$0XWYxn|VFCSTTE>m19C+-cnisq$MznwVra27h!Kb|sX&t|>=r|q zbohN67g}r$nsgF&^C9e+h25bDeRkCMbvTL8plclij)83k4or8z>;Kc=@Bg=xT+cD! z82GOk5Z$xrY>Hd5YwOO<@mlLc-$7Y8t~IzxfuWaT#PU)+gld7`X9E~oYz@K#aX$i@ M2G=+S{wf1M0seY~DF6Tf diff --git a/protollm-synthetic/examples/aspect_summarisation_example.py b/protollm-synthetic/examples/aspect_summarisation_example.py index 3ba8af2..051f6ef 100644 --- a/protollm-synthetic/examples/aspect_summarisation_example.py +++ b/protollm-synthetic/examples/aspect_summarisation_example.py @@ -23,8 +23,8 @@ aspect = "politics" -qwen_large_api_key = os.environ.get("QWEN2VL_OPENAI_API_KEY") -qwen_large_api_base = os.environ.get("QWEN2VL_OPENAI_API_BASE") +qwen_large_api_key = os.environ.get("OPENAI_API_KEY") +qwen_large_api_base = os.environ.get("OPENAI_API_BASE") llm=VLLMChatOpenAI( api_key=qwen_large_api_key, diff --git a/protollm-synthetic/examples/quiz_example.py b/protollm-synthetic/examples/quiz_example.py index 48ecc0c..e951d21 100644 --- a/protollm-synthetic/examples/quiz_example.py +++ b/protollm-synthetic/examples/quiz_example.py @@ -1,6 +1,6 @@ import os -from samplefactory.synthetic_pipelines.chains import QuizChain -from samplefactory.utils import Dataset, VLLMChatOpenAI +from protollm_synthetic.synthetic_pipelines.chains import QuizChain +from protollm_synthetic.utils import Dataset, VLLMChatOpenAI import json import asyncio diff --git a/protollm-synthetic/examples/rag_example.py b/protollm-synthetic/examples/rag_example.py index bb0979e..4d12453 100644 --- a/protollm-synthetic/examples/rag_example.py +++ b/protollm-synthetic/examples/rag_example.py @@ -1,10 +1,16 @@ import os +import os import json import logging from protollm_synthetic.synthetic_pipelines.chains import RAGChain from protollm_synthetic.utils import Dataset, VLLMChatOpenAI import asyncio +import logging +from protollm_synthetic.synthetic_pipelines.chains import RAGChain +from protollm_synthetic.utils import Dataset, VLLMChatOpenAI +import asyncio + # Сохраняем набор данных texts = [ @@ -98,8 +104,8 @@ path = 'tmp_data/sample_data_rag_spb.json' dataset = Dataset(data_col='content', path=path) -qwen_large_api_key = os.environ.get("QWEN_OPENAI_API_KEY") -qwen_large_api_base = os.environ.get("QWEN_OPENAI_API_BASE") +qwen_large_api_key = os.environ.get("OPENAI_API_KEY") +qwen_large_api_base = os.environ.get("OPENAI_API_BASE") logger.info("Initializing LLM connection") @@ -130,4 +136,5 @@ logger.info(f"Writing result to {path}") df.to_json(path, orient="records") +logger.info("Generation successfully finished") logger.info("Generation successfully finished") \ No newline at end of file diff --git a/protollm-synthetic/protollm_synthetic/synthetic_pipelines/chains.py b/protollm-synthetic/protollm_synthetic/synthetic_pipelines/chains.py index e617155..467d193 100644 --- a/protollm-synthetic/protollm_synthetic/synthetic_pipelines/chains.py +++ b/protollm-synthetic/protollm_synthetic/synthetic_pipelines/chains.py @@ -3,14 +3,14 @@ import copy from datetime import datetime import logging -from samplefactory.synthetic_pipelines.prompts import (generate_summary_system_prompt, generate_summary_evaluation_system_prompt, +from protollm_synthetic.synthetic_pipelines.prompts import (generate_summary_system_prompt, generate_summary_evaluation_system_prompt, generate_rag_system_prompt, check_summary_quality_human_prompt, generate_rag_human_prompt, generate_aspect_summarisation_prompt, generate_summary_human_prompt, generate_aspect_summarisation_evaluation_system_prompt, generate_quiz_system_prompt, generate_quiz_human_prompt, generate_instruction_one_shot_system_prompt, generate_instruction_one_shot_human_prompt, merge_instructions, merge_instructions_human_prompt) -from samplefactory.utils import Dataset +from protollm_synthetic.utils import Dataset import numpy as np import asyncio from typing import List, Optional, Dict, Any, TypeVar, cast @@ -26,10 +26,10 @@ RunnableParallel, RunnableLambda) from langchain.chains.combine_documents import create_stuff_documents_chain from openai import APIConnectionError -from samplefactory.synthetic_pipelines.genetic_evolver import GeneticEvolver +from protollm_synthetic.synthetic_pipelines.genetic_evolver import GeneticEvolver import random -from samplefactory.synthetic_pipelines.schemes import (SummaryQualitySchema, +from protollm_synthetic.synthetic_pipelines.schemes import (SummaryQualitySchema, RAGScheme, AspectSummarisationQualitySchema, QuizScheme, FreeQueryScheme, FreeQueryMerger) diff --git a/protollm-synthetic/pyproject.toml b/protollm-synthetic/pyproject.toml index 993f189..4a66247 100644 --- a/protollm-synthetic/pyproject.toml +++ b/protollm-synthetic/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "samplefactory" +name = "protollm-synthetic" version = "0.1.0" description = "Sample generation with LLMs" authors = ["Your Name "] diff --git a/protollm-synthetic/tests/test_summarization_chain.py b/protollm-synthetic/tests/test_summarization_chain.py index 0571a96..aadc887 100644 --- a/protollm-synthetic/tests/test_summarization_chain.py +++ b/protollm-synthetic/tests/test_summarization_chain.py @@ -1,7 +1,7 @@ import unittest import os -from samplefactory.synthetic_pipelines.chains import SummarisationChain -from samplefactory.utils import VLLMChatOpenAI, Dataset +from protollm_synthetic.synthetic_pipelines.chains import SummarisationChain +from protollm_synthetic.utils import VLLMChatOpenAI, Dataset import pandas as pd import asyncio