Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ data class DatabaseConfiguration(
val version: Int,
val create: (DatabaseConnection) -> Unit,
val upgrade: (DatabaseConnection, Int, Int) -> Unit = { _, _, _ -> },
val downgrade: (DatabaseConnection, Int, Int) -> Unit = { _, initialVersion, version ->
error("Database version $initialVersion newer than config version $version")
},
val inMemory: Boolean = false,
val journalMode: JournalMode = JournalMode.WAL,
val extendedConfig:Extended = Extended(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class NativeDatabaseConnection internal constructor(
fun migrateIfNeeded(
create: (DatabaseConnection) -> Unit,
upgrade: (DatabaseConnection, Int, Int) -> Unit,
downgrade: (DatabaseConnection, Int, Int) -> Unit,
version: Int
) {
this.withTransaction {
Expand All @@ -100,10 +101,11 @@ class NativeDatabaseConnection internal constructor(
create(this)
setVersion(version)
} else if (initialVersion != version) {
if (initialVersion > version)
throw IllegalStateException("Database version $initialVersion newer than config version $version")

upgrade(this, initialVersion, version)
if (initialVersion > version) {
downgrade(this, initialVersion, version)
} else {
upgrade(this, initialVersion, version)
}
setVersion(version)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class NativeDatabaseManager(private val path:String,
try {
val version = configuration.version
if(version != NO_VERSION_CHECK)
conn.migrateIfNeeded(configuration.create, configuration.upgrade, version)
conn.migrateIfNeeded(configuration.create, configuration.upgrade, configuration.downgrade, version)
} catch (e: Exception) {

// If this failed, we have to close the connection or we will end up leaking it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ class DatabaseManagerTest : BaseDatabaseTest(){
}

@Test
fun downgradeNotAllowed(){
val upgradeCalled = AtomicInt(0)
fun downgradeNotAllowedByDefault() {
val config1 = DatabaseConfiguration(
name = TEST_DB_NAME,
version = 1,
Expand All @@ -159,16 +158,11 @@ class DatabaseManagerTest : BaseDatabaseTest(){
execute()
}
},
upgrade = { _, _, _ ->
upgradeCalled.increment()
},
loggingConfig = DatabaseConfiguration.Logging(logger = NoneLogger),
)

createDatabaseManager(config1).withConnection { }
assertEquals(0, upgradeCalled.value)
createDatabaseManager(config1.copy(version = 2)).withConnection { }
assertEquals(1, upgradeCalled.value)
createDatabaseManager(config1).withConnection { }
createDatabaseManager(config1.copy(version = 2)).withConnection { }

var conn: DatabaseConnection? = null
assertFails {
Expand All @@ -178,6 +172,40 @@ class DatabaseManagerTest : BaseDatabaseTest(){
conn?.close()
}

@Test
fun downgradeCalled() {
val upgradeCalled = AtomicInt(0)
val downgradeCalled = AtomicInt(0)
val config1 = DatabaseConfiguration(
name = TEST_DB_NAME,
version = 1,
create = { db ->
db.withStatement(TWO_COL) {
execute()
}
},
upgrade = { _, _, _ ->
upgradeCalled.increment()
},
downgrade = { _, _, _ ->
downgradeCalled.increment()
},
loggingConfig = DatabaseConfiguration.Logging(logger = NoneLogger),
)

createDatabaseManager(config1).withConnection { }
assertEquals(0, upgradeCalled.value)
assertEquals(0, downgradeCalled.value)

createDatabaseManager(config1.copy(version = 2)).withConnection { }
assertEquals(1, upgradeCalled.value)
assertEquals(0, downgradeCalled.value)

createDatabaseManager(config1.copy(version = 1)).withConnection { }
assertEquals(1, upgradeCalled.value)
assertEquals(1, downgradeCalled.value)
}

@Test
fun failedCreateRollsBack(){
val configFail = DatabaseConfiguration(
Expand Down
Loading