package main import ( "context" "database/sql" "fmt" "log" "os" "strconv" "time" "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/mysql" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/joho/godotenv" _ "github.com/go-sql-driver/mysql" ) func main() { // 加载 .env 文件 err := godotenv.Load(".env") if err != nil { log.Println("未能加载 .env 文件,可能使用环境变量。") } if len(os.Args) < 2 { log.Fatalf("使用方法: go run main.go [args...]") } command := os.Args[1] if command == "create" { if len(os.Args) < 3 { log.Fatalf("使用方法: go run main.go create ") } createMigration(os.Args[2]) return } databaseUrl := os.Getenv("DATABASE_DSN") if databaseUrl == "" { log.Fatalf("DATABASE_DSN 环境变量未设置") } m, err := migrate.New( "file://sql", // 迁移文件路径 databaseUrl, ) if err != nil { log.Fatalf("初始化迁移工具失败: %v", err) } defer func() { if sourceErr, databaseErr := m.Close(); sourceErr != nil || databaseErr != nil { log.Printf("关闭迁移工具时出错: sourceErr=%v, databaseErr=%v", sourceErr, databaseErr) } }() switch command { case "up": if err := m.Up(); err != nil && err != migrate.ErrNoChange { log.Fatalf("执行迁移 Up 失败: %v", err) } log.Println("迁移 Up 成功") case "down": if len(os.Args) < 3 { // 不带版本号时回滚最新一个版本 if err := m.Down(); err != nil && err != migrate.ErrNoChange { log.Fatalf("执行迁移 Down 失败: %v", err) } log.Println("迁移 Down 成功 (回滚一个版本)") } else { // 带版本号时回滚到指定版本 version, err := strconv.Atoi(os.Args[2]) if err != nil { log.Fatalf("版本号错误: %v", err) } if err := m.Migrate(uint(version)); err != nil && err != migrate.ErrNoChange { log.Fatalf("执行迁移 Migrate 到版本 %d 失败: %v", version, err) } log.Printf("迁移 Down 成功 (回滚到版本 %d)", version) } case "goto": if len(os.Args) < 3 { log.Fatalf("使用方法: go run main.go goto ") } version, err := strconv.Atoi(os.Args[2]) if err != nil { log.Fatalf("版本号错误: %v", err) } if err := m.Migrate(uint(version)); err != nil && err != migrate.ErrNoChange { log.Fatalf("执行迁移 Goto 版本 %d 失败: %v", version, err) } log.Printf("迁移 Goto 成功 (到版本 %d)", version) case "force": if len(os.Args) < 3 { log.Fatalf("使用方法: go run main.go force ") } version, err := strconv.Atoi(os.Args[2]) if err != nil { log.Fatalf("版本号错误: %v", err) } if err := m.Force(int(version)); err != nil { log.Fatalf("强制设置版本 %d 失败: %v", version, err) } log.Printf("强制设置版本 %d 成功", version) case "version": version, dirty, err := m.Version() if err != nil && err != migrate.ErrNoChange { log.Fatalf("获取版本失败: %v", err) } if dirty { log.Printf("当前数据库版本: %d (dirty)", version) } else { log.Printf("当前数据库版本: %d", version) } default: log.Fatalf("未知命令: %s. 支持的命令: up, down, goto, force, version, create", command) } } func createMigration(name string) { timestamp := time.Now().Unix() upFilename := fmt.Sprintf("sql/%d_%s.up.sql", timestamp, name) downFilename := fmt.Sprintf("sql/%d_%s.down.sql", timestamp, name) if err := os.WriteFile(upFilename, []byte(""), 0644); err != nil { log.Fatalf("创建 %s 失败: %v", upFilename, err) } if err := os.WriteFile(downFilename, []byte(""), 0644); err != nil { log.Fatalf("创建 %s 失败: %v", downFilename, err) } log.Printf("成功创建迁移文件: %s 和 %s", upFilename, downFilename) } // 辅助函数,用于检查数据库连接是否正常 func checkDBConnection(dsn string) error { db, err := sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("无法连接数据库: %v", err) } defer db.Close() db.SetMaxOpenConns(1) db.SetMaxIdleConns(1) db.SetConnMaxLifetime(5 * time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err = db.PingContext(ctx) if err != nil { return fmt.Errorf("无法 Ping 数据库: %v", err) } return nil }