Skip to content

Commit

Permalink
When using DbTransactions, set the transaction on the DbCommands
Browse files Browse the repository at this point in the history
Partial fix for issue #1142

Based heavily on @ChrisKrikade's PR (mostly just moved CreateCommand() and
ExecuteWithinTransaction() down to X509CertificateDatabase where it likely
belongs instead of SqlCertificateDatabase).
  • Loading branch information
jstedfast committed Feb 22, 2025
1 parent 12a0047 commit 893a29c
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 86 deletions.
6 changes: 3 additions & 3 deletions MimeKit/Cryptography/NpgsqlCertificateDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public NpgsqlCertificateDatabase (DbConnection connection, string password) : ba
/// <returns>The list of columns.</returns>
protected override IList<DataColumn> GetTableColumns (DbConnection connection, string tableName)
{
using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = $"PRAGMA table_info({tableName})";
using (var reader = command.ExecuteReader ()) {
var columns = new List<DataColumn> ();
Expand Down Expand Up @@ -229,7 +229,7 @@ protected override void CreateTable (DbConnection connection, DataTable table)

statement.Append (')');

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = statement.ToString ();
command.CommandType = CommandType.Text;
command.ExecuteNonQuery ();
Expand All @@ -254,7 +254,7 @@ protected override void AddTableColumn (DbConnection connection, DataTable table
statement.Append (" ADD COLUMN ");
Build (statement, table, column, ref primaryKeys);

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = statement.ToString ();
command.CommandType = CommandType.Text;
command.ExecuteNonQuery ();
Expand Down
14 changes: 7 additions & 7 deletions MimeKit/Cryptography/SQLServerCertificateDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ protected override void AddTableColumn (DbConnection connection, DataTable table
statement.Append (" ADD COLUMN ");
Build (statement, table, column, ref primaryKeys);

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = statement.ToString ();
command.CommandType = CommandType.Text;
command.ExecuteNonQuery ();
Expand Down Expand Up @@ -134,7 +134,7 @@ protected override void CreateTable (DbConnection connection, DataTable table)

statement.Append (')');

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = statement.ToString ();
command.CommandType = CommandType.Text;
command.ExecuteNonQuery ();
Expand Down Expand Up @@ -188,7 +188,7 @@ static void Build (StringBuilder statement, DataTable table, DataColumn column,
/// <returns>The list of columns.</returns>
protected override IList<DataColumn> GetTableColumns (DbConnection connection, string tableName)
{
using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = $"select top 1 * from {tableName}";
using (var reader = command.ExecuteReader ()) {
var columns = new List<DataColumn> ();
Expand Down Expand Up @@ -216,7 +216,7 @@ protected override void CreateIndex (DbConnection connection, string tableName,
var indexName = GetIndexName (tableName, columnNames);
var query = string.Format ("IF NOT EXISTS (Select 8 from sys.indexes where name='{0}' and object_id=OBJECT_ID('{1}')) CREATE INDEX {0} ON {1}({2})", indexName, tableName, string.Join (", ", columnNames));

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = query;
command.ExecuteNonQuery ();
}
Expand All @@ -236,7 +236,7 @@ protected override void RemoveIndex (DbConnection connection, string tableName,
var indexName = GetIndexName (tableName, columnNames);
var query = string.Format ("IF EXISTS (Select 8 from sys.indexes where name='{0}' and object_id=OBJECT_ID('{1}')) DROP INDEX {0} ON {1}", indexName, tableName);

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = query;
command.ExecuteNonQuery ();
}
Expand All @@ -257,7 +257,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, X509Cert
var fingerprint = certificate.GetFingerprint ().ToLowerInvariant ();
var serialNumber = certificate.SerialNumber.ToString ();
var issuerName = certificate.IssuerDN.ToString ();
var command = connection.CreateCommand ();
var command = CreateCommand ();
var query = CreateSelectQuery (fields).Replace ("SELECT", "SELECT top 1");

// FIXME: Is this really the best way to query for an exact match of a certificate?
Expand Down Expand Up @@ -285,7 +285,7 @@ protected override DbCommand GetInsertCommand (DbConnection connection, X509Cert
{
var statement = new StringBuilder ("INSERT INTO CERTIFICATES(");
var variables = new StringBuilder ("VALUES(");
var command = connection.CreateCommand ();
var command = CreateCommand ();
var columns = CertificatesTable.Columns;

for (int i = 1; i < columns.Count; i++) {
Expand Down
132 changes: 59 additions & 73 deletions MimeKit/Cryptography/SqlCertificateDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ protected virtual void CreateIndex (DbConnection connection, string tableName, p
var indexName = GetIndexName (tableName, columnNames);
var query = string.Format ("CREATE INDEX IF NOT EXISTS {0} ON {1}({2})", indexName, tableName, string.Join (", ", columnNames));

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = query;
command.ExecuteNonQuery ();
}
Expand All @@ -244,7 +244,7 @@ protected virtual void RemoveIndex (DbConnection connection, string tableName, p
var indexName = GetIndexName (tableName, columnNames);
var query = string.Format ("DROP INDEX IF EXISTS {0}", indexName);

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = query;
command.ExecuteNonQuery ();
}
Expand Down Expand Up @@ -274,44 +274,37 @@ void CreateCertificatesTable (DbConnection connection, DataTable table)

if (!hasAnchorColumn) {
// Upgrade from Version 1.
using (var transaction = connection.BeginTransaction ()) {
try {
var column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.Anchor)];
AddTableColumn (connection, table, column);

column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectName)];
AddTableColumn (connection, table, column);

column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectKeyIdentifier)];
AddTableColumn (connection, table, column);

// Note: The SubjectEmail column exists, but the SubjectDnsNames column was added later, so make sure to add that.
column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectDnsNames)];
AddTableColumn (connection, table, column);

foreach (var record in Find (null, false, X509CertificateRecordFields.Id | X509CertificateRecordFields.Certificate)) {
var statement = $"UPDATE {CertificatesTableName} SET {CertificateColumnNames.Anchor} = @ANCHOR, {CertificateColumnNames.SubjectName} = @SUBJECTNAME, {CertificateColumnNames.SubjectKeyIdentifier} = @SUBJECTKEYIDENTIFIER, {CertificateColumnNames.SubjectEmail} = @SUBJECTEMAIL, {CertificateColumnNames.SubjectDnsNames} = @SUBJECTDNSNAMES WHERE {CertificateColumnNames.Id} = @ID";

using (var command = connection.CreateCommand ()) {
command.AddParameterWithValue ("@ID", record.Id);
command.AddParameterWithValue ("@ANCHOR", record.IsAnchor);
command.AddParameterWithValue ("@SUBJECTNAME", record.SubjectName);
command.AddParameterWithValue ("@SUBJECTKEYIDENTIFIER", record.SubjectKeyIdentifier?.AsHex ());
command.AddParameterWithValue ("@SUBJECTEMAIL", record.SubjectEmail);
command.AddParameterWithValue ("@SUBJECTDNSNAMES", EncodeDnsNames (record.SubjectDnsNames));
command.CommandType = CommandType.Text;
command.CommandText = statement;

command.ExecuteNonQuery ();
}
ExecuteWithinTransaction (() => {
var column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.Anchor)];
AddTableColumn (connection, table, column);

column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectName)];
AddTableColumn (connection, table, column);

column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectKeyIdentifier)];
AddTableColumn (connection, table, column);

// Note: The SubjectEmail column exists, but the SubjectDnsNames column was added later, so make sure to add that.
column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectDnsNames)];
AddTableColumn (connection, table, column);

