Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework settings #39

Merged
merged 9 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions admyral/actions/integrations/ai/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import Annotated
from anthropic import Anthropic
from anthropic.types import TextBlock
from pydantic import BaseModel

from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.secret.secret import register_secret


@register_secret(secret_type="Anthropic")
class AnthropicSecret(BaseModel):
api_key: str


@action(
Expand Down Expand Up @@ -66,9 +73,9 @@ def anthropic_chat_completion(
# https://docs.anthropic.com/en/api/messages
# TODO: error handling
secret = ctx.get().secrets.get("ANTHROPIC_SECRET")
api_key = secret["api_key"]
secret = AnthropicSecret.model_validate(secret)

client = Anthropic(api_key=api_key)
client = Anthropic(api_key=secret.api_key)

model_params = {}
if top_p is not None:
Expand Down
22 changes: 15 additions & 7 deletions admyral/actions/integrations/ai/azure_openai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from typing import Annotated
from openai import AzureOpenAI
from pydantic import BaseModel

from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.secret.secret import register_secret


@register_secret(secret_type="Azure OpenAI")
class AzureOpenAISecret(BaseModel):
endpoint: str
api_key: str
deployment_name: str


# TODO: test
@action(
display_name="Chat Completion",
display_namespace="Azure OpenAI",
Expand Down Expand Up @@ -53,23 +61,23 @@ def azure_openai_chat_completion(
# TODO: add authentication via Entra ID: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py
# TODO: error handling
secret = ctx.get().secrets.get("AZURE_OPENAI_SECRET")
endpoint = secret["endpoint"]
api_key = secret["api_key"]
model = secret["deployment_name"]
secret = AzureOpenAISecret.model_validate(secret)

client = AzureOpenAI(
api_version="2024-06-01", azure_endpoint=endpoint, api_key=api_key
api_version="2024-06-01", azure_endpoint=secret.endpoint, api_key=secret.api_key
)

model_params = {}
if top_p is not None:
model_params["top_p"] = top_p
if temperature is not None:
model_params["temperature"] = temperature
if stop_tokens is not None and not model.startswith("o1"):
if stop_tokens is not None and not secret.deployment_name.startswith("o1"):
model_params["stop"] = stop_tokens

chat_completion = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], model=model, **model_params
messages=[{"role": "user", "content": prompt}],
model=secret.deployment_name,
**model_params,
)
return chat_completion.choices[0].message.content
11 changes: 9 additions & 2 deletions admyral/actions/integrations/ai/mistralai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Annotated
from mistralai.client import MistralClient
from pydantic import BaseModel

from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.secret.secret import register_secret


@register_secret(secret_type="Mistral AI")
class MistralAISecret(BaseModel):
api_key: str


@action(
Expand Down Expand Up @@ -58,9 +65,9 @@ def mistralai_chat_completion(
# https://docs.mistral.ai/api/#tag/chat
# TODO: error handling
secret = ctx.get().secrets.get("MISTRALAI_SECRET")
api_key = secret["api_key"]
secret = MistralAISecret.model_validate(secret)

client = MistralClient(api_key=api_key)
client = MistralClient(api_key=secret.api_key)

model_params = {}
if top_p is not None:
Expand Down
11 changes: 9 additions & 2 deletions admyral/actions/integrations/ai/openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Annotated
from openai import OpenAI
from pydantic import BaseModel

from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.secret.secret import register_secret


@register_secret(secret_type="OpenAI")
class OpenAISecret(BaseModel):
api_key: str


@action(
Expand Down Expand Up @@ -58,7 +65,7 @@ def openai_chat_completion(
# https://platform.openai.com/docs/api-reference/chat/create
# TODO: error handling
secret = ctx.get().secrets.get("OPENAI_SECRET")
api_key = secret["api_key"]
openai_secret = OpenAISecret.model_validate(secret)

model_params = {}
if top_p is not None:
Expand All @@ -68,7 +75,7 @@ def openai_chat_completion(
if stop_tokens is not None and not model.startswith("o1"):
model_params["stop"] = stop_tokens

client = OpenAI(api_key=api_key)
client = OpenAI(api_key=openai_secret.api_key)
chat_completion = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
model=model,
Expand Down
47 changes: 24 additions & 23 deletions admyral/actions/integrations/cases/jira.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from typing import Annotated, Literal
import base64
from httpx import Client
from pydantic import BaseModel

from admyral.typings import JsonValue
from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.utils.collections import is_empty
from admyral.secret.secret import register_secret


def get_jira_client(domain: str, email: str, api_key: str) -> Client:
api_key_base64 = base64.b64encode(f"{email}:{api_key}".encode()).decode()
@register_secret(secret_type="Jira")
class JiraSecret(BaseModel):
domain: str
email: str
api_key: str


def get_jira_client(secret: JiraSecret) -> Client:
api_key_base64 = base64.b64encode(
f"{secret.email}:{secret.api_key}".encode()
).decode()
return Client(
base_url=f"https://{domain}/rest/api/3",
base_url=f"https://{secret.domain}/rest/api/3",
headers={
"Authorization": f"Basic {api_key_base64}",
"Content-Type": "application/json",
Expand Down Expand Up @@ -112,9 +123,7 @@ def create_jira_issue(
# Atlassian Document Format: https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/#api-rest-api-3-issue-post
secret = ctx.get().secrets.get("JIRA_SECRET")
domain = secret["domain"]
email = secret["email"]
api_key = secret["api_key"]
secret = JiraSecret.model_validate(secret)

body = {
"fields": {
Expand Down Expand Up @@ -148,7 +157,7 @@ def create_jira_issue(
for key, value in custom_fields.items():
body["fields"][key] = value

with get_jira_client(domain, email, api_key) as client:
with get_jira_client(secret) as client:
response = client.post(
"/issue",
json=body,
Expand Down Expand Up @@ -182,11 +191,9 @@ def update_jira_issue_status(
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/#api-rest-api-3-issue-issueidorkey-transitions-post
# https://community.atlassian.com/t5/Jira-questions/How-do-i-find-a-transition-ID/qaq-p/2113213
secret = ctx.get().secrets.get("JIRA_SECRET")
domain = secret["domain"]
email = secret["email"]
api_key = secret["api_key"]
secret = JiraSecret.model_validate(secret)

with get_jira_client(domain, email, api_key) as client:
with get_jira_client(secret) as client:
response = client.post(
f"/issue/{issue_id_or_key}/transitions",
json={
Expand Down Expand Up @@ -222,11 +229,9 @@ def comment_jira_issue_status(
) -> None:
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/#api-rest-api-3-issue-issueidorkey-transitions-post
secret = ctx.get().secrets.get("JIRA_SECRET")
domain = secret["domain"]
email = secret["email"]
api_key = secret["api_key"]
secret = JiraSecret.model_validate(secret)

with get_jira_client(domain, email, api_key) as client:
with get_jira_client(secret) as client:
response = client.post(
f"/issue/{issue_id_or_key}/comment",
json={"body": comment},
Expand Down Expand Up @@ -258,11 +263,9 @@ def search_jira_issues(
) -> list[JsonValue]:
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issue-search/#api-rest-api-3-search-get
secret = ctx.get().secrets.get("JIRA_SECRET")
domain = secret["domain"]
email = secret["email"]
api_key = secret["api_key"]
secret = JiraSecret.model_validate(secret)

with get_jira_client(domain, email, api_key) as client:
with get_jira_client(secret) as client:
offset = 0
issues = []

Expand Down Expand Up @@ -320,11 +323,9 @@ def get_jira_audit_records(
):
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-audit-records/#api-rest-api-3-auditing-record-get
secret = ctx.get().secrets.get("JIRA_SECRET")
domain = secret["domain"]
email = secret["email"]
api_key = secret["api_key"]
secret = JiraSecret.model_validate(secret)

with get_jira_client(domain, email, api_key) as client:
with get_jira_client(secret) as client:
offset = 0
logs = []

Expand Down
19 changes: 13 additions & 6 deletions admyral/actions/integrations/cases/opsgenie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@
from httpx import Client
import time
import random
from pydantic import BaseModel

from admyral.typings import JsonValue
from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.secret.secret import register_secret


def get_opsgenie_client(instance: str, api_key: str) -> Client:
@register_secret(secret_type="OpsGenie")
class OpsGenieSecret(BaseModel):
api_key: str
instance: str | None


def get_opsgenie_client(secret: OpsGenieSecret) -> Client:
base_api_url = (
"https://api.eu.opsgenie.com"
if instance and instance.lower() == "eu"
if secret.instance and secret.instance.lower() == "eu"
else "https://api.opsgenie.com"
)
return Client(
base_url=base_api_url,
headers={
"Authorization": f"GenieKey {api_key}",
"Authorization": f"GenieKey {secret.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
},
Expand Down Expand Up @@ -125,8 +133,7 @@ def create_opsgenie_alert(
) -> JsonValue:
# https://docs.opsgenie.com/docs/alert-api#section-create-alert
opsgenie_secret = ctx.get().secrets.get("OPSGENIE_SECRET")
api_key = opsgenie_secret["api_key"]
instance = opsgenie_secret.get("instance")
opsgenie_secret = OpsGenieSecret.model_validate(opsgenie_secret)

body = {
"message": message,
Expand Down Expand Up @@ -156,7 +163,7 @@ def create_opsgenie_alert(
if note:
body["note"] = note

with get_opsgenie_client(instance, api_key) as client:
with get_opsgenie_client(opsgenie_secret) as client:
response = client.post(
"/v2/alerts",
json=body,
Expand Down
19 changes: 13 additions & 6 deletions admyral/actions/integrations/cases/pagerduty.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from typing import Annotated, Literal
from httpx import Client
from pydantic import BaseModel

from admyral.action import action, ArgumentMetadata
from admyral.context import ctx
from admyral.typings import JsonValue
from admyral.secret.secret import register_secret


def get_pagerduty_client(email: str, api_key: str) -> Client:
@register_secret(secret_type="PagerDuty")
class PagerDutySecret(BaseModel):
api_key: str
email: str


def get_pagerduty_client(secret: PagerDutySecret) -> Client:
return Client(
base_url="https://api.pagerduty.com",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Token token={api_key}",
"From": email,
"Authorization": f"Token token={secret.api_key}",
"From": secret.email,
},
)

Expand Down Expand Up @@ -63,8 +71,7 @@ def create_pagerduty_incident(
) -> JsonValue:
# https://developer.pagerduty.com/api-reference/a7d81b0e9200f-create-an-incident
secret = ctx.get().secrets.get("PAGERDUTY_SECRET")
api_key = secret["api_key"]
email = secret["email"]
secret = PagerDutySecret.model_validate(secret)

body = {
"incident": {
Expand All @@ -88,7 +95,7 @@ def create_pagerduty_incident(

# Note: ignoring incident key and escalation policy for now

with get_pagerduty_client(email, api_key) as client:
with get_pagerduty_client(secret) as client:
response = client.post(
"/incidents",
json=body,
Expand Down
10 changes: 6 additions & 4 deletions admyral/actions/integrations/cdr/ms_defender_for_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from admyral.actions.integrations.shared.ms_graph import (
ms_graph_list_alerts_v2,
MsSecurityGraphAlertServiceSource,
AzureSecret,
)


# TODO: OCSF schema mapping
@action(
display_name="List Alerts",
display_namespace="Microsoft Defender for Cloud",
Expand Down Expand Up @@ -41,10 +41,12 @@ def list_ms_defender_for_cloud_alerts(
) -> list[dict[str, JsonValue]]:
# https://learn.microsoft.com/en-us/rest/api/defenderforcloud/alerts/list?view=rest-defenderforcloud-2022-01-01&tabs=HTTP
secret = ctx.get().secrets.get("AZURE_SECRET")
secret = AzureSecret.model_validate(secret)

return ms_graph_list_alerts_v2(
tenant_id=secret["tenant_id"],
client_id=secret["client_id"],
client_secret=secret["client_secret"],
tenant_id=secret.tenant_id,
client_id=secret.client_id,
client_secret=secret.client_secret,
start_time=start_time,
end_time=end_time,
limit=limit,
Expand Down
Loading
Loading