diff --git a/cmd/check.go b/cmd/check.go index 566ebf0f..883adfee 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -52,8 +52,6 @@ type channelExport struct { Files []exportFile `json:"files"` } -const dateLayout = "2006-01-02_15-04-05" - var termWidth = func() (width int, err error) { width, _, err = term.GetSize(int(os.Stdout.Fd())) if err == nil { @@ -65,6 +63,7 @@ var termWidth = func() (width int, err error) { func NewCheckCmd() *cobra.Command { var cfg config.ServerCmdConfig + loader := config.NewConfigLoader() cmd := &cobra.Command{ Use: "check", Short: "Check and purge incomplete files", @@ -72,11 +71,7 @@ func NewCheckCmd() *cobra.Command { runCheckCmd(cmd, &cfg) }, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - loader := config.NewConfigLoader() - if err := loader.InitializeConfig(cmd); err != nil { - return err - } - if err := loader.Load(&cfg); err != nil { + if err := loader.Load(cmd, &cfg); err != nil { return err } if err := checkRequiredCheckFlags(&cfg); err != nil { @@ -85,18 +80,13 @@ func NewCheckCmd() *cobra.Command { return nil }, } - addChecktFlags(cmd, &cfg) + loader.RegisterPlags(cmd.Flags(), "", cfg, true) + cmd.Flags().Bool("export", true, "Export incomplete files to json file") + cmd.Flags().Bool("clean", false, "Clean missing and orphan file parts") + cmd.Flags().String("user", "", "Telegram User Name") return cmd } -func addChecktFlags(cmd *cobra.Command, cfg *config.ServerCmdConfig) { - flags := cmd.Flags() - config.AddCommonFlags(flags, cfg) - flags.Bool("export", true, "Export incomplete files to json file") - flags.Bool("clean", false, "Clean missing and orphan file parts") - flags.String("user", "", "Telegram User Name") -} - func checkRequiredCheckFlags(cfg *config.ServerCmdConfig) error { var missingFields []string @@ -245,6 +235,9 @@ func runCheckCmd(cmd *cobra.Command, cfg *config.ServerCmdConfig) { lg.Fatalf("Channel %d: found %d messages out of %d", id, len(msgs), total) continue } + + msgIds := utils.Map(msgs, func(m tg.NotEmptyMessage) int { return m.GetID() }) + uploadPartIds := []int{} if err := db.Model(&models.Upload{}).Where("user_id = ?", user.UserId).Where("channel_id = ?", id). Pluck("part_id", &uploadPartIds).Error; err != nil { @@ -256,8 +249,9 @@ func runCheckCmd(cmd *cobra.Command, cfg *config.ServerCmdConfig) { for _, partID := range uploadPartIds { uploadPartMap[partID] = true } + msgMap := make(map[int]bool) - for _, m := range msgs { + for _, m := range msgIds { if m > 0 && !uploadPartMap[m] { msgMap[m] = true } @@ -343,7 +337,7 @@ func runCheckCmd(cmd *cobra.Command, cfg *config.ServerCmdConfig) { } -func loadChannelMessages(ctx context.Context, client *telegram.Client, channelId int64) (msgs []int, total int, err error) { +func loadChannelMessages(ctx context.Context, client *telegram.Client, channelId int64) (msgs []tg.NotEmptyMessage, total int, err error) { errChan := make(chan error, 1) go func() { @@ -392,7 +386,7 @@ func loadChannelMessages(ctx context.Context, client *telegram.Client, channelId for msgiter.Next(ctx) { msg := msgiter.Value() - msgs = append(msgs, msg.Msg.GetID()) + msgs = append(msgs, msg.Msg) count++ bar.Set(count) } diff --git a/cmd/run.go b/cmd/run.go index 94ca1909..5159f4e0 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -6,7 +6,6 @@ import ( "net" "net/http" "regexp" - "strings" "time" "github.com/go-chi/chi/v5" @@ -21,7 +20,6 @@ import ( "github.com/tgdrive/teldrive/internal/chizap" "github.com/tgdrive/teldrive/internal/config" "github.com/tgdrive/teldrive/internal/database" - "github.com/tgdrive/teldrive/internal/duration" "github.com/tgdrive/teldrive/internal/logging" "github.com/tgdrive/teldrive/internal/middleware" "github.com/tgdrive/teldrive/internal/tgc" @@ -37,6 +35,7 @@ import ( func NewRun() *cobra.Command { var cfg config.ServerCmdConfig + loader := config.NewConfigLoader() cmd := &cobra.Command{ Use: "run", Short: "Start Teldrive Server", @@ -45,96 +44,18 @@ func NewRun() *cobra.Command { }, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - loader := config.NewConfigLoader() - if err := loader.InitializeConfig(cmd); err != nil { + if err := loader.Load(cmd, &cfg); err != nil { return err } - if err := loader.Load(&cfg); err != nil { - return err - } - if err := checkRequiredRunFlags(&cfg); err != nil { + if err := loader.Validate(); err != nil { return err } return nil }, } - addServerFlags(cmd, &cfg) + loader.RegisterPlags(cmd.Flags(), "", cfg, false) return cmd } -func addServerFlags(cmd *cobra.Command, cfg *config.ServerCmdConfig) { - - flags := cmd.Flags() - - config.AddCommonFlags(flags, cfg) - - // Server config - flags.IntVarP(&cfg.Server.Port, "server-port", "p", 8080, "Server port") - duration.DurationVar(flags, &cfg.Server.GracefulShutdown, "server-graceful-shutdown", 10*time.Second, "Server graceful shutdown timeout") - flags.BoolVar(&cfg.Server.EnablePprof, "server-enable-pprof", false, "Enable Pprof Profiling") - duration.DurationVar(flags, &cfg.Server.ReadTimeout, "server-read-timeout", 1*time.Hour, "Server read timeout") - duration.DurationVar(flags, &cfg.Server.WriteTimeout, "server-write-timeout", 1*time.Hour, "Server write timeout") - - // CronJobs config - flags.BoolVar(&cfg.CronJobs.Enable, "cronjobs-enable", true, "Run cron jobs") - duration.DurationVar(flags, &cfg.CronJobs.CleanFilesInterval, "cronjobs-clean-files-interval", 1*time.Hour, "Clean files interval") - duration.DurationVar(flags, &cfg.CronJobs.CleanUploadsInterval, "cronjobs-clean-uploads-interval", 12*time.Hour, "Clean uploads interval") - duration.DurationVar(flags, &cfg.CronJobs.FolderSizeInterval, "cronjobs-folder-size-interval", 2*time.Hour, "Folder size update interval") - - // Cache config - flags.IntVar(&cfg.Cache.MaxSize, "cache-max-size", 10*1024*1024, "Max Cache max size (memory)") - flags.StringVar(&cfg.Cache.RedisAddr, "cache-redis-addr", "", "Redis address") - flags.StringVar(&cfg.Cache.RedisPass, "cache-redis-pass", "", "Redis password") - - // JWT config - flags.StringVar(&cfg.JWT.Secret, "jwt-secret", "", "JWT secret key") - duration.DurationVar(flags, &cfg.JWT.SessionTime, "jwt-session-time", (30*24)*time.Hour, "JWT session duration") - flags.StringSliceVar(&cfg.JWT.AllowedUsers, "jwt-allowed-users", []string{}, "Allowed users") - - // Telegram config - flags.StringVar(&cfg.TG.StorageFile, "tg-storage-file", "", "Sqlite Storage file path") - flags.BoolVar(&cfg.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting for telegram client") - flags.IntVar(&cfg.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst for telegram client") - flags.IntVar(&cfg.TG.Rate, "tg-rate", 100, "Limiting rate for telegram client") - flags.StringVar(&cfg.TG.Proxy, "tg-proxy", "", "HTTP OR SOCKS5 proxy URL") - flags.BoolVar(&cfg.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable Stream bots") - flags.BoolVar(&cfg.TG.Ntp, "tg-ntp", false, "Use NTP server time") - flags.BoolVar(&cfg.TG.EnableLogging, "tg-enable-logging", false, "Enable telegram client logging") - flags.Int64Var(&cfg.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size") - - // Telegram Uploads config - flags.StringVar(&cfg.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key") - flags.IntVar(&cfg.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads") - flags.IntVar(&cfg.TG.Uploads.MaxRetries, "tg-uploads-max-retries", 10, "Uploads Retries") - duration.DurationVar(flags, &cfg.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout") - duration.DurationVar(flags, &cfg.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration") - flags.IntVar(&cfg.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads") - flags.IntVar(&cfg.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers") - duration.DurationVar(flags, &cfg.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 20*time.Second, "Chunk Fetch Timeout") - -} - -func checkRequiredRunFlags(cfg *config.ServerCmdConfig) error { - var missingFields []string - - if cfg.DB.DataSource == "" { - missingFields = append(missingFields, "db-data-source") - } - if cfg.JWT.Secret == "" { - missingFields = append(missingFields, "jwt-secret") - } - if cfg.TG.AppHash == "" { - missingFields = append(missingFields, "tg-app-hash") - } - if cfg.TG.AppId == 0 { - missingFields = append(missingFields, "tg-app-id") - } - - if len(missingFields) > 0 { - return fmt.Errorf("required configuration values not set: %s", strings.Join(missingFields, ", ")) - } - - return nil -} func findAvailablePort(startPort int) (int, error) { for port := startPort; port < startPort+100; port++ { diff --git a/config.sample.toml b/config.sample.toml index c948febb..27ea92dc 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -1,53 +1,57 @@ -[db] - data-source = "" - prepare-stmt = true - log-level = 1 - [db.pool] - enable = true - max-idle-connections = 25 - max-lifetime = "10m" - max-open-connections = 25 +[cache] +max-size = 10485760 +redis-addr = '' +redis-pass = '' [cronjobs] - enable = true +clean-files-interval = '1h' +clean-uploads-interval = '12h' +enable = true +folder-size-interval = '2h' + +[db] +log-level = 'info' +prepare-stmt = true + +[db.pool] +enable = true +max-idle-connections = 25 +max-lifetime = '10m' +max-open-connections = 25 [jwt] - allowed-users = [""] - secret = "" - session-time = "30d" +session-time = '30d' +secret = '' +allowed-users = [] [log] - development = true - level = -1 +level = 'info' +file = '' [server] - graceful-shutdown = "15s" - port = 8080 - read-timeout = "1h" - write-timeout = "1h" +graceful-shutdown = '10s' +port = 8080 +read-timeout = '1h' +write-timeout = '1h' [tg] - app-hash = "" - app-id = 0 - app-version = "4.6.3 K" - bg-bots-limit = 5 - device-model = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0" - disable-stream-bots = false - lang-code = "en" - lang-pack = "webk" - rate = 100 - rate-burst = 5 - rate-limit = true - session-file = "" - system-lang-code = "en-US" - system-version = "Win32" - proxy= "http://127.0.0.1:8080" - - [tg.uploads] - encryption-key = "" - retention = "7d" - threads = 8 - [tg.stream] - multi-threads = 0 - buffers = 8 +pool-size = 8 +rate = 100 +rate-burst = 5 +ntp = false +disable-stream-bots = false +storage-file = '' +proxy = '' +rate-limit = true +reconnect-timeout = '5m' + +[tg.stream] +buffers = 8 +chunk-timeout = '20s' +[tg.uploads] +multi-threads = 0 +encryption-key = '' +max-retries = 10 +retention = '7d' +threads = '8' diff --git a/go.mod b/go.mod index 163574cb..fc979544 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/go-chi/chi/v5 v5.2.0 github.com/go-chi/cors v1.2.1 github.com/go-co-op/gocron v1.37.0 - github.com/go-viper/mapstructure/v2 v2.2.1 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/gotd/contrib v0.21.0 @@ -18,6 +17,7 @@ require ( github.com/iyear/connectproxy v0.1.1 github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213 github.com/manifoldco/promptui v0.9.0 + github.com/mitchellh/mapstructure v1.5.0 github.com/ogen-go/ogen v1.8.1 github.com/redis/go-redis/v9 v9.7.0 github.com/schollz/progressbar/v3 v3.18.0 @@ -58,7 +58,6 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/go.sum b/go.sum index 97c5020a..f990c75f 100644 --- a/go.sum +++ b/go.sum @@ -93,8 +93,6 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= -github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= diff --git a/internal/config/config.go b/internal/config/config.go index 210b95c5..b654a769 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,108 +8,98 @@ import ( "strings" "time" - "github.com/go-viper/mapstructure/v2" + "github.com/mitchellh/mapstructure" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/tgdrive/teldrive/internal/duration" - "go.uber.org/zap/zapcore" ) +type ServerCmdConfig struct { + Server ServerConfig `config:"server"` + Log LoggingConfig `config:"log"` + JWT JWTConfig `config:"jwt"` + DB DBConfig `config:"db"` + TG TGConfig `config:"tg"` + CronJobs CronJobConfig `config:"cronjobs"` + Cache CacheConfig `config:"cache"` +} + type ServerConfig struct { - Port int `mapstructure:"port"` - GracefulShutdown time.Duration `mapstructure:"graceful-shutdown"` - EnablePprof bool `mapstructure:"enable-pprof"` - ReadTimeout time.Duration `mapstructure:"read-timeout"` - WriteTimeout time.Duration `mapstructure:"write-timeout"` + Port int `config:"port" description:"HTTP port for the server to listen on" default:"8080"` + GracefulShutdown time.Duration `config:"graceful-shutdown" description:"Grace period for server shutdown" default:"10s"` + EnablePprof bool `config:"enable-pprof" description:"Enable pprof debugging endpoints"` + ReadTimeout time.Duration `config:"read-timeout" description:"Maximum duration for reading entire request" default:"1h"` + WriteTimeout time.Duration `config:"write-timeout" description:"Maximum duration for writing response" default:"1h"` } type CacheConfig struct { - MaxSize int `mapstructure:"max-size"` - RedisAddr string `mapstructure:"redis-addr"` - RedisPass string `mapstructure:"redis-pass"` + MaxSize int `config:"max-size" description:"Maximum cache size in bytes" default:"10485760"` + RedisAddr string `config:"redis-addr" description:"Redis server address"` + RedisPass string `config:"redis-pass" description:"Redis server password"` } type LoggingConfig struct { - Level string `mapstructure:"level"` - File string `mapstructure:"file"` + Level string `config:"level" description:"Logging level (debug, info, warn, error)" default:"info"` + File string `config:"file" description:"Log file path, if empty logs to stdout"` } type JWTConfig struct { - Secret string `mapstructure:"secret"` - SessionTime time.Duration `mapstructure:"session-time"` - AllowedUsers []string `mapstructure:"allowed-users"` + Secret string `config:"secret" description:"JWT signing secret key" required:"true"` + SessionTime time.Duration `config:"session-time" description:"JWT token validity duration" default:"30d"` + AllowedUsers []string `config:"allowed-users" description:"List of allowed usernames"` } +type DBPool struct { + Enable bool `config:"enable" description:"Enable connection pooling" default:"true"` + MaxOpenConnections int `config:"max-open-connections" description:"Maximum number of open connections" default:"25"` + MaxIdleConnections int `config:"max-idle-connections" description:"Maximum number of idle connections" default:"25"` + MaxLifetime time.Duration `config:"max-lifetime" description:"Maximum connection lifetime" default:"10m"` +} type DBConfig struct { - DataSource string `mapstructure:"data-source"` - PrepareStmt bool `mapstructure:"prepare-stmt"` - LogLevel string `mapstructure:"log-level"` - Pool struct { - Enable bool `mapstructure:"enable"` - MaxOpenConnections int `mapstructure:"max-open-connections"` - MaxIdleConnections int `mapstructure:"max-idle-connections"` - MaxLifetime time.Duration `mapstructure:"max-lifetime"` - } `mapstructure:"pool"` + DataSource string `config:"data-source" description:"Database connection string" required:"true"` + PrepareStmt bool `config:"prepare-stmt" description:"Use prepared statements" default:"true"` + LogLevel string `config:"log-level" description:"Database logging level" default:"info"` + Pool DBPool `config:"pool"` } type CronJobConfig struct { - Enable bool `mapstructure:"enable"` - CleanFilesInterval time.Duration `mapstructure:"clean-files-interval"` - CleanUploadsInterval time.Duration `mapstructure:"clean-uploads-interval"` - FolderSizeInterval time.Duration `mapstructure:"folder-size-interval"` + Enable bool `config:"enable" description:"Enable scheduled background jobs" default:"true"` + CleanFilesInterval time.Duration `config:"clean-files-interval" description:"Interval for cleaning expired files" default:"1h"` + CleanUploadsInterval time.Duration `config:"clean-uploads-interval" description:"Interval for cleaning incomplete uploads" default:"12h"` + FolderSizeInterval time.Duration `config:"folder-size-interval" description:"Interval for updating folder sizes" default:"2h"` } -type TGConfig struct { - AppId int `mapstructure:"app-id"` - AppHash string `mapstructure:"app-hash"` - RateLimit bool `mapstructure:"rate-limit"` - RateBurst int `mapstructure:"rate-burst"` - Rate int `mapstructure:"rate"` - UserName string `mapstructure:"user-name"` - DeviceModel string `mapstructure:"device-model"` - SystemVersion string `mapstructure:"system-version"` - AppVersion string `mapstructure:"app-version"` - LangCode string `mapstructure:"lang-code"` - SystemLangCode string `mapstructure:"system-lang-code"` - LangPack string `mapstructure:"lang-pack"` - Ntp bool `mapstructure:"ntp"` - StorageFile string `mapstructure:"storage-file"` - DisableStreamBots bool `mapstructure:"disable-stream-bots"` - Proxy string `mapstructure:"proxy"` - ReconnectTimeout time.Duration `mapstructure:"reconnect-timeout"` - PoolSize int64 `mapstructure:"pool-size"` - EnableLogging bool `mapstructure:"enable-logging"` - Uploads struct { - EncryptionKey string `mapstructure:"encryption-key"` - Threads int `mapstructure:"threads"` - MaxRetries int `mapstructure:"max-retries"` - Retention time.Duration `mapstructure:"retention"` - } `mapstructure:"uploads"` - Stream struct { - MultiThreads int `mapstructure:"multi-threads"` - Buffers int `mapstructure:"buffers"` - ChunkTimeout time.Duration `mapstructure:"chunk-timeout"` - } `mapstructure:"stream"` +type TGStream struct { + MultiThreads int `config:"multi-threads" description:"Number of download threads"` + Buffers int `config:"buffers" description:"Number of stream buffers" default:"8"` + ChunkTimeout time.Duration `config:"chunk-timeout" description:"Chunk download timeout" default:"20s"` } -type ServerCmdConfig struct { - Server ServerConfig `mapstructure:"server"` - Log LoggingConfig `mapstructure:"log"` - JWT JWTConfig `mapstructure:"jwt"` - DB DBConfig `mapstructure:"db"` - TG TGConfig `mapstructure:"tg"` - CronJobs CronJobConfig `mapstructure:"cronjobs"` - Cache CacheConfig `mapstructure:"cache"` +type TGUpload struct { + EncryptionKey string `config:"encryption-key" description:"Encryption key for uploads" required:"true"` + Threads int `config:"threads" description:"Number of upload threads" default:"8"` + MaxRetries int `config:"max-retries" description:"Maximum upload retry attempts" default:"10"` + Retention time.Duration `config:"retention" description:"Upload retention period" default:"7d"` } - -type MigrateCmdConfig struct { - DB DBConfig `mapstructure:"db"` - Log LoggingConfig `mapstructure:"log"` +type TGConfig struct { + RateLimit bool `config:"rate-limit" description:"Enable rate limiting for API calls" default:"true"` + RateBurst int `config:"rate-burst" description:"Maximum burst size for rate limiting" default:"5"` + Rate int `config:"rate" description:"Rate limit in requests per minute" default:"100"` + Ntp bool `config:"ntp" description:"Use NTP for time synchronization"` + StorageFile string `config:"storage-file" description:"Path to SQLite storage file"` + DisableStreamBots bool `config:"disable-stream-bots" description:"Disable streaming bots"` + Proxy string `config:"proxy" description:"HTTP/SOCKS5 proxy URL"` + ReconnectTimeout time.Duration `config:"reconnect-timeout" description:"Client reconnection timeout" default:"5m"` + PoolSize int `config:"pool-size" description:"Session pool size" default:"8"` + EnableLogging bool `config:"enable-logging" description:"Enable Telegram client logging"` + Uploads TGUpload `config:"uploads"` + Stream TGStream `config:"stream"` } type ConfigLoader struct { - v *viper.Viper + v *viper.Viper + requiredFields []string } func NewConfigLoader() *ConfigLoader { @@ -136,11 +126,11 @@ func StringToDurationHook() mapstructure.DecodeHookFunc { } } -func (cl *ConfigLoader) InitializeConfig(cmd *cobra.Command) error { +func (cl *ConfigLoader) Load(cmd *cobra.Command, cfg interface{}) error { + cl.v.SetConfigType("toml") cfgFile := cmd.Flags().Lookup("config").Value.String() - if cfgFile != "" { cl.v.SetConfigFile(cfgFile) } else { @@ -153,59 +143,134 @@ func (cl *ConfigLoader) InitializeConfig(cmd *cobra.Command) error { cl.v.SetConfigName("config") } - cl.v.SetEnvPrefix("teldrive") - cl.v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) - cl.v.AutomaticEnv() - - if err := cl.v.BindPFlags(cmd.Flags()); err != nil { - return fmt.Errorf("error binding flags: %v", err) - } - if err := cl.v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { return fmt.Errorf("error reading config file: %v", err) } } + return cl.load(cfg) +} + +func (cl *ConfigLoader) Validate() error { + missingFields := []string{} + for _, key := range cl.requiredFields { + if !cl.v.IsSet(key) { + missingFields = append(missingFields, strings.ReplaceAll(key, ".", "-")) + } + } + if len(missingFields) > 0 { + return fmt.Errorf("required configuration values not set: %s", strings.Join(missingFields, ", ")) + } return nil } -func (cl *ConfigLoader) Load(cfg interface{}) error { - config := &mapstructure.DecoderConfig{ - DecodeHook: mapstructure.ComposeDecodeHookFunc( - StringToDurationHook(), - ), - WeaklyTypedInput: true, - Result: cfg, +func (cl *ConfigLoader) RegisterPlags(flags *pflag.FlagSet, prefix string, v interface{}, skipFlags bool) error { + flags.StringP("config", "c", "", "Config file path (default $HOME/.teldrive/config.toml)") + return cl.walkStruct(v, prefix, func(key string, field reflect.StructField, value reflect.Value) error { + return cl.setDefault(flags, key, field, skipFlags) + }) +} + +func (cl *ConfigLoader) setDefault(flags *pflag.FlagSet, key string, field reflect.StructField, skipFlags bool) error { + description := field.Tag.Get("description") + defaultVal := field.Tag.Get("default") + + if defaultVal != "" { + description += fmt.Sprintf(" (default %s)", defaultVal) } + if required := field.Tag.Get("required"); required == "true" { + cl.requiredFields = append(cl.requiredFields, key) + } + + flagKey := strings.ReplaceAll(key, ".", "-") - decoder, err := mapstructure.NewDecoder(config) - if err != nil { - return fmt.Errorf("failed to create decoder: %v", err) + if defaultVal != "" { + cl.v.SetDefault(key, defaultVal) } - if err := decoder.Decode(cl.v.AllSettings()); err != nil { - return fmt.Errorf("failed to decode config: %v", err) + if skipFlags { + return nil + } + + switch field.Type.Kind() { + case reflect.String: + flags.String(flagKey, "", description) + case reflect.Int: + flags.Int(flagKey, 0, description) + case reflect.Int64: + flags.Int64(flagKey, 0, description) + case reflect.Bool: + flags.Bool(flagKey, false, description) + case reflect.Slice: + switch field.Type.Elem().Kind() { + case reflect.String: + flags.StringSlice(flagKey, nil, description) + case reflect.Int: + flags.IntSlice(flagKey, nil, description) + + } + default: + if field.Type == reflect.TypeOf(time.Duration(0)) { + flags.Duration(flagKey, time.Duration(0), description) + + } + } + if err := cl.v.BindPFlag(key, flags.Lookup(flagKey)); err != nil { + return fmt.Errorf("error binding flag %s: %w", key, err) } return nil } -func AddCommonFlags(flags *pflag.FlagSet, config *ServerCmdConfig) { +func (cl *ConfigLoader) walkStruct(v interface{}, prefix string, fn func(key string, field reflect.StructField, value reflect.Value) error) error { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + typ := val.Type() + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + value := val.Field(i) + configTag := field.Tag.Get("config") + if configTag == "" { + continue + } + key := configTag + if prefix != "" { + key = prefix + "." + configTag + } + if field.Type.Kind() == reflect.Struct { + var nestedValue interface{} + if value.CanAddr() { + nestedValue = value.Addr().Interface() + } else { + nestedValue = value.Interface() + } + if err := cl.walkStruct(nestedValue, key, fn); err != nil { + return err + } + continue + } + + if err := fn(key, field, value); err != nil { + return err + } + } - flags.StringP("config", "c", "", "Config file path (default $HOME/.teldrive/config.toml)") + return nil +} + +func decodeTag(tag string) viper.DecoderConfigOption { + return func(c *mapstructure.DecoderConfig) { + c.TagName = tag + } +} - // Log config - flags.StringVar(&config.Log.Level, "log-level", zapcore.InfoLevel.String(), "Logging level") - flags.StringVar(&config.Log.File, "log-file", "", "Logging file path") - - // DB config - flags.StringVar(&config.DB.DataSource, "db-data-source", "", "Database connection string") - flags.StringVar(&config.DB.LogLevel, "db-log-level", zapcore.InfoLevel.String(), "Database log level") - flags.BoolVar(&config.DB.PrepareStmt, "db-prepare-stmt", true, "Enable prepared statements") - flags.BoolVar(&config.DB.Pool.Enable, "db-pool-enable", true, "Enable database pool") - flags.IntVar(&config.DB.Pool.MaxIdleConnections, "db-pool-max-open-connections", 25, "Database max open connections") - flags.IntVar(&config.DB.Pool.MaxIdleConnections, "db-pool-max-idle-connections", 25, "Database max idle connections") - duration.DurationVar(flags, &config.DB.Pool.MaxLifetime, "db-pool-max-lifetime", 10*time.Minute, "Database max connection lifetime") +func (cl *ConfigLoader) load(cfg interface{}) error { + return cl.v.Unmarshal(&cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( + StringToDurationHook(), + )), decodeTag("config")) } diff --git a/internal/duration/duration.go b/internal/duration/duration.go index 5c40a782..964f054e 100644 --- a/internal/duration/duration.go +++ b/internal/duration/duration.go @@ -5,8 +5,6 @@ import ( "strconv" "strings" "time" - - "github.com/spf13/pflag" ) type Duration time.Duration @@ -85,10 +83,6 @@ func newDurationValue(val time.Duration, p *time.Duration) *Duration { return (*Duration)(p) } -func DurationVar(f *pflag.FlagSet, p *time.Duration, name string, value time.Duration, usage string) { - f.VarP(newDurationValue(value, p), name, "", usage) -} - func ParseDuration(age string) (time.Duration, error) { return parseDurationFromNow(age) }