# ---- databend_driver/__init__.py ----
from .client import Client
from .connection import Connection

VERSION = (0, 2, 4)
__version__ = '.'.join(str(x) for x in VERSION)

__all__ = ['Client', 'Connection']


# ---- databend_driver/client.py ----
import json
from time import time
from urllib.parse import urlparse, parse_qs, unquote

from databend_driver.connection import Connection
from databend_driver.result import QueryResult
from databend_driver.util.helper import asbool


class Client(object):
    """
    Client for communication with the databend http server.

    A single Connection is established per connected instance of the client.
    """

    def __init__(self, *args, **kwargs):
        # Copy so caller-side mutation of the settings dict does not leak in.
        self.settings = (kwargs.pop('settings', None) or {}).copy()
        self.connection = Connection(*args, **kwargs)
        self.query_result_cls = QueryResult

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Added so ``with Client(...) as c:`` actually releases session state.
        self.disconnect()

    def disconnect(self):
        """Drop the underlying connection/session state."""
        self.disconnect_connection()

    def disconnect_connection(self):
        self.connection.disconnect()

    def data_generator(self, raw_data):
        """Yield follow-up result pages until the server stops sending next_uri.

        ``raw_data`` is the first (already decoded) response; only the pages
        AFTER it are yielded, matching how receive_result() consumes this.
        """
        while raw_data['next_uri'] is not None:
            try:
                raw_data = self.receive_data(raw_data['next_uri'])
                if not raw_data:
                    break
                yield raw_data
            except (Exception, KeyboardInterrupt):
                self.disconnect()
                raise

    def receive_data(self, next_uri: str):
        """Fetch and decode one result page addressed by ``next_uri``."""
        resp = self.connection.next_page(next_uri)
        # BUG FIX: the payload was unconditionally decoded twice
        # (json.loads(json.loads(...))), which raises TypeError on a normally
        # encoded response. Decode once, and only decode again if the body
        # turns out to be a double-encoded JSON string.
        raw_data = json.loads(resp.content)
        if isinstance(raw_data, str):
            raw_data = json.loads(raw_data)
        self.connection.check_error(raw_data)
        return raw_data

    def receive_result(self, query, query_id=None, with_column_types=False):
        """Run ``query`` and collect the full (possibly paginated) result."""
        raw_data = self.connection.query(query, None)
        self.connection.check_error(raw_data)
        columns_types = [field["data_type"]["type"]
                         for field in raw_data["schema"]["fields"]]

        if raw_data['next_uri'] is None:
            # Single-page result: no pagination round-trips needed.
            if with_column_types:
                return raw_data['data'], columns_types
            return raw_data['data']

        gen = self.data_generator(raw_data)
        result = self.query_result_cls(gen, with_column_types=with_column_types)
        return result.get_result()

    def iter_receive_result(self, query, with_column_types=False):
        """Run ``query`` and yield its rows one at a time."""
        raw_data = self.connection.query(query, None)
        self.connection.check_error(raw_data)
        if raw_data['next_uri'] is None:
            # BUG FIX: ``return raw_data`` inside a generator yields nothing
            # (the value is swallowed by StopIteration); emit the rows instead.
            for row in raw_data['data']:
                yield row
            return
        gen = self.data_generator(raw_data)
        result = self.query_result_cls(gen, with_column_types=with_column_types)
        rv = result.get_result()
        # get_result() returns (data, types) when with_column_types is set.
        data = rv[0] if with_column_types else rv
        # BUG FIX: the original nested loop yielded individual cells, not rows.
        for row in data:
            yield row

    def execute(self, query, params=None, with_column_types=False,
                query_id=None, settings=None):
        """
        Executes query.

        Establishes new connection if it wasn't established yet.
        After query execution connection remains intact for next queries.
        If connection can't be reused it will be closed and new connection will
        be created.

        :param query: query that will be sent to the server.
        :param params: substitution parameters for SELECT queries and data for
                       INSERT queries. Data for INSERT can be `list`, `tuple`
                       or :data:`~types.GeneratorType`.
                       Defaults to ``None`` (no parameters or data).
        :param with_column_types: if specified column names and types will be
                                  returned alongside with result.
                                  Defaults to ``False``.
        :param query_id: the query identifier. If no query id specified
                         the server will generate it.
        :param settings: dictionary of query settings.
                         Defaults to ``None`` (no additional settings).

        :return: * number of inserted rows for INSERT queries with data.
                   Returning rows count from INSERT FROM SELECT is not
                   supported.
                 * if `with_column_types=False`: `list` of `tuples` with
                   rows/columns.
                 * if `with_column_types=True`: `tuple` of 2 elements:
                     * The first element is `list` of `tuples` with
                       rows/columns.
                     * The second element information is about columns: names
                       and types.
        """
        return self.process_ordinary_query(
            query, params=params, with_column_types=with_column_types,
            query_id=query_id)

    def process_ordinary_query(self, query, params=None,
                               with_column_types=False, query_id=None):
        return self.receive_result(query, query_id=query_id,
                                   with_column_types=with_column_types)

    @classmethod
    def from_url(cls, url):
        """
        Return a client configured from the given URL.

        For example::

            http://[user:password]@localhost:8000/default
            https://[user:password]@localhost:443/default

        Any additional querystring arguments will be passed along to
        the Connection class's initializer.
        """
        url = urlparse(url)

        settings = {}
        kwargs = {}

        host = url.hostname

        # Only forward the port when the URL specifies one; otherwise the
        # Connection default applies. (The original computed an unused
        # ``port = 443`` fallback here.)
        if url.port is not None:
            kwargs['port'] = url.port

        path = url.path.replace('/', '', 1)
        if path:
            kwargs['database'] = path

        if url.username is not None:
            kwargs['user'] = unquote(url.username)

        if url.password is not None:
            kwargs['password'] = unquote(url.password)

        if url.scheme == 'https':
            kwargs['secure'] = True

        # Options that must be converted to float (loop-invariant; the
        # original rebuilt this set on every querystring iteration).
        timeouts = {
            'connect_timeout',
            'send_receive_timeout',
            'sync_request_timeout',
        }

        for name, value in parse_qs(url.query).items():
            if not value:
                continue

            value = value[0]

            if name == 'client_name':
                kwargs[name] = value
            elif name == 'secure':
                kwargs[name] = asbool(value)
            elif name in timeouts:
                kwargs[name] = float(value)
            else:
                settings[name] = value

        if settings:
            kwargs['settings'] = settings

        return cls(host, **kwargs)
