Merge branch 'master' into 2025-01-31_cli_number_of_fields
jaegeral authored Feb 7, 2025
2 parents e03ade0 + 5eef913 commit 11a69b0
Showing 8 changed files with 262 additions and 119 deletions.
56 changes: 31 additions & 25 deletions data/timesketch.conf
@@ -353,36 +353,42 @@ CONTEXT_LINKS_CONFIG_PATH = '/etc/timesketch/context_links.yaml'

# LLM provider configs
LLM_PROVIDER_CONFIGS = {
# To use the Ollama provider you need to download and run an Ollama server.
# See instructions at: https://ollama.ai/
'ollama': {
'server_url': 'http://localhost:11434',
'model': 'gemma:7b',
},
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
# Configure a LLM provider for a specific LLM enabled feature, or the
# default provider will be used.
# Supported LLM Providers:
# - ollama: Self-hosted, open-source.
# To use the Ollama provider you need to download and run an Ollama server.
# See instructions at: https://ollama.ai/
# - vertexai: Google Cloud Vertex AI. Requires Google Cloud Project.
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
#
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
# recommended to revoke it immediately from the Google Cloud Console.
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
},
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# pip3 install google-generativeai
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
# recommended to revoke it immediately from the Google Cloud Console.
# - aistudio: Google AI Studio (API key). Get API key from Google AI Studio website.
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# $ pip3 install google-generativeai
'nl2q': {
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
},
},
'default': {
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
},
}
}


# LLM nl2q configuration
DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv'
PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q'
EXAMPLES_NL2Q = '/etc/timesketch/nl2q/examples_nl2q'
LLM_PROVIDER = ''
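
For reference, a minimal sketch of what a complete configuration in the new per-feature layout might look like (the provider and model names mirror the defaults shown above; the project ID and API key are placeholders you would supply yourself):

# Hypothetical example of the new layout: one provider per LLM-enabled
# feature, plus a 'default' entry used when a feature has no explicit provider.
LLM_PROVIDER_CONFIGS = {
    'nl2q': {
        'vertexai': {
            'model': 'gemini-1.5-flash-001',
            'project_id': 'my-gcp-project',  # placeholder: your GCP project ID
        },
    },
    'default': {
        'aistudio': {
            'api_key': 'YOUR_AISTUDIO_API_KEY',  # placeholder: key from https://aistudio.google.com/
            'model': 'gemini-2.0-flash-exp',
        },
    },
}
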
29 changes: 13 additions & 16 deletions timesketch/api/v1/resources/nl2q.py
@@ -178,35 +178,33 @@ def post(self, sketch_id):
Returns:
JSON representing the LLM prediction.
"""
llm_provider = current_app.config.get("LLM_PROVIDER", "")
if not llm_provider:
logger.error("No LLM provider was defined in the main configuration file")
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)
form = request.json
if not form:
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"No JSON data provided",
)
abort(HTTP_STATUS_CODE_BAD_REQUEST, "No JSON data provided")

if "question" not in form:
abort(HTTP_STATUS_CODE_BAD_REQUEST, "The 'question' parameter is required!")

llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS")
if not llm_configs:
logger.error("No LLM provider configuration defined.")
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"The 'question' parameter is required!",
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)

question = form.get("question")
prompt = self.build_prompt(question, sketch_id)

result_schema = {
"name": "AI generated search query",
"query_string": None,
"error": None,
}

feature_name = "nl2q"
try:
llm = manager.LLMManager().get_provider(llm_provider)()
llm = manager.LLMManager.create_provider(feature_name=feature_name)
except Exception as e: # pylint: disable=broad-except
logger.error("Error LLM Provider: {}".format(e))
result_schema["error"] = (
@@ -223,7 +221,6 @@ def post(self, sketch_id):
"Please try again later!"
)
return jsonify(result_schema)
# The model sometimes output triple backticks that needs to be removed.
result_schema["query_string"] = prediction.strip("```")

result_schema["query_string"] = prediction.strip("```")
return jsonify(result_schema)
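
Roughly, the new call pattern replaces the old per-name get_provider lookup with a feature-aware factory; a hedged sketch of how a caller resolves a provider (the fallback to the 'default' entry is inferred from the settings change below, not spelled out in this hunk):

from timesketch.lib.llms import manager

# Resolve the provider configured for the "nl2q" feature in
# LLM_PROVIDER_CONFIGS; create_provider() is expected to fall back to the
# 'default' entry when no feature-specific provider is configured.
llm = manager.LLMManager.create_provider(feature_name="nl2q")

# In the resource this runs inside a Flask request context and the prompt
# comes from build_prompt(); the literal prompt here is only illustrative.
prediction = llm.generate("Convert this question into an OpenSearch query: ...")
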
32 changes: 29 additions & 3 deletions timesketch/api/v1/resources/settings.py
@@ -13,11 +13,13 @@
# limitations under the License.
"""System settings."""

from flask import current_app
from flask import jsonify
import logging
from flask import current_app, jsonify
from flask_restful import Resource
from flask_login import login_required

logger = logging.getLogger("timesketch.system_settings")


class SystemSettingsResource(Resource):
"""Resource to get system settings."""
@@ -30,10 +32,34 @@ def get(self):
JSON object with system settings.
"""
# Settings from timesketch.conf to expose to the frontend clients.
settings_to_return = ["LLM_PROVIDER", "DFIQ_ENABLED"]
settings_to_return = ["DFIQ_ENABLED"]
result = {}

for setting in settings_to_return:
result[setting] = current_app.config.get(setting)

# Derive the default LLM provider from the new configuration.
# Expecting the "default" config to be a dict with exactly one key:
# the provider name.
llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {})
default_provider = None
default_conf = llm_configs.get("default")
if default_conf and isinstance(default_conf, dict) and len(default_conf) == 1:
default_provider = next(iter(default_conf))
result["LLM_PROVIDER"] = default_provider

# TODO(mvd): Remove by 2025/06/01 once all users have updated their config.
old_llm_provider = current_app.config.get("LLM_PROVIDER")
if (
old_llm_provider and "default" not in llm_configs
): # Basic check for old config
warning_message = (
"Your LLM configuration in timesketch.conf is outdated and may cause "
"issues with LLM features. "
"Please update your LLM_PROVIDER_CONFIGS section to the new format. "
"Refer to the documentation for the updated configuration structure."
)
result["llm_config_warning"] = warning_message
logger.warning(warning_message)

return jsonify(result)
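
Under the new scheme, the settings payload for an up-to-date configuration would look roughly like the sketch below (a Python dict mirroring the jsonify()'d result; the provider name is an assumption, and llm_config_warning only appears when the deprecated LLM_PROVIDER setting is still set while the config lacks a 'default' entry):

# Illustrative response body from the system settings resource:
{
    "DFIQ_ENABLED": True,
    "LLM_PROVIDER": "aistudio",  # assumption: the 'default' entry names the aistudio provider
}
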
27 changes: 17 additions & 10 deletions timesketch/api/v1/resources_test.py
@@ -1191,10 +1191,10 @@ class TestNl2qResource(BaseTest):

resource_url = "/api/v1/sketches/1/nl2q/"

@mock.patch("timesketch.lib.llms.manager.LLMManager")
@mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider")
@mock.patch("timesketch.api.v1.utils.run_aggregator")
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager):
def test_nl2q_prompt(self, mock_aggregator, mock_create_provider):
"""Test the prompt is created correctly."""

self.login()
@@ -1207,7 +1207,7 @@ def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager):
mock_aggregator.return_value = (mock_AggregationResult, {})
mock_llm = mock.Mock()
mock_llm.generate.return_value = "LLM generated query"
mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm
mock_create_provider.return_value = mock_llm
response = self.client.post(
self.resource_url,
data=json.dumps(data),
@@ -1313,7 +1313,8 @@ def test_nl2q_no_question(self):
def test_nl2q_wrong_llm_provider(self, mock_aggregator):
"""Test nl2q with llm provider that does not exist."""

self.app.config["LLM_PROVIDER"] = "DoesNotExists"
self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"DoesNotExists": {}}}
self.login()
self.login()
data = dict(question="Question for LLM?")
mock_AggregationResult = mock.MagicMock()
@@ -1333,9 +1334,10 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator):

@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_no_llm_provider(self):
"""Test nl2q with no llm provider configured."""
"""Test nl2q with no LLM provider configured."""

del self.app.config["LLM_PROVIDER"]
if "LLM_PROVIDER_CONFIGS" in self.app.config:
del self.app.config["LLM_PROVIDER_CONFIGS"]
self.login()
data = dict(question="Question for LLM?")
response = self.client.post(
@@ -1371,10 +1373,10 @@ def test_nl2q_no_permission(self):
)
self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN)

@mock.patch("timesketch.lib.llms.manager.LLMManager")
@mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider")
@mock.patch("timesketch.api.v1.utils.run_aggregator")
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager):
def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider):
"""Test nl2q with llm error."""

self.login()
@@ -1387,13 +1389,15 @@ def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider):
mock_aggregator.return_value = (mock_AggregationResult, {})
mock_llm = mock.Mock()
mock_llm.generate.side_effect = Exception("Test exception")
mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm
mock_create_provider.return_value = mock_llm
response = self.client.post(
self.resource_url,
data=json.dumps(data),
content_type="application/json",
)
self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK)
self.assertEqual(
response.status_code, HTTP_STATUS_CODE_OK
) # Still expect 200 OK with error in JSON
data = json.loads(response.get_data(as_text=True))
self.assertIsNotNone(data.get("error"))

@@ -1405,6 +1409,9 @@ class SystemSettingsResourceTest(BaseTest):

def test_system_settings_resource(self):
"""Authenticated request to get system settings."""
self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"test": {}}}
self.app.config["DFIQ_ENABLED"] = False

self.login()
response = self.client.get(self.resource_url)
expected_response = {"DFIQ_ENABLED": False, "LLM_PROVIDER": "test"}
62 changes: 27 additions & 35 deletions timesketch/lib/llms/interface.py
@@ -16,8 +16,6 @@
import string
from typing import Optional

from flask import current_app

DEFAULT_TEMPERATURE = 0.1
DEFAULT_TOP_P = 0.1
DEFAULT_TOP_K = 1
@@ -27,12 +25,19 @@


class LLMProvider:
"""Base class for LLM providers."""
"""
Base class for LLM providers.
The provider is instantiated with a configuration dictionary that
was extracted (by the manager) from timesketch.conf.
Subclasses should override the NAME attribute.
"""

NAME = "name"

def __init__(
self,
config: dict,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
@@ -43,47 +48,32 @@ def __init__(
"""Initialize the LLM provider.
Args:
temperature: The temperature to use for the response.
top_p: The top_p to use for the response.
top_k: The top_k to use for the response.
max_output_tokens: The maximum number of output tokens to generate.
stream: Whether to stream the response.
location: The cloud location/region to use for the provider.
config: A dictionary of provider-specific configuration options.
temperature: Temperature setting for text generation.
top_p: Top probability (p) value used for generation.
top_k: Top-k value used for generation.
max_output_tokens: Maximum number of tokens to generate in the output.
stream: Whether to enable streaming of the generated content.
location: An optional location parameter for the provider.
Attributes:
config: The configuration for the LLM provider.
Raises:
Exception: If the LLM provider is not configured.
"""
config = {}
config["temperature"] = temperature
config["top_p"] = top_p
config["top_k"] = top_k
config["max_output_tokens"] = max_output_tokens
config["stream"] = stream
config["location"] = location

# Load the LLM provider config from the Flask app config
config_from_flask = current_app.config.get("LLM_PROVIDER_CONFIGS").get(
self.NAME
)
if not config_from_flask:
raise Exception(f"{self.NAME} config not found")

config.update(config_from_flask)
self.config = config
self.config = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"max_output_tokens": max_output_tokens,
"stream": stream,
"location": location,
}
self.config.update(config)

def prompt_from_template(self, template: str, kwargs: dict) -> str:
"""Format a prompt from a template.
Args:
template: The template to format.
kwargs: The keyword arguments to format the template with.
Returns:
The formatted prompt.
"""
"""Format a prompt from a template."""
formatter = string.Formatter()
return formatter.format(template, **kwargs)

@@ -97,5 +87,7 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
Returns:
The generated response.
Subclasses must override this method.
"""
raise NotImplementedError()
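
A minimal, hypothetical subclass illustrating the new constructor contract, in which the manager (not the provider) extracts the per-provider dict from LLM_PROVIDER_CONFIGS and passes it in as config:

from typing import Optional

from timesketch.lib.llms.interface import LLMProvider


class EchoProvider(LLMProvider):
    """Hypothetical provider that simply echoes the prompt back."""

    NAME = "echo"

    def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
        # self.config holds the generation defaults (temperature, top_p, ...)
        # merged with the provider-specific dict passed in by the manager.
        return prompt


# The manager would construct a provider roughly like this:
provider = EchoProvider(config={"model": "noop"}, temperature=0.2)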