GoTest/pkg/database/database.go
2025-11-29 03:27:19 +08:00

99 lines
1.7 KiB
Go

package database
import (
"database/sql"
"fmt"
"log"
"time"
"yinli-api/pkg/config"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
var DB *gorm.DB
// InitDatabase 初始化数据库连接
func InitDatabase(cfg *config.Config) error {
dsn := cfg.GetDSN()
db, err := gorm.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("连接数据库失败: %w", err)
}
// 设置连接池参数
db.DB().SetMaxIdleConns(cfg.Database.MaxIdleConns)
db.DB().SetMaxOpenConns(cfg.Database.MaxOpenConns)
db.DB().SetConnMaxLifetime(time.Hour)
// 测试连接
if err := db.DB().Ping(); err != nil {
return fmt.Errorf("数据库连接测试失败: %w", err)
}
// 启用日志
if cfg.Server.Mode == "debug" {
db.LogMode(true)
}
DB = db
log.Println("数据库连接成功")
return nil
}
// CloseDatabase 关闭数据库连接
func CloseDatabase() error {
if DB != nil {
return DB.Close()
}
return nil
}
// GetDB 获取数据库实例
func GetDB() *gorm.DB {
return DB
}
// Ping 检查数据库连接
func Ping() error {
if DB == nil {
return fmt.Errorf("数据库未初始化")
}
return DB.DB().Ping()
}
// Transaction 执行事务
func Transaction(fn func(*gorm.DB) error) error {
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// RawQuery 执行原生SQL查询
func RawQuery(query string, args ...interface{}) (*sql.Rows, error) {
return DB.Raw(query, args...).Rows()
}
// Exec 执行原生SQL
func Exec(sql string, values ...interface{}) error {
return DB.Exec(sql, values...).Error
}