# ---- databend_driver/connection.py ----
import base64
import json
import os
import time

import environs
import requests
from mysql.connector.errors import Error

from . import defines
from . import log

# Headers sent with every request to the databend HTTP handler.
headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}


def format_result(results):
    """Render rows as newline-separated, space-joined strings.

    Booleans are lower-cased; an empty row is replaced by a tab so it stays
    visible in the output.
    """
    if results is None:
        return ""

    res = ""
    for line in results:
        buf = ""
        for item in line:
            if isinstance(item, bool):
                item = str(item).lower()
            # every item separated by a single space
            buf = str(item) if buf == "" else buf + " " + str(item)
        if len(buf) == 0:
            buf = "\t"
        res = res + buf + "\n"
    return res


def get_data_type(field):
    """Return the type name of a schema field, unwrapping an 'inner' type
    when present (presumably nullable wrappers — TODO confirm against the
    server schema format). Returns None when 'data_type' is absent.
    """
    if 'data_type' in field:
        inner = field['data_type'].get('inner')
        if inner is not None:
            return inner['type']
        return field['data_type']['type']


def get_query_options(response):
    """Map each column type of the response schema to a one-letter code:
    I = integer, F = float/double, B = bool, T = everything else.
    Returns "" when the response carries an error.
    """
    ret = ""
    if get_error(response) is not None:
        return ret
    for field in response['schema']['fields']:
        typ = str.lower(get_data_type(field))
        # BUG FIX: the log module exposes ``log.logger``; ``log.debug``
        # raised AttributeError.
        log.logger.debug(f"type:{typ}")
        if "int" in typ:
            ret += "I"
        elif "float" in typ or "double" in typ:
            ret += "F"
        elif "bool" in typ:
            ret += "B"
        else:
            ret += "T"
    return ret


