Skip to content

Commit

Permalink
Add possibility to use custom api address (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElieTaillard authored Apr 15, 2024
1 parent 79ed9f1 commit 06fdee4
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 48 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CUSTOM_API_ADDRESS=http://127.0.0.1:5000
164 changes: 116 additions & 48 deletions ikabot/helpers/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import socket
import struct

from ikabot.config import *
from ikabot.helpers.process import run

def getDNSTXTRecordWithSocket(domain, DNS_server = '8.8.8.8'):

def getDNSTXTRecordWithSocket(domain, DNS_server="8.8.8.8"):
"""Returns the TXT record from the DNS server for the given domain
Parameters
----------
Expand All @@ -19,24 +21,25 @@ def getDNSTXTRecordWithSocket(domain, DNS_server = '8.8.8.8'):
str
TXT record
"""

# DNS Query
def build_query(domain):
# Header Section
ID = struct.pack('>H', 0x1234) # Identifier: transaction ID
FLAGS = struct.pack('>H', 0x0100) # Standard query with recursion
QDCOUNT = struct.pack('>H', 0x0001) # One question
ANCOUNT = struct.pack('>H', 0x0000) # No answers
NSCOUNT = struct.pack('>H', 0x0000) # No authority records
ARCOUNT = struct.pack('>H', 0x0000) # No additional records
ID = struct.pack(">H", 0x1234) # Identifier: transaction ID
FLAGS = struct.pack(">H", 0x0100) # Standard query with recursion
QDCOUNT = struct.pack(">H", 0x0001) # One question
ANCOUNT = struct.pack(">H", 0x0000) # No answers
NSCOUNT = struct.pack(">H", 0x0000) # No authority records
ARCOUNT = struct.pack(">H", 0x0000) # No additional records
header = ID + FLAGS + QDCOUNT + ANCOUNT + NSCOUNT + ARCOUNT

# Question Section
question = b''
for part in domain.split('.'):
question += struct.pack('B', len(part)) + part.encode('utf-8')
question += struct.pack('B', 0) # End of string
QTYPE = struct.pack('>H', 0x0010) # TXT record
QCLASS = struct.pack('>H', 0x0001) # IN class
question = b""
for part in domain.split("."):
question += struct.pack("B", len(part)) + part.encode("utf-8")
question += struct.pack("B", 0) # End of string
QTYPE = struct.pack(">H", 0x0010) # TXT record
QCLASS = struct.pack(">H", 0x0001) # IN class
question += QTYPE + QCLASS

return header + question
Expand Down Expand Up @@ -66,7 +69,7 @@ def parse_response(response):
# Read the answer section
while offset < len(response):
# Read the name
if response[offset] == 0xc0:
if response[offset] == 0xC0:
offset += 2 # Pointer to a name
else:
# Name in the form of a sequence of labels
Expand All @@ -77,31 +80,32 @@ def parse_response(response):
offset += length + 1
offset += 1 # End of the name

type = struct.unpack('>H', response[offset:offset+2])[0]
type = struct.unpack(">H", response[offset : offset + 2])[0]
offset += 10 # Type (2 bytes) + Class (2 bytes) + TTL (4 bytes) + Data length (2 bytes)

if type == 16: # TXT record
txt_length = struct.unpack('>H', response[offset-2:offset])[0]
txt_data = response[offset:offset + txt_length]
txt_length = struct.unpack(">H", response[offset - 2 : offset])[0]
txt_data = response[offset : offset + txt_length]
# TXT records can be split into multiple strings
txt_strings = []
while txt_data:
string_length = txt_data[0]
txt_strings.append(txt_data[1:string_length+1].decode('utf-8'))
txt_data = txt_data[string_length+1:]
return ' '.join(txt_strings)
txt_strings.append(txt_data[1 : string_length + 1].decode("utf-8"))
txt_data = txt_data[string_length + 1 :]
return " ".join(txt_strings)
else:
# Skip this record and move to the next
data_length = struct.unpack('>H', response[offset-2:offset])[0]
data_length = struct.unpack(">H", response[offset - 2 : offset])[0]
offset += data_length

raise ValueError('No TXT record found')
raise ValueError("No TXT record found")

query = build_query(domain)
response = send_query(query)
return 'http://' + parse_response(response)
return "http://" + parse_response(response)


def getDNSTXTRecordWithNSlookup(domain, DNS_server = '8.8.8.8'):
def getDNSTXTRecordWithNSlookup(domain, DNS_server="8.8.8.8"):
"""Returns the TXT record from the DNS server for the given domain using the nslookup tool
Parameters
----------
Expand All @@ -114,12 +118,15 @@ def getDNSTXTRecordWithNSlookup(domain, DNS_server = '8.8.8.8'):
str
TXT record
"""
text = run(f'nslookup -q=txt {domain} {DNS_server}')
text = run(f"nslookup -q=txt {domain} {DNS_server}")
parts = text.split('"')
if len(parts) < 2:
# the DNS output is not well formed
raise Exception(f"The command \"nslookup -q=txt {domain} {DNS_server}\" returned bad data: {text}")
return 'http://' + parts[1]
raise Exception(
f'The command "nslookup -q=txt {domain} {DNS_server}" returned bad data: {text}'
)
return "http://" + parts[1]


def getAddressWithSocket(session, domain):
"""Makes multiple attempts to obtain the ikabot public API server address with the socket library
Expand All @@ -129,19 +136,37 @@ def getAddressWithSocket(session, domain):
server address
"""
try:
return getDNSTXTRecordWithSocket(domain, 'ns2.afraid.org')
return getDNSTXTRecordWithSocket(domain, "ns2.afraid.org")
except Exception as e:
session.writeLog("Failed to obtain public API address from ns2.afraid.org, trying with 8.8.8.8: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from ns2.afraid.org, trying with 8.8.8.8: "
+ str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
try:
return getDNSTXTRecordWithSocket(domain, '8.8.8.8')
return getDNSTXTRecordWithSocket(domain, "8.8.8.8")
except Exception as e:
session.writeLog("Failed to obtain public API address from 8.8.8.8, trying with 1.1.1.1: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from 8.8.8.8, trying with 1.1.1.1: "
+ str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
try:
return getDNSTXTRecordWithSocket(domain, '1.1.1.1')
return getDNSTXTRecordWithSocket(domain, "1.1.1.1")
except Exception as e:
session.writeLog("Failed to obtain public API address from 1.1.1.1: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from 1.1.1.1: " + str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
raise e


def getAddressWithNSlookup(session, domain):
"""Makes multiple attempts to obtain the ikabot public API server address with the nslookup tool if it's installed
Returns
Expand All @@ -150,21 +175,44 @@ def getAddressWithNSlookup(session, domain):
server address
"""
try:
return getDNSTXTRecordWithNSlookup(domain, 'ns2.afraid.org')
return getDNSTXTRecordWithNSlookup(domain, "ns2.afraid.org")
except Exception as e:
session.writeLog("Failed to obtain public API address from nslookup with ns2.afraid.org: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from nslookup with ns2.afraid.org: "
+ str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
try:
return getDNSTXTRecordWithNSlookup(domain, '8.8.8.8')
return getDNSTXTRecordWithNSlookup(domain, "8.8.8.8")
except Exception as e:
session.writeLog("Failed to obtain public API address from nslookup with 8.8.8.8: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from nslookup with 8.8.8.8: " + str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
try:
return getDNSTXTRecordWithNSlookup(domain, '1.1.1.1')
return getDNSTXTRecordWithNSlookup(domain, "1.1.1.1")
except Exception as e:
session.writeLog("Failed to obtain public API address from nslookup with 1.1.1.1: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from nslookup with 1.1.1.1: " + str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
raise e


def getAddress(session = type('test', (object,), {'writeLog': lambda *args, **kwargs: print(str(args) + ' ' + str(kwargs))})(), domain = 'ikagod.twilightparadox.com'):

def getAddress(
session=type(
"test",
(object,),
{"writeLog": lambda *args, **kwargs: print(str(args) + " " + str(kwargs))},
)(),
domain="ikagod.twilightparadox.com",
):
"""Makes multiple attempts to obtain the ikabot public API server address
Parameters
----------
Expand All @@ -175,19 +223,39 @@ def getAddress(session = type('test', (object,), {'writeLog': lambda *args, **kw
str
server address
"""
custom_address = os.getenv("CUSTOM_API_ADDRESS")
if custom_address:
return custom_address
try:
address = getAddressWithSocket(session, domain)
assert '.' in address or ':' in address.replace('http://', ''), "Bad server address: " + address
return address.replace('/ikagod/ikabot','')
assert "." in address or ":" in address.replace("http://", ""), (
"Bad server address: " + address
)
return address.replace("/ikagod/ikabot", "")
except Exception as e:
session.writeLog("Failed to obtain public API address from socket, falling back to nslookup: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from socket, falling back to nslookup: "
+ str(e),
level=logLevels.WARN,
module=__name__,
logTraceback=True,
)
try:
address = getAddressWithNSlookup(session, domain)
assert '.' in address or ':' in address.replace('http://', ''), "Bad server address: " + address #address is either hostname, IPv4 or IPv6
return address.replace('/ikagod/ikabot','')
assert "." in address or ":" in address.replace("http://", ""), (
"Bad server address: " + address
) # address is either hostname, IPv4 or IPv6
return address.replace("/ikagod/ikabot", "")
except Exception as e:
session.writeLog("Failed to obtain public API address from both socket and nslookup: " + str(e), level=logLevels.ERROR, module= __name__, logTraceback=True)
session.writeLog(
"Failed to obtain public API address from both socket and nslookup: "
+ str(e),
level=logLevels.ERROR,
module=__name__,
logTraceback=True,
)
raise e


# session = type('test', (object,), {'writeLog': lambda *args, **kwargs: print(str(args) + ' ' + str(kwargs))})() # Useful mock session object for testing
# print(getAddress(session, 'ikagod.twilightparadox.com'))
# print(getAddress(session, 'ikagod.twilightparadox.com'))

0 comments on commit 06fdee4

Please sign in to comment.