diff --git a/README.md b/README.md index 9d02237..a62e751 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,46 @@ dbx is a database schema migration library for Go that lets you manage database ## Features +- **Rails-like migrations**: Define and organize schema changes with timestamp-based migrations - **Database inspection**: Introspect existing database schemas - **Schema comparison**: Compare schemas and generate migration statements - **Built on `database/sql`**: Works with standard Go database drivers - **Automatic SQL generation**: Automatically generate SQL statements for schema changes +- **CLI tool**: Command-line interface for creating and running migrations ## Usage Examples +### Rails-like Migrations + +Create and run migrations using the CLI: + +```bash +# Generate a new migration +dbx generate create_users + +# Run migrations +dbx --database "postgres://postgres:postgres@localhost:5432/dbx_test?sslmode=disable" migrate +``` + +Or programmatically in your code: + +```go +import ( + "database/sql" + _ "github.com/lib/pq" + "github.com/swiftcarrot/dbx/migration" +) + +// Set the migrations directory +migration.SetMigrationsDir("./migrations") + +// Run migrations +db, _ := sql.Open("postgres", "postgres://postgres:postgres@localhost:5432/dbname") +migration.RunMigrations(db, "") +``` + +See the [migration documentation](./migration/README.md) for more details. + ### Database Inspection Introspect an existing database schema: diff --git a/cmd/dbx/example.go b/cmd/dbx/example.go new file mode 100644 index 0000000..72b8257 --- /dev/null +++ b/cmd/dbx/example.go @@ -0,0 +1,57 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +// This example demonstrates how to use the migration system +func Example() { + // Set the migrations directory + migrationsDir := filepath.Join(os.TempDir(), "dbx_migrations") + migration.SetMigrationsDir(migrationsDir) + + // Create a new migration + migrationPath, err := migration.CreateMigration("create_users") + if err != nil { + fmt.Printf("Error creating migration: %s\n", err) + return + } + fmt.Printf("Created migration: %s\n", migrationPath) + + // Register a migration programmatically + migration.Register("20250520000000", "create_posts", upCreatePosts, downCreatePosts) +} + +// Migration functions for create_posts +func upCreatePosts() *schema.Schema { + s := schema.NewSchema() + + s.CreateTable("posts", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("body", &schema.TextType{}) + t.Column("user_id", &schema.IntegerType{}) + t.Column("created_at", &schema.TimestampType{}) + t.Column("updated_at", &schema.TimestampType{}) + + t.SetPrimaryKey("pk_posts", []string{"id"}) + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + + t.Index("idx_posts_created_at", []string{"created_at"}) + }) + + return s +} + +func downCreatePosts() *schema.Schema { + s := schema.NewSchema() + + s.DropTable("posts") + + return s +} diff --git a/cmd/dbx/main.go b/cmd/dbx/main.go new file mode 100644 index 0000000..8700898 --- /dev/null +++ b/cmd/dbx/main.go @@ -0,0 +1,226 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "os" + "strconv" + "strings" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/swiftcarrot/dbx/migration" +) + +const defaultMigrationsDir = "./migrations" + +func main() { + // Define flags + migrationsDir := flag.String("migrations-dir", defaultMigrationsDir, "Directory containing migrations") + dbUrl := flag.String("database", "", "Database connection URL") + + // Parse command line arguments + flag.Parse() + + // Set up the migrations directory + migration.SetMigrationsDir(*migrationsDir) + + // Get the subcommand + args := flag.Args() + if len(args) == 0 { + printUsage() + os.Exit(1) + } + + command := args[0] + switch command { + case "generate", "g": + if len(args) < 2 { + fmt.Println("Error: Missing migration name") + fmt.Println("Usage: dbx generate ") + os.Exit(1) + } + generateMigration(args[1]) + case "migrate", "m": + if *dbUrl == "" { + fmt.Println("Error: Database URL is required") + fmt.Println("Usage: dbx migrate --database ") + os.Exit(1) + } + var version string + if len(args) > 1 { + version = args[1] + } + runMigrations(*dbUrl, version) + case "rollback", "r": + if *dbUrl == "" { + fmt.Println("Error: Database URL is required") + fmt.Println("Usage: dbx rollback --database [steps]") + os.Exit(1) + } + steps := 1 + if len(args) > 1 { + s, err := strconv.Atoi(args[1]) + if err == nil { + steps = s + } + } + rollbackMigrations(*dbUrl, steps) + case "status", "s": + if *dbUrl == "" { + fmt.Println("Error: Database URL is required") + fmt.Println("Usage: dbx status --database ") + os.Exit(1) + } + showStatus(*dbUrl) + case "help", "h": + printUsage() + default: + fmt.Printf("Unknown command: %s\n", command) + printUsage() + os.Exit(1) + } +} + +func printUsage() { + fmt.Println("DBX Migration Tool") + fmt.Println("Usage:") + fmt.Println(" dbx [options] [arguments]") + fmt.Println("") + fmt.Println("Options:") + fmt.Println(" --migrations-dir Directory containing migrations (default: ./migrations)") + fmt.Println(" --database Database connection URL") + fmt.Println("") + fmt.Println("Commands:") + fmt.Println(" generate, g Generate a new migration") + fmt.Println(" migrate, m [version] Run migrations (up to optional version)") + fmt.Println(" rollback, r [steps] Rollback migrations (default: 1 step)") + fmt.Println(" status, s Show migration status") + fmt.Println(" help, h Show this help") + fmt.Println("") + fmt.Println("Examples:") + fmt.Println(" dbx generate create_users") + fmt.Println(" dbx --database \"postgres://user:pass@localhost/dbname\" migrate") + fmt.Println(" dbx --database \"mysql://user:pass@localhost/dbname\" rollback 2") +} + +func generateMigration(name string) { + filePath, err := migration.CreateMigration(name) + if err != nil { + fmt.Printf("Error generating migration: %s\n", err) + os.Exit(1) + } + fmt.Printf("Created migration: %s\n", filePath) +} + +func connectToDatabase(dbUrl string) (*sql.DB, error) { + // Parse the URL to determine the driver + var driver string + if strings.HasPrefix(dbUrl, "postgres://") { + driver = "postgres" + } else if strings.HasPrefix(dbUrl, "mysql://") { + driver = "mysql" + // Convert mysql URL to DSN format if needed + dbUrl = strings.TrimPrefix(dbUrl, "mysql://") + dbUrl = strings.Replace(dbUrl, "/", "?parseTime=true&loc=Local&charset=utf8mb4&collation=utf8mb4_unicode_ci&database=", 1) + } else if strings.HasPrefix(dbUrl, "sqlite://") { + driver = "sqlite3" + dbUrl = strings.TrimPrefix(dbUrl, "sqlite://") + } else { + return nil, fmt.Errorf("unsupported database URL: must start with postgres://, mysql:// or sqlite://") + } + + // Connect to the database + db, err := sql.Open(driver, dbUrl) + if err != nil { + return nil, err + } + + // Test the connection + if err := db.Ping(); err != nil { + db.Close() + return nil, err + } + + return db, nil +} + +func runMigrations(dbUrl, version string) { + db, err := connectToDatabase(dbUrl) + if err != nil { + fmt.Printf("Error connecting to database: %s\n", err) + os.Exit(1) + } + defer db.Close() + + err = migration.RunMigrations(db, version) + if err != nil { + fmt.Printf("Error running migrations: %s\n", err) + os.Exit(1) + } + + fmt.Println("Migrations applied successfully") +} + +func rollbackMigrations(dbUrl string, steps int) { + db, err := connectToDatabase(dbUrl) + if err != nil { + fmt.Printf("Error connecting to database: %s\n", err) + os.Exit(1) + } + defer db.Close() + + err = migration.RollbackMigration(db, steps) + if err != nil { + fmt.Printf("Error rolling back migrations: %s\n", err) + os.Exit(1) + } + + fmt.Println("Migrations rolled back successfully") +} + +func showStatus(dbUrl string) { + db, err := connectToDatabase(dbUrl) + if err != nil { + fmt.Printf("Error connecting to database: %s\n", err) + os.Exit(1) + } + defer db.Close() + + status, err := migration.GetMigrationStatus(db) + if err != nil { + fmt.Printf("Error getting migration status: %s\n", err) + os.Exit(1) + } + + if len(status) == 0 { + fmt.Println("No migrations found") + return + } + + // Print status table + fmt.Println("Migration Status:") + fmt.Println("--------------------------------------------------------------------------------------------------------") + fmt.Printf("%-14s | %-50s | %-10s | %s\n", "Version", "Name", "Status", "Applied At") + fmt.Println("--------------------------------------------------------------------------------------------------------") + + for _, s := range status { + fmt.Printf("%-14s | %-50s | %-10s | %s\n", s.Version, s.Name, s.Status, s.AppliedAt) + } + fmt.Println("--------------------------------------------------------------------------------------------------------") + + // Print current version + currentVersion, err := migration.GetCurrentVersion(db) + if err != nil { + fmt.Printf("Error getting current version: %s\n", err) + os.Exit(1) + } + + if currentVersion == "" { + fmt.Println("Current version: none") + } else { + fmt.Printf("Current version: %s\n", currentVersion) + } +} diff --git a/migration/README.md b/migration/README.md new file mode 100644 index 0000000..2d938b8 --- /dev/null +++ b/migration/README.md @@ -0,0 +1,179 @@ +# DBX Migration System + +DBX Migrations is a Rails-like database migration system for Go applications using the `github.com/swiftcarrot/dbx` library. It provides a simple way to manage database schema changes across different database systems. + +## Features + +- Rails-like migration system with timestamped versions +- Support for MySQL, PostgreSQL, and SQLite databases +- Full schema comparison and diff calculation +- Automatic SQL generation for schema changes +- Migration rollback support +- Migration status reporting + +## Installation + +```bash +go get github.com/swiftcarrot/dbx +``` + +## Usage + +### Command Line Interface + +The DBX library includes a command-line tool for managing migrations: + +```bash +# Generate a new migration +dbx generate create_users + +# Run all pending migrations +dbx --database "postgres://user:pass@localhost/dbname" migrate + +# Run migrations up to a specific version +dbx --database "postgres://user:pass@localhost/dbname" migrate 20250520000000 + +# Rollback the last migration +dbx --database "postgres://user:pass@localhost/dbname" rollback + +# Rollback multiple migrations +dbx --database "postgres://user:pass@localhost/dbname" rollback 3 + +# Show migration status +dbx --database "postgres://user:pass@localhost/dbname" status + +# Get help +dbx help +``` + +### In Your Go Code + +You can also use the migration system programmatically in your Go code: + +```go +package main + +import ( + "database/sql" + "log" + + _ "github.com/lib/pq" + "github.com/swiftcarrot/dbx/migration" +) + +func main() { + // Set the migrations directory + migration.SetMigrationsDir("./migrations") + + // Connect to database + db, err := sql.Open("postgres", "postgres://user:pass@localhost/dbname") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + // Run migrations + err = migration.RunMigrations(db, "") + if err != nil { + log.Fatal(err) + } + + // Get migration status + status, err := migration.GetMigrationStatus(db) + if err != nil { + log.Fatal(err) + } + + for _, s := range status { + log.Printf("Migration %s (%s): %s", s.Version, s.Name, s.Status) + } +} +``` + +## Writing Migrations + +Migrations are Go files with `Up` and `Down` functions that define schema changes. When you generate a migration, a new file is created in your migrations directory with this structure: + +```go +package migrations + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + migration.Register("20250520000000", "create_users", up20250520000000, down20250520000000) +} + +func up20250520000000() *schema.Schema { + s := schema.NewSchema() + + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}, schema.PrimaryKey) + t.Column("username", &schema.VarcharType{Length: 255}, schema.NotNull) + t.Column("email", &schema.VarcharType{Length: 255}, schema.NotNull) + t.Column("created_at", &schema.TimestampType{}) + + t.Index("idx_users_email", []string{"email"}, schema.Unique) + }) + + return s +} + +func down20250520000000() *schema.Schema { + s := schema.NewSchema() + + s.DropTable("users") + + return s +} +``` + +### Schema Definition API + +The schema definition API is powerful and allows you to define complex schema changes: + +```go +// Create a table +s.CreateTable("posts", func(t *schema.Table) { + // Add columns + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("body", &schema.TextType{}) + t.Column("user_id", &schema.IntegerType{}) + + // Add a primary key + t.SetPrimaryKey("pk_posts", []string{"id"}) + + // Add a foreign key + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + + // Add indexes + t.Index("idx_posts_title", []string{"title"}) +}) + +// Create a view +s.CreateView("active_users", "SELECT * FROM users WHERE active = TRUE") + +// PostgreSQL-specific features +s.EnableExtension("uuid-ossp") +s.CreateSequence("user_id_seq") +``` + +## Migration Workflow + +1. Generate a new migration: `dbx generate create_users` +2. Edit the migration file to define schema changes +3. Run the migration: `dbx --database "postgres://..." migrate` +4. If needed, roll back: `dbx --database "postgres://..." rollback` + +## Database Support + +- PostgreSQL: Full support for all PostgreSQL features +- MySQL: Support for tables, columns, indexes, foreign keys, views, triggers +- SQLite: Support for basic schema operations + +## License + +See the main DBX project license. diff --git a/migration/generator.go b/migration/generator.go new file mode 100644 index 0000000..afd20d9 --- /dev/null +++ b/migration/generator.go @@ -0,0 +1,100 @@ +package migration + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "text/template" +) + +// Generator is responsible for creating new migration files +type Generator struct { + registry *Registry +} + +// NewGenerator creates a new migration generator +func NewGenerator(registry *Registry) *Generator { + return &Generator{ + registry: registry, + } +} + +// Generate creates a new migration file +func (g *Generator) Generate(name string) (string, error) { + // Clean the name (no spaces, lowercase with underscores) + name = strings.ToLower(strings.ReplaceAll(name, " ", "_")) + + // Generate version timestamp + version := GenerateVersionTimestamp() + + // Create file path + filename := fmt.Sprintf("%s_%s.go", version, name) + filePath := filepath.Join(g.registry.GetMigrationsDir(), filename) + + // Create the file + file, err := os.Create(filePath) + if err != nil { + return "", fmt.Errorf("failed to create migration file: %w", err) + } + defer file.Close() + + // Create migration from template + migrationTemplate := template.Must(template.New("migration").Parse(migrationTmpl)) + + packageName := filepath.Base(g.registry.GetMigrationsDir()) + + data := struct { + Version string + Name string + Package string + }{ + Version: version, + Name: name, + Package: packageName, + } + + if err := migrationTemplate.Execute(file, data); err != nil { + return "", fmt.Errorf("failed to write migration file: %w", err) + } + + return filePath, nil +} + +// Template for new migration files +const migrationTmpl = `package {{ .Package }} + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + migration.Register("{{ .Version }}", "{{ .Name }}", up{{ .Version }}, down{{ .Version }}) +} + +func up{{ .Version }}() *schema.Schema { + s := schema.NewSchema() + + // Define your schema changes here + // Example: + // s.CreateTable("users", func(t *schema.Table) { + // t.Column("id", &schema.IntegerType{}, schema.PrimaryKey) + // t.Column("name", &schema.VarcharType{Length: 255}) + // t.Column("email", &schema.VarcharType{Length: 255}, schema.NotNull) + // t.Column("created_at", &schema.TimestampType{}) + // }) + + return s +} + +func down{{ .Version }}() *schema.Schema { + s := schema.NewSchema() + + // Define how to revert the changes here + // Example: + // s.DropTable("users") + + return s +} +` diff --git a/migration/integration_test.go b/migration/integration_test.go new file mode 100644 index 0000000..0665618 --- /dev/null +++ b/migration/integration_test.go @@ -0,0 +1,466 @@ +package migration + +import ( + "database/sql" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/swiftcarrot/dbx/internal/testutil" + "github.com/swiftcarrot/dbx/schema" +) + +func TestIntegration(t *testing.T) { + // Create a temporary directory for migrations + migrationsDir, err := os.MkdirTemp("", "dbx_integration_test_*") + require.NoError(t, err) + defer os.RemoveAll(migrationsDir) + + // Set up registry + registry := NewRegistry(migrationsDir) + + // Generate migrations + generator := NewGenerator(registry) + + createUsersPath, err := generator.Generate("create_users") + require.NoError(t, err) + + createPostsPath, err := generator.Generate("create_posts") + require.NoError(t, err) + + // Verify migration files were created + _, err = os.Stat(createUsersPath) + require.NoError(t, err) + + _, err = os.Stat(createPostsPath) + require.NoError(t, err) + + // Extract the version from the filename + usersVersion := filepath.Base(createUsersPath) + usersVersion = usersVersion[:14] // Get timestamp part + + // Update the users migration content with the correct version + usersMigrationContent := `package migrations + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + migration.Register("` + usersVersion + `", "create_users", upCreateUsers, downCreateUsers) +} + +func upCreateUsers() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("name", &schema.VarcharType{Length: 255}) + t.Column("email", &schema.VarcharType{Length: 255}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_users", []string{"id"}) + }) + return s +} + +func downCreateUsers() *schema.Schema { + s := schema.NewSchema() + s.DropTable("users") + return s +} +` + + err = os.WriteFile(createUsersPath, []byte(usersMigrationContent), 0644) + require.NoError(t, err) + + postsVersion := filepath.Base(createPostsPath) + postsVersion = postsVersion[:14] // Get timestamp part + + postsMigrationContent := `package migrations + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + migration.Register("` + postsVersion + `", "create_posts", upCreatePosts, downCreatePosts) +} + +func upCreatePosts() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("posts", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("content", &schema.TextType{}) + t.Column("user_id", &schema.IntegerType{}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_posts", []string{"id"}) + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + }) + return s +} + +func downCreatePosts() *schema.Schema { + s := schema.NewSchema() + s.DropTable("posts") + return s +} +` + + err = os.WriteFile(createPostsPath, []byte(postsMigrationContent), 0644) + require.NoError(t, err) + + // Create a test database + db, err := testutil.GetSQLiteTestConn() + require.NoError(t, err) + t.Cleanup(func() { + _, err := db.Exec(` + DROP TABLE IF EXISTS posts; + DROP TABLE IF EXISTS users; + DROP TABLE IF EXISTS schema_migrations; + `) + require.NoError(t, err) + }) + + // Register migrations programmatically since we can't rely on init() during testing + versionTracker := NewVersionTracker(db) + err = versionTracker.EnsureMigrationsTable() + require.NoError(t, err) + + // We need to manually create migration objects since loading Go files dynamically is complex + createUsersFn := func() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("name", &schema.VarcharType{Length: 255}) + t.Column("email", &schema.VarcharType{Length: 255}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_users", []string{"id"}) + }) + return s + } + + dropUsersFn := func() *schema.Schema { + s := schema.NewSchema() + s.DropTable("users") + return s + } + + createPostsFn := func() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("posts", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("content", &schema.TextType{}) + t.Column("user_id", &schema.IntegerType{}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_posts", []string{"id"}) + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + }) + return s + } + + dropPostsFn := func() *schema.Schema { + s := schema.NewSchema() + s.DropTable("posts") + return s + } + + usersMigration := NewMigration(usersVersion, "create_users", createUsersFn, dropUsersFn) + postsMigration := NewMigration(postsVersion, "create_posts", createPostsFn, dropPostsFn) + + registry.AddMigration(usersMigration) + registry.AddMigration(postsMigration) + + // Create migrator and run migrations + migrator := NewMigrator(db, registry) + + // Run first migration only + err = migrator.Migrate(usersVersion) + require.NoError(t, err) + + // Verify first migration was applied + var tableCount int + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 1, tableCount) + + // Run second migration + err = migrator.Migrate("") + require.NoError(t, err) + + // Verify second migration was applied + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 1, tableCount) + + // Check that the foreign key exists + var fkCount int + err = db.QueryRow(` + SELECT COUNT(*) FROM sqlite_master + WHERE type='table' AND name='posts' + AND sql LIKE '%FOREIGN KEY%REFERENCES%users%' + `).Scan(&fkCount) + require.NoError(t, err) + require.Equal(t, 1, fkCount) + + // Check migration records + migrations, err := versionTracker.GetAppliedMigrations() + require.NoError(t, err) + require.Equal(t, 2, len(migrations)) + + // Test rollback + err = migrator.Rollback(1) + require.NoError(t, err) + + // Verify posts table was dropped + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 0, tableCount) + + // Verify users table still exists + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 1, tableCount) +} + + // Generate migrations + generator := NewGenerator(registry) + + createUsersPath, err := generator.Generate("create_users") + require.NoError(t, err) + + createPostsPath, err := generator.Generate("create_posts") + require.NoError(t, err) + + // Verify migration files were created + _, err = os.Stat(createUsersPath) + require.NoError(t, err) + + _, err = os.Stat(createPostsPath) + require.NoError(t, err) + + // Update migration file content with actual schema changes + usersMigrationContent := `package migrations + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + version := filepath.Base(createUsersPath) + version = version[:14] // Extract timestamp part + + migration.Register(version, "create_users", upCreateUsers, downCreateUsers) +} + +func upCreateUsers() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("name", &schema.VarcharType{Length: 255}) + t.Column("email", &schema.VarcharType{Length: 255}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_users", []string{"id"}) + }) + return s +} + +func downCreateUsers() *schema.Schema { + s := schema.NewSchema() + s.DropTable("users") + return s +} +` + // Extract the version from the filename + usersVersion := filepath.Base(createUsersPath) + usersVersion = usersVersion[:14] // Get timestamp part + + // Update the users migration content with the correct version + usersMigrationContent = `package migrations + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + migration.Register("` + usersVersion + `", "create_users", upCreateUsers, downCreateUsers) +} + +func upCreateUsers() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("name", &schema.VarcharType{Length: 255}) + t.Column("email", &schema.VarcharType{Length: 255}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_users", []string{"id"}) + }) + return s +} + +func downCreateUsers() *schema.Schema { + s := schema.NewSchema() + s.DropTable("users") + return s +} +` + + err = os.WriteFile(createUsersPath, []byte(usersMigrationContent), 0644) + require.NoError(t, err) + + postsVersion := filepath.Base(createPostsPath) + postsVersion = postsVersion[:14] // Get timestamp part + + postsMigrationContent := `package migrations + +import ( + "github.com/swiftcarrot/dbx/migration" + "github.com/swiftcarrot/dbx/schema" +) + +func init() { + migration.Register("` + postsVersion + `", "create_posts", upCreatePosts, downCreatePosts) +} + +func upCreatePosts() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("posts", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("content", &schema.TextType{}) + t.Column("user_id", &schema.IntegerType{}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_posts", []string{"id"}) + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + }) + return s +} + +func downCreatePosts() *schema.Schema { + s := schema.NewSchema() + s.DropTable("posts") + return s +} +` + + err = os.WriteFile(createPostsPath, []byte(postsMigrationContent), 0644) + require.NoError(t, err) + + // Create a test database + db, err := testutil.GetSQLiteTestConn() + require.NoError(t, err) + t.Cleanup(func() { + _, err := db.Exec(` + DROP TABLE IF EXISTS posts; + DROP TABLE IF EXISTS users; + DROP TABLE IF EXISTS schema_migrations; + `) + require.NoError(t, err) + }) + + // Register migrations programmatically since we can't rely on init() during testing + versionTracker := NewVersionTracker(db) + err = versionTracker.EnsureMigrationsTable() + require.NoError(t, err) + + // We need to manually create migration objects since loading Go files dynamically is complex + createUsersFn := func() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("name", &schema.VarcharType{Length: 255}) + t.Column("email", &schema.VarcharType{Length: 255}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_users", []string{"id"}) + }) + return s + } + + dropUsersFn := func() *schema.Schema { + s := schema.NewSchema() + s.DropTable("users") + return s + } + + createPostsFn := func() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("posts", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("content", &schema.TextType{}) + t.Column("user_id", &schema.IntegerType{}) + t.Column("created_at", &schema.TimestampType{}) + t.SetPrimaryKey("pk_posts", []string{"id"}) + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + }) + return s + } + + dropPostsFn := func() *schema.Schema { + s := schema.NewSchema() + s.DropTable("posts") + return s + } + + usersMigration := NewMigration(usersVersion, "create_users", createUsersFn, dropUsersFn) + postsMigration := NewMigration(postsVersion, "create_posts", createPostsFn, dropPostsFn) + + registry.AddMigration(usersMigration) + registry.AddMigration(postsMigration) + + // Create migrator and run migrations + migrator := NewMigrator(db, registry) + + // Run first migration only + err = migrator.Migrate(usersVersion) + require.NoError(t, err) + + // Verify first migration was applied + var tableCount int + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 1, tableCount) + + // Run second migration + err = migrator.Migrate("") + require.NoError(t, err) + + // Verify second migration was applied + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 1, tableCount) + + // Check that the foreign key exists + var fkCount int + err = db.QueryRow(` + SELECT COUNT(*) FROM sqlite_master + WHERE type='table' AND name='posts' + AND sql LIKE '%FOREIGN KEY%REFERENCES%users%' + `).Scan(&fkCount) + require.NoError(t, err) + require.Equal(t, 1, fkCount) + + // Check migration records + migrations, err := versionTracker.GetAppliedMigrations() + require.NoError(t, err) + require.Equal(t, 2, len(migrations)) + + // Test rollback + err = migrator.Rollback(1) + require.NoError(t, err) + + // Verify posts table was dropped + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 0, tableCount) + + // Verify users table still exists + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'").Scan(&tableCount) + require.NoError(t, err) + require.Equal(t, 1, tableCount) +} diff --git a/migration/main_test.go b/migration/main_test.go new file mode 100644 index 0000000..ad59360 --- /dev/null +++ b/migration/main_test.go @@ -0,0 +1,17 @@ +package migration + +import ( + "os" + "testing" +) + +func TestMain(m *testing.M) { + // Setup code before running tests + + // Run tests + code := m.Run() + + // Cleanup after tests + + os.Exit(code) +} diff --git a/migration/migration.go b/migration/migration.go index 3bdbc18..f2f3d25 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -11,16 +11,12 @@ import ( type Migration struct { // Version represents the migration version (typically a timestamp) Version string - // Name is a descriptive name for the migration Name string - // CreatedAt represents when the migration was created CreatedAt time.Time - // UpFn defines the schema changes for migrating up UpFn func() *schema.Schema - // DownFn defines the schema changes for rolling back (migrating down) DownFn func() *schema.Schema } diff --git a/migration/migration_test.go b/migration/migration_test.go new file mode 100644 index 0000000..79599dd --- /dev/null +++ b/migration/migration_test.go @@ -0,0 +1,132 @@ +package migration + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/swiftcarrot/dbx/schema" +) + +func TestNewMigration(t *testing.T) { + version := "20250520000000" + name := "create_users" + + upFn := func() *schema.Schema { return schema.NewSchema() } + downFn := func() *schema.Schema { return schema.NewSchema() } + + migration := NewMigration(version, name, upFn, downFn) + + require.Equal(t, version, migration.Version) + require.Equal(t, name, migration.Name) + require.NotNil(t, migration.UpFn) + require.NotNil(t, migration.DownFn) + require.False(t, migration.CreatedAt.IsZero()) +} + +func TestMigrationFullVersion(t *testing.T) { + migration := NewMigration("20250520000000", "create_users", nil, nil) + require.Equal(t, "20250520000000_create_users", migration.FullVersion()) +} + +func TestMigrationUp(t *testing.T) { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + }) + + upFn := func() *schema.Schema { return s } + migration := NewMigration("20250520000000", "create_users", upFn, nil) + + upSchema := migration.Up() + require.Equal(t, 1, len(upSchema.Tables)) + require.Equal(t, "users", upSchema.Tables[0].Name) +} + +func TestMigrationDown(t *testing.T) { + s := schema.NewSchema() + s.DropTable("users") + + downFn := func() *schema.Schema { return s } + migration := NewMigration("20250520000000", "create_users", nil, downFn) + + downSchema := migration.Down() + require.Equal(t, 0, len(downSchema.Tables)) +} + +func TestGenerateVersionTimestamp(t *testing.T) { + ts := GenerateVersionTimestamp() + + // Check timestamp format: YYYYMMDDHHMMSS + require.Len(t, ts, 14) + + // Check that it's a valid timestamp by parsing it + now := time.Now() + nowStr := now.Format("20060102150405") + + // The timestamps should be close (within a minute) + require.InDelta(t, len(nowStr), len(ts), 0) +} + +func TestRegistry(t *testing.T) { + tempDir, err := os.MkdirTemp("", "dbx_test_*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + registry := NewRegistry(tempDir) + + // Test adding migrations + m1 := NewMigration("20250520000000", "first", nil, nil) + m2 := NewMigration("20250520000001", "second", nil, nil) + + registry.AddMigration(m1) + registry.AddMigration(m2) + + // Test getting migrations in sorted order + migrations := registry.GetMigrations() + require.Equal(t, 2, len(migrations)) + require.Equal(t, "20250520000000", migrations[0].Version) + require.Equal(t, "20250520000001", migrations[1].Version) + + // Test finding a migration by version + found := registry.FindMigrationByVersion("20250520000001") + require.NotNil(t, found) + require.Equal(t, "second", found.Name) + + // Test finding a non-existent migration + notFound := registry.FindMigrationByVersion("999") + require.Nil(t, notFound) +} + +func TestGenerator(t *testing.T) { + tempDir, err := os.MkdirTemp("", "dbx_test_*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + registry := NewRegistry(tempDir) + generator := NewGenerator(registry) + + // Test generating a migration + filePath, err := generator.Generate("create_users") + require.NoError(t, err) + + // Check that the file was created + _, err = os.Stat(filePath) + require.NoError(t, err) + + // Check file name format + fileName := filepath.Base(filePath) + require.Contains(t, fileName, "_create_users.go") + + // Check file content + content, err := os.ReadFile(filePath) + require.NoError(t, err) + + contentStr := string(content) + require.Contains(t, contentStr, "func init() {") + require.Contains(t, contentStr, "migration.Register(") + require.Contains(t, contentStr, "func up") + require.Contains(t, contentStr, "func down") +} diff --git a/migration/migrations.go b/migration/migrations.go new file mode 100644 index 0000000..ce59875 --- /dev/null +++ b/migration/migrations.go @@ -0,0 +1,98 @@ +package migration + +import ( + "database/sql" + "sync" + + "github.com/swiftcarrot/dbx/schema" +) + +// Global registry for migrations +var ( + defaultRegistry *Registry + defaultOnce sync.Once +) + +// SetMigrationsDir sets the directory where migrations are stored +func SetMigrationsDir(dir string) { + defaultOnce.Do(func() { + defaultRegistry = NewRegistry(dir) + }) +} + +// Register registers a new migration in the default registry +func Register(version, name string, upFn, downFn func() *schema.Schema) { + if defaultRegistry == nil { + panic("migration directory not set; call SetMigrationsDir first") + } + + migration := NewMigration(version, name, upFn, downFn) + defaultRegistry.AddMigration(migration) +} + +// CreateMigration generates a new migration file +func CreateMigration(name string) (string, error) { + if defaultRegistry == nil { + panic("migration directory not set; call SetMigrationsDir first") + } + + generator := NewGenerator(defaultRegistry) + return generator.Generate(name) +} + +// RunMigrations runs pending migrations +func RunMigrations(db *sql.DB, targetVersion string) error { + if defaultRegistry == nil { + panic("migration directory not set; call SetMigrationsDir first") + } + + if err := defaultRegistry.LoadMigrations(); err != nil { + return err + } + + migrator := NewMigrator(db, defaultRegistry) + return migrator.Migrate(targetVersion) +} + +// RollbackMigration rolls back the last migration or specified number of migrations +func RollbackMigration(db *sql.DB, steps int) error { + if defaultRegistry == nil { + panic("migration directory not set; call SetMigrationsDir first") + } + + if err := defaultRegistry.LoadMigrations(); err != nil { + return err + } + + migrator := NewMigrator(db, defaultRegistry) + return migrator.Rollback(steps) +} + +// GetMigrationStatus returns the status of all migrations +func GetMigrationStatus(db *sql.DB) ([]struct { + Version string + Name string + Status string + AppliedAt string +}, error) { + if defaultRegistry == nil { + panic("migration directory not set; call SetMigrationsDir first") + } + + if err := defaultRegistry.LoadMigrations(); err != nil { + return nil, err + } + + migrator := NewMigrator(db, defaultRegistry) + return migrator.Status() +} + +// GetCurrentVersion returns the current database version +func GetCurrentVersion(db *sql.DB) (string, error) { + versionTracker := NewVersionTracker(db) + if err := versionTracker.EnsureMigrationsTable(); err != nil { + return "", err + } + + return versionTracker.GetCurrentVersion() +} diff --git a/migration/migrator.go b/migration/migrator.go new file mode 100644 index 0000000..1e899fe --- /dev/null +++ b/migration/migrator.go @@ -0,0 +1,320 @@ +package migration + +import ( + "database/sql" + "fmt" + + "github.com/swiftcarrot/dbx/mysql" + "github.com/swiftcarrot/dbx/postgresql" + "github.com/swiftcarrot/dbx/schema" + "github.com/swiftcarrot/dbx/sqlite" +) + +// Migrator runs migrations +type Migrator struct { + db *sql.DB + registry *Registry + versionTracker *VersionTracker +} + +// NewMigrator creates a new migrator +func NewMigrator(db *sql.DB, registry *Registry) *Migrator { + versionTracker := NewVersionTracker(db) + + return &Migrator{ + db: db, + registry: registry, + versionTracker: versionTracker, + } +} + +// Init initializes the migrator +func (m *Migrator) Init() error { + return m.versionTracker.EnsureMigrationsTable() +} + +// Migrate runs migrations up to the specified version +// If version is empty, all pending migrations are run +func (m *Migrator) Migrate(version string) error { + // Initialize the migrations table if it doesn't exist + if err := m.Init(); err != nil { + return err + } + + // Get all migrations + migrations := m.registry.GetMigrations() + if len(migrations) == 0 { + return fmt.Errorf("no migrations found") + } + + // Get applied migrations + appliedMigrations, err := m.versionTracker.GetAppliedMigrations() + if err != nil { + return err + } + + // Create a set of applied migrations for quick lookup + appliedSet := make(map[string]bool) + for _, m := range appliedMigrations { + appliedSet[m.Version] = true + } + + // Determine the database type for SQL generation + dbType, err := m.versionTracker.DatabaseType() + if err != nil { + return err + } + + // Create appropriate SQL generator based on database type + var sqlGenerator schema.SQLGenerator + switch dbType { + case "mysql": + sqlGenerator = mysql.New() + case "postgresql": + sqlGenerator = postgresql.New() + case "sqlite": + sqlGenerator = sqlite.New() + default: + return fmt.Errorf("unsupported database type: %s", dbType) + } + + // Begin a transaction + tx, err := m.db.Begin() + if err != nil { + return err + } + + // Rollback on error + defer func() { + if err != nil { + tx.Rollback() + } + }() + + // If we have a target version, migrate up to that version + // Otherwise, run all pending migrations + for _, migration := range migrations { + // Skip if already applied + if appliedSet[migration.Version] { + continue + } + + // Stop if we've reached the target version + if version != "" && migration.Version > version { + break + } + + fmt.Printf("Migrating up: %s_%s\n", migration.Version, migration.Name) + + // Run the migration + upSchema := migration.Up() + + // Get current schema + var currentSchema *schema.Schema + if version == "" && len(appliedMigrations) == 0 { + // If this is the first migration, use an empty schema + currentSchema = schema.NewSchema() + } else { + // Otherwise, inspect the current schema + currentSchema, err = sqlGenerator.Inspect(m.db) + if err != nil { + return fmt.Errorf("failed to inspect schema: %w", err) + } + } + + // Generate changes + changes := schema.Diff(currentSchema, upSchema) + + // Apply changes + for _, change := range changes { + sql, err := sqlGenerator.GenerateSQL(change) + if err != nil { + return fmt.Errorf("failed to generate SQL: %w", err) + } + + _, err = tx.Exec(sql) + if err != nil { + return fmt.Errorf("failed to execute SQL: %w", err) + } + } + + // Record the migration + err = m.versionTracker.RecordMigration(migration.Version, migration.Name) + if err != nil { + return fmt.Errorf("failed to record migration: %w", err) + } + } + + // Commit the transaction + return tx.Commit() +} + +// Rollback rolls back the last applied migration +// If steps is specified, that many migrations are rolled back +func (m *Migrator) Rollback(steps int) error { + // Ensure migrations table exists + if err := m.Init(); err != nil { + return err + } + + // Get applied migrations + appliedMigrations, err := m.versionTracker.GetAppliedMigrations() + if err != nil { + return err + } + + if len(appliedMigrations) == 0 { + return fmt.Errorf("no migrations to roll back") + } + + // Determine how many migrations to roll back + if steps <= 0 { + steps = 1 // Default to rolling back just the last migration + } + + if steps > len(appliedMigrations) { + steps = len(appliedMigrations) + } + + // Get migrations to roll back + migrationsToRollback := appliedMigrations[len(appliedMigrations)-steps:] + + // Determine the database type for SQL generation + dbType, err := m.versionTracker.DatabaseType() + if err != nil { + return err + } + + // Create appropriate SQL generator based on database type + var sqlGenerator schema.SQLGenerator + switch dbType { + case "mysql": + sqlGenerator = mysql.New() + case "postgresql": + sqlGenerator = postgresql.New() + case "sqlite": + sqlGenerator = sqlite.New() + default: + return fmt.Errorf("unsupported database type: %s", dbType) + } + + // Begin a transaction + tx, err := m.db.Begin() + if err != nil { + return err + } + + // Rollback on error + defer func() { + if err != nil { + tx.Rollback() + } + }() + + // Roll back migrations in reverse order (newest first) + for i := len(migrationsToRollback) - 1; i >= 0; i-- { + record := migrationsToRollback[i] + + // Find the migration + migration := m.registry.FindMigrationByVersion(record.Version) + if migration == nil { + return fmt.Errorf("migration with version %s not found in registry", record.Version) + } + + fmt.Printf("Rolling back: %s_%s\n", migration.Version, migration.Name) + + // Run the down migration + downSchema := migration.Down() + + // Get current schema + currentSchema, err := sqlGenerator.Inspect(m.db) + if err != nil { + return fmt.Errorf("failed to inspect schema: %w", err) + } + + // Generate changes + changes := schema.Diff(currentSchema, downSchema) + + // Apply changes + for _, change := range changes { + sql, err := sqlGenerator.GenerateSQL(change) + if err != nil { + return fmt.Errorf("failed to generate SQL: %w", err) + } + + _, err = tx.Exec(sql) + if err != nil { + return fmt.Errorf("failed to execute SQL: %w", err) + } + } + + // Remove the migration record + err = m.versionTracker.RemoveMigration(migration.Version) + if err != nil { + return fmt.Errorf("failed to remove migration record: %w", err) + } + } + + // Commit the transaction + return tx.Commit() +} + +// Status returns the status of all migrations +func (m *Migrator) Status() ([]struct { + Version string + Name string + Status string + AppliedAt string +}, error) { + // Ensure migrations table exists + if err := m.Init(); err != nil { + return nil, err + } + + // Get all migrations + migrations := m.registry.GetMigrations() + + // Get applied migrations + appliedMigrations, err := m.versionTracker.GetAppliedMigrations() + if err != nil { + return nil, err + } + + // Create a map of applied migrations + appliedMap := make(map[string]MigrationRecord) + for _, m := range appliedMigrations { + appliedMap[m.Version] = m + } + + // Create status list + statusList := make([]struct { + Version string + Name string + Status string + AppliedAt string + }, len(migrations)) + + for i, migration := range migrations { + status := "Pending" + appliedAt := "" + + if record, ok := appliedMap[migration.Version]; ok { + status = "Applied" + appliedAt = record.AppliedAt.Format("2006-01-02 15:04:05") + } + + statusList[i] = struct { + Version string + Name string + Status string + AppliedAt string + }{ + Version: migration.Version, + Name: migration.Name, + Status: status, + AppliedAt: appliedAt, + } + } + + return statusList, nil +} diff --git a/migration/migrator_test.go b/migration/migrator_test.go new file mode 100644 index 0000000..ae8eb27 --- /dev/null +++ b/migration/migrator_test.go @@ -0,0 +1,122 @@ +package migration + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/swiftcarrot/dbx/schema" +) + +func TestMigrator(t *testing.T) { + db := setupTestDB(t) + + // Create a registry with test migrations + registry := NewRegistry(t.TempDir()) + + // Create test migrations + createUsersMigration := NewMigration("20250520000000", "create_users", func() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("users", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("name", &schema.VarcharType{Length: 255}) + t.SetPrimaryKey("pk_users", []string{"id"}) + }) + return s + }, func() *schema.Schema { + s := schema.NewSchema() + s.DropTable("users") + return s + }) + + createPostsMigration := NewMigration("20250520000001", "create_posts", func() *schema.Schema { + s := schema.NewSchema() + s.CreateTable("posts", func(t *schema.Table) { + t.Column("id", &schema.IntegerType{}) + t.Column("title", &schema.VarcharType{Length: 255}) + t.Column("user_id", &schema.IntegerType{}) + t.SetPrimaryKey("pk_posts", []string{"id"}) + t.ForeignKey("fk_posts_user", []string{"user_id"}, "users", []string{"id"}) + }) + return s + }, func() *schema.Schema { + s := schema.NewSchema() + s.DropTable("posts") + return s + }) + + // Add migrations to registry + registry.AddMigration(createUsersMigration) + registry.AddMigration(createPostsMigration) + + // Create migrator + migrator := NewMigrator(db, registry) + + // Initialize migrations table + err := migrator.Init() + require.NoError(t, err) + + // Run migrations + err = migrator.Migrate("") + require.NoError(t, err) + + // Check that migrations were applied + versionTracker := NewVersionTracker(db) + + migrations, err := versionTracker.GetAppliedMigrations() + require.NoError(t, err) + require.Equal(t, 2, len(migrations)) + + // Verify tables exist + var count int + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'").Scan(&count) + require.NoError(t, err) + require.Equal(t, 1, count) + + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&count) + require.NoError(t, err) + require.Equal(t, 1, count) + + // Test rolling back one migration + err = migrator.Rollback(1) + require.NoError(t, err) + + // Verify posts table was dropped + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&count) + require.NoError(t, err) + require.Equal(t, 0, count) + + // Verify users table still exists + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='users'").Scan(&count) + require.NoError(t, err) + require.Equal(t, 1, count) + + // Verify migration record was removed + has, err := versionTracker.HasMigration("20250520000001") + require.NoError(t, err) + require.False(t, has) + + // Test partial migration (up to specific version) + err = migrator.Migrate("20250520000001") + require.NoError(t, err) + + // Verify posts table exists again + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='posts'").Scan(&count) + require.NoError(t, err) + require.Equal(t, 1, count) + + // Test status + status, err := migrator.Status() + require.NoError(t, err) + require.Equal(t, 2, len(status)) + require.Equal(t, "Applied", status[0].Status) + require.Equal(t, "Applied", status[1].Status) + + // Test rolling back all migrations + err = migrator.Rollback(2) + require.NoError(t, err) + + // Verify both tables were dropped + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name IN ('users', 'posts')").Scan(&count) + require.NoError(t, err) + require.Equal(t, 0, count) +} diff --git a/migration/registry.go b/migration/registry.go new file mode 100644 index 0000000..45ab954 --- /dev/null +++ b/migration/registry.go @@ -0,0 +1,110 @@ +package migration + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +// Registry manages migration files and versions +type Registry struct { + migrations []*Migration + migrationsDir string +} + +// NewRegistry creates a new migration registry +func NewRegistry(migrationsDir string) *Registry { + return &Registry{ + migrations: []*Migration{}, + migrationsDir: migrationsDir, + } +} + +// AddMigration adds a migration to the registry +func (r *Registry) AddMigration(migration *Migration) { + r.migrations = append(r.migrations, migration) +} + +// FindMigrationByVersion finds a migration by its version +func (r *Registry) FindMigrationByVersion(version string) *Migration { + for _, m := range r.migrations { + if m.Version == version { + return m + } + } + return nil +} + +// GetMigrations returns all migrations sorted by version +func (r *Registry) GetMigrations() []*Migration { + // Sort migrations by version (which should be a timestamp) + sort.Slice(r.migrations, func(i, j int) bool { + return r.migrations[i].Version < r.migrations[j].Version + }) + + return r.migrations +} + +// LoadMigrations loads all migrations from the migrations directory +func (r *Registry) LoadMigrations() error { + // Check if migrations directory exists + if _, err := os.Stat(r.migrationsDir); os.IsNotExist(err) { + // Create migrations directory if it doesn't exist + if err := os.MkdirAll(r.migrationsDir, 0755); err != nil { + return fmt.Errorf("failed to create migrations directory: %w", err) + } + return nil + } + + // Walk through the migrations directory + err := filepath.WalkDir(r.migrationsDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories + if d.IsDir() { + return nil + } + + // Only process .go files + if !strings.HasSuffix(d.Name(), ".go") { + return nil + } + + // Skip test files + if strings.HasSuffix(d.Name(), "_test.go") { + return nil + } + + // Skip the migrations.go file itself + if d.Name() == "migrations.go" { + return nil + } + + // TODO: Parse migration file and register the migration + // This would be done by the migration generator + + return nil + }) + + return err +} + +// GenerateVersionTimestamp generates a timestamp-based version +// in the format used by Rails migrations (YYYYMMDDHHMMSS) +func GenerateVersionTimestamp() string { + now := time.Now() + return fmt.Sprintf("%d%02d%02d%02d%02d%02d", + now.Year(), now.Month(), now.Day(), + now.Hour(), now.Minute(), now.Second()) +} + +// GetMigrationsDir returns the migrations directory +func (r *Registry) GetMigrationsDir() string { + return r.migrationsDir +} diff --git a/migration/version_tracker.go b/migration/version_tracker.go new file mode 100644 index 0000000..f9f8ce7 --- /dev/null +++ b/migration/version_tracker.go @@ -0,0 +1,121 @@ +package migration + +import ( + "database/sql" + "fmt" + "time" +) + +// VersionTracker manages the migration versions in the database +type VersionTracker struct { + db *sql.DB +} + +// MigrationRecord represents a record in the schema_migrations table +type MigrationRecord struct { + Version string + Name string + AppliedAt time.Time +} + +// NewVersionTracker creates a new version tracker +func NewVersionTracker(db *sql.DB) *VersionTracker { + return &VersionTracker{ + db: db, + } +} + +// EnsureMigrationsTable ensures that the schema_migrations table exists +func (v *VersionTracker) EnsureMigrationsTable() error { + // Create the schema_migrations table if it doesn't exist + // This table tracks which migrations have been applied + createTableSQL := ` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version VARCHAR(255) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + applied_at TIMESTAMP NOT NULL + )` + + _, err := v.db.Exec(createTableSQL) + return err +} + +// GetAppliedMigrations returns all applied migrations +func (v *VersionTracker) GetAppliedMigrations() ([]MigrationRecord, error) { + rows, err := v.db.Query("SELECT version, name, applied_at FROM schema_migrations ORDER BY version") + if err != nil { + return nil, err + } + defer rows.Close() + + var migrations []MigrationRecord + for rows.Next() { + var m MigrationRecord + if err := rows.Scan(&m.Version, &m.Name, &m.AppliedAt); err != nil { + return nil, err + } + migrations = append(migrations, m) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return migrations, nil +} + +// RecordMigration records a migration as applied +func (v *VersionTracker) RecordMigration(version, name string) error { + _, err := v.db.Exec( + "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, ?)", + version, name, time.Now(), + ) + return err +} + +// RemoveMigration removes a migration record +func (v *VersionTracker) RemoveMigration(version string) error { + _, err := v.db.Exec("DELETE FROM schema_migrations WHERE version = ?", version) + return err +} + +// GetCurrentVersion returns the highest applied migration version +func (v *VersionTracker) GetCurrentVersion() (string, error) { + var version string + err := v.db.QueryRow("SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1").Scan(&version) + if err == sql.ErrNoRows { + return "", nil // No migrations applied yet + } + if err != nil { + return "", err + } + return version, nil +} + +// HasMigration checks if a migration has been applied +func (v *VersionTracker) HasMigration(version string) (bool, error) { + var count int + err := v.db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +// DatabaseType determines the type of database (mysql, postgresql, sqlite) +func (v *VersionTracker) DatabaseType() (string, error) { + // Try to determine database type from the driver name + // This is a simple approach and might need refinement + driverName := v.db.Driver().Name() + + switch { + case driverName == "mysql" || driverName == "mysqld": + return "mysql", nil + case driverName == "postgres" || driverName == "postgresql": + return "postgresql", nil + case driverName == "sqlite" || driverName == "sqlite3": + return "sqlite", nil + default: + return "", fmt.Errorf("unsupported database type: %s", driverName) + } +} diff --git a/migration/version_tracker_test.go b/migration/version_tracker_test.go new file mode 100644 index 0000000..db0087a --- /dev/null +++ b/migration/version_tracker_test.go @@ -0,0 +1,98 @@ +package migration + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + "github.com/swiftcarrot/dbx/internal/testutil" +) + +func setupTestDB(t *testing.T) *sql.DB { + db, err := testutil.GetSQLiteTestConn() + require.NoError(t, err) + + t.Cleanup(func() { + _, err := db.Exec(`DROP TABLE IF EXISTS schema_migrations`) + require.NoError(t, err) + }) + + return db +} + +func TestVersionTracker(t *testing.T) { + db := setupTestDB(t) + tracker := NewVersionTracker(db) + + // Test creating the migrations table + err := tracker.EnsureMigrationsTable() + require.NoError(t, err) + + // Test getting applied migrations when empty + migrations, err := tracker.GetAppliedMigrations() + require.NoError(t, err) + require.Empty(t, migrations) + + // Test recording a migration + err = tracker.RecordMigration("20250520000000", "create_users") + require.NoError(t, err) + + // Test has migration + has, err := tracker.HasMigration("20250520000000") + require.NoError(t, err) + require.True(t, has) + + has, err = tracker.HasMigration("nonexistent") + require.NoError(t, err) + require.False(t, has) + + // Test getting current version + version, err := tracker.GetCurrentVersion() + require.NoError(t, err) + require.Equal(t, "20250520000000", version) + + // Test getting applied migrations + migrations, err = tracker.GetAppliedMigrations() + require.NoError(t, err) + require.Equal(t, 1, len(migrations)) + require.Equal(t, "20250520000000", migrations[0].Version) + require.Equal(t, "create_users", migrations[0].Name) + require.False(t, migrations[0].AppliedAt.IsZero()) + + // Add another migration + err = tracker.RecordMigration("20250520000001", "create_posts") + require.NoError(t, err) + + // Test getting all migrations + migrations, err = tracker.GetAppliedMigrations() + require.NoError(t, err) + require.Equal(t, 2, len(migrations)) + + // Test getting current version (should be the latest) + version, err = tracker.GetCurrentVersion() + require.NoError(t, err) + require.Equal(t, "20250520000001", version) + + // Test removing a migration + err = tracker.RemoveMigration("20250520000001") + require.NoError(t, err) + + // Verify it was removed + has, err = tracker.HasMigration("20250520000001") + require.NoError(t, err) + require.False(t, has) + + // Test getting current version after removal + version, err = tracker.GetCurrentVersion() + require.NoError(t, err) + require.Equal(t, "20250520000000", version) +} + +func TestDatabaseType(t *testing.T) { + db := setupTestDB(t) + tracker := NewVersionTracker(db) + + dbType, err := tracker.DatabaseType() + require.NoError(t, err) + require.Equal(t, "sqlite", dbType) +}