Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SQL syntax error in COPY INTO when using non-ASCII column names #9

Merged
merged 5 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion example/test.yml.example
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# The catalog_name and schema_name must be created in advance.
# The catalog_name.schema_name, catalog_name.non_ascii_schema_name, non_ascii_catalog_name.non_ascii_schema_name and non_ascii_catalog_name.schema_name must be created in advance.

server_hostname:
http_path:
personal_access_token:
catalog_name:
schema_name:
non_ascii_schema_name:
non_ascii_catalog_name:
table_prefix:
staging_volume_name_prefix:
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema
if (i != 0) {
sb.append(" , ");
}
sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), column.getName()));
String quotedColumnName = quoteIdentifierString(column.getName());
sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), quotedColumnName));
}
sb.append(" FROM ");
sb.append(quoteIdentifierString(filePath, "\""));
Expand All @@ -157,6 +158,15 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema
return sb.toString();
}

@Override
protected String quoteIdentifierString(String str, String quoteString) {
  // Databricks identifiers are quoted with backticks; a literal backtick inside an
  // identifier is escaped by doubling it.
  // https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html
  if (quoteString.equals("`")) {
    // Use replace (literal) instead of replaceAll (regex): the quote string is data,
    // not a pattern. With replaceAll, a regex-special quote character would be
    // misinterpreted; replace behaves identically for "`" and is safe in general.
    return quoteString + str.replace(quoteString, quoteString + quoteString) + quoteString;
  }
  // Any other quote character keeps the superclass behavior.
  return super.quoteIdentifierString(str, quoteString);
}

// This is almost a copy of JdbcOutputConnection except for aggregating fromTables to first from
// table,
// because Databricks MERGE INTO source can only specify a single table.
Expand Down
130 changes: 130 additions & 0 deletions src/test/java/org/embulk/output/databricks/TestDatabaseMetadata.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package org.embulk.output.databricks;

import static java.lang.String.format;
import static org.embulk.output.databricks.util.ConnectionUtil.*;
import static org.junit.Assert.assertEquals;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.jdbc.JdbcUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

// The purpose of this class is to understand the behavior of DatabaseMetadata,
// so if this test fails due to a library update, please change the test result.
public class TestDatabaseMetadata {
  // JDBC metadata handle obtained from the live connection in setup().
  private DatabaseMetaData dbm;
  // Live connection to the test Databricks endpoint; closed in cleanup().
  private Connection conn;

  // Test fixture names come from the shared test configuration; both ASCII and
  // non-ASCII catalog/schema names must exist in advance (see example/test.yml.example).
  ConfigUtil.TestTask t = ConfigUtil.createTestTask();
  String catalog = t.getCatalogName();
  String schema = t.getSchemaName();
  String table = t.getTablePrefix() + "_test";
  String nonAsciiCatalog = t.getNonAsciiCatalogName();
  String nonAsciiSchema = t.getNonAsciiSchemaName();
  String nonAsciiTable = t.getTablePrefix() + "_テスト";

  @Before
  public void setup() throws SQLException, ClassNotFoundException {
    // run/connectByTestTask/dropAllTemporaryTables are static imports from ConnectionUtil.
    conn = connectByTestTask();
    dbm = conn.getMetaData();
    // Pin the session's default catalog/schema so metadata lookups with null
    // catalog/schema arguments have a known resolution context.
    run(conn, "USE CATALOG " + catalog);
    run(conn, "USE SCHEMA " + schema);
    createTables();
  }

  @After
  public void cleanup() {
    try {
      conn.close();
    } catch (SQLException ignored) {
      // Best-effort close; failure here must not mask the test result.
    }
    // Presumably opens its own connection, since conn is already closed — TODO confirm.
    dropAllTemporaryTables();
  }

