Skip to content

Commit

Permalink
Eliminate deprecated Unicode API
Browse files Browse the repository at this point in the history
  • Loading branch information
mkleehammer committed Aug 26, 2023
1 parent 1f99534 commit 5c1f1c0
Show file tree
Hide file tree
Showing 14 changed files with 234 additions and 303 deletions.
10 changes: 10 additions & 0 deletions HACKING.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,13 @@ If a segmentation fault occurs while running tests, pytest will have eaten the o

python setup.py build_ext --inplace -D PYODBC_TRACE
pytest test/test_postgresql.py -vxk test_text -vs


# Notes

## uint16_t

You'll notice we use uint16_t instead of SQLWCHAR. The unixODBC headers would define SQLWCHAR
as wchar_t even when wchar_t is defined by the C library as uint32_t. The data in the buffer
was still 16 bit however.

7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

VERSION = '5.0.0'

import sys, os, re, shlex
import sys, os, re, shlex, subprocess
from os.path import exists, abspath, dirname, join, isdir, relpath, expanduser
from inspect import cleandoc

Expand All @@ -13,6 +13,11 @@
from configparser import ConfigParser


def _run(cmd):
return subprocess.run(cmd, check=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
encoding='utf_8', shell=True).stdout


def main():
settings = get_compiler_settings()

Expand Down
87 changes: 16 additions & 71 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,16 @@ static char* StrDup(const char* text) {
}


