From 642beae2b8f4d9626a3afc4d7b9a0bba6fbb0d5a Mon Sep 17 00:00:00 2001 From: ElieTaillard Date: Mon, 15 Apr 2024 17:39:45 +0200 Subject: [PATCH] Add possibility to use custom api address --- .env.example | 1 + ikabot/helpers/dns.py | 164 +++++++++++++++++++++++++++++------------- 2 files changed, 117 insertions(+), 48 deletions(-) create mode 100644 .env.example diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..6d7eb155 --- /dev/null +++ b/.env.example @@ -0,0 +1 @@ +CUSTOM_API_ADDRESS=http://127.0.0.1:5000 \ No newline at end of file diff --git a/ikabot/helpers/dns.py b/ikabot/helpers/dns.py index e059ab28..6d01fcde 100644 --- a/ikabot/helpers/dns.py +++ b/ikabot/helpers/dns.py @@ -3,10 +3,12 @@ import socket import struct + from ikabot.config import * from ikabot.helpers.process import run -def getDNSTXTRecordWithSocket(domain, DNS_server = '8.8.8.8'): + +def getDNSTXTRecordWithSocket(domain, DNS_server="8.8.8.8"): """Returns the TXT record from the DNS server for the given domain Parameters ---------- @@ -19,24 +21,25 @@ def getDNSTXTRecordWithSocket(domain, DNS_server = '8.8.8.8'): str TXT record """ + # DNS Query def build_query(domain): # Header Section - ID = struct.pack('>H', 0x1234) # Identifier: transaction ID - FLAGS = struct.pack('>H', 0x0100) # Standard query with recursion - QDCOUNT = struct.pack('>H', 0x0001) # One question - ANCOUNT = struct.pack('>H', 0x0000) # No answers - NSCOUNT = struct.pack('>H', 0x0000) # No authority records - ARCOUNT = struct.pack('>H', 0x0000) # No additional records + ID = struct.pack(">H", 0x1234) # Identifier: transaction ID + FLAGS = struct.pack(">H", 0x0100) # Standard query with recursion + QDCOUNT = struct.pack(">H", 0x0001) # One question + ANCOUNT = struct.pack(">H", 0x0000) # No answers + NSCOUNT = struct.pack(">H", 0x0000) # No authority records + ARCOUNT = struct.pack(">H", 0x0000) # No additional records header = ID + FLAGS + QDCOUNT + ANCOUNT + NSCOUNT + ARCOUNT # Question Section - question = b'' - for part in domain.split('.'): - question += struct.pack('B', len(part)) + part.encode('utf-8') - question += struct.pack('B', 0) # End of string - QTYPE = struct.pack('>H', 0x0010) # TXT record - QCLASS = struct.pack('>H', 0x0001) # IN class + question = b"" + for part in domain.split("."): + question += struct.pack("B", len(part)) + part.encode("utf-8") + question += struct.pack("B", 0) # End of string + QTYPE = struct.pack(">H", 0x0010) # TXT record + QCLASS = struct.pack(">H", 0x0001) # IN class question += QTYPE + QCLASS return header + question @@ -66,7 +69,7 @@ def parse_response(response): # Read the answer section while offset < len(response): # Read the name - if response[offset] == 0xc0: + if response[offset] == 0xC0: offset += 2 # Pointer to a name else: # Name in the form of a sequence of labels @@ -77,31 +80,32 @@ def parse_response(response): offset += length + 1 offset += 1 # End of the name - type = struct.unpack('>H', response[offset:offset+2])[0] + type = struct.unpack(">H", response[offset : offset + 2])[0] offset += 10 # Type (2 bytes) + Class (2 bytes) + TTL (4 bytes) + Data length (2 bytes) if type == 16: # TXT record - txt_length = struct.unpack('>H', response[offset-2:offset])[0] - txt_data = response[offset:offset + txt_length] + txt_length = struct.unpack(">H", response[offset - 2 : offset])[0] + txt_data = response[offset : offset + txt_length] # TXT records can be split into multiple strings txt_strings = [] while txt_data: string_length = txt_data[0] - txt_strings.append(txt_data[1:string_length+1].decode('utf-8')) - txt_data = txt_data[string_length+1:] - return ' '.join(txt_strings) + txt_strings.append(txt_data[1 : string_length + 1].decode("utf-8")) + txt_data = txt_data[string_length + 1 :] + return " ".join(txt_strings) else: # Skip this record and move to the next - data_length = struct.unpack('>H', response[offset-2:offset])[0] + data_length = struct.unpack(">H", response[offset - 2 : offset])[0] offset += data_length - raise ValueError('No TXT record found') + raise ValueError("No TXT record found") query = build_query(domain) response = send_query(query) - return 'http://' + parse_response(response) + return "http://" + parse_response(response) + -def getDNSTXTRecordWithNSlookup(domain, DNS_server = '8.8.8.8'): +def getDNSTXTRecordWithNSlookup(domain, DNS_server="8.8.8.8"): """Returns the TXT record from the DNS server for the given domain using the nslookup tool Parameters ---------- @@ -114,12 +118,15 @@ def getDNSTXTRecordWithNSlookup(domain, DNS_server = '8.8.8.8'): str TXT record """ - text = run(f'nslookup -q=txt {domain} {DNS_server}') + text = run(f"nslookup -q=txt {domain} {DNS_server}") parts = text.split('"') if len(parts) < 2: # the DNS output is not well formed - raise Exception(f"The command \"nslookup -q=txt {domain} {DNS_server}\" returned bad data: {text}") - return 'http://' + parts[1] + raise Exception( + f'The command "nslookup -q=txt {domain} {DNS_server}" returned bad data: {text}' + ) + return "http://" + parts[1] + def getAddressWithSocket(session, domain): """Makes multiple attempts to obtain the ikabot public API server address with the socket library @@ -129,19 +136,37 @@ def getAddressWithSocket(session, domain): server address """ try: - return getDNSTXTRecordWithSocket(domain, 'ns2.afraid.org') + return getDNSTXTRecordWithSocket(domain, "ns2.afraid.org") except Exception as e: - session.writeLog("Failed to obtain public API address from ns2.afraid.org, trying with 8.8.8.8: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from ns2.afraid.org, trying with 8.8.8.8: " + + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) try: - return getDNSTXTRecordWithSocket(domain, '8.8.8.8') + return getDNSTXTRecordWithSocket(domain, "8.8.8.8") except Exception as e: - session.writeLog("Failed to obtain public API address from 8.8.8.8, trying with 1.1.1.1: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from 8.8.8.8, trying with 1.1.1.1: " + + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) try: - return getDNSTXTRecordWithSocket(domain, '1.1.1.1') + return getDNSTXTRecordWithSocket(domain, "1.1.1.1") except Exception as e: - session.writeLog("Failed to obtain public API address from 1.1.1.1: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from 1.1.1.1: " + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) raise e + def getAddressWithNSlookup(session, domain): """Makes multiple attempts to obtain the ikabot public API server address with the nslookup tool if it's installed Returns @@ -150,21 +175,44 @@ def getAddressWithNSlookup(session, domain): server address """ try: - return getDNSTXTRecordWithNSlookup(domain, 'ns2.afraid.org') + return getDNSTXTRecordWithNSlookup(domain, "ns2.afraid.org") except Exception as e: - session.writeLog("Failed to obtain public API address from nslookup with ns2.afraid.org: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from nslookup with ns2.afraid.org: " + + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) try: - return getDNSTXTRecordWithNSlookup(domain, '8.8.8.8') + return getDNSTXTRecordWithNSlookup(domain, "8.8.8.8") except Exception as e: - session.writeLog("Failed to obtain public API address from nslookup with 8.8.8.8: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from nslookup with 8.8.8.8: " + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) try: - return getDNSTXTRecordWithNSlookup(domain, '1.1.1.1') + return getDNSTXTRecordWithNSlookup(domain, "1.1.1.1") except Exception as e: - session.writeLog("Failed to obtain public API address from nslookup with 1.1.1.1: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from nslookup with 1.1.1.1: " + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) raise e - -def getAddress(session = type('test', (object,), {'writeLog': lambda *args, **kwargs: print(str(args) + ' ' + str(kwargs))})(), domain = 'ikagod.twilightparadox.com'): + +def getAddress( + session=type( + "test", + (object,), + {"writeLog": lambda *args, **kwargs: print(str(args) + " " + str(kwargs))}, + )(), + domain="ikagod.twilightparadox.com", +): """Makes multiple attempts to obtain the ikabot public API server address Parameters ---------- @@ -175,19 +223,39 @@ def getAddress(session = type('test', (object,), {'writeLog': lambda *args, **kw str server address """ + custom_address = os.getenv("CUSTOM_API_ADDRESS") + if custom_address: + return custom_address try: address = getAddressWithSocket(session, domain) - assert '.' in address or ':' in address.replace('http://', ''), "Bad server address: " + address - return address.replace('/ikagod/ikabot','') + assert "." in address or ":" in address.replace("http://", ""), ( + "Bad server address: " + address + ) + return address.replace("/ikagod/ikabot", "") except Exception as e: - session.writeLog("Failed to obtain public API address from socket, falling back to nslookup: " + str(e), level=logLevels.WARN, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from socket, falling back to nslookup: " + + str(e), + level=logLevels.WARN, + module=__name__, + logTraceback=True, + ) try: address = getAddressWithNSlookup(session, domain) - assert '.' in address or ':' in address.replace('http://', ''), "Bad server address: " + address #address is either hostname, IPv4 or IPv6 - return address.replace('/ikagod/ikabot','') + assert "." in address or ":" in address.replace("http://", ""), ( + "Bad server address: " + address + ) # address is either hostname, IPv4 or IPv6 + return address.replace("/ikagod/ikabot", "") except Exception as e: - session.writeLog("Failed to obtain public API address from both socket and nslookup: " + str(e), level=logLevels.ERROR, module= __name__, logTraceback=True) + session.writeLog( + "Failed to obtain public API address from both socket and nslookup: " + + str(e), + level=logLevels.ERROR, + module=__name__, + logTraceback=True, + ) raise e + # session = type('test', (object,), {'writeLog': lambda *args, **kwargs: print(str(args) + ' ' + str(kwargs))})() # Useful mock session object for testing -# print(getAddress(session, 'ikagod.twilightparadox.com')) \ No newline at end of file +# print(getAddress(session, 'ikagod.twilightparadox.com'))