diff --git a/example/test.yml.example b/example/test.yml.example index 5c449eb..99d3e15 100644 --- a/example/test.yml.example +++ b/example/test.yml.example @@ -1,9 +1,11 @@ -# The catalog_name and schema_name must be created in advance. +# The catalog_name.schema_name, catalog_name.non_ascii_schema_name, non_ascii_catalog_name.non_ascii_schema_name and non_ascii_catalog_name.schema_name must be created in advance. server_hostname: http_path: personal_access_token: catalog_name: schema_name: +non_ascii_schema_name: +non_ascii_catalog_name: table_prefix: staging_volume_name_prefix: diff --git a/src/main/java/org/embulk/output/databricks/DatabricksOutputConnection.java b/src/main/java/org/embulk/output/databricks/DatabricksOutputConnection.java index 7355437..6b8717b 100644 --- a/src/main/java/org/embulk/output/databricks/DatabricksOutputConnection.java +++ b/src/main/java/org/embulk/output/databricks/DatabricksOutputConnection.java @@ -144,7 +144,8 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema if (i != 0) { sb.append(" , "); } - sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), column.getName())); + String quotedColumnName = quoteIdentifierString(column.getName()); + sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), quotedColumnName)); } sb.append(" FROM "); sb.append(quoteIdentifierString(filePath, "\"")); @@ -157,6 +158,15 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema return sb.toString(); } + @Override + protected String quoteIdentifierString(String str, String quoteString) { + // https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html + if (quoteString.equals("`")) { + return quoteString + str.replaceAll(quoteString, quoteString + quoteString) + quoteString; + } + return super.quoteIdentifierString(str, quoteString); + } + // This is almost a copy of JdbcOutputConnection except for aggregating fromTables to first from // 
package org.embulk.output.databricks;

import static java.lang.String.format;
import static org.embulk.output.databricks.util.ConnectionUtil.*;
import static org.junit.Assert.assertEquals;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.jdbc.JdbcUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

// The purpose of this class is to understand the behavior of DatabaseMetadata,
// so if this test fails due to a library update, please change the test result.
public class TestDatabaseMetadata {
  private DatabaseMetaData dbm;
  private Connection conn;

  ConfigUtil.TestTask t = ConfigUtil.createTestTask();
  String catalog = t.getCatalogName();
  String schema = t.getSchemaName();
  String table = t.getTablePrefix() + "_test";
  String nonAsciiCatalog = t.getNonAsciiCatalogName();
  String nonAsciiSchema = t.getNonAsciiSchemaName();
  String nonAsciiTable = t.getTablePrefix() + "_テスト";

  @Before
  public void setup() throws SQLException, ClassNotFoundException {
    conn = connectByTestTask();
    dbm = conn.getMetaData();
    // Pin the session's default catalog/schema so the null-catalog lookups below
    // exercise the driver's defaulting behavior.
    run(conn, "USE CATALOG " + catalog);
    run(conn, "USE SCHEMA " + schema);
    createTables();
  }

  @After
  public void cleanup() {
    try {
      conn.close();
    } catch (SQLException ignored) {
      // Best-effort close; the temporary tables are dropped regardless.
    }
    dropAllTemporaryTables();
  }

