Skip to content

Commit

Permalink
Added create_functions to be called the sync_pgfunctions command zach…
Browse files Browse the repository at this point in the history
  • Loading branch information
Scott Walton committed May 10, 2013
1 parent 986c655 commit b2c911d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 22 deletions.
63 changes: 44 additions & 19 deletions django_postgres/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ def _generate_function(name, args, fields, definition):
return sql


def _function_exists(cursor, name):
"""Returns True or False depending whether function with name exists.
"""
function_query = (
u"SELECT COUNT(*) "
u"FROM pg_catalog.pg_namespace n "
u"JOIN pg_catalog.pg_proc p "
u"ON pronamespace = n.oid "
u"WHERE nspname = 'public' and proname = %s;")
cursor.execute(function_query, [name])
return cursor.fetchone()[0] > 0


def _force_required(cursor, name, args):
"""Returns whether the function signature is compatible with the new
definition.
"""
function_detail_query = (
u"SELECT pronargs "
u"FROM pg_catalog.pg_namespace n "
u"JOIN pg_catalog.pg_proc p "
u"ON pronamespace = n.oid "
u"WHERE nspname = 'public' and proname = %s;")
cursor.execute(function_detail_query, [name])
return cursor.fetchone()[0] != len(args)


def create_function(connection, function_name, function_fields,
function_definition, update=True):
"""
Expand All @@ -41,37 +68,25 @@ def create_function(connection, function_name, function_fields,
If ``update`` is True (default), attempt to update an existing function.
"""
cursor_wrapper = connection.cursor()
cursor = cursor_wrapper.cursor.cursor
cursor = cursor_wrapper.cursor

name, args = _split_function_args(function_name)

try:
force_required = False
# Determine if function already exists.
function_query = (
u"SELECT COUNT(*) "
u"FROM pg_catalog.pg_namespace n "
u"JOIN pg_catalog.pg_proc p "
u"ON pronamespace = n.oid "
u"WHERE nspname = 'public' and proname = %s;")
cursor.execute(function_query, [name])
function_exists = cursor.fetchone()[0] > 0

function_exists = _function_exists(cursor, name)

if function_exists and not update:
return 'EXISTS'
elif function_exists:
function_detail_query = (
u"SELECT pronargs "
u"FROM pg_catalog.pg_namespace n "
u"JOIN pg_catalog.pg_proc p "
u"ON pronamespace = n.oid "
u"WHERE nspname = 'public' and proname = %s;")
cursor.execute(function_detail_query, [name])
force_required = cursor.fetchone()[0] != len(args)
force_required = _force_required(cursor, name, args)

if not force_required:
function_sql = _generate_function(
name, args, function_fields, function_definition)

print function_sql
cursor.execute(function_sql)
ret = 'UPDATED' if function_exists else 'CREATED'
else:
Expand All @@ -83,6 +98,15 @@ def create_function(connection, function_name, function_fields,
cursor_wrapper.close()


def _get_field_type(field):
"""Returns the field type as a string for SQL.
"""
return field.db_type(
connection).replace(
'serial', 'bigint').replace(
'integer', 'bigint')


def create_functions(models_module, update=True):
"""Create the database functions for a given models module.
"""
Expand All @@ -97,8 +121,9 @@ def create_functions(models_module, update=True):

function_name = function_cls._meta.db_table
fields = tuple(
' '.join(n, b.db_type(connection)) for n, b in
' '.join((n, _get_field_type(f))) for n, f in
get_fields_by_name(function_cls, '*').iteritems())

definition = function_cls.sql

create_function(
Expand Down
6 changes: 4 additions & 2 deletions tests/test_project/functiontest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ class UserTypeCounter(django_postgres.Function):
"""A simple class that tests the function. Can be called with
either True or False as arguments
"""
sql = """SELECT COUNT(*) AS my_count FROM auth_user WHERE
is_superuser = $1;"""
sql = """SELECT COUNT(*) AS my_count, CAST(1 AS BIGINT)
FROM auth_user WHERE
is_superuser = $1"""

my_count = models.IntegerField()
id = models.IntegerField(primary_key=True)

class Meta:
db_table = 'user_type (BOOLEAN)'
14 changes: 13 additions & 1 deletion tests/test_project/functiontest/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import models

from django_postgres.function import create_function
from django_postgres.function import (create_function, create_functions,
_function_exists)


class FunctionTestCase(TestCase):
Expand Down Expand Up @@ -77,3 +78,14 @@ def test_error_function(self):
updated = create_function(connection, name, field, definition)

self.assertEqual(updated, 'ERROR: Manually Drop This Function')

def test_create_functions_from_models(self):
"""Create functions using the create_functions and passing the models
module.
"""
create_functions(models)

# Now check it was created
cursor_wrapper = connection.cursor()
cursor = cursor_wrapper.cursor
self.assertEqual(_function_exists(cursor, 'user_type'), True)

0 comments on commit b2c911d

Please sign in to comment.