def get_next_uri(response):
    """Return the pagination URI, or None when the result is complete."""
    if "next_uri" in response:
        return response['next_uri']
    return None


def get_result(response):
    """Return the data rows of one response page."""
    return response['data']


def get_error(response):
    """Convert a server-side error payload into an Error, or None."""
    if response['error'] is None:
        return None

    # Wrap errno into msg, for result check
    return Error(msg=response['error']['message'],
                 errno=response['error']['code'])


class Connection(object):
    """Connection to the Databend HTTP handler.

    Doc: https://databend.rs/doc/reference/api/rest

    Construct with a driver dict, e.g.::

        {'user': 'root', 'host': '127.0.0.1', 'port': 3307,
         'database': 'default'}
    """

    def __init__(self, host, port=None, user=defines.DEFAULT_USER,
                 password=defines.DEFAULT_PASSWORD,
                 database=defines.DEFAULT_DATABASE, secure=False):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.secure = secure
        self.session_max_idle_time = defines.DEFAULT_SESSION_IDLE_TIME
        # BUG FIX: the rest of the class reads/writes ``self._session`` and
        # ``self._query_option`` but the original initialized ``self.session``
        # and ``self.query_option`` instead, so the first call to
        # query_with_session()/get_query_option() raised AttributeError.
        # Both spellings are kept for backward compatibility.
        self.session = {}
        self._session = {}
        self.query_option = None
        self._query_option = None
        self.additional_headers = dict()
        env = environs.Env()
        if os.getenv("ADDITIONAL_HEADERS") is not None:
            self.additional_headers = env.dict("ADDITIONAL_HEADERS")

    def make_headers(self):
        """Build request headers.

        BUG FIX: the original dropped ``additional_headers`` entirely when
        they did not contain an Authorization entry; now they are always
        merged, and basic auth is added only when absent.
        """
        if "Authorization" not in self.additional_headers:
            auth = "Basic " + base64.b64encode("{}:{}".format(
                self.user, self.password).encode(encoding="utf-8")).decode()
            return {**headers, **self.additional_headers,
                    "Authorization": auth}
        return {**headers, **self.additional_headers}

    def get_description(self):
        return '{}:{}'.format(self.host, self.port)

    def disconnect(self):
        """Forget the sticky server session."""
        self._session = {}

    @staticmethod
    def _loads(content):
        """Decode a response body, tolerating a double-encoded payload.

        BUG FIX: the original called ``json.loads(json.loads(...))``
        unconditionally, which raises TypeError on a normally encoded
        response.
        """
        data = json.loads(content)
        if isinstance(data, str):
            data = json.loads(data)
        return data

    def query(self, statement, session):
        """POST ``statement`` to /v1/query and return the decoded response."""
        url = self.format_url()
        log.logger.debug(f"http sql: {statement}")
        query_sql = {'sql': statement, "string_fields": True}
        if session is not None:
            query_sql['session'] = session
        log.logger.debug(f"http headers {self.make_headers()}")
        # NOTE(review): verify=False disables TLS certificate validation —
        # flagged for follow-up rather than silently changed here.
        response = requests.post(url,
                                 data=json.dumps(query_sql),
                                 headers=self.make_headers(), verify=False)

        try:
            return json.loads(response.content)
        except Exception as err:
            log.logger.error(
                f"http error, SQL: {statement}\ncontent: {response.content}\nerror msg:{str(err)}"
            )
            raise

    def _scheme(self):
        # BUG FIX: ``secure`` was accepted by __init__ but never used; honour
        # it when building URLs.
        return "https" if self.secure else "http"

    def format_url(self):
        return f"{self._scheme()}://{self.host}:{self.port}/v1/query/"

    def reset_session(self):
        self._session = {}

    def next_page(self, next_uri):
        """GET the next result page addressed by the server-relative URI."""
        url = "{}://{}:{}{}".format(self._scheme(), self.host, self.port,
                                    next_uri)
        return requests.get(url=url, headers=self.make_headers())

    def query_with_session(self, statement):
        """Run ``statement`` reusing the sticky session.

        Returns the list of all response pages (follows next_uri until the
        server reports the result complete), refreshing the stored session
        from every page.
        """
        response_list = list()
        response = self.query(statement, self._session)
        log.logger.debug(f"response content: {response}")
        response_list.append(response)
        start_time = time.time()
        time_limit = 12  # seconds before we start warning about a slow query
        session = response.get('session')
        if session:
            self._session = session
        while response['next_uri'] is not None:
            resp = self.next_page(response['next_uri'])
            response = self._loads(resp.content)
            log.logger.debug(f"Sql in progress, fetch next_uri content: {response}")
            self.check_error(response)
            session = response.get('session')
            if session:
                self._session = session
            response_list.append(response)
            if time.time() - start_time > time_limit:
                log.logger.warning(
                    f"after waited for {time_limit} secs, query still not finished (next uri not none)!"
                )
        return response_list

    def check_error(self, resp):
        """Raise the wrapped server error, if the response carries one."""
        error = get_error(resp)
        if error:
            raise error

    def fetch_all(self, statement):
        """Run ``statement`` and return all rows across pages (or None)."""
        resp_list = self.query_with_session(statement)
        if len(resp_list) == 0:
            log.logger.warning("fetch all with empty results")
            return None
        self._query_option = get_query_options(resp_list[0])  # record schema
        data_list = list()
        for response in resp_list:
            data = get_result(response)
            if len(data) != 0:
                data_list.extend(data)
        return data_list

    def get_query_option(self):
        return self._query_option


