Merge pull request #9 from trocco-io/feature/bug-fix-for-nonascii-column-name

Fix SQL syntax error in COPY INTO when using non-ASCII characters in column names
NamedPython authored Dec 20, 2024
2 parents 54844dc + 2e976a9 commit c7d0ebb
Showing 8 changed files with 439 additions and 28 deletions.
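The gist of the fix: buildCopySQL previously emitted column names unquoted in the SELECT projection of the generated COPY INTO statement, so a non-ASCII column name produced a SQL syntax error. An illustrative before/after sketch of the generated SQL for a column named あ (catalog, schema, table, and file path are hypothetical placeholders; the quoted form matches the test expectation further below):

-- before the fix: the alias あ is unquoted and Databricks rejects the statement
COPY INTO `catalog`.`schema`.`table` FROM ( SELECT _c0::string あ FROM "filePath" ) FILEFORMAT = CSV
-- after the fix: the alias is backtick-quoted
COPY INTO `catalog`.`schema`.`table` FROM ( SELECT _c0::string `あ` FROM "filePath" ) FILEFORMAT = CSV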
4 changes: 3 additions & 1 deletion example/test.yml.example
@@ -1,9 +1,11 @@
# The catalog_name and schema_name must be created in advance.
# The catalog_name.schema_name, catalog_name.non_ascii_schema_name, non_ascii_catalog_name.non_ascii_schema_name and non_ascii_catalog_name.schema_name must be created in advance.

server_hostname:
http_path:
personal_access_token:
catalog_name:
schema_name:
non_ascii_schema_name:
non_ascii_catalog_name:
table_prefix:
staging_volume_name_prefix:
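
For illustration, a filled-in test.yml might look like the following; every value here is a hypothetical placeholder:

server_hostname: xxxxx.cloud.databricks.com
http_path: /sql/1.0/warehouses/0123456789abcdef
personal_access_token: dapi0123456789abcdef
catalog_name: embulk_test_catalog
schema_name: embulk_test_schema
non_ascii_schema_name: テストスキーマ
non_ascii_catalog_name: テストカタログ
table_prefix: embulk_test
staging_volume_name_prefix: embulk_staging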
@@ -107,7 +107,8 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema
if (i != 0) {
sb.append(" , ");
}
sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), column.getName()));
String quotedColumnName = quoteIdentifierString(column.getName());
sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), quotedColumnName));
}
sb.append(" FROM ");
sb.append(quoteIdentifierString(filePath, "\""));
@@ -120,6 +121,15 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema
return sb.toString();
}

@Override
protected String quoteIdentifierString(String str, String quoteString) {
// https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html
if (quoteString.equals("`")) {
return quoteString + str.replaceAll(quoteString, quoteString + quoteString) + quoteString;
}
return super.quoteIdentifierString(str, quoteString);
}

// This is almost a copy of JdbcOutputConnection, except that the fromTables are aggregated
// into the first from-table, because a Databricks MERGE INTO source can only specify a
// single table.
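For reference, a minimal standalone sketch of the escaping rule the new override implements (wrap the name in backticks and double any embedded backtick, per the linked Databricks identifier docs). The class and method names here are hypothetical and not part of the plugin:

public class QuoteSketch {
  // Wrap a name in backticks, doubling any backtick embedded in it.
  static String quote(String name) {
    return "`" + name.replace("`", "``") + "`";
  }

  public static void main(String[] args) {
    System.out.println(quote("col0")); // `col0`
    System.out.println(quote("あ"));   // `あ`
    System.out.println(quote("a`b"));  // `a``b`
  }
}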
130 changes: 130 additions & 0 deletions src/test/java/org/embulk/output/databricks/TestDatabaseMetadata.java
@@ -0,0 +1,130 @@
package org.embulk.output.databricks;

import static java.lang.String.format;
import static org.embulk.output.databricks.util.ConnectionUtil.*;
import static org.junit.Assert.assertEquals;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.jdbc.JdbcUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

// The purpose of this class is to understand the behavior of java.sql.DatabaseMetaData,
// so if this test fails after a library update, please update the expected values.
public class TestDatabaseMetadata {
private DatabaseMetaData dbm;
private Connection conn;

ConfigUtil.TestTask t = ConfigUtil.createTestTask();
String catalog = t.getCatalogName();
String schema = t.getSchemaName();
String table = t.getTablePrefix() + "_test";
String nonAsciiCatalog = t.getNonAsciiCatalogName();
String nonAsciiSchema = t.getNonAsciiSchemaName();
String nonAsciiTable = t.getTablePrefix() + "_テスト";

@Before
public void setup() throws SQLException, ClassNotFoundException {
conn = connectByTestTask();
dbm = conn.getMetaData();
run(conn, "USE CATALOG " + catalog);
run(conn, "USE SCHEMA " + schema);
createTables();
}

@After
public void cleanup() {
try {
conn.close();
} catch (SQLException ignored) {

}
dropAllTemporaryTables();
}

@Test
public void testGetPrimaryKeys() throws SQLException {
assertEquals(1, countPrimaryKeys(catalog, schema, table, "a0"));
assertEquals(1, countPrimaryKeys(null, schema, table, "a0"));
assertEquals(1, countPrimaryKeys(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0"));
assertEquals(1, countPrimaryKeys(null, nonAsciiSchema, nonAsciiTable, "d0"));
}

@Test
public void testGetTables() throws SQLException {
assertEquals(1, countTablesResult(catalog, schema, table));
assertEquals(2, countTablesResult(null, schema, table));
assertEquals(1, countTablesResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
assertEquals(0, countTablesResult(null, nonAsciiSchema, nonAsciiTable)); // 2 would be expected, but the driver returns 0
}

@Test
public void testGetColumns() throws SQLException {
assertEquals(2, countColumnsResult(catalog, schema, table));
assertEquals(4, countColumnsResult(null, schema, table));
assertEquals(2, countColumnsResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
assertEquals(0, countColumnsResult(null, nonAsciiSchema, nonAsciiTable)); // 2 would be expected, but the driver returns 0
}

private void createTables() {
String queryFormat =
"CREATE TABLE IF NOT EXISTS `%s`.`%s`.`%s` (%s String PRIMARY KEY, %s INTEGER)";
run(conn, format(queryFormat, catalog, schema, table, "a0", "a1"));
run(conn, format(queryFormat, catalog, schema, nonAsciiTable, "b0", "b1"));
run(conn, format(queryFormat, catalog, nonAsciiSchema, table, "c0", "c1"));
run(conn, format(queryFormat, catalog, nonAsciiSchema, nonAsciiTable, "d0", "d1"));
run(conn, format(queryFormat, nonAsciiCatalog, schema, table, "e0", "e1"));
run(conn, format(queryFormat, nonAsciiCatalog, schema, nonAsciiTable, "f0", "f1"));
run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, table, "g0", "g1"));
run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0", "h1"));
}

private int countPrimaryKeys(
String catalogName, String schemaName, String tableName, String primaryKey)
throws SQLException {
try (ResultSet rs = dbm.getPrimaryKeys(catalogName, schemaName, tableName)) {
int count = 0;
while (rs.next()) {
String columnName = rs.getString("COLUMN_NAME");
assertEquals(primaryKey, columnName);
count += 1;
}
return count;
}
}

private int countTablesResult(String catalogName, String schemaName, String tableName)
throws SQLException {
String e = dbm.getSearchStringEscape();
String c = JdbcUtils.escapeSearchString(catalogName, e);
String s = JdbcUtils.escapeSearchString(schemaName, e);
String t = JdbcUtils.escapeSearchString(tableName, e);
try (ResultSet rs = dbm.getTables(c, s, t, null)) {
return countResultSet(rs);
}
}

private int countColumnsResult(String catalogName, String schemaName, String tableName)
throws SQLException {
String e = dbm.getSearchStringEscape();
String c = JdbcUtils.escapeSearchString(catalogName, e);
String s = JdbcUtils.escapeSearchString(schemaName, e);
String t = JdbcUtils.escapeSearchString(tableName, e);
try (ResultSet rs = dbm.getColumns(c, s, t, null)) {
return countResultSet(rs);
}
}

private int countResultSet(ResultSet rs) throws SQLException {
int count = 0;
while (rs.next()) {
count += 1;
}
return count;
}
}
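
The escaping in the helpers above is needed because JDBC metadata lookups such as getTables and getColumns treat the schema and table arguments as LIKE-style search patterns in which _ and % are wildcards. A minimal sketch of that kind of escaping, under the assumption that it mirrors what Embulk's JdbcUtils.escapeSearchString does (this is not the actual implementation):

public class EscapeSearchStringSketch {
  // Escape the escape string itself first, then the LIKE wildcards _ and %.
  static String escapeSearchString(String pattern, String escapeString) {
    if (pattern == null || escapeString == null || escapeString.isEmpty()) {
      return pattern;
    }
    String escaped = pattern.replace(escapeString, escapeString + escapeString);
    escaped = escaped.replace("_", escapeString + "_");
    return escaped.replace("%", escapeString + "%");
  }

  public static void main(String[] args) {
    System.out.println(escapeSearchString("embulk_test_テスト", "\\")); // embulk\_test\_テスト
  }
}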
@@ -1,8 +1,13 @@
package org.embulk.output.databricks;

import static org.embulk.output.databricks.util.ConnectionUtil.*;
import static org.junit.Assert.assertTrue;

import java.sql.*;
import java.util.*;
import java.util.concurrent.Executor;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.databricks.util.ConnectionUtil;
import org.embulk.output.jdbc.JdbcColumn;
import org.embulk.output.jdbc.JdbcSchema;
import org.embulk.output.jdbc.MergeConfig;
@@ -11,21 +16,46 @@
import org.junit.Test;

public class TestDatabricksOutputConnection {
@Test
public void testTableExists() throws SQLException, ClassNotFoundException {
ConfigUtil.TestTask t = ConfigUtil.createTestTask();
String asciiTableName = t.getTablePrefix() + "_test";
String nonAsciiTableName = t.getTablePrefix() + "_テスト";
testTableExists(t.getCatalogName(), t.getSchemaName(), asciiTableName);
testTableExists(t.getNonAsciiCatalogName(), t.getSchemaName(), asciiTableName);
testTableExists(t.getCatalogName(), t.getNonAsciiSchemaName(), asciiTableName);
testTableExists(t.getCatalogName(), t.getSchemaName(), nonAsciiTableName);
testTableExists(t.getNonAsciiCatalogName(), t.getNonAsciiSchemaName(), nonAsciiTableName);
}

private void testTableExists(String catalogName, String schemaName, String tableName)
throws SQLException, ClassNotFoundException {
String fullTableName = String.format("`%s`.`%s`.`%s`", catalogName, schemaName, tableName);
try (Connection conn = ConnectionUtil.connectByTestTask()) {
run(conn, "CREATE TABLE IF NOT EXISTS " + fullTableName);
try (DatabricksOutputConnection outputConn =
buildOutputConnection(conn, catalogName, schemaName)) {
assertTrue(outputConn.tableExists(new TableIdentifier(null, null, tableName)));
}
} finally {
run("DROP TABLE IF EXISTS " + fullTableName);
}
}

@Test
public void TestBuildCopySQL() throws SQLException {
try (DatabricksOutputConnection conn = buildOutputConnection()) {
public void testBuildCopySQL() throws SQLException {
try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
TableIdentifier tableIdentifier = new TableIdentifier("database", "schemaName", "tableName");
String actual = conn.buildCopySQL(tableIdentifier, "filePath", buildJdbcSchema());
String expected =
"COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string col0 , _c1::bigint col1 FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )";
"COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string `あ` , _c1::bigint ```` FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )";
Assert.assertEquals(expected, actual);
}
}

@Test
public void TestBuildAggregateSQL() throws SQLException {
try (DatabricksOutputConnection conn = buildOutputConnection()) {
public void testBuildAggregateSQL() throws SQLException {
try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
List<TableIdentifier> fromTableIdentifiers = new ArrayList<>();
fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName0"));
fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName1"));
@@ -39,28 +69,28 @@ public void TestBuildAggregateSQL() throws SQLException {
}

@Test
public void TestMergeConfigSQLWithMergeRules() throws SQLException {
public void testMergeConfigSQLWithMergeRules() throws SQLException {
List<String> mergeKeys = buildMergeKeys("col0", "col1");
Optional<List<String>> mergeRules =
buildMergeRules("col0 = CONCAT(T.col0, 'test')", "col1 = T.col1 + S.col1");
String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules));
String expected =
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);";
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (``, ````) VALUES (S.``, S.````);";
Assert.assertEquals(expected, actual);
}

@Test
public void TestMergeConfigSQLWithNoMergeRules() throws SQLException {
public void testMergeConfigSQLWithNoMergeRules() throws SQLException {
List<String> mergeKeys = buildMergeKeys("col0", "col1");
Optional<List<String>> mergeRules = Optional.empty();
String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules));
String expected =
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `col0` = S.`col0`, `col1` = S.`col1` WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);";
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `` = S.``, ```` = S.```` WHEN NOT MATCHED THEN INSERT (``, ````) VALUES (S.``, S.````);";
Assert.assertEquals(expected, actual);
}

private String mergeConfigSQL(MergeConfig mergeConfig) throws SQLException {
try (DatabricksOutputConnection conn = buildOutputConnection()) {
try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
TableIdentifier aggregateToTable =
new TableIdentifier("database", "schemaName", "tableName9");
TableIdentifier toTable = new TableIdentifier("database", "schemaName", "tableName100");
@@ -76,15 +106,21 @@ private Optional<List<String>> buildMergeRules(String... keys) {
return keys.length > 0 ? Optional.of(Arrays.asList(keys)) : Optional.empty();
}

private DatabricksOutputConnection buildOutputConnection() throws SQLException {
private DatabricksOutputConnection buildOutputConnection(
Connection conn, String catalogName, String schemaName)
throws SQLException, ClassNotFoundException {
return new DatabricksOutputConnection(conn, catalogName, schemaName);
}

private DatabricksOutputConnection buildDummyOutputConnection() throws SQLException {
return new DatabricksOutputConnection(
buildDummyConnection(), "defaultCatalogName", "defaultSchemaName");
}

private JdbcSchema buildJdbcSchema() {
List<JdbcColumn> jdbcColumns = new ArrayList<>();
jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("col0", Types.VARCHAR, "string", true, false));
jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("col1", Types.BIGINT, "bigint", true, false));
jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("", Types.VARCHAR, "string", true, false));
jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("`", Types.BIGINT, "bigint", true, false));
return new JdbcSchema(jdbcColumns);
}