  @Test
  public void testGetPrimaryKeys() throws SQLException {
    // With a null catalog the driver resolves against the session catalog set in
    // setup(): the non-ASCII-schema lookup finds the table created in
    // catalog.nonAsciiSchema ("d0"), not nonAsciiCatalog.nonAsciiSchema ("h0").
    assertEquals(1, countPrimaryKeys(catalog, schema, table, "a0"));
    assertEquals(1, countPrimaryKeys(null, schema, table, "a0"));
    assertEquals(1, countPrimaryKeys(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0"));
    assertEquals(1, countPrimaryKeys(null, nonAsciiSchema, nonAsciiTable, "d0"));
  }

  @Test
  public void testGetTables() throws SQLException {
    // null catalog matches the table in both catalogs (2 hits) for the ASCII schema,
    // but the non-ASCII schema/table lookup returns nothing — observed driver
    // behavior, not the expected 2 (see trailing comment).
    assertEquals(1, countTablesResult(catalog, schema, table));
    assertEquals(2, countTablesResult(null, schema, table));
    assertEquals(1, countTablesResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
    assertEquals(0, countTablesResult(null, nonAsciiSchema, nonAsciiTable)); // expected 2
  }

  @Test
  public void testGetColumns() throws SQLException {
    // Each fixture table has 2 columns; counts mirror testGetTables (4 = 2 tables x 2
    // columns when the null catalog matches both catalogs).
    assertEquals(2, countColumnsResult(catalog, schema, table));
    assertEquals(4, countColumnsResult(null, schema, table));
    assertEquals(2, countColumnsResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
    assertEquals(0, countColumnsResult(null, nonAsciiSchema, nonAsciiTable)); // expected 2
  }

  // Creates one table per catalog/schema/table-name combination (8 total), each with a
  // distinct primary-key column name (a0..h0) so tests can tell which table a metadata
  // lookup actually matched.
  private void createTables() {
    String queryFormat =
        "CREATE TABLE IF NOT EXISTS `%s`.`%s`.`%s` (%s String PRIMARY KEY, %s INTEGER)";
    run(conn, format(queryFormat, catalog, schema, table, "a0", "a1"));
    run(conn, format(queryFormat, catalog, schema, nonAsciiTable, "b0", "b1"));
    run(conn, format(queryFormat, catalog, nonAsciiSchema, table, "c0", "c1"));
    run(conn, format(queryFormat, catalog, nonAsciiSchema, nonAsciiTable, "d0", "d1"));
    run(conn, format(queryFormat, nonAsciiCatalog, schema, table, "e0", "e1"));
    run(conn, format(queryFormat, nonAsciiCatalog, schema, nonAsciiTable, "f0", "f1"));
    run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, table, "g0", "g1"));
    run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0", "h1"));
  }

  // Counts rows returned by DatabaseMetaData.getPrimaryKeys and asserts every matched
  // primary-key column has the expected name (identifies which table was matched).
  private int countPrimaryKeys(
      String catalogName, String schemaName, String tableName, String primaryKey)
      throws SQLException {
    try (ResultSet rs = dbm.getPrimaryKeys(catalogName, schemaName, tableName)) {
      int count = 0;
      while (rs.next()) {
        String columnName = rs.getString("COLUMN_NAME");
        assertEquals(primaryKey, columnName);
        count += 1;
      }
      return count;
    }
  }

  // Counts getTables matches; name arguments are escaped with the driver's search-string
  // escape so they are treated literally, not as LIKE patterns.
  private int countTablesResult(String catalogName, String schemaName, String tableName)
      throws SQLException {
    String e = dbm.getSearchStringEscape();
    String c = JdbcUtils.escapeSearchString(catalogName, e);
    String s = JdbcUtils.escapeSearchString(schemaName, e);
    String t = JdbcUtils.escapeSearchString(tableName, e);
    try (ResultSet rs = dbm.getTables(c, s, t, null)) {
      return countResultSet(rs);
    }
  }

  // Counts getColumns matches; same literal-escaping treatment as countTablesResult.
  private int countColumnsResult(String catalogName, String schemaName, String tableName)
      throws SQLException {
    String e = dbm.getSearchStringEscape();
    String c = JdbcUtils.escapeSearchString(catalogName, e);
    String s = JdbcUtils.escapeSearchString(schemaName, e);
    String t = JdbcUtils.escapeSearchString(tableName, e);
    try (ResultSet rs = dbm.getColumns(c, s, t, null)) {
      return countResultSet(rs);
    }
  }

  // Drains a ResultSet and returns its row count. Caller owns closing the ResultSet.
  private int countResultSet(ResultSet rs) throws SQLException {
    int count = 0;
    while (rs.next()) {
      count += 1;
    }
    return count;
  }
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package org.embulk.output.databricks;

import static org.embulk.output.databricks.util.ConnectionUtil.*;
import static org.junit.Assert.assertTrue;

import java.sql.*;
import java.util.*;
import java.util.concurrent.Executor;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.databricks.util.ConnectionUtil;
import org.embulk.output.jdbc.JdbcColumn;
import org.embulk.output.jdbc.JdbcSchema;
import org.embulk.output.jdbc.MergeConfig;
Expand All @@ -11,21 +16,46 @@
import org.junit.Test;

public class TestDatabricksOutputConnection {
@Test
public void testTableExists() throws SQLException, ClassNotFoundException {
  // Exercise tableExists across ASCII / non-ASCII combinations of catalog,
  // schema and table identifiers.
  ConfigUtil.TestTask task = ConfigUtil.createTestTask();
  String plainTable = task.getTablePrefix() + "_test";
  String unicodeTable = task.getTablePrefix() + "_テスト";
  testTableExists(task.getCatalogName(), task.getSchemaName(), plainTable);
  testTableExists(task.getNonAsciiCatalogName(), task.getSchemaName(), plainTable);
  testTableExists(task.getCatalogName(), task.getNonAsciiSchemaName(), plainTable);
  testTableExists(task.getCatalogName(), task.getSchemaName(), unicodeTable);
  testTableExists(task.getNonAsciiCatalogName(), task.getNonAsciiSchemaName(), unicodeTable);
}

// Creates the table, then asserts tableExists finds it via an unqualified
// TableIdentifier resolved against the connection's default catalog/schema.
// The table is dropped again in the finally block.
private void testTableExists(String catalogName, String schemaName, String tableName)
    throws SQLException, ClassNotFoundException {
  // Fully qualified, backtick-quoted name so non-ASCII identifiers are legal SQL.
  String qualifiedName = String.format("`%s`.`%s`.`%s`", catalogName, schemaName, tableName);
  try (Connection rawConn = ConnectionUtil.connectByTestTask()) {
    run(rawConn, "CREATE TABLE IF NOT EXISTS " + qualifiedName);
    try (DatabricksOutputConnection outputConn =
        buildOutputConnection(rawConn, catalogName, schemaName)) {
      assertTrue(outputConn.tableExists(new TableIdentifier(null, null, tableName)));
    }
  } finally {
    run("DROP TABLE IF EXISTS " + qualifiedName);
  }
}

@Test
public void TestBuildCopySQL() throws SQLException {
try (DatabricksOutputConnection conn = buildOutputConnection()) {
public void testBuildCopySQL() throws SQLException {
try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
TableIdentifier tableIdentifier = new TableIdentifier("database", "schemaName", "tableName");
String actual = conn.buildCopySQL(tableIdentifier, "filePath", buildJdbcSchema());
String expected =
"COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string col0 , _c1::bigint col1 FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )";
"COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string `あ` , _c1::bigint ```` FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )";
Assert.assertEquals(expected, actual);
}
}

@Test
public void TestBuildAggregateSQL() throws SQLException {
try (DatabricksOutputConnection conn = buildOutputConnection()) {
public void testBuildAggregateSQL() throws SQLException {
try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
List<TableIdentifier> fromTableIdentifiers = new ArrayList<>();
fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName0"));
fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName1"));
Expand All @@ -39,28 +69,28 @@ public void TestBuildAggregateSQL() throws SQLException {
}

@Test
public void TestMergeConfigSQLWithMergeRules() throws SQLException {
public void testMergeConfigSQLWithMergeRules() throws SQLException {
List<String> mergeKeys = buildMergeKeys("col0", "col1");
Optional<List<String>> mergeRules =
buildMergeRules("col0 = CONCAT(T.col0, 'test')", "col1 = T.col1 + S.col1");
String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules));
String expected =
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);";
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (``, ````) VALUES (S.``, S.````);";
Assert.assertEquals(expected, actual);
}

@Test
public void TestMergeConfigSQLWithNoMergeRules() throws SQLException {
public void testMergeConfigSQLWithNoMergeRules() throws SQLException {
List<String> mergeKeys = buildMergeKeys("col0", "col1");
Optional<List<String>> mergeRules = Optional.empty();
String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules));
String expected =
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `col0` = S.`col0`, `col1` = S.`col1` WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);";
"MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `` = S.``, ```` = S.```` WHEN NOT MATCHED THEN INSERT (``, ````) VALUES (S.``, S.````);";
Assert.assertEquals(expected, actual);
}

private String mergeConfigSQL(MergeConfig mergeConfig) throws SQLException {
try (DatabricksOutputConnection conn = buildOutputConnection()) {
try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
TableIdentifier aggregateToTable =
new TableIdentifier("database", "schemaName", "tableName9");
TableIdentifier toTable = new TableIdentifier("database", "schemaName", "tableName100");
Expand All @@ -76,15 +106,21 @@ private Optional<List<String>> buildMergeRules(String... keys) {
return keys.length > 0 ? Optional.of(Arrays.asList(keys)) : Optional.empty();
}

private DatabricksOutputConnection buildOutputConnection() throws SQLException {
private DatabricksOutputConnection buildOutputConnection(
Connection conn, String catalogName, String schemaName)
throws SQLException, ClassNotFoundException {
return new DatabricksOutputConnection(conn, catalogName, schemaName);
}

private DatabricksOutputConnection buildDummyOutputConnection() throws SQLException {
return new DatabricksOutputConnection(
buildDummyConnection(), "defaultCatalogName", "defaultSchemaName");
}

// Builds a two-column schema with deliberately awkward column names to exercise
// identifier quoting in the generated SQL (the second name is a literal backtick,
// which must be escaped by doubling).
// NOTE(review): the first column name reads as "" here, but testBuildCopySQL
// expects the alias `あ` — this looks like a non-ASCII character lost in transit;
// confirm against the repository before relying on it.
private JdbcSchema buildJdbcSchema() {
  List<JdbcColumn> jdbcColumns = new ArrayList<>();
  jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("", Types.VARCHAR, "string", true, false));
  jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("`", Types.BIGINT, "bigint", true, false));
  return new JdbcSchema(jdbcColumns);
}

Expand Down
Loading
Loading