# ---- databend_driver/defines.py ----
DEFAULT_DATABASE = 'default'
DEFAULT_USER = 'root'
DEFAULT_PASSWORD = ''
DEFAULT_SESSION_IDLE_TIME = 30

DBMS_NAME = 'Databend'
CLIENT_NAME = 'python-driver'

STRINGS_ENCODING = 'utf-8'


# ---- databend_driver/log.py ----
import logging

logger = logging.getLogger(__name__)

# Severity names — presumably indexed by a server-side log priority value;
# TODO confirm against the server protocol.
log_priorities = (
    'Unknown',
    'Fatal',
    'Critical',
    'Error',
    'Warning',
    'Notice',
    'Information',
    'Debug',
    'Trace'
)
# ---- databend_driver/result.py ----
class QueryResult(object):
    """
    Accumulates a query result streamed over multiple response pages.
    """

    def __init__(self, data_generator, with_column_types=False):
        # ``data_generator`` yields decoded response pages (dicts with
        # "data" and "schema" keys).
        self.data_generator = data_generator
        self.with_column_types = with_column_types

        self.data = []
        self.columns_with_types = []
        self.columns = []

        super(QueryResult, self).__init__()

    def store(self, rawData: dict):
        """Fold one response page into the accumulated result.

        BUG FIX: the original assigned ``self.data = page["data"]`` so only
        the LAST page survived (contradicting the class docstring), and it
        re-appended the schema columns for every page.
        """
        self.data.extend(rawData.get("data") or [])
        if not self.columns:
            for field in rawData.get("schema")["fields"]:
                self.columns_with_types.append(field["data_type"]["type"])
                self.columns.append(field["name"])

    def get_result(self):
        """Drain the generator and return the stored query result.

        :return: the row list, or ``(rows, column_types)`` when
                 ``with_column_types`` was requested.
        """
        for page in self.data_generator:
            self.store(page)

        if self.with_column_types:
            return self.data, self.columns_with_types
        return self.data


