Skip to content

Commit

Permalink
Send numbers and boolean values in binary form, get rid of asprintf().
Browse files Browse the repository at this point in the history
  • Loading branch information
mbalmer committed Feb 15, 2015
1 parent be09676 commit 017d520
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 36 deletions.
116 changes: 81 additions & 35 deletions luapgsql.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@

/* PostgreSQL extension module (using Lua) */

#include <endian.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>

#include <libpq-fe.h>
#include <libpq/libpq-fs.h>
Expand Down Expand Up @@ -340,20 +342,20 @@ conn_exec(lua_State *L)
}

static int
get_sql_params(lua_State *L, int t, int n, Oid *paramTypes, char **paramValues)
get_sql_params(lua_State *L, int t, int n, Oid *paramTypes, char **paramValues,
int *paramLengths, int *paramFormats)
{
double v;
int k;

switch (lua_type(L, t)) {
case LUA_TBOOLEAN:
if (paramTypes != NULL)
paramTypes[n] = BOOLOID;
if (paramValues != NULL) {
if (lua_toboolean(L, t))
paramValues[n] = strdup("true");
else
paramValues[n] = strdup("false");
paramValues[n] = malloc(1);
*(char *)paramValues[n] = lua_toboolean(L, t);
paramLengths[n] = 1;
paramFormats[n] = 1;
}
n = 1;
break;
Expand All @@ -362,12 +364,20 @@ get_sql_params(lua_State *L, int t, int n, Oid *paramTypes, char **paramValues)
* XXX Does not handle math.huge (Infinity), since Infinity
* and -Infinity are not defined for PostgreSQL numeric values.
*/
v = lua_tonumber(L, t);
if (paramTypes != NULL)
paramTypes[n] = NUMERICOID;
if (paramValues != NULL)
if (asprintf(&paramValues[n], "%f", v) == -1)
paramValues[n] = NULL;
paramTypes[n] = FLOAT8OID;
if (paramValues != NULL) {
union {
double v;
uint64_t i;
} swap;

swap.v = lua_tonumber(L, t);
paramValues[n] = malloc(sizeof(uint64_t));
*(uint64_t *)paramValues[n] = htobe64(swap.i);

This comment has been minimized.

Copy link
@daurnimator

daurnimator Feb 15, 2015

Contributor

Not sure if you checked on all platforms, the linux man page has this conformity note about htobe64:

These functions are nonstandard. Similar functions are present on the BSDs, where the required header file is <sys/endian.h> instead of <endian.h>. Unfortunately, NetBSD, FreeBSD, and glibc haven't followed the original OpenBSD naming convention for these functions, whereby the nn component always appears at the end of the function name (thus, for example, in NetBSD, FreeBSD, and glibc, the equivalent of OpenBSDs "betoh32" is "be32toh").

This comment has been minimized.

Copy link
@daurnimator

daurnimator Feb 15, 2015

Contributor

Sadly endian.h seems to be terribly supported across compilers and operating systems.
I found https://gist.github.com/panzi/6856583 which you might be able to use.

This comment has been minimized.

Copy link
@kiug

kiug via email Feb 16, 2015

This comment has been minimized.

Copy link
@mbalmer

mbalmer via email Feb 16, 2015

Author Collaborator

This comment has been minimized.

Copy link
@kiug

kiug via email Feb 16, 2015

This comment has been minimized.

Copy link
@mbalmer

mbalmer via email Feb 16, 2015

Author Collaborator

This comment has been minimized.

Copy link
@mbalmer

mbalmer via email Feb 16, 2015

Author Collaborator
paramLengths[n] = sizeof(uint64_t);
paramFormats[n] = 1;
}
n = 1;
break;
case LUA_TSTRING:
Expand All @@ -388,7 +398,8 @@ get_sql_params(lua_State *L, int t, int n, Oid *paramTypes, char **paramValues)
lua_gettable(L, t);
if (lua_isnil(L, -1))
break;
n += get_sql_params(L, -1, n, paramTypes, paramValues);
n += get_sql_params(L, -1, n, paramTypes, paramValues,
paramLengths, paramFormats);
lua_pop(L, 1);
}
lua_pop(L, 1);
Expand All @@ -405,37 +416,45 @@ conn_execParams(lua_State *L)
PGresult **res;
Oid *paramTypes;
char **paramValues;
int n, nParams, sqlParams;
int n, nParams, sqlParams, *paramLengths, *paramFormats;

nParams = lua_gettop(L) - 2; /* subtract connection and command */
if (nParams < 0)
nParams = 0;

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, sqlParams, NULL, NULL);
sqlParams += get_sql_params(L, 3 + n, sqlParams, NULL, NULL,
NULL, NULL);

