diff --git a/data/timesketch.conf b/data/timesketch.conf
index 7a03e64093..6cca7592e1 100644
--- a/data/timesketch.conf
+++ b/data/timesketch.conf
@@ -362,7 +362,9 @@ LLM_PROVIDER_CONFIGS = {
     # To use the Vertex AI provider you need to:
     # 1. Create and export a Service Account Key from the Google Cloud Console.
     # 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
-    #    to your service account private key file.
+    #    to your service account private key file by adding it to the docker-compose.yml
+    #    under environment:
+    #    GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/.json
     # 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
     #
     # IMPORTANT: Private keys must be kept secret. If you expose your private key it is
@@ -372,7 +374,7 @@ LLM_PROVIDER_CONFIGS = {
         'project_id': '',
     },
     # To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
-    # pip install google-generativeai
+    # pip3 install google-generativeai
     'aistudio': {
         'api_key': '',
         'model': 'gemini-2.0-flash-exp',
@@ -384,4 +386,3 @@ DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv'
 PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q'
 EXAMPLES_NL2Q = '/etc/timesketch/nl2q/examples_nl2q'
 LLM_PROVIDER = ''
-
diff --git a/timesketch/lib/llms/interface.py b/timesketch/lib/llms/interface.py
index 014e0887b4..a54699fac8 100644
--- a/timesketch/lib/llms/interface.py
+++ b/timesketch/lib/llms/interface.py
@@ -20,9 +20,10 @@
 
 DEFAULT_TEMPERATURE = 0.1
 DEFAULT_TOP_P = 0.1
-DEFAULT_TOP_K = 0
+DEFAULT_TOP_K = 1
 DEFAULT_MAX_OUTPUT_TOKENS = 2048
 DEFAULT_STREAM = False
+DEFAULT_LOCATION = None
 
 
 class LLMProvider:
@@ -37,6 +38,7 @@ def __init__(
         top_k: int = DEFAULT_TOP_K,
         max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
         stream: bool = DEFAULT_STREAM,
+        location: Optional[str] = DEFAULT_LOCATION,
     ):
         """Initialize the LLM provider.
 
@@ -46,6 +48,7 @@ def __init__(
             top_k: The top_k to use for the response.
             max_output_tokens: The maximum number of output tokens to generate.
             stream: Whether to stream the response.
+            location: The cloud location/region to use for the provider.
 
         Attributes:
             config: The configuration for the LLM provider.
@@ -59,6 +62,7 @@ def __init__(
         config["top_k"] = top_k
         config["max_output_tokens"] = max_output_tokens
         config["stream"] = stream
+        config["location"] = location
 
         # Load the LLM provider config from the Flask app config
         config_from_flask = current_app.config.get("LLM_PROVIDER_CONFIGS").get(
diff --git a/timesketch/lib/llms/vertexai.py b/timesketch/lib/llms/vertexai.py
index 2166e6baba..e4f25f7f7e 100644
--- a/timesketch/lib/llms/vertexai.py
+++ b/timesketch/lib/llms/vertexai.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google Inc. All rights reserved.
+# Copyright 2025 Google Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Vertex AI LLM provider."""
+import json
+from typing import Optional
+
 from timesketch.lib.llms import interface
 from timesketch.lib.llms import manager
 
@@ -21,6 +24,7 @@
 try:
     from google.cloud import aiplatform
     from vertexai.preview.generative_models import GenerativeModel
+    from vertexai.preview.generative_models import GenerationConfig
 except ImportError:
     has_required_deps = False
 
@@ -30,31 +34,54 @@ class VertexAI(interface.LLMProvider):
 
     NAME = "vertexai"
 
-    def generate(self, prompt: str) -> str:
+    def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
         """
         Generate text using the Vertex AI service.
 
         Args:
             prompt: The prompt to use for the generation.
-            temperature: The temperature to use for the generation.
-            stream: Whether to stream the generation or not.
+            response_schema: An optional JSON schema to define the expected
+                response format.
 
         Returns:
-            The generated text as a string.
+            The generated text as a string (or parsed data if
+            response_schema is provided).
         """
         aiplatform.init(
             project=self.config.get("project_id"),
+            location=self.config.get("location"),
         )
 
         model = GenerativeModel(self.config.get("model"))
+
+        if response_schema:
+            generation_config = GenerationConfig(
+                temperature=self.config.get("temperature"),
+                top_k=self.config.get("top_k"),
+                top_p=self.config.get("top_p"),
+                response_mime_type="application/json",
+                response_schema=response_schema,
+            )
+        else:
+            generation_config = GenerationConfig(
+                temperature=self.config.get("temperature"),
+                top_k=self.config.get("top_k"),
+                top_p=self.config.get("top_p"),
+            )
+
         response = model.generate_content(
             prompt,
-            generation_config={
-                "max_output_tokens": self.config.get("max_output_tokens"),
-                "temperature": self.config.get("temperature"),
-            },
+            generation_config=generation_config,
             stream=self.config.get("stream"),
         )
 
+        if response_schema:
+            try:
+                return json.loads(response.text)
+            except Exception as error:
+                raise ValueError(
+                    f"Error JSON parsing text: {response.text}: {error}"
+                ) from error
+
         return response.text