# ---- databend_driver/util/helper.py ----
from itertools import islice, tee


def chunks(seq, n):
    """Yield successive lists of at most ``n`` items from ``seq``.

    Lists/tuples are sliced directly — islice is MUCH slower there;
    arbitrary iterables go through islice.
    """
    if isinstance(seq, (list, tuple)):
        i = 0
        piece = seq[i:i + n]
        while piece:
            yield list(piece)
            i += n
            piece = seq[i:i + n]
    else:
        it = iter(seq)
        piece = list(islice(it, n))
        while piece:
            yield piece
            piece = list(islice(it, n))


def pairwise(iterable):
    """Yield consecutive overlapping pairs: s -> (s0,s1), (s1,s2), ..."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def column_chunks(columns, n):
    """Yield ``columns`` split column-wise into chunks of ``n`` rows each."""
    for column in columns:
        if not isinstance(column, (list, tuple)):
            raise TypeError(
                'Unsupported column type: {}. list or tuple is expected.'
                .format(type(column))
            )

    # create chunk generator for every column
    generators = [chunks(column, n) for column in columns]

    while True:
        # get next chunk for every column
        item = [next(g, []) for g in generators]
        if not any(item):
            break
        yield item


# from paste.deploy.converters
def asbool(obj):
    """Coerce common true/false strings (or any object) to bool.

    :raises ValueError: for a string that is not a recognized boolean word.
    """
    if isinstance(obj, str):
        obj = obj.strip().lower()
        if obj in ['true', 'yes', 'on', 'y', 't', '1']:
            return True
        elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
            return False
        else:
            raise ValueError('String is not true/false: %r' % obj)
    return bool(obj)


# ---- tests/log.py ----
from logging.config import dictConfig


def configure(level):
    """Install a simple stderr logging config at the given level."""
    dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s %(levelname)-8s %(name)s: %(message)s'
            },
        },
        'handlers': {
            'default': {
                'level': level,
                'formatter': 'standard',
                'class': 'logging.StreamHandler',
            },
        },
        'loggers': {
            '': {
                'handlers': ['default'],
                'level': level,
                'propagate': True
            },
        }
    })
ClientFromUrlTestCase(TestCase): + def assertHostsEqual(self, client, another, msg=None): + self.assertEqual(client.connection.host, another, msg=msg) + + def test_simple(self): + c = Client.from_url('https://app.databend.com:443') + + self.assertHostsEqual(c, 'app.databend.com') + self.assertEqual(c.connection.database, 'default') + self.assertEqual(c.connection.user, 'root') + + c = Client.from_url('https://host:443/db') + + self.assertHostsEqual(c, 'host') + self.assertEqual(c.connection.database, 'db') + self.assertEqual(c.connection.password, '') + + def test_ordinary_query(self): + c = Client.from_url('http://localhost:8081') + r = c.execute("select 1", with_column_types=False) + self.assertEqual(r, [['1']]) diff --git a/tests/testcase.py b/tests/testcase.py new file mode 100644 index 0000000..5554c95 --- /dev/null +++ b/tests/testcase.py @@ -0,0 +1,39 @@ +import configparser +from contextlib import contextmanager +import subprocess +from unittest import TestCase + +from databend_driver.client import Client +from tests import log + +file_config = configparser.ConfigParser() +file_config.read(['../setup.cfg']) + +log.configure(file_config.get('log', 'level')) + + +class BaseTestCase(TestCase): + required_server_version = None + server_version = None + + host = file_config.get('db', 'host') + port = file_config.getint('db', 'port') + database = file_config.get('db', 'database') + user = file_config.get('db', 'user') + password = file_config.get('db', 'password') + + client = None + client_kwargs = None + + def _create_client(self, **kwargs): + client_kwargs = { + 'port': self.port, + 'database': self.database, + 'user': self.user, + 'password': self.password + } + client_kwargs.update(kwargs) + return Client(self.host, **client_kwargs) + + def created_client(self, **kwargs): + return self._create_client(**kwargs)