if (sqlParams) {
paramTypes = calloc(sqlParams, sizeof(Oid));
paramValues = calloc(sqlParams, sizeof(char *));
paramLengths = calloc(sqlParams, sizeof(int));
paramFormats = calloc(sqlParams, sizeof(int));

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, sqlParams,
paramTypes, paramValues);
paramTypes, paramValues, paramLengths,
paramFormats);
} else {
paramTypes = NULL;
paramValues = NULL;
paramLengths = NULL;
paramFormats = NULL;
}
res = lua_newuserdata(L, sizeof(PGresult *));
*res = PQexecParams(*(PGconn **)luaL_checkudata(L, 1, CONN_METATABLE),
luaL_checkstring(L, 2), sqlParams, paramTypes,
(const char * const*)paramValues, NULL, NULL, 0);
(const char * const*)paramValues, paramLengths, paramFormats, 0);
luaL_getmetatable(L, RES_METATABLE);
lua_setmetatable(L, -2);
if (sqlParams) {
for (n = 0; n < sqlParams; n++)
free((void *)paramValues[n]);
free(paramTypes);
free(paramValues);
free(paramLengths);
free(paramFormats);
}
return 1;
}
Expand All @@ -452,14 +471,15 @@ conn_prepare(lua_State *L)
nParams = 0;

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 4 + n, sqlParams, NULL, NULL);
sqlParams += get_sql_params(L, 4 + n, sqlParams, NULL, NULL,
NULL, NULL);

if (sqlParams) {
paramTypes = calloc(sqlParams, sizeof(Oid));

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 4 + n, sqlParams,
paramTypes, NULL);
paramTypes, NULL, NULL, NULL);
} else
paramTypes = NULL;
res = lua_newuserdata(L, sizeof(PGresult *));
Expand All @@ -478,34 +498,42 @@ conn_execPrepared(lua_State *L)
{
PGresult **res;
char **paramValues;
int n, nParams, sqlParams;
int n, nParams, sqlParams, *paramLengths, *paramFormats;

nParams = lua_gettop(L) - 2; /* subtract connection and name */
if (nParams < 0)
nParams = 0;

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, sqlParams, NULL, NULL);
sqlParams += get_sql_params(L, 3 + n, sqlParams, NULL, NULL,
NULL, NULL);

if (sqlParams) {
paramValues = calloc(sqlParams, sizeof(char *));
paramLengths = calloc(sqlParams, sizeof(int));
paramFormats = calloc(sqlParams, sizeof(int));

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, sqlParams, NULL,
paramValues);
} else
paramValues, paramLengths, paramFormats);
} else {
paramValues = NULL;
paramLengths = NULL;
paramFormats = NULL;
}
res = lua_newuserdata(L, sizeof(PGresult *));
*res = PQexecPrepared(*(PGconn **)luaL_checkudata(L, 1, CONN_METATABLE),
luaL_checkstring(L, 2), sqlParams, (const char * const*)paramValues,
NULL, NULL, 0);
paramLengths, paramFormats, 0);
luaL_getmetatable(L, RES_METATABLE);
lua_setmetatable(L, -2);
if (sqlParams) {
for (n = 0; n < sqlParams; n++)
if (paramValues[n] != NULL)
free((void *)paramValues[n]);
free(paramValues);
free(paramLengths);
free(paramFormats);
}
return 1;
}
Expand Down Expand Up @@ -649,36 +677,44 @@ conn_sendQueryParams(lua_State *L)
{
Oid *paramTypes;
char **paramValues;
int n, nParams, sqlParams;
int n, nParams, sqlParams, *paramLengths, *paramFormats;

nParams = lua_gettop(L) - 2; /* subtract connection and command */
if (nParams < 0)
nParams = 0;

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, 0, NULL, NULL);
sqlParams += get_sql_params(L, 3 + n, 0, NULL, NULL, NULL,
NULL);

