From 12bebc4cdaab4d67bd4518061b046bbaae10da79 Mon Sep 17 00:00:00 2001 From: yuanoOo Date: Thu, 28 Nov 2024 16:24:49 +0800 Subject: [PATCH] BugFix: Fix unable to connect to oracle tenant. --- docs/spark-connector-oceanbase.md | 8 +++++++- docs/spark-connector-oceanbase_cn.md | 6 ++++++ .../oceanbase/spark/cfg/ConnectionOptions.java | 2 ++ .../com/oceanbase/spark/jdbc/OBJdbcUtils.scala | 17 +++++++++++++++-- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/spark-connector-oceanbase.md b/docs/spark-connector-oceanbase.md index e02fad1..5e4b175 100644 --- a/docs/spark-connector-oceanbase.md +++ b/docs/spark-connector-oceanbase.md @@ -279,7 +279,7 @@ df.write sql-port - + 2881 Integer The SQL port. @@ -313,6 +313,12 @@ df.write String The table name. + + driver + com.mysql.cj.jdbc.Driver + String + The class name of the JDBC driver. By default, it connects to the MySQL tenant. If you need to connect to Oracle tenant, the name needs to be com.oceanbase.jdbc.Driver + diff --git a/docs/spark-connector-oceanbase_cn.md b/docs/spark-connector-oceanbase_cn.md index 7267f8a..0414e44 100644 --- a/docs/spark-connector-oceanbase_cn.md +++ b/docs/spark-connector-oceanbase_cn.md @@ -310,6 +310,12 @@ df.write String 表名。 + + driver + com.mysql.cj.jdbc.Driver + String + JDBC 驱动程序的类名。默认支持连接MySQL租户。如果需要连接到Oracle租户,请修改为 com.oceanbase.jdbc.Driver + diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/java/com/oceanbase/spark/cfg/ConnectionOptions.java b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/java/com/oceanbase/spark/cfg/ConnectionOptions.java index f3597b5..4e986a6 100644 --- a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/java/com/oceanbase/spark/cfg/ConnectionOptions.java +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/java/com/oceanbase/spark/cfg/ConnectionOptions.java @@ -24,6 +24,8 @@ public interface ConnectionOptions { String PASSWORD = "password"; String SCHEMA_NAME = "schema-name"; String TABLE_NAME = "table-name"; + String DRIVER = "driver"; + String DRIVER_DEFAULT = "com.mysql.cj.jdbc.Driver"; /* Direct-load config */ String ENABLE_DIRECT_LOAD_WRITE = "direct-load.enabled"; diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/jdbc/OBJdbcUtils.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/jdbc/OBJdbcUtils.scala index 5c0ce3e..1f6cc23 100644 --- a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/jdbc/OBJdbcUtils.scala +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/jdbc/OBJdbcUtils.scala @@ -25,6 +25,10 @@ import java.sql.{Connection, DriverManager} object OBJdbcUtils { val OB_MYSQL_URL = s"jdbc:mysql://%s:%d/%s" private val OB_ORACLE_URL = s"jdbc:oceanbase://%s:%d/%s" + private val MYSQL_JDBC_DRIVER = "com.mysql.cj.jdbc.Driver" + private val MYSQL_LEGACY_JDBC_DRIVER = "com.mysql.jdbc.Driver" + private val OB_JDBC_DRIVER = "com.oceanbase.jdbc.Driver" + private val OB_LEGACY_JDBC_DRIVER = "com.alipay.oceanbase.jdbc.Driver" def getConnection(sparkSettings: SparkSettings): Connection = { val connection = DriverManager.getConnection( @@ -41,19 +45,28 @@ object OBJdbcUtils { def getJdbcUrl(sparkSettings: SparkSettings): String = { var url: String = null - if ("MYSQL".equalsIgnoreCase(getCompatibleMode(sparkSettings))) { + val driver = + sparkSettings.getProperty(ConnectionOptions.DRIVER, ConnectionOptions.DRIVER_DEFAULT) + if ( + driver.equalsIgnoreCase(MYSQL_JDBC_DRIVER) || driver.equalsIgnoreCase( + MYSQL_LEGACY_JDBC_DRIVER) + ) { url = OBJdbcUtils.OB_MYSQL_URL.format( sparkSettings.getProperty(ConnectionOptions.HOST), sparkSettings.getIntegerProperty(ConnectionOptions.SQL_PORT), sparkSettings.getProperty(ConnectionOptions.SCHEMA_NAME) ) - } else { + } else if ( + driver.equalsIgnoreCase(OB_JDBC_DRIVER) || driver.equalsIgnoreCase(OB_LEGACY_JDBC_DRIVER) + ) { JdbcDialects.registerDialect(OceanBaseOracleDialect) url = OBJdbcUtils.OB_ORACLE_URL.format( sparkSettings.getProperty(ConnectionOptions.HOST), sparkSettings.getIntegerProperty(ConnectionOptions.SQL_PORT), sparkSettings.getProperty(ConnectionOptions.SCHEMA_NAME) ) + } else { + throw new RuntimeException(String.format("Unsupported driver name: %s", driver)) } url }