static bool Connect(PyObject* pConnectString, HDBC hdbc, bool fAnsi, long timeout,
Object& encoding)
static bool Connect(PyObject* pConnectString, HDBC hdbc, long timeout, Object& encoding)
{
// This should have been checked by the global connect function.
assert(PyUnicode_Check(pConnectString) || PyUnicode_Check(pConnectString));

// The driver manager determines if the app is a Unicode app based on whether we call SQLDriverConnectA or
// SQLDriverConnectW. Some drivers, notably Microsoft Access/Jet, change their behavior based on this, so we try
// the Unicode version first. (The Access driver only supports Unicode text, but SQLDescribeCol returns SQL_CHAR
// instead of SQL_WCHAR if we connect with the ANSI version. Obviously this causes lots of errors since we believe
// what it tells us (SQL_CHAR).)

// Python supports only UCS-2 and UCS-4, so we shouldn't need to worry about receiving surrogate pairs. However,
// Windows does use UTF-16, so it is possible something would be misinterpreted as one. We may need to examine
// this more.
assert(PyUnicode_Check(pConnectString));

SQLRETURN ret;

if (timeout > 0)
{
Py_BEGIN_ALLOW_THREADS
ret = SQLSetConnectAttr(hdbc, SQL_ATTR_LOGIN_TIMEOUT, (SQLPOINTER)(uintptr_t)timeout, SQL_IS_UINTEGER);
ret = SQLSetConnectAttrW(hdbc, SQL_ATTR_LOGIN_TIMEOUT, (SQLPOINTER)(uintptr_t)timeout, SQL_IS_UINTEGER);
Py_END_ALLOW_THREADS
if (!SQL_SUCCEEDED(ret))
RaiseErrorFromHandle(0, "SQLSetConnectAttr(SQL_ATTR_LOGIN_TIMEOUT)", hdbc, SQL_NULL_HANDLE);
Expand All @@ -96,28 +84,12 @@ static bool Connect(PyObject* pConnectString, HDBC hdbc, bool fAnsi, long timeou
}
}

if (!fAnsi)
{
// I want to call the W version when possible since the driver can use it as an
// indication that we can handle Unicode.

SQLWChar wchar(pConnectString, szEncoding ? szEncoding : ENCSTR_UTF16NE);
if (!wchar.isValid())
return false;

Py_BEGIN_ALLOW_THREADS
ret = SQLDriverConnectW(hdbc, 0, wchar.psz, SQL_NTS, 0, 0, 0, SQL_DRIVER_NOPROMPT);
Py_END_ALLOW_THREADS
if (SQL_SUCCEEDED(ret))
return true;
}

SQLWChar wchar(pConnectString, szEncoding ? szEncoding : "utf-8");
if (!wchar.isValid())
SQLWChar cstring(pConnectString, szEncoding ? szEncoding : ENCSTR_UTF16NE);
if (!cstring.isValid())
return false;

Py_BEGIN_ALLOW_THREADS
ret = SQLDriverConnect(hdbc, 0, (SQLCHAR*)wchar.psz, SQL_NTS, 0, 0, 0, SQL_DRIVER_NOPROMPT);
ret = SQLDriverConnectW(hdbc, 0, cstring, SQL_NTS, 0, 0, 0, SQL_DRIVER_NOPROMPT);
Py_END_ALLOW_THREADS
if (SQL_SUCCEEDED(ret))
return true;
Expand All @@ -133,6 +105,8 @@ static bool ApplyPreconnAttrs(HDBC hdbc, SQLINTEGER ikey, PyObject *value, char
SQLPOINTER ivalue = 0;
SQLINTEGER vallen = 0;

SQLWChar sqlchar;

if (PyLong_Check(value))
{
if (_PyLong_Sign(value) >= 0)
Expand All @@ -150,31 +124,11 @@ static bool ApplyPreconnAttrs(HDBC hdbc, SQLINTEGER ikey, PyObject *value, char
ivalue = (SQLPOINTER)PyByteArray_AsString(value);
vallen = SQL_IS_POINTER;
}
else if (PyBytes_Check(value))
{
ivalue = PyBytes_AS_STRING(value);
vallen = SQL_IS_POINTER;
}
else if (PyUnicode_Check(value))
{
Object stringholder;
if (sizeof(Py_UNICODE) == 2 // This part should be compile-time.
&& (!strencoding || !strcmp(strencoding, "utf-16le")))
{
// default or utf-16le is set, pass through directly
ivalue = PyUnicode_AS_UNICODE(value);
}
else
{
// use strencoding to convert, default to utf-16le if not set.
stringholder = PyCodec_Encode(value, strencoding ? strencoding : "utf-16le", "strict");
ivalue = PyBytes_AS_STRING(stringholder.Get());
}
sqlchar.set(value, strencoding ? strencoding : "utf-16le");
ivalue = sqlchar.get();
vallen = SQL_NTS;
Py_BEGIN_ALLOW_THREADS
ret = SQLSetConnectAttrW(hdbc, ikey, ivalue, vallen);
Py_END_ALLOW_THREADS
goto checkSuccess;
}
else if (PySequence_Check(value))
{
Expand All @@ -190,10 +144,9 @@ static bool ApplyPreconnAttrs(HDBC hdbc, SQLINTEGER ikey, PyObject *value, char
}

Py_BEGIN_ALLOW_THREADS
ret = SQLSetConnectAttr(hdbc, ikey, ivalue, vallen);
ret = SQLSetConnectAttrW(hdbc, ikey, ivalue, vallen);
Py_END_ALLOW_THREADS

checkSuccess:
if (!SQL_SUCCEEDED(ret))
{
RaiseErrorFromHandle(0, "SQLSetConnectAttr", hdbc, SQL_NULL_HANDLE);
Expand All @@ -205,15 +158,9 @@ static bool ApplyPreconnAttrs(HDBC hdbc, SQLINTEGER ikey, PyObject *value, char
return true;
}

PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, bool fAnsi, long timeout, bool fReadOnly,
PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, long timeout, bool fReadOnly,
PyObject* attrs_before, Object& encoding)
{
// pConnectString
// A string or unicode object. (This must be checked by the caller.)
//
// fAnsi
// If true, do not attempt a Unicode connection.

//
// Allocate HDBC and connect
//
Expand Down Expand Up @@ -255,7 +202,7 @@ PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, bool fAnsi,
}
}

if (!Connect(pConnectString, hdbc, fAnsi, timeout, encoding))
if (!Connect(pConnectString, hdbc, timeout, encoding))
{
// Connect has already set an exception.
Py_BEGIN_ALLOW_THREADS
Expand Down Expand Up @@ -1298,7 +1245,7 @@ static void NormalizeCodecName(const char* src, char* dest, size_t cbDest)
*pch = '\0';
}

static bool SetTextEncCommon(TextEnc& enc, const char* encoding, int ctype, bool allow_raw)
static bool SetTextEncCommon(TextEnc& enc, const char* encoding, int ctype)
{
// Code common to setencoding and setdecoding.

Expand Down Expand Up @@ -1396,9 +1343,8 @@ static PyObject* Connection_setencoding(PyObject* self, PyObject* args, PyObject
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|si", kwlist, &encoding, &ctype))
return 0;
TextEnc& enc = cnxn->unicode_enc;
bool allow_raw = false;

if (!SetTextEncCommon(enc, encoding, ctype, allow_raw))
if (!SetTextEncCommon(enc, encoding, ctype))
return 0;

Py_RETURN_NONE;
Expand Down Expand Up @@ -1426,7 +1372,6 @@ static PyObject* Connection_setdecoding(PyObject* self, PyObject* args, PyObject
int sqltype;
char* encoding = 0;
int ctype = 0;
bool allow_raw = false;

static char *kwlist[] = {"sqltype", "encoding", "ctype", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "i|si", kwlist, &sqltype, &encoding, &ctype))
Expand All @@ -1438,7 +1383,7 @@ static PyObject* Connection_setdecoding(PyObject* self, PyObject* args, PyObject
TextEnc& enc = (sqltype == SQL_CHAR) ? cnxn->sqlchar_enc :
((sqltype == SQL_WMETADATA) ? cnxn->metadata_enc : cnxn->sqlwchar_enc);

if (!SetTextEncCommon(enc, encoding, ctype, allow_raw))
if (!SetTextEncCommon(enc, encoding, ctype))
return 0;

Py_RETURN_NONE;
Expand Down
2 changes: 1 addition & 1 deletion src/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct Connection
* Used by the module's connect function to create new connection objects. If unable to connect to the database, an
* exception is set and zero is returned.
*/
PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, bool fAnsi, long timeout, bool fReadOnly,
PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, long timeout, bool fReadOnly,
PyObject* attrs_before, Object& encoding);

/*
Expand Down
26 changes: 13 additions & 13 deletions src/cursor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ static bool create_name_map(Cursor* cur, SQLSMALLINT field_count, bool lower)
bool success = false;
PyObject *desc = 0, *colmap = 0, *colinfo = 0, *type = 0, *index = 0, *nullable_obj=0;
SQLSMALLINT nameLen = 300;
ODBCCHAR *szName = NULL;
uint16_t *szName = NULL;
SQLRETURN ret;

assert(cur->hstmt != SQL_NULL_HANDLE && cur->colinfos != 0);
Expand All @@ -149,7 +149,7 @@ static bool create_name_map(Cursor* cur, SQLSMALLINT field_count, bool lower)

desc = PyTuple_New((Py_ssize_t)field_count);
colmap = PyDict_New();
szName = (ODBCCHAR*) PyMem_Malloc((nameLen + 1) * sizeof(ODBCCHAR));
szName = (uint16_t*) PyMem_Malloc((nameLen + 1) * sizeof(uint16_t));
if (!desc || !colmap || !szName)
goto done;

Expand Down Expand Up @@ -182,7 +182,7 @@ static bool create_name_map(Cursor* cur, SQLSMALLINT field_count, bool lower)
// If needed, allocate a bigger column name message buffer and retry.
if (cchName > nameLen - 1) {
nameLen = cchName + 1;
if (!PyMem_Realloc((BYTE**) &szName, (nameLen + 1) * sizeof(ODBCCHAR))) {
if (!PyMem_Realloc((BYTE**) &szName, (nameLen + 1) * sizeof(uint16_t))) {
PyErr_NoMemory();
goto done;
}
Expand Down Expand Up @@ -584,10 +584,10 @@ static int GetDiagRecs(Cursor* cur)
PyObject* msg_list; // the "messages" as a Python list of diagnostic records

SQLSMALLINT iRecNumber = 1; // the index of the diagnostic records (1-based)
ODBCCHAR cSQLState[6]; // five-character SQLSTATE code (plus terminating NULL)
uint16_t cSQLState[6]; // five-character SQLSTATE code (plus terminating NULL)
SQLINTEGER iNativeError;
SQLSMALLINT iMessageLen = 1023;
ODBCCHAR *cMessageText = (ODBCCHAR*) PyMem_Malloc((iMessageLen + 1) * sizeof(ODBCCHAR));
uint16_t *cMessageText = (uint16_t*) PyMem_Malloc((iMessageLen + 1) * sizeof(uint16_t));
SQLSMALLINT iTextLength;

SQLRETURN ret;
Expand Down Expand Up @@ -621,7 +621,7 @@ static int GetDiagRecs(Cursor* cur)
// If needed, allocate a bigger error message buffer and retry.
if (iTextLength > iMessageLen - 1) {
iMessageLen = iTextLength + 1;
if (!PyMem_Realloc((BYTE**) &cMessageText, (iMessageLen + 1) * sizeof(ODBCCHAR))) {
if (!PyMem_Realloc((BYTE**) &cMessageText, (iMessageLen + 1) * sizeof(uint16_t))) {
PyMem_Free(cMessageText);
PyErr_NoMemory();
return 0;
Expand All @@ -643,13 +643,13 @@ static int GetDiagRecs(Cursor* cur)
// Default to UTF-16, which may not work if the driver/manager is using some other encoding
const char *unicode_enc = cur->cnxn ? cur->cnxn->metadata_enc.name : ENCSTR_UTF16NE;
PyObject* msg_value = PyUnicode_Decode(
(char*)cMessageText, iTextLength * sizeof(ODBCCHAR), unicode_enc, "strict"
(char*)cMessageText, iTextLength * sizeof(uint16_t), unicode_enc, "strict"
);
if (!msg_value)
{
// If the char cannot be decoded, return something rather than nothing.
Py_XDECREF(msg_value);
msg_value = PyBytes_FromStringAndSize((char*)cMessageText, iTextLength * sizeof(ODBCCHAR));
msg_value = PyBytes_FromStringAndSize((char*)cMessageText, iTextLength * sizeof(uint16_t));
}

PyObject* msg_tuple = PyTuple_New(2); // the message as a Python tuple of class and value
Expand Down Expand Up @@ -748,7 +748,7 @@ static PyObject* execute(Cursor* cur, PyObject* pSql, PyObject* params, bool ski
bool isWide = (penc->ctype == SQL_C_WCHAR);

const char* pch = PyBytes_AS_STRING(query.Get());
SQLINTEGER cch = (SQLINTEGER)(PyBytes_GET_SIZE(query.Get()) / (isWide ? sizeof(ODBCCHAR) : 1));
SQLINTEGER cch = (SQLINTEGER)(PyBytes_GET_SIZE(query.Get()) / (isWide ? sizeof(uint16_t) : 1));

Py_BEGIN_ALLOW_THREADS
if (isWide)
Expand Down Expand Up @@ -1466,10 +1466,10 @@ static PyObject* Cursor_columns(PyObject* self, PyObject* args, PyObject* kwargs

Py_BEGIN_ALLOW_THREADS
ret = SQLColumnsW(cur->hstmt,
catalog.psz, SQL_NTS,
schema.psz, SQL_NTS,
table.psz, SQL_NTS,
column.psz, SQL_NTS);
catalog, SQL_NTS,
schema, SQL_NTS,
table, SQL_NTS,
column, SQL_NTS);
Py_END_ALLOW_THREADS

if (!SQL_SUCCEEDED(ret))
Expand Down
13 changes: 5 additions & 8 deletions src/errors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ PyObject* RaiseErrorV(const char* sqlstate, PyObject* exc_class, const char* for
}


#define PyUnicode_CompareWithASCIIString(lhs, rhs) _strcmpi(PyUnicode_AS_STRING(lhs), rhs)


bool HasSqlState(PyObject* ex, const char* szSqlState)
{
// Returns true if `ex` is an exception and has the given SQLSTATE. It is safe to pass 0 for
Expand Down Expand Up @@ -205,9 +202,9 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
SQLINTEGER nNativeError;
SQLSMALLINT cchMsg;

ODBCCHAR sqlstateT[6];
uint16_t sqlstateT[6];
SQLSMALLINT msgLen = 1023;
ODBCCHAR *szMsg = (ODBCCHAR*) PyMem_Malloc((msgLen + 1) * sizeof(ODBCCHAR));
uint16_t *szMsg = (uint16_t*) PyMem_Malloc((msgLen + 1) * sizeof(uint16_t));

if (!szMsg) {
PyErr_NoMemory();
Expand Down Expand Up @@ -254,7 +251,7 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
// If needed, allocate a bigger error message buffer and retry.
if (cchMsg > msgLen - 1) {
msgLen = cchMsg + 1;
if (!PyMem_Realloc((BYTE**) &szMsg, (msgLen + 1) * sizeof(ODBCCHAR))) {
if (!PyMem_Realloc((BYTE**) &szMsg, (msgLen + 1) * sizeof(uint16_t))) {
PyErr_NoMemory();
PyMem_Free(szMsg);
return 0;
Expand All @@ -272,7 +269,7 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
// For now, default to UTF-16 if this is not in the context of a connection.
// Note that this will not work if the DM is using a different wide encoding (e.g. UTF-32).
const char *unicode_enc = conn ? conn->metadata_enc.name : ENCSTR_UTF16NE;
Object msgStr(PyUnicode_Decode((char*)szMsg, cchMsg * sizeof(ODBCCHAR), unicode_enc, "strict"));
Object msgStr(PyUnicode_Decode((char*)szMsg, cchMsg * sizeof(uint16_t), unicode_enc, "strict"));

if (cchMsg != 0 && msgStr.Get())
{
Expand Down Expand Up @@ -314,7 +311,7 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
// Raw message buffer not needed anymore
PyMem_Free(szMsg);

if (!msg || PyUnicode_GetSize(msg.Get()) == 0)
if (!msg || PyUnicode_GET_LENGTH(msg.Get()) == 0)
{
// This only happens using unixODBC. (Haven't tried iODBC yet.) Either the driver or the driver manager is
// buggy and has signaled a fault without recording error information.
Expand Down
6 changes: 3 additions & 3 deletions src/errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ inline PyObject* RaiseErrorFromException(PyObject* pError)
return 0;
}

inline void CopySqlState(const ODBCCHAR* src, char* dest)
inline void CopySqlState(const uint16_t* src, char* dest)
{
// Copies a SQLSTATE read as SQLWCHAR into a character buffer. We know that SQLSTATEs are
// composed of ASCII characters and we need one standard to compare when choosing
Expand All @@ -71,14 +71,14 @@ inline void CopySqlState(const ODBCCHAR* src, char* dest)
// Strangely, even when the error messages are UTF-8, PostgreSQL and MySQL encode the
// sqlstate as UTF-16LE. We'll simply copy all non-zero bytes, with some checks for
// running off the end of the buffers which will work for ASCII, UTF8, and UTF16 LE & BE.
// It would work for UTF32 if I increase the size of the ODBCCHAR buffer to handle it.
// It would work for UTF32 if I increase the size of the uint16_t buffer to handle it.
//
// (In the worst case, if a driver does something totally weird, we'll have an incomplete
// SQLSTATE.)
//

const char* pchSrc = (const char*)src;
const char* pchSrcMax = pchSrc + sizeof(ODBCCHAR) * 5;
const char* pchSrcMax = pchSrc + sizeof(uint16_t) * 5;
char* pchDest = dest; // Where we are copying into dest
char* pchDestMax = dest + 5; // We know a SQLSTATE is 5 characters long

Expand Down
Loading

0 comments on commit 5c1f1c0

Please sign in to comment.