// GoTest/pkg/config/config.go
// 2025-11-29 03:27:19 +08:00
//
// 136 lines
// 3.4 KiB
// Go

package config
import (
	"fmt"
	"log"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"

	"github.com/spf13/viper"
)
// Config is the root of the application configuration. LoadConfig decodes the
// loaded settings into it with viper.Unmarshal; each field corresponds to the
// top-level section named in its mapstructure tag.
type Config struct {
	Server    ServerConfig    `mapstructure:"server"`
	Database  DatabaseConfig  `mapstructure:"database"`
	Redis     RedisConfig     `mapstructure:"redis"`
	JWT       JWTConfig       `mapstructure:"jwt"`
	RateLimit RateLimitConfig `mapstructure:"rateLimit"`
	CORS      CORSConfig      `mapstructure:"cors"`
	Log       LogConfig       `mapstructure:"log"`
}
// ServerConfig holds the "server" section of the configuration.
type ServerConfig struct {
	// Port is stored as a string (the database and Redis ports are ints);
	// presumably used to build a listen address — confirm with the caller.
	Port string `mapstructure:"port"`
	// Mode is a free-form run-mode string; its accepted values are defined by
	// whatever consumes it, not here.
	Mode string `mapstructure:"mode"`
}
// DatabaseConfig holds the "database" section. Every field except the two
// pool limits is interpolated into the connection string built by
// Config.GetDSN.
type DatabaseConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	Username string `mapstructure:"username"`
	Password string `mapstructure:"password"`
	DBName   string `mapstructure:"dbname"`

	// Charset, ParseTime and Loc become the charset, parseTime and loc query
	// parameters of the DSN.
	Charset   string `mapstructure:"charset"`
	ParseTime bool   `mapstructure:"parseTime"`
	Loc       string `mapstructure:"loc"`

	// Connection-pool limits. Not read in this file; presumably applied via
	// sql.DB.SetMaxIdleConns / SetMaxOpenConns by the caller — confirm.
	MaxIdleConns int `mapstructure:"maxIdleConns"`
	MaxOpenConns int `mapstructure:"maxOpenConns"`
}
// RedisConfig holds the "redis" section. Host and Port are combined into an
// address by Config.GetRedisAddr; the other fields are not read in this file.
type RedisConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	Password string `mapstructure:"password"`
	// DB is presumably the Redis logical database index — confirm with the
	// client setup.
	DB       int `mapstructure:"db"`
	PoolSize int `mapstructure:"poolSize"`
}
// JWTConfig holds the "jwt" section.
type JWTConfig struct {
	// Secret is presumably the token signing key — confirm with the JWT code.
	Secret string `mapstructure:"secret"`
	// ExpireHours is the token lifetime; the field name indicates it is
	// expressed in hours.
	ExpireHours int `mapstructure:"expireHours"`
}
// RateLimitConfig holds the "rateLimit" section.
type RateLimitConfig struct {
	Enabled bool `mapstructure:"enabled"`
	// Requests and Window presumably mean "at most Requests requests per
	// Window"; the unit of Window (seconds?) is not visible here — confirm
	// with the limiter implementation.
	Requests int `mapstructure:"requests"`
	Window   int `mapstructure:"window"`
}
// CORSConfig holds the "cors" section: the lists configured under
// allowOrigins, allowMethods and allowHeaders. Presumably fed to a CORS
// middleware elsewhere — confirm with the router setup.
type CORSConfig struct {
	AllowOrigins []string `mapstructure:"allowOrigins"`
	AllowMethods []string `mapstructure:"allowMethods"`
	AllowHeaders []string `mapstructure:"allowHeaders"`
}
// LogConfig holds the "log" section. The accepted values of Level, Format and
// Output are defined by the logger setup that consumes them, not here.
type LogConfig struct {
	Level  string `mapstructure:"level"`
	Format string `mapstructure:"format"`
	Output string `mapstructure:"output"`
}
// AppConfig holds the configuration most recently loaded by LoadConfig. It is
// nil until a LoadConfig call succeeds and is left unchanged when a call fails.
// Access to it is not synchronized.
var AppConfig *Config

// LoadConfig reads the configuration file named env (parsed as YAML; env
// defaults to "dev" when empty) from ./config, ../config or ../../config,
// searched in that order. It applies environment-variable overrides and
// decodes the result into a Config.
//
// Overrides use the YINLI_ prefix, with nested key segments joined by '_'
// (e.g. database.password -> YINLI_DATABASE_PASSWORD). NOTE: viper only
// consults the environment for keys it already knows (e.g. keys present in
// the file), so an env var cannot introduce a key the file omits.
//
// On success the result is also stored in AppConfig and the path of the file
// used is logged. A wrapped error is returned if the file cannot be found or
// read, or cannot be decoded.
//
// LoadConfig configures the package-level viper instance. Its settings stay
// in effect for any other code that uses viper's global functions, and the
// search paths accumulate across repeated calls.
func LoadConfig(env string) (*Config, error) {
	if env == "" {
		env = "dev"
	}

	viper.SetConfigName(env)
	viper.SetConfigType("yaml")
	viper.AddConfigPath("./config")
	viper.AddConfigPath("../config")
	viper.AddConfigPath("../../config")

	// Environment overrides. Without the key replacer viper would look up
	// names like YINLI_DATABASE.PASSWORD for nested keys, which shells cannot
	// set, so nested values could never be overridden.
	viper.SetEnvPrefix("YINLI")
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	viper.AutomaticEnv()

	if err := viper.ReadInConfig(); err != nil {
		return nil, fmt.Errorf("读取配置文件失败: %w", err)
	}

	var config Config
	if err := viper.Unmarshal(&config); err != nil {
		return nil, fmt.Errorf("解析配置文件失败: %w", err)
	}

	AppConfig = &config
	log.Printf("已加载配置文件: %s", viper.ConfigFileUsed())
	return &config, nil
}
// GetEnv returns the value of the environment variable named by key, or
// defaultValue when that variable is unset. A variable that is set to the
// empty string is treated the same as an unset one.
func GetEnv(key, defaultValue string) string {
	v, found := os.LookupEnv(key)
	if !found || v == "" {
		return defaultValue
	}
	return v
}
// GetDSN builds the database connection string from c.Database in the form
//
//	username:password@tcp(host:port)/dbname?charset=...&parseTime=...&loc=...
//
// This is the DSN format of github.com/go-sql-driver/mysql (presumably the
// driver in use — confirm).
//
//   - The address is joined with net.JoinHostPort, so IPv6 hosts are
//     bracketed (e.g. "[::1]:3306"), which the driver requires.
//   - The loc value is query-escaped because time-zone names such as
//     "Asia/Shanghai" contain '/'; the driver unescapes loc before loading it.
//   - charset is left as-is, since the driver does not unescape it.
func (c *Config) GetDSN() string {
	db := &c.Database
	return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&parseTime=%t&loc=%s",
		db.Username,
		db.Password,
		net.JoinHostPort(db.Host, strconv.Itoa(db.Port)),
		db.DBName,
		db.Charset,
		db.ParseTime,
		url.QueryEscape(db.Loc),
	)
}
// GetRedisAddr returns the Redis server address built from c.Redis.Host and
// c.Redis.Port. It uses net.JoinHostPort so that IPv6 hosts come out
// bracketed ("[::1]:6379") instead of producing an ambiguous "::1:6379".
func (c *Config) GetRedisAddr() string {
	return net.JoinHostPort(c.Redis.Host, strconv.Itoa(c.Redis.Port))
}