  @Test
  public void testGetPrimaryKeys() throws SQLException {
    assertEquals(1, countPrimaryKeys(catalog, schema, table, "a0"));
    assertEquals(1, countPrimaryKeys(null, schema, table, "a0"));
    assertEquals(1, countPrimaryKeys(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0"));
    assertEquals(1, countPrimaryKeys(null, nonAsciiSchema, nonAsciiTable, "d0"));
  }

  @Test
  public void testGetTables() throws SQLException {
    assertEquals(1, countTablesResult(catalog, schema, table));
    assertEquals(2, countTablesResult(null, schema, table));
    assertEquals(1, countTablesResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
    assertEquals(0, countTablesResult(null, nonAsciiSchema, nonAsciiTable)); // expected 2
  }

  @Test
  public void testGetColumns() throws SQLException {
    assertEquals(2, countColumnsResult(catalog, schema, table));
    assertEquals(4, countColumnsResult(null, schema, table));
    assertEquals(2, countColumnsResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
    assertEquals(0, countColumnsResult(null, nonAsciiSchema, nonAsciiTable)); // expected 2
  }

  // Creates one table per (catalog, schema, table-name) ASCII/non-ASCII combination,
  // each with a distinct primary-key column name (a0..h0) so lookups are distinguishable.
  private void createTables() {
    String queryFormat =
        "CREATE TABLE IF NOT EXISTS `%s`.`%s`.`%s` (%s String PRIMARY KEY, %s INTEGER)";
    run(conn, format(queryFormat, catalog, schema, table, "a0", "a1"));
    run(conn, format(queryFormat, catalog, schema, nonAsciiTable, "b0", "b1"));
    run(conn, format(queryFormat, catalog, nonAsciiSchema, table, "c0", "c1"));
    run(conn, format(queryFormat, catalog, nonAsciiSchema, nonAsciiTable, "d0", "d1"));
    run(conn, format(queryFormat, nonAsciiCatalog, schema, table, "e0", "e1"));
    run(conn, format(queryFormat, nonAsciiCatalog, schema, nonAsciiTable, "f0", "f1"));
    run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, table, "g0", "g1"));
    run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0", "h1"));
  }

  // Counts rows returned by getPrimaryKeys, asserting each row reports the expected column.
  private int countPrimaryKeys(
      String catalogName, String schemaName, String tableName, String primaryKey)
      throws SQLException {
    try (ResultSet rs = dbm.getPrimaryKeys(catalogName, schemaName, tableName)) {
      int count = 0;
      while (rs.next()) {
        assertEquals(primaryKey, rs.getString("COLUMN_NAME"));
        count++;
      }
      return count;
    }
  }

  private int countTablesResult(String catalogName, String schemaName, String tableName)
      throws SQLException {
    // getTables takes search patterns, so escape any pattern metacharacters first.
    String escape = dbm.getSearchStringEscape();
    try (ResultSet rs =
        dbm.getTables(
            JdbcUtils.escapeSearchString(catalogName, escape),
            JdbcUtils.escapeSearchString(schemaName, escape),
            JdbcUtils.escapeSearchString(tableName, escape),
            null)) {
      return countResultSet(rs);
    }
  }

  private int countColumnsResult(String catalogName, String schemaName, String tableName)
      throws SQLException {
    // getColumns takes search patterns, so escape any pattern metacharacters first.
    String escape = dbm.getSearchStringEscape();
    try (ResultSet rs =
        dbm.getColumns(
            JdbcUtils.escapeSearchString(catalogName, escape),
            JdbcUtils.escapeSearchString(schemaName, escape),
            JdbcUtils.escapeSearchString(tableName, escape),
            null)) {
      return countResultSet(rs);
    }
  }

  private int countResultSet(ResultSet rs) throws SQLException {
    int count = 0;
    while (rs.next()) {
      count++;
    }
    return count;
  }
}
ConnectionUtil.connectByTestTask()) { + run(conn, "CREATE TABLE IF NOT EXISTS " + fullTableName); + try (DatabricksOutputConnection outputConn = + buildOutputConnection(conn, catalogName, schemaName)) { + assertTrue(outputConn.tableExists(new TableIdentifier(null, null, tableName))); + } + } finally { + run("DROP TABLE IF EXISTS " + fullTableName); + } + } @Test - public void TestBuildCopySQL() throws SQLException { - try (DatabricksOutputConnection conn = buildOutputConnection()) { + public void testBuildCopySQL() throws SQLException { + try (DatabricksOutputConnection conn = buildDummyOutputConnection()) { TableIdentifier tableIdentifier = new TableIdentifier("database", "schemaName", "tableName"); String actual = conn.buildCopySQL(tableIdentifier, "filePath", buildJdbcSchema()); String expected = - "COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string col0 , _c1::bigint col1 FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )"; + "COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string `あ` , _c1::bigint ```` FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )"; Assert.assertEquals(expected, actual); } } @Test - public void TestBuildAggregateSQL() throws SQLException { - try (DatabricksOutputConnection conn = buildOutputConnection()) { + public void testBuildAggregateSQL() throws SQLException { + try (DatabricksOutputConnection conn = buildDummyOutputConnection()) { List fromTableIdentifiers = new ArrayList<>(); fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName0")); fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName1")); @@ -39,28 +69,28 @@ public void TestBuildAggregateSQL() throws SQLException { } @Test - public void TestMergeConfigSQLWithMergeRules() throws SQLException { + public void testMergeConfigSQLWithMergeRules() throws SQLException { List mergeKeys = 
buildMergeKeys("col0", "col1"); Optional> mergeRules = buildMergeRules("col0 = CONCAT(T.col0, 'test')", "col1 = T.col1 + S.col1"); String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules)); String expected = - "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);"; + "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (`あ`, ````) VALUES (S.`あ`, S.````);"; Assert.assertEquals(expected, actual); } @Test - public void TestMergeConfigSQLWithNoMergeRules() throws SQLException { + public void testMergeConfigSQLWithNoMergeRules() throws SQLException { List mergeKeys = buildMergeKeys("col0", "col1"); Optional> mergeRules = Optional.empty(); String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules)); String expected = - "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `col0` = S.`col0`, `col1` = S.`col1` WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);"; + "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `あ` = S.`あ`, ```` = S.```` WHEN NOT MATCHED THEN INSERT (`あ`, ````) VALUES (S.`あ`, S.````);"; Assert.assertEquals(expected, actual); } private String mergeConfigSQL(MergeConfig mergeConfig) throws SQLException { - try (DatabricksOutputConnection conn = buildOutputConnection()) { + try (DatabricksOutputConnection conn = 
buildDummyOutputConnection()) { TableIdentifier aggregateToTable = new TableIdentifier("database", "schemaName", "tableName9"); TableIdentifier toTable = new TableIdentifier("database", "schemaName", "tableName100"); @@ -76,15 +106,21 @@ private Optional> buildMergeRules(String... keys) { return keys.length > 0 ? Optional.of(Arrays.asList(keys)) : Optional.empty(); } - private DatabricksOutputConnection buildOutputConnection() throws SQLException { + private DatabricksOutputConnection buildOutputConnection( + Connection conn, String catalogName, String schemaName) + throws SQLException, ClassNotFoundException { + return new DatabricksOutputConnection(conn, catalogName, schemaName); + } + + private DatabricksOutputConnection buildDummyOutputConnection() throws SQLException { return new DatabricksOutputConnection( buildDummyConnection(), "defaultCatalogName", "defaultSchemaName"); } private JdbcSchema buildJdbcSchema() { List jdbcColumns = new ArrayList<>(); - jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("col0", Types.VARCHAR, "string", true, false)); - jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("col1", Types.BIGINT, "bigint", true, false)); + jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("あ", Types.VARCHAR, "string", true, false)); + jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("`", Types.BIGINT, "bigint", true, false)); return new JdbcSchema(jdbcColumns); } diff --git a/src/test/java/org/embulk/output/databricks/TestDatabricksOutputPluginByNonAscii.java b/src/test/java/org/embulk/output/databricks/TestDatabricksOutputPluginByNonAscii.java new file mode 100644 index 0000000..6d2a23c --- /dev/null +++ b/src/test/java/org/embulk/output/databricks/TestDatabricksOutputPluginByNonAscii.java @@ -0,0 +1,162 @@ +package org.embulk.output.databricks; + +import static org.embulk.output.databricks.util.ConfigUtil.createPluginConfigSource; +import static org.embulk.output.databricks.util.ConfigUtil.setMergeRule; +import static 
package org.embulk.output.databricks;

import static org.embulk.output.databricks.util.ConfigUtil.createPluginConfigSource;
import static org.embulk.output.databricks.util.ConfigUtil.setMergeRule;
import static org.embulk.output.databricks.util.ConfigUtil.setNonAsciiCatalogName;
import static org.embulk.output.databricks.util.ConfigUtil.setNonAsciiSchemaName;
import static org.embulk.output.databricks.util.ConfigUtil.setNonAsciiStagingVolumeNamePrefix;
import static org.embulk.output.databricks.util.ConfigUtil.setNonAsciiTable;
import static org.embulk.output.databricks.util.ConnectionUtil.quotedDstTableName;
import static org.embulk.output.databricks.util.ConnectionUtil.run;
import static org.embulk.output.databricks.util.ConnectionUtil.runQuery;
import static org.embulk.output.databricks.util.IOUtil.createInputFile;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.embulk.config.ConfigSource;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.jdbc.AbstractJdbcOutputPlugin.Mode;
import org.junit.Assert;
import org.junit.Test;

// Exercises the plugin with non-ASCII identifiers — column, catalog, schema, table
// and staging-volume names — across every output mode.
public class TestDatabricksOutputPluginByNonAscii extends AbstractTestDatabricksOutputPlugin {

  @Test
  public void testColumnNameInsert() throws Exception {
    runOutputAndAssertResult(createPluginConfigSource(Mode.INSERT));
  }

  @Test
  public void testColumnNameInsertDirect() throws Exception {
    runOutputAndAssertResult(createPluginConfigSource(Mode.INSERT_DIRECT));
  }

  @Test
  public void testColumnNameTruncateInsert() throws Exception {
    runOutputAndAssertResult(createPluginConfigSource(Mode.TRUNCATE_INSERT));
  }

  @Test
  public void testColumnNameReplace() throws Exception {
    runOutputAndAssertResult(createPluginConfigSource(Mode.REPLACE));
  }

  @Test
  public void testColumnNameMergeWithMergeKeys() throws Exception {
    ConfigSource configSource = createPluginConfigSource(Mode.MERGE);
    ConfigUtil.setMergeKeys(configSource, "あ");
    setupForMerge(configSource, false);
    runOutputAndAssertResult(configSource, "あ", "い");
  }

  @Test
  public void testColumnNameWithoutMergeKeys() throws Exception {
    // No explicit merge keys: the destination table's primary key is used instead.
    ConfigSource configSource = createPluginConfigSource(Mode.MERGE);
    setupForMerge(configSource, true);
    runOutputAndAssertResult(configSource, "あ", "い");
  }

  @Test
  public void testColumnNameMergeWithMergeRule() throws Exception {
    ConfigSource configSource = createPluginConfigSource(Mode.MERGE);
    setMergeRule(configSource, "`い` = CONCAT(T.`い`, 'あ', S.`い`)");
    setupForMerge(configSource, true);
    runOutputAndAssertMergeWithMergeRule(configSource);
  }

  @Test
  public void testCatalogName() throws Exception {
    ConfigSource configSource = createPluginConfigSource(Mode.INSERT);
    setNonAsciiCatalogName(configSource);
    runOutputAndAssertResult(configSource);
  }

  @Test
  public void testSchemaName() throws Exception {
    ConfigSource configSource = createPluginConfigSource(Mode.INSERT);
    setNonAsciiSchemaName(configSource);
    runOutputAndAssertResult(configSource);
  }

  @Test
  public void testTableName() throws Exception {
    ConfigSource configSource = createPluginConfigSource(Mode.INSERT);
    setNonAsciiTable(configSource);
    runOutputAndAssertResult(configSource);
  }

  @Test
  public void testStagingVolumeNamePrefix() throws Exception {
    ConfigSource configSource = createPluginConfigSource(Mode.INSERT);
    setNonAsciiStagingVolumeNamePrefix(configSource);
    runOutputAndAssertResult(configSource);
  }

  @Test
  public void testAllAttributes() throws Exception {
    // Combine every non-ASCII attribute in a single MERGE run.
    ConfigSource configSource = createPluginConfigSource(Mode.MERGE);
    setNonAsciiCatalogName(configSource);
    setNonAsciiSchemaName(configSource);
    setNonAsciiTable(configSource);
    setNonAsciiStagingVolumeNamePrefix(configSource);
    setMergeRule(configSource, "`い` = CONCAT(T.`い`, 'あ', S.`い`)");
    setupForMerge(configSource, true);
    runOutputAndAssertMergeWithMergeRule(configSource);
  }

  // Pre-creates the destination table (optionally with a primary key) and seeds one row.
  private void setupForMerge(ConfigSource configSource, boolean hasPrimaryKey) {
    String dstTable = quotedDstTableName(configSource);
    String primaryKey = hasPrimaryKey ? "PRIMARY KEY" : "";
    run("CREATE TABLE " + dstTable + " (`あ` STRING " + primaryKey + ", `い` STRING)");
    run("INSERT INTO " + dstTable + "(`あ`, `い`) VALUES ('test0', 'hoge')");
  }

  private void runOutputAndAssertMergeWithMergeRule(ConfigSource configSource) throws IOException {
    runOutput(configSource, "あ", "い");

    List<Map<String, Object>> results =
        runQuery("SELECT * FROM " + quotedDstTableName(configSource));
    Assert.assertEquals(1, results.size());
    Assert.assertEquals("test0", results.get(0).get("あ"));
    Assert.assertEquals("hogeあtest1", results.get(0).get("い"));
  }

  private void runOutputAndAssertResult(ConfigSource configSource) throws IOException {
    // Default column set covers a non-ASCII name plus each quoting character.
    runOutputAndAssertResult(configSource, "あ", "`", "\"", "'");
  }

  private void runOutputAndAssertResult(ConfigSource configSource, String... columnNames)
      throws IOException {
    runOutput(configSource, columnNames);
    assertResult(configSource, columnNames);
  }

  // Writes a one-row CSV input file (columns typed as string) and runs the plugin on it.
  private void runOutput(ConfigSource configSource, String... columnNames) throws IOException {
    String header =
        Arrays.stream(columnNames).map(name -> name + ":string").collect(Collectors.joining(","));
    String data =
        IntStream.range(0, columnNames.length)
            .mapToObj(i -> "test" + i)
            .collect(Collectors.joining(","));
    File inputFile = createInputFile(testFolder, header, data);
    embulk.runOutput(configSource, inputFile.toPath());
  }

  private void assertResult(ConfigSource configSource, String... columnNames) {
    List<Map<String, Object>> results =
        runQuery("SELECT * FROM " + quotedDstTableName(configSource));
    Assert.assertEquals(1, results.size());
    for (int i = 0; i < columnNames.length; i++) {
      Assert.assertEquals("test" + i, results.get(0).get(columnNames[i]));
    }
  }
}
mergeKeys) { + return configSource.set("merge_keys", mergeKeys); + } + + public static ConfigSource setMergeRule(ConfigSource configSource, String... mergeRule) { + return configSource.set("merge_rule", mergeRule); + } + public static ConfigSource setColumnOption( ConfigSource configSource, String columnName, String type) { return setColumnOption(configSource, columnName, type, null, null, null); diff --git a/src/test/java/org/embulk/output/databricks/util/ConnectionUtil.java b/src/test/java/org/embulk/output/databricks/util/ConnectionUtil.java index 0280cd5..f7fa73a 100644 --- a/src/test/java/org/embulk/output/databricks/util/ConnectionUtil.java +++ b/src/test/java/org/embulk/output/databricks/util/ConnectionUtil.java @@ -33,25 +33,36 @@ public static String quotedDstTableName(ConfigSource configSource) { public static void dropAllTemporaryTables() { ConfigUtil.TestTask t = ConfigUtil.createTestTask(); + try (Connection conn = connectByTestTask()) { + dropAllTemporaryTables(conn, t.getCatalogName(), t.getSchemaName()); + dropAllTemporaryTables(conn, t.getCatalogName(), t.getNonAsciiSchemaName()); + dropAllTemporaryTables(conn, t.getNonAsciiCatalogName(), t.getSchemaName()); + dropAllTemporaryTables(conn, t.getNonAsciiCatalogName(), t.getNonAsciiSchemaName()); + } catch (SQLException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private static void dropAllTemporaryTables( + Connection conn, String catalogName, String schemaName) { + ConfigUtil.TestTask t = ConfigUtil.createTestTask(); String tableNamesSQL = String.format( "select table_name from system.information_schema.tables where table_catalog = '%s' AND table_schema = '%s' AND table_name LIKE '%s%%'", - t.getCatalogName(), t.getSchemaName(), t.getTablePrefix()); - runQuery(tableNamesSQL) + catalogName, schemaName, t.getTablePrefix()); + runQuery(conn, tableNamesSQL) .forEach( x -> { String tableName = (String) x.get("table_name"); String dropSql = String.format( - "drop table if 
exists `%s`.`%s`.`%s`", - t.getCatalogName(), t.getSchemaName(), tableName); - run(dropSql); + "drop table if exists `%s`.`%s`.`%s`", catalogName, schemaName, tableName); + run(conn, dropSql); }); } - public static List> runQuery(String query) { - try (Connection conn = connectByTestTask(); - Statement stmt = conn.createStatement(); + public static List> runQuery(Connection conn, String query) { + try (Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(query)) { List> result = new ArrayList<>(); while (rs.next()) { @@ -62,15 +73,30 @@ public static List> runQuery(String query) { result.add(resMap); } return result; + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public static List> runQuery(String query) { + try (Connection conn = connectByTestTask()) { + return runQuery(conn, query); } catch (SQLException | ClassNotFoundException e) { throw new RuntimeException(e); } } - public static Boolean run(String query) { - try (Connection conn = connectByTestTask(); - Statement stmt = conn.createStatement()) { + public static Boolean run(Connection conn, String query) { + try (Statement stmt = conn.createStatement()) { return stmt.execute(query); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public static Boolean run(String query) { + try (Connection conn = connectByTestTask()) { + return run(conn, query); } catch (SQLException | ClassNotFoundException e) { throw new RuntimeException(e); } diff --git a/src/test/java/org/embulk/output/databricks/util/DatabricksApiClientUtil.java b/src/test/java/org/embulk/output/databricks/util/DatabricksApiClientUtil.java index b74b80c..078ec34 100644 --- a/src/test/java/org/embulk/output/databricks/util/DatabricksApiClientUtil.java +++ b/src/test/java/org/embulk/output/databricks/util/DatabricksApiClientUtil.java @@ -8,8 +8,16 @@ public class DatabricksApiClientUtil { public static void deleteAllTemporaryStagingVolumes() { + ConfigUtil.TestTask t = 
ConfigUtil.createTestTask(); + deleteAllTemporaryStagingVolumes(t.getCatalogName(), t.getSchemaName()); + deleteAllTemporaryStagingVolumes(t.getCatalogName(), t.getNonAsciiSchemaName()); + deleteAllTemporaryStagingVolumes(t.getNonAsciiCatalogName(), t.getSchemaName()); + deleteAllTemporaryStagingVolumes(t.getNonAsciiCatalogName(), t.getNonAsciiSchemaName()); + } + + public static void deleteAllTemporaryStagingVolumes(String catalogName, String schemaName) { WorkspaceClient client = createWorkspaceClient(); - fetchAllTemporaryStagingVolumes() + fetchAllTemporaryStagingVolumes(catalogName, schemaName) .forEach( x -> { String name = @@ -20,11 +28,17 @@ public static void deleteAllTemporaryStagingVolumes() { public static List fetchAllTemporaryStagingVolumes() { ConfigUtil.TestTask t = ConfigUtil.createTestTask(); + return fetchAllTemporaryStagingVolumes(t.getCatalogName(), t.getSchemaName()); + } + + private static List fetchAllTemporaryStagingVolumes( + String catalogName, String schemaName) { + ConfigUtil.TestTask t = ConfigUtil.createTestTask(); WorkspaceClient client = createWorkspaceClient(); List results = new ArrayList<>(); client .volumes() - .list(t.getCatalogName(), t.getSchemaName()) + .list(catalogName, schemaName) .forEach( x -> { if (x.getName().startsWith(t.getStagingVolumeNamePrefix())) {