foreach (var record in Find (null, false, X509CertificateRecordFields.Id | X509CertificateRecordFields.Certificate)) {
var statement = $"UPDATE {CertificatesTableName} SET {CertificateColumnNames.Anchor} = @ANCHOR, {CertificateColumnNames.SubjectName} = @SUBJECTNAME, {CertificateColumnNames.SubjectKeyIdentifier} = @SUBJECTKEYIDENTIFIER, {CertificateColumnNames.SubjectEmail} = @SUBJECTEMAIL, {CertificateColumnNames.SubjectDnsNames} = @SUBJECTDNSNAMES WHERE {CertificateColumnNames.Id} = @ID";

using (var command = CreateCommand ()) {
command.AddParameterWithValue ("@ID", record.Id);
command.AddParameterWithValue ("@ANCHOR", record.IsAnchor);
command.AddParameterWithValue ("@SUBJECTNAME", record.SubjectName);
command.AddParameterWithValue ("@SUBJECTKEYIDENTIFIER", record.SubjectKeyIdentifier?.AsHex ());
command.AddParameterWithValue ("@SUBJECTEMAIL", record.SubjectEmail);
command.AddParameterWithValue ("@SUBJECTDNSNAMES", EncodeDnsNames (record.SubjectDnsNames));
command.CommandType = CommandType.Text;
command.CommandText = statement;

command.ExecuteNonQuery ();
}

transaction.Commit ();
} catch {
transaction.Rollback ();
throw;
}
}
});