if (sqlParams) {
paramTypes = calloc(sqlParams, sizeof(Oid));
paramValues = calloc(sqlParams, sizeof(char *));
paramLengths = calloc(sqlParams, sizeof(int));
paramFormats = calloc(sqlParams, sizeof(int));

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, sqlParams,
paramTypes, paramValues);
paramTypes, paramValues, paramLengths,
paramFormats);
} else {
paramTypes = NULL;
paramValues = NULL;
paramLengths = NULL;
paramFormats = NULL;
}
lua_pushinteger(L,
PQsendQueryParams(*(PGconn **)luaL_checkudata(L, 1, CONN_METATABLE),
luaL_checkstring(L, 2), sqlParams, paramTypes,
(const char * const*)paramValues, NULL, NULL, 0));
(const char * const*)paramValues, paramLengths, paramFormats, 0));
if (sqlParams) {
for (n = 0; n < sqlParams; n++)
if (paramValues[n] != NULL)
free((void *)paramValues[n]);
free(paramTypes);
free(paramValues);
free(paramLengths);
free(paramFormats);
}
return 1;
}
Expand All @@ -694,14 +730,15 @@ conn_sendPrepare(lua_State *L)
nParams = 0;

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 4 + n, 0, NULL, NULL);
sqlParams += get_sql_params(L, 4 + n, 0, NULL, NULL, NULL,
NULL);

if (sqlParams) {
paramTypes = calloc(sqlParams, sizeof(Oid));

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 4 + n, sqlParams,
paramTypes, NULL);
paramTypes, NULL, NULL, NULL);
} else
paramTypes = NULL;
lua_pushinteger(L,
Expand All @@ -717,33 +754,42 @@ static int
conn_sendQueryPrepared(lua_State *L)
{
char **paramValues;
int n, nParams, sqlParams;
int n, nParams, sqlParams, *paramLengths, *paramFormats;

nParams = lua_gettop(L) - 2; /* subtract connection and name */
if (nParams < 0)
nParams = 0;

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, 0, NULL, NULL);
sqlParams += get_sql_params(L, 3 + n, 0, NULL, NULL, NULL,
NULL);

if (sqlParams) {
paramValues = calloc(sqlParams, sizeof(char *));
paramLengths = calloc(sqlParams, sizeof(int));
paramFormats = calloc(sqlParams, sizeof(int));

for (n = 0, sqlParams = 0; n < nParams; n++)
sqlParams += get_sql_params(L, 3 + n, sqlParams, NULL,
paramValues);
} else
paramValues, paramLengths, paramFormats);
} else {
paramValues = NULL;
paramLengths = NULL;
paramFormats = NULL;
}
lua_pushinteger(L,
PQsendQueryPrepared(*(PGconn **)luaL_checkudata(L, 1,
CONN_METATABLE),
luaL_checkstring(L, 2), nParams, (const char * const*)paramValues,
NULL, NULL, 0));
paramLengths, paramFormats, 0));
if (nParams) {
for (n = 0; n < nParams; n++)
if (paramValues[n] != NULL)
free((void *)paramValues[n]);
free(paramValues);
free(paramLengths);
free(paramFormats);

}
return 1;
}
Expand Down Expand Up @@ -1525,7 +1571,7 @@ pgsql_set_info(lua_State *L)
lua_pushliteral(L, "PostgreSQL binding for Lua");
lua_settable(L, -3);
lua_pushliteral(L, "_VERSION");
lua_pushliteral(L, "pgsql 1.4.1");
lua_pushliteral(L, "pgsql 1.4.2");
lua_settable(L, -3);
}

Expand Down
2 changes: 1 addition & 1 deletion luapgsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
/* OIDs from server/pg_type.h */
#define BOOLOID 16
#define TEXTOID 25
#define NUMERICOID 1700
#define FLOAT8OID 701

typedef struct largeObject {
PGconn *conn;
Expand Down

0 comments on commit 017d520

Please sign in to comment.