diff --git a/edenai_apis/apis/tenstorrent/info.json b/edenai_apis/apis/tenstorrent/info.json index f2046885..ea52543c 100644 --- a/edenai_apis/apis/tenstorrent/info.json +++ b/edenai_apis/apis/tenstorrent/info.json @@ -32,6 +32,12 @@ }, "named_entity_recognition": { "version": "v1.0.0" + }, + "chat": { + "version": "v1.0.0" + }, + "generation": { + "version": "v1.0.0" } } -} \ No newline at end of file +} diff --git a/edenai_apis/apis/tenstorrent/outputs/text/chat_output.json b/edenai_apis/apis/tenstorrent/outputs/text/chat_output.json new file mode 100644 index 00000000..7f532df5 --- /dev/null +++ b/edenai_apis/apis/tenstorrent/outputs/text/chat_output.json @@ -0,0 +1,44 @@ +{ + "original_response": { + "id": "chat-f3a36686fe6b4201a7cf3bd86cdf1bfe", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "Barack, Obama, United States, Democratic, Illinois", + "role": "assistant", + "tool_calls": [] + }, + "stop_reason": null + } + ], + "created": 1737667927, + "model": "tenstorrent/Meta-Llama-3.1-70B-Instruct", + "object": "chat.completion", + "usage": { + "completion_tokens": 15, + "prompt_tokens": 248, + "total_tokens": 263 + }, + "prompt_logprobs": null + }, + "standardized_response": { + "generated_text": "Barack, Obama, United States, Democratic, Illinois", + "message": [ + { + "role": "user", + "message": "Barack Hussein Obama is an American politician who served as the 44th president of the United States from 2009 to 2017. A member of the Democratic Party, Obama was the first African-American president of the United States. He previously served as a U.S. 
senator from Illinois from 2005 to 2008 and as an Illinois state senator from 1997 to 2004.", + "tools": null, + "tool_calls": null + }, + { + "role": "assistant", + "message": "Barack, Obama, United States, Democratic, Illinois", + "tools": null, + "tool_calls": null + } + ] + } +} \ No newline at end of file diff --git a/edenai_apis/apis/tenstorrent/outputs/text/generation_output.json b/edenai_apis/apis/tenstorrent/outputs/text/generation_output.json new file mode 100644 index 00000000..c619f5df --- /dev/null +++ b/edenai_apis/apis/tenstorrent/outputs/text/generation_output.json @@ -0,0 +1,26 @@ +{ + "original_response": { + "id": "cmpl-3f525da338214ac594e080413aefa2ad", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "text": " I've never seen you before.\n\nAI Assistant: Hello! *big virtual smile* I'm delighted to meet you! I'm an AI created to assist and chat with humans. Think of me as a friendly companion, always ready to lend a helping hand (or rather, a helping byte). How can I help you today?\n\nHuman: That's great! I'm looking for some advice on how to plan a surprise birthday party for my friend's 30th birthday. Can you help me come up with some ideas?\n\nAI Assistant: What an exciting task! I'd be happy to help you plan an unforgettable celebration for your friend's milestone birthday!\n\nTo get started, can you tell me a bit more about your friend? What are their interests, favorite hobbies, or things they love? This will help me get a better sense of their personality and tailor some party ideas that will make their 30th birthday truly special.\n\nAlso, what's the guest list looking like? 
Is it an intimate gathering or a bigger bash?\n\n(And, by the way, I have some fun party planning tools and ideas up my sleeve, so get ready for some brainstorming fun!)", + "stop_reason": null, + "prompt_logprobs": null + } + ], + "created": 1737667994, + "model": "tenstorrent/Meta-Llama-3.1-70B-Instruct", + "object": "text_completion", + "usage": { + "completion_tokens": 236, + "prompt_tokens": 31, + "total_tokens": 267 + } + }, + "standardized_response": { + "generated_text": " I've never seen you before.\n\nAI Assistant: Hello! *big virtual smile* I'm delighted to meet you! I'm an AI created to assist and chat with humans. Think of me as a friendly companion, always ready to lend a helping hand (or rather, a helping byte). How can I help you today?\n\nHuman: That's great! I'm looking for some advice on how to plan a surprise birthday party for my friend's 30th birthday. Can you help me come up with some ideas?\n\nAI Assistant: What an exciting task! I'd be happy to help you plan an unforgettable celebration for your friend's milestone birthday!\n\nTo get started, can you tell me a bit more about your friend? What are their interests, favorite hobbies, or things they love? This will help me get a better sense of their personality and tailor some party ideas that will make their 30th birthday truly special.\n\nAlso, what's the guest list looking like? 
Is it an intimate gathering or a bigger bash?\n\n(And, by the way, I have some fun party planning tools and ideas up my sleeve, so get ready for some brainstorming fun!)" + } +} \ No newline at end of file diff --git a/edenai_apis/apis/tenstorrent/tenstorrent_api.py b/edenai_apis/apis/tenstorrent/tenstorrent_api.py index 9e5a1593..91193186 100644 --- a/edenai_apis/apis/tenstorrent/tenstorrent_api.py +++ b/edenai_apis/apis/tenstorrent/tenstorrent_api.py @@ -4,7 +4,7 @@ from edenai_apis.features import ProviderInterface from edenai_apis.loaders.data_loader import ProviderDataEnum from edenai_apis.loaders.loaders import load_provider - +from openai import OpenAI class TenstorrentApi( ProviderInterface, @@ -23,3 +23,9 @@ def __init__(self, api_keys: Dict = {}): "content-type": "application/json", "Tenstorrent-Version": "2023-06-26", } + self.chatgen_base_url = "https://chat-and-generation--eden-ai.workload.tenstorrent.com" + self.chatgen_api_version = "v1" + self.chatgen_url = f"{self.chatgen_base_url}/{self.chatgen_api_version}" + self.client = OpenAI( + api_key=self.api_key, base_url=self.chatgen_url + ) diff --git a/edenai_apis/apis/tenstorrent/tenstorrent_text_api.py b/edenai_apis/apis/tenstorrent/tenstorrent_text_api.py index 7f89f69e..508508b1 100644 --- a/edenai_apis/apis/tenstorrent/tenstorrent_text_api.py +++ b/edenai_apis/apis/tenstorrent/tenstorrent_text_api.py @@ -1,7 +1,5 @@ -from typing import List, Optional - import requests - +from typing import Dict, List, Optional, Union from edenai_apis.features.text.keyword_extraction.keyword_extraction_dataclass import ( KeywordExtractionDataClass, ) @@ -14,13 +12,19 @@ from edenai_apis.features.text.sentiment_analysis.sentiment_analysis_dataclass import ( SentimentAnalysisDataClass, ) -from edenai_apis.features.text.text_interface import TextInterface from edenai_apis.features.text.topic_extraction.topic_extraction_dataclass import ( TopicExtractionDataClass, ) +from edenai_apis.features.text.text_interface import 
TextInterface from edenai_apis.utils.exception import ProviderException from edenai_apis.utils.types import ResponseType - +from edenai_apis.features.text.chat import ChatDataClass, ChatMessageDataClass +from edenai_apis.features.text.chat.chat_dataclass import ( + StreamChat, + ChatStreamResponse, +) +from edenai_apis.features.text.generation import GenerationDataClass +from openai import OpenAI class TenstorrentTextApi(TextInterface): def text__keyword_extraction( @@ -31,7 +35,6 @@ def text__keyword_extraction( payload = { "text": text, } - try: original_response = requests.post(url, json=payload, headers=self.headers) except requests.exceptions.RequestException as exc: @@ -44,7 +47,6 @@ def text__keyword_extraction( # Check for errors self.__check_for_errors(original_response, status_code) - standardized_response = KeywordExtractionDataClass( items=original_response["items"] ) @@ -114,7 +116,6 @@ def text__question_answer( # Check for errors self.__check_for_errors(original_response, status_code) - standardized_response = QuestionAnswerDataClass( answers=[original_response["answer"]] ) @@ -180,7 +181,113 @@ def text__topic_extraction( original_response=original_response, standardized_response=standardized_response, ) + + def text__chat( + self, + text: str, + chatbot_global_action: Optional[str], + previous_history: Optional[List[Dict[str, str]]], + temperature: float, + max_tokens: int, + model: str, + stream=False, + ) -> ResponseType[Union[ChatDataClass, StreamChat]]: + messages = [] + for msg in previous_history: + message = { + "role": msg.get("role"), + "content": msg.get("message"), + } + messages.append(message) + + if text: + messages.append({"role": "user", "content": text}) + + if chatbot_global_action: + messages.insert(0, {"role": "system", "content": chatbot_global_action}) + + payload = { + "model": model, + "temperature": temperature, + "messages": messages, + "max_tokens": max_tokens, + "stream": stream, + } + + + try: + response = 
self.client.chat.completions.create(**payload) + except Exception as exc: + raise ProviderException(str(exc)) + + # Standardize the response + if stream is False: + message = response.choices[0].message + generated_text = message.content + messages = [ + ChatMessageDataClass(role="user", message=text), + ChatMessageDataClass( + role="assistant", + message=generated_text, + ), + ] + messages_json = [m.dict() for m in messages] + + standardized_response = ChatDataClass( + generated_text=generated_text, message=messages_json + ) + + return ResponseType[ChatDataClass]( + original_response=response.to_dict(), + standardized_response=standardized_response, + ) + else: + stream = ( + ChatStreamResponse( + text=chunk.to_dict()["choices"][0]["delta"].get("content", ""), + blocked=not chunk.to_dict()["choices"][0].get("finish_reason") in (None, "stop"), + provider="tenstorrent", + ) + for chunk in response + if chunk + ) + + return ResponseType[StreamChat]( + original_response=None, standardized_response=StreamChat(stream=stream) + ) + + def text__generation( + self, + text: str, + temperature: float, + max_tokens: int, + model: str, + ) -> ResponseType[GenerationDataClass]: + payload = { + "model": model, + "prompt": text, + "temperature": temperature, + "max_tokens": max_tokens, + } + + try: + response = self.client.completions.create(**payload) + except Exception as exc: + raise ProviderException(str(exc)) + + # Standardize the response + generated_text = response.choices[0].text + + standardized_response = GenerationDataClass( + generated_text=generated_text, + ) + + return ResponseType[GenerationDataClass]( + original_response=response.to_dict(), + standardized_response=standardized_response, + ) def __check_for_errors(self, response, status_code = None): if "message" in response: raise ProviderException(response["message"], code= status_code) + \ No newline at end of file diff --git a/edenai_apis/features/text/chat/chat_args.py 
b/edenai_apis/features/text/chat/chat_args.py index 876f5503..8159c130 100644 --- a/edenai_apis/features/text/chat/chat_args.py +++ b/edenai_apis/features/text/chat/chat_args.py @@ -23,7 +23,8 @@ def chat_arguments(provider_name: str): "perplexityai": "llama-3.1-sonar-large-128k-chat", "replicate": "llama-2-70b-chat", "anthropic": "claude-3-sonnet-20240229-v1:0", - "xai": "grok-beta" + "xai": "grok-beta", + "tenstorrent": "tenstorrent/Meta-Llama-3.1-70B-Instruct" }, "stream": False, } diff --git a/edenai_apis/features/text/generation/generation_args.py b/edenai_apis/features/text/generation/generation_args.py index 11105737..63848591 100644 --- a/edenai_apis/features/text/generation/generation_args.py +++ b/edenai_apis/features/text/generation/generation_args.py @@ -12,5 +12,6 @@ def generation_arguments(provider_name: str): "mistral": "large-latest", "ai21labs": "j2-ultra", "meta": "llama3-1-70b-instruct-v1:0", + "tenstorrent": "tenstorrent/Meta-Llama-3.1-70B-Instruct", }, }