Skip to content

Commit

Permalink
Adds support to Postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
pgeadas committed Jun 19, 2024
1 parent 31e1c6f commit cb76924
Show file tree
Hide file tree
Showing 18 changed files with 650 additions and 40 deletions.
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies {
runtimeOnly 'com.h2database:h2:1.4.200'
runtimeOnly 'mysql:mysql-connector-java:8.0.28'
runtimeOnly 'org.xerial:sqlite-jdbc:3.36.0.3'
runtimeOnly("org.postgresql:postgresql:42.7.3")
implementation 'info.picocli:picocli:4.6.3'
annotationProcessor 'info.picocli:picocli-codegen:4.6.3'

Expand All @@ -62,6 +63,8 @@ dependencies {
testImplementation "org.testcontainers:testcontainers:1.16.3"
testImplementation "org.testcontainers:mysql:1.16.3"
testImplementation "org.testcontainers:spock:1.16.3"
testImplementation 'org.testcontainers:postgresql:1.16.3'
testImplementation("org.postgresql:postgresql:42.7.3")

// Dummy library containing some migration files among their resources for testing purposes
testImplementation files("libs/jar-with-resources.jar")
Expand Down
3 changes: 3 additions & 0 deletions conf/resource-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
},
{
"pattern":"\\Qschema/sqlite.sql\\E"
},
{
"pattern":"\\Qschema/postgres.sql\\E"
}
]},
"bundles":[{
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/io/seqera/migtool/App.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public Integer call() {
.withUser(username)
.withPassword(password)
.withUrl(url)
.withDialect(dialect!=null ? dialect : dialectFromUrl(url))
.withDriver(driver!=null ? driver : driverFromUrl(url))
.withDialect(dialect!=null ? dialect : dialectFromUrl(url).toString())
.withDriver(driver!=null ? driver : driverFromUrl(url).toString())
.withLocations(location)
.withPattern(pattern);

Expand Down
60 changes: 60 additions & 0 deletions src/main/java/io/seqera/migtool/Dialect.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.seqera.migtool;

import java.util.List;

public enum Dialect {
MYSQL(List.of("mysql")),
H2(List.of("h2")),
MARIADB(List.of("mariadb")),
SQLITE(List.of("sqlite")),
POSTGRES(List.of("postgres", "postgresql")),
TCPOSTGRES(List.of("tc"));

private final List<String> names;

Dialect(List<String> names) {
this.names = names;
}

public static Dialect from(String dialect) {
for (Dialect d : Dialect.values()) {
if (d.names.contains(dialect.toLowerCase())) {
return d;
}
}
throw new IllegalStateException("Unknown dialect: " + dialect);
}

boolean isPostgres() {
return this == POSTGRES;
}

boolean isMySQL() {
return this == MYSQL;
}

boolean isH2() {
return this == H2;
}

boolean isSQLite() {
return this == SQLITE;
}

boolean isMariaDB() {
return this == MARIADB;
}

boolean isTestContainersPostgres() {
return this == TCPOSTGRES;
}

@Override
public String toString() {
if (this == TCPOSTGRES) {
// the dialect is actually POSTGRES
return "postgres";
}
return names.get(0);
}
}
29 changes: 29 additions & 0 deletions src/main/java/io/seqera/migtool/Driver.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.seqera.migtool;

public enum Driver {
MYSQL("com.mysql.cj.jdbc.Driver"),
H2("org.h2.Driver"),
SQLITE("org.sqlite.JDBC"),
POSTGRES("org.postgresql.Driver"),
TCPOSTGRES("org.testcontainers.jdbc.ContainerDatabaseDriver");

private final String driver;

Driver(String driver) {
this.driver = driver;
}

static Driver from(String driver) {
for (Driver d : Driver.values()) {
if (d.driver.equals(driver)) {
return d;
}
}
throw new IllegalStateException("Unknown driver: " + driver);
}

@Override
public String toString() {
return driver;
}
}
28 changes: 16 additions & 12 deletions src/main/java/io/seqera/migtool/Helper.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,25 @@ private static void closeWithWarning(Closeable closeable) {
tryClose(closeable);
}

static public String dialectFromUrl(String url) {
static public Dialect dialectFromUrl(String url) {
if( url==null )
return null;
String[] parts = url.split(":");
return parts.length > 1 ? parts[1] : null;
}

static public String driverFromUrl(String url) {
final String dialect = dialectFromUrl(url);
if( "mysql".equals(dialect) )
return "com.mysql.cj.jdbc.Driver";
if( "h2".equals(dialect))
return "org.h2.Driver";
if( "sqlite".equals(dialect))
return "org.sqlite.JDBC";
return parts.length > 1 ? Dialect.from(parts[1]) : null;
}

static public Driver driverFromUrl(String url) {
final Dialect dialect = dialectFromUrl(url);
if( dialect.isMySQL() )
return Driver.MYSQL;
if( dialect.isH2() )
return Driver.H2;
if( dialect.isSQLite() )
return Driver.SQLITE;
if( dialect.isPostgres() )
return Driver.POSTGRES;
if( dialect.isTestContainersPostgres() )
return Driver.TCPOSTGRES;
return null;
}
}
2 changes: 1 addition & 1 deletion src/main/java/io/seqera/migtool/MigRecord.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*
* @author Paolo Di Tommaso <[email protected]>
*/
class MigRecord implements Comparable<MigRecord> {
public class MigRecord implements Comparable<MigRecord> {

enum Language {
SQL,
Expand Down
62 changes: 41 additions & 21 deletions src/main/java/io/seqera/migtool/MigTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
*/
package io.seqera.migtool;

import static io.seqera.migtool.Dialect.*;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -42,13 +44,13 @@ public class MigTool {

static final String MIGTOOL_TABLE = "MIGTOOL_HISTORY";

static final String[] DIALECTS = {"h2", "mysql", "mariadb","sqlite"};
static final Dialect[] DIALECTS = {H2, MYSQL, MARIADB, SQLITE, POSTGRES, TCPOSTGRES};

String driver;
Driver driver;
String url;
String user;
String password;
String dialect;
Dialect dialect;
String locations;
ClassLoader classLoader;
Pattern pattern;
Expand All @@ -66,7 +68,7 @@ public MigTool() {
}

public MigTool withDriver(String driver) {
this.driver = driver;
this.driver = Driver.from(driver);
return this;
}

Expand All @@ -86,7 +88,7 @@ public MigTool withPassword(String password) {
}

public MigTool withDialect(String dialect) {
this.dialect = dialect;
this.dialect = Dialect.from(dialect);
return this;
}

Expand Down Expand Up @@ -121,8 +123,7 @@ public MigTool run() {
createIfNotExists();
scanMigrations();
apply();
}
finally {
} finally {
if( previous!=null ) {
Thread.currentThread().setContextClassLoader(previous);
}
Expand Down Expand Up @@ -167,11 +168,17 @@ List<MigRecord> getOverrideEntries() {
* Validate the expected input params and open the connection with the DB
*/
protected void init() {
if( dialect==null || dialect.isEmpty() )
validateAttributes();
loadDriver();
loadSchemaAndCatalog();
}

private void validateAttributes() {
if( dialect==null )
throw new IllegalStateException("Missing 'dialect' attribute");
if( url==null || url.isEmpty() )
throw new IllegalStateException("Missing 'url' attribute");
if( driver==null || driver.isEmpty() )
if( driver==null )
throw new IllegalStateException("Missing 'driver' attribute");
if( user==null || user.isEmpty() )
throw new IllegalStateException("Missing 'user' attribute");
Expand All @@ -181,15 +188,18 @@ protected void init() {
throw new IllegalStateException("Unsupported dialect: " + dialect);
if( locations==null )
throw new IllegalStateException("Missing 'locations' attribute");
}

private void loadDriver() {
try {
// load driver
Class.forName(driver);
Class.forName(driver.toString());
}
catch (ClassNotFoundException e) {
throw new IllegalStateException("Unable to find driver class: " + driver, e);
}
}

private void loadSchemaAndCatalog() {
try( Connection conn = getConnection() ) {
if( conn == null )
throw new IllegalStateException("Unable to acquire DB connection");
Expand Down Expand Up @@ -230,14 +240,14 @@ protected String dumpResultSet(boolean hasData, ResultSet rs) {
final int cols=rs.getMetaData().getColumnCount();
for( int i=1; i<=cols; i++ ) {
if( i>1 ) result.append(",");
result.append( String.valueOf(rs.getMetaData().getColumnName(i)) );
result.append(rs.getMetaData().getColumnName(i));
}
result.append( "}; ");

int row=0;
do {
row++;
result.append( "{row"+row+":");
result.append("{row").append(row).append(":");
for( int i=1; i<=cols; i++ ) {
if( i>1 ) result.append(",");
result.append( String.valueOf(rs.getObject(i)).replaceAll("\n"," ") );
Expand All @@ -259,14 +269,14 @@ protected String dumpResultSet(boolean hasData, ResultSet rs) {
protected void createIfNotExists() {
try (Connection conn = getConnection()) {
if( !existTable(conn, MIGTOOL_TABLE) ) {
log.info("Creating MigTool schema using dialect: " + dialect);
String schema = Helper.getResourceAsString("/schema/" + dialect + ".sql");
log.info("Creating MigTool schema using dialect: {}", dialect.toString());
String schema = Helper.getResourceAsString("/schema/" + dialect.toString() + ".sql");
try ( Statement stm = conn.createStatement() ) {
stm.execute(schema);
}
log.info("Created");
}
}
catch (SQLException e) {
} catch (SQLException e) {
throw new IllegalStateException("Unable to create MigTool schema -- cause: " + e.getMessage(), e);
}
}
Expand Down Expand Up @@ -334,7 +344,7 @@ else if(entry.isPatch())
* Apply the migration files
*/
protected void apply() {
if( migrationEntries.size()==0 ) {
if(migrationEntries.isEmpty()) {
log.info("No DB migrations found");
}

Expand All @@ -350,7 +360,9 @@ protected void apply() {

protected void checkRank(MigRecord entry) {
try(Connection conn=getConnection(); Statement stm = conn.createStatement()) {
ResultSet rs = stm.executeQuery("select max(`rank`) from "+MIGTOOL_TABLE);
String rank = dialect.isPostgres() || dialect.isTestContainersPostgres() ? "rank" : "`rank`";
String sql = String.format("select max(%s) from %s", rank, MIGTOOL_TABLE);
ResultSet rs = stm.executeQuery(sql);
int last = rs.next() ? rs.getInt(1) : 0;
int expected = last+1;
if( entry.rank != expected) {
Expand Down Expand Up @@ -433,8 +445,11 @@ private int migrate(MigRecord entry) throws SQLException {
// compute the delta
int delta = (int)(System.currentTimeMillis()-now);

String columns = dialect.isPostgres() || dialect.isTestContainersPostgres()
? "rank, script, checksum, created_on, execution_time"
: "`rank`,`script`,`checksum`,`created_on`,`execution_time`";
// save the current migration
final String insertSql = "insert into "+MIGTOOL_TABLE+" (`rank`,`script`,`checksum`,`created_on`,`execution_time`) values (?,?,?,?,?)";
final String insertSql = String.format("insert into %s (%s) values (?,?,?,?,?)", MIGTOOL_TABLE, columns);
try (Connection conn=getConnection(); PreparedStatement insert = conn.prepareStatement(insertSql)) {
insert.setInt(1, entry.rank);
insert.setString(2, entry.script);
Expand All @@ -447,7 +462,12 @@ private int migrate(MigRecord entry) throws SQLException {
}

protected boolean checkMigrated(MigRecord entry) {
String sql = "select `id`, `checksum`, `script` from " + MIGTOOL_TABLE + " where `rank` = ? and `script` = ?";
String sql;
if(dialect.isPostgres() || dialect.isTestContainersPostgres()) {
sql = "select id, checksum, script from " + MIGTOOL_TABLE + " where rank = ? and script = ?";
} else {
sql = "select `id`, `checksum`, `script` from " + MIGTOOL_TABLE + " where `rank` = ? and `script` = ?";
}

try (Connection conn=getConnection(); PreparedStatement stm = conn.prepareStatement(sql)) {
stm.setInt(1, entry.rank);
Expand Down
9 changes: 9 additions & 0 deletions src/main/resources/schema/postgres.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
create table if not exists MIGTOOL_HISTORY
(
id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
rank INTEGER NOT NULL,
script VARCHAR(250) NOT NULL,
checksum VARCHAR(64) NOT NULL,
created_on timestamp NOT NULL,
execution_time INTEGER
);
20 changes: 20 additions & 0 deletions src/nativeCliTest/groovy/io/seqera/migtool/HelperTest.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.seqera.migtool

import spock.lang.Specification

class HelperTest extends Specification {

def 'should load resources' () {
expect:
Helper.getResourceFiles('/db/mysql') == ['/db/mysql/file1.sql'].toSet()
Helper.getResourceFiles('/db/mariadb') == ['/db/mariadb/V01__maria1.sql', '/db/mariadb/V02__maria2.sql', '/db/mariadb/v01-foo.txt'].toSet()
Helper.getResourceFiles('/db/postgres') == ['/db/postgres/V01__postgres.sql'].toSet()
}

def 'should read resource file' () {
expect:
Helper.getResourceAsString('db/mysql/file1.sql').trim() == 'select * from my-table;'
Helper.getResourceAsString('db/mariadb/V01__maria1.sql') == 'create table XXX ( col1 varchar(1) );\n'
Helper.getResourceAsString('db/postgres/V01__postgres.sql').trim() == 'select * from my-table;'
}
}
Loading

0 comments on commit cb76924

Please sign in to comment.