// Remove some old indexes
RemoveIndex (connection, table.TableName, CertificateColumnNames.Trusted);
Expand All @@ -321,31 +314,24 @@ void CreateCertificatesTable (DbConnection connection, DataTable table)
RemoveIndex (connection, table.TableName, CertificateColumnNames.BasicConstraints, CertificateColumnNames.SubjectEmail);
} else if (!hasSubjectDnsNamesColumn) {
// Upgrade from Version 2.
using (var transaction = connection.BeginTransaction ()) {
try {
var column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectDnsNames)];
AddTableColumn (connection, table, column);

foreach (var record in Find (null, false, X509CertificateRecordFields.Id | X509CertificateRecordFields.Certificate)) {
var statement = $"UPDATE {CertificatesTableName} SET {CertificateColumnNames.SubjectEmail} = @SUBJECTEMAIL, {CertificateColumnNames.SubjectDnsNames} = @SUBJECTDNSNAMES WHERE {CertificateColumnNames.Id} = @ID";

using (var command = connection.CreateCommand ()) {
command.AddParameterWithValue ("@ID", record.Id);
command.AddParameterWithValue ("@SUBJECTEMAIL", record.SubjectEmail);
command.AddParameterWithValue ("@SUBJECTDNSNAMES", EncodeDnsNames (record.SubjectDnsNames));
command.CommandType = CommandType.Text;
command.CommandText = statement;

command.ExecuteNonQuery ();
}
}
ExecuteWithinTransaction (() => {
var column = table.Columns[table.Columns.IndexOf (CertificateColumnNames.SubjectDnsNames)];
AddTableColumn (connection, table, column);

foreach (var record in Find (null, false, X509CertificateRecordFields.Id | X509CertificateRecordFields.Certificate)) {
var statement = $"UPDATE {CertificatesTableName} SET {CertificateColumnNames.SubjectEmail} = @SUBJECTEMAIL, {CertificateColumnNames.SubjectDnsNames} = @SUBJECTDNSNAMES WHERE {CertificateColumnNames.Id} = @ID";

transaction.Commit ();
} catch {
transaction.Rollback ();
throw;
using (var command = CreateCommand ()) {
command.AddParameterWithValue ("@ID", record.Id);
command.AddParameterWithValue ("@SUBJECTEMAIL", record.SubjectEmail);
command.AddParameterWithValue ("@SUBJECTDNSNAMES", EncodeDnsNames (record.SubjectDnsNames));
command.CommandType = CommandType.Text;
command.CommandText = statement;

command.ExecuteNonQuery ();
}
}
}
});

// Remove some old indexes
RemoveIndex (connection, table.TableName, CertificateColumnNames.BasicConstraints, CertificateColumnNames.SubjectEmail, CertificateColumnNames.NotBefore, CertificateColumnNames.NotAfter);
Expand Down Expand Up @@ -435,7 +421,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, X509Cert
var fingerprint = certificate.GetFingerprint ().ToLowerInvariant ();
var serialNumber = certificate.SerialNumber.ToString ();
var issuerName = certificate.IssuerDN.ToString ();
var command = connection.CreateCommand ();
var command = CreateCommand ();
var query = CreateSelectQuery (fields);

// FIXME: Is this really the best way to query for an exact match of a certificate?
Expand Down Expand Up @@ -467,7 +453,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, X509Cert
/// <param name="fields">The fields to return.</param>
protected override DbCommand GetSelectCommand (DbConnection connection, MailboxAddress mailbox, DateTime now, bool requirePrivateKey, X509CertificateRecordFields fields)
{
var command = connection.CreateCommand ();
var command = CreateCommand ();
var query = CreateSelectQuery (fields);

query = query.Append (" WHERE ").Append (CertificateColumnNames.BasicConstraints).Append (" = @BASICCONSTRAINTS ");
Expand Down Expand Up @@ -521,7 +507,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, MailboxA
/// <param name="fields">The fields to return.</param>
protected override DbCommand GetSelectCommand (DbConnection connection, ISelector<X509Certificate> selector, bool trustedAnchorsOnly, bool requirePrivateKey, X509CertificateRecordFields fields)
{
var command = connection.CreateCommand ();
var command = CreateCommand ();
var query = CreateSelectQuery (fields);
int baseQueryLength = query.Length;

Expand Down Expand Up @@ -655,7 +641,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, ISelecto
protected override DbCommand GetSelectCommand (DbConnection connection, X509Name issuer, X509CrlRecordFields fields)
{
var query = CreateSelectQuery (fields).Append (" WHERE ").Append (CrlColumnNames.IssuerName).Append (" = @ISSUERNAME");
var command = connection.CreateCommand ();
var command = CreateCommand ();

command.CommandText = query.ToString ();
command.AddParameterWithValue ("@ISSUERNAME", issuer.ToString ());
Expand All @@ -681,7 +667,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, X509Crl
.Append (CrlColumnNames.IssuerName).Append ("= @ISSUERNAME AND ")
.Append (CrlColumnNames.ThisUpdate).Append (" = @THISUPDATE LIMIT 1");
var issuerName = crl.IssuerDN.ToString ();
var command = connection.CreateCommand ();
var command = CreateCommand ();

command.CommandText = query.ToString ();
command.AddParameterWithValue ("@DELTA", crl.IsDelta ());
Expand All @@ -702,7 +688,7 @@ protected override DbCommand GetSelectCommand (DbConnection connection, X509Crl
/// <param name="connection">The database connection.</param>
protected override DbCommand GetSelectAllCrlsCommand (DbConnection connection)
{
var command = connection.CreateCommand ();
var command = CreateCommand ();

command.CommandText = $"SELECT {CrlColumnNames.Id}, {CrlColumnNames.Crl} FROM {CrlsTableName}";
command.CommandType = CommandType.Text;
Expand All @@ -721,7 +707,7 @@ protected override DbCommand GetSelectAllCrlsCommand (DbConnection connection)
/// <param name="record">The certificate record.</param>
protected override DbCommand GetDeleteCommand (DbConnection connection, X509CertificateRecord record)
{
var command = connection.CreateCommand ();
var command = CreateCommand ();

command.CommandText = $"DELETE FROM {CertificatesTableName} WHERE {CertificateColumnNames.Id} = @ID";
command.AddParameterWithValue ("@ID", record.Id);
Expand All @@ -741,7 +727,7 @@ protected override DbCommand GetDeleteCommand (DbConnection connection, X509Cert
/// <param name="record">The record.</param>
protected override DbCommand GetDeleteCommand (DbConnection connection, X509CrlRecord record)
{
var command = connection.CreateCommand ();
var command = CreateCommand ();

command.CommandText = $"DELETE FROM {CrlsTableName} WHERE {CrlColumnNames.Id} = @ID";
command.AddParameterWithValue ("@ID", record.Id);
Expand All @@ -763,7 +749,7 @@ protected override DbCommand GetInsertCommand (DbConnection connection, X509Cert
{
var statement = new StringBuilder ("INSERT INTO ").Append (CertificatesTableName).Append ('(');
var variables = new StringBuilder ("VALUES(");
var command = connection.CreateCommand ();
var command = CreateCommand ();
var columns = CertificatesTable.Columns;

for (int i = 1; i < columns.Count; i++) {
Expand Down Expand Up @@ -802,7 +788,7 @@ protected override DbCommand GetInsertCommand (DbConnection connection, X509CrlR
{
var statement = new StringBuilder ("INSERT INTO ").Append (CrlsTableName).Append ('(');
var variables = new StringBuilder ("VALUES(");
var command = connection.CreateCommand ();
var command = CreateCommand ();
var columns = CrlsTable.Columns;

for (int i = 1; i < columns.Count; i++) {
Expand Down Expand Up @@ -842,7 +828,7 @@ protected override DbCommand GetUpdateCommand (DbConnection connection, X509Cert
{
var statement = new StringBuilder ("UPDATE ").Append (CertificatesTableName).Append (" SET ");
var columns = GetColumnNames (fields & ~X509CertificateRecordFields.Id);
var command = connection.CreateCommand ();
var command = CreateCommand ();

for (int i = 0; i < columns.Length; i++) {
var value = GetValue (record, columns[i]);
Expand Down Expand Up @@ -880,7 +866,7 @@ protected override DbCommand GetUpdateCommand (DbConnection connection, X509Cert
protected override DbCommand GetUpdateCommand (DbConnection connection, X509CrlRecord record)
{
var statement = new StringBuilder ("UPDATE ").Append (CrlsTableName).Append (" SET ");
var command = connection.CreateCommand ();
var command = CreateCommand ();
var columns = CrlsTable.Columns;

for (int i = 1; i < columns.Count; i++) {
Expand Down
6 changes: 3 additions & 3 deletions MimeKit/Cryptography/SqliteCertificateDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ public SqliteCertificateDatabase (DbConnection connection, string password, Secu
/// <returns>The list of columns.</returns>
protected override IList<DataColumn> GetTableColumns (DbConnection connection, string tableName)
{
using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = $"PRAGMA table_info({tableName})";
using (var reader = command.ExecuteReader ()) {
var columns = new List<DataColumn> ();
Expand Down Expand Up @@ -422,7 +422,7 @@ protected override void CreateTable (DbConnection connection, DataTable table)

statement.Append (')');

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = statement.ToString ();
command.CommandType = CommandType.Text;
command.ExecuteNonQuery ();
Expand All @@ -447,7 +447,7 @@ protected override void AddTableColumn (DbConnection connection, DataTable table
statement.Append (" ADD COLUMN ");
Build (statement, table, column, ref primaryKeys, false);

using (var command = connection.CreateCommand ()) {
using (var command = CreateCommand ()) {
command.CommandText = statement.ToString ();
command.CommandType = CommandType.Text;
command.ExecuteNonQuery ();
Expand Down
Loading

0 comments on commit 893a29c

Please sign in to comment.