diff --git a/dataset/table.py b/dataset/table.py index 0b817ef..907333e 100644 --- a/dataset/table.py +++ b/dataset/table.py @@ -121,10 +121,12 @@ def insert(self, row, ensure=None, types=None): return res.inserted_primary_key[0] return True - def insert_ignore(self, row, keys, ensure=None, types=None): + def insert_ignore(self, row, keys=None, ensure=None, types=None): """Add a ``row`` dict into the table if the row does not exist. If rows with matching ``keys`` exist no change is made. + If ``keys`` are not passed, keys will be + replaced by columns with unique constraints of table. Setting ``ensure`` results in automatically creating missing columns, i.e., keys of the row are not table columns. @@ -139,6 +141,11 @@ def insert_ignore(self, row, keys, ensure=None, types=None): data = dict(id=10, title='I am a banana!') table.insert_ignore(data, ['id']) """ + if not keys and not self.unique_columns: + log.warning("Insert ignore can not be executed. Table does not have unique columns") + return + + keys = keys or self.unique_columns row = self._sync_columns(row, ensure, types=types) if self._check_ensure(ensure): self.create_index(keys) @@ -184,7 +191,7 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None): self.table.insert().execute(chunk) chunk = [] - def update(self, row, keys, ensure=None, types=None, return_count=False): + def update(self, row, keys=None, ensure=None, types=None, return_count=False): """Update a row in the table. The update is managed via the set of column names stated in ``keys``: @@ -200,7 +207,15 @@ def update(self, row, keys, ensure=None, types=None, return_count=False): If keys in ``row`` update columns not present in the table, they will be created based on the settings of ``ensure`` and ``types``, matching the behavior of :py:meth:`insert() `. + + If ``keys`` are not passed, keys will be + replaced by columns with unique constraints of table. """ + if not keys and not self.unique_columns: + log.warning("Update can not be executed. Table does not have unique columns") + return + + keys = keys or self.unique_columns row = self._sync_columns(row, ensure, types=types) args, row = self._keys_to_args(row, keys) clause = self._args_to_clause(args) @@ -213,7 +228,7 @@ def update(self, row, keys, ensure=None, types=None, return_count=False): if return_count: return self.count(clause) - def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): + def update_many(self, rows, keys=None, chunk_size=1000, ensure=None, types=None): """Update many rows in the table at a time. This is significantly faster than updating them one by one. Per default @@ -222,7 +237,14 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): See :py:meth:`update() ` for details on the other parameters. + If ``keys`` are not passed, keys will be + replaced by columns with unique constraints of table. """ + if not keys and not self.unique_columns: + log.warning("Update can not be executed. Table does not have unique columns") + return + + keys = keys or self.unique_columns keys = ensure_list(keys) chunk = [] @@ -247,16 +269,23 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): self.db.executable.execute(stmt, chunk) chunk = [] - def upsert(self, row, keys, ensure=None, types=None): + def upsert(self, row, keys=None, ensure=None, types=None): """An UPSERT is a smart combination of insert and update. If rows with matching ``keys`` exist they will be updated, otherwise a new row is inserted in the table. + If ``keys`` are not passed, keys will be + replaced by columns with unique constraints of table. :: data = dict(id=10, title='I am a banana!') table.upsert(data, ['id']) """ + if not keys and not self.unique_columns: + log.warning("Upsert can not be executed. Table does not have unique columns") + return + + keys = keys or self.unique_columns row = self._sync_columns(row, ensure, types=types) if self._check_ensure(ensure): self.create_index(keys) @@ -265,7 +294,7 @@ def upsert(self, row, keys, ensure=None, types=None): return self.insert(row, ensure=False) return True - def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): + def upsert_many(self, rows, keys=None, chunk_size=1000, ensure=None, types=None): """ Sorts multiple input rows into upserts and inserts. Inserts are passed to insert and upserts are updated. @@ -296,6 +325,15 @@ def delete(self, *clauses, **filters): rp = self.db.executable.execute(stmt) return rp.rowcount > 0 + @property + def unique_columns(self): + """Get table unique columns""" + u_constraints = self.db.inspect.get_unique_constraints(self.name, schema=self.db.schema) + return list({ + column for constraint in u_constraints + for column in constraint['column_names'] + }) + def _reflect_table(self): """Load the tables definition from the database.""" with self.db.lock: diff --git a/test/test_dataset.py b/test/test_dataset.py index 9c2bb8e..7d03492 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -156,8 +156,31 @@ def setUp(self): for row in TEST_DATA: self.tbl.insert(row) + self.tbl_uniq_columns = self._init_table_with_unique_columns() + + def _init_table_with_unique_columns(self): + table = self.db.create_table('uniq_table') + table.create_column('key1', self.db.types.string(255), unique=True, nullable=False) + + if "sqlite" not in self.db.engine.dialect.dbapi.__name__: + # we can't execute create_column for sqlite multiple times + table.create_column('key2', self.db.types.float, unique=False) + table.create_column('key3', self.db.types.integer, unique=True) + + return table + + def _get_data_for_table_with_unique_columns(self): + data = dict(key1='I am banana') + unique_columns = ['key1'] + if "sqlite" not in self.db.engine.dialect.dbapi.__name__: + data.update(key2=3.1456, key3=42) + unique_columns.append('key3') + + return data, unique_columns + def tearDown(self): self.tbl.drop() + self.tbl_uniq_columns.drop() def test_insert(self): assert len(self.tbl) == len(TEST_DATA), len(self.tbl) @@ -237,6 +260,14 @@ def test_upsert_id(self): table.upsert(data, ["id"]) assert len(table) == 1, len(table) + def test_upsert_with_default_unique_columns(self): + data, unique_columns = self._get_data_for_table_with_unique_columns() + table = self.tbl_uniq_columns + table.upsert(data) + assert len(table) == 1, len(table) + table.upsert(data, unique_columns) + assert len(table) == 1, len(table) + def test_update_while_iter(self): for row in self.tbl: row["foo"] = "bar" @@ -540,6 +571,10 @@ def test_empty_query(self): empty = list(self.tbl.find(place="not in data")) assert len(empty) == 0, empty + def test_unique_columns(self): + data, unique_columns = self._get_data_for_table_with_unique_columns() + assert len(set(self.tbl_uniq_columns.unique_columns) & set(unique_columns)) == len(unique_columns) + class Constructor(dict): """Very simple low-functionality extension to ``dict`` to