Skip to content

Commit

Permalink
Add config file support
Browse files Browse the repository at this point in the history
  • Loading branch information
chriskuehl committed Sep 7, 2024
1 parent 1fdaa46 commit 37fac6a
Show file tree
Hide file tree
Showing 18 changed files with 355 additions and 76 deletions.
File renamed without changes.
47 changes: 37 additions & 10 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,74 @@ import (

"github.com/chriskuehl/fluffy/server"
"github.com/chriskuehl/fluffy/server/config"
"github.com/chriskuehl/fluffy/server/config/loader"
"github.com/chriskuehl/fluffy/server/logging"
)

var Version = "(dev)"

func newConfigFromArgs(args []string) (*config.Config, error) {
c := server.NewConfig()
type cmdConfig struct {
conf *config.Config

printConfig bool
configPath string
}

func newConfigFromArgs(args []string) (*cmdConfig, error) {
c := &cmdConfig{
conf: server.NewConfig(),
}
fs := flag.NewFlagSet("fluffy", flag.ExitOnError)
fs.StringVar(&c.Host, "host", "localhost", "host to listen on")
fs.UintVar(&c.Port, "port", 8080, "port to listen on")
fs.BoolVar(&c.DevMode, "dev", false, "enable dev mode")
fs.BoolVar(&c.printConfig, "print-config", false, "print the config and exit")
fs.StringVar(&c.configPath, "config", "", "path to config file")
fs.StringVar(&c.conf.Host, "host", "localhost", "host to listen on")
fs.UintVar(&c.conf.Port, "port", 8080, "port to listen on")
fs.BoolVar(&c.conf.DevMode, "dev", false, "enable dev mode")
if err := fs.Parse(args); err != nil {
return nil, err
}
c.Version = Version
if c.configPath != "" {
if err := loader.LoadConfigTOML(c.conf, c.configPath); err != nil {
return nil, fmt.Errorf("loading config: %w", err)
}
}
c.conf.Version = Version
return c, nil
}

func run(ctx context.Context, w io.Writer, args []string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()

config, err := newConfigFromArgs(args)
conf, err := newConfigFromArgs(args)
if err != nil {
return fmt.Errorf("parsing args: %w", err)
}

if conf.printConfig {
c, err := loader.DumpConfigTOML(conf.conf)
if err != nil {
return fmt.Errorf("dumping config: %w", err)
}

fmt.Println(c)
return nil
}

logger := logging.NewSlogLogger(slog.New(slog.NewTextHandler(w, nil)))

handler, err := server.NewServer(logger, config)
handler, err := server.NewServer(logger, conf.conf)
if err != nil {
return fmt.Errorf("creating server: %w", err)
}

httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, strconv.FormatUint(uint64(config.Port), 10)),
Addr: net.JoinHostPort(conf.conf.Host, strconv.FormatUint(uint64(conf.conf.Port), 10)),
Handler: handler,
}
go func() {
logger.Info(ctx, "listening", "addr", httpServer.Addr)
if config.DevMode {
if conf.conf.DevMode {
logger.Warn(ctx, "dev mode enabled! do not use in production!")
}
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
)

require (
github.com/BurntSushi/toml v1.4.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sys v0.24.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/adrg/xdg v0.5.0 h1:dDaZvhMXatArP1NPHhnfaQUqWBLBsmx1h1HXQdMoFCY=
github.com/adrg/xdg v0.5.0/go.mod h1:dDdY4M4DF9Rjy4kHPeNL+ilVF+p2lK8IdM9/rTSGcI4=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
Expand Down
10 changes: 5 additions & 5 deletions server/assets/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ func assetObjectPath(path, hash string) string {
//
// In development mode, this will return a URL served by the fluffy server itself. In production,
// this will return a URL to the object store.
func AssetURL(c *config.Config, path string) (string, error) {
if c.DevMode {
url := c.HomeURL
func AssetURL(conf *config.Config, path string) (string, error) {
if conf.DevMode {
url := conf.HomeURL
url.Path = "/dev/static/" + path
return url.String(), nil
}
Expand All @@ -79,8 +79,8 @@ func AssetURL(c *config.Config, path string) (string, error) {
if !ok {
return "", fmt.Errorf("asset not found: %s", path)
}
url := c.ObjectURLPattern
url.Path = fmt.Sprintf(url.Path, assetObjectPath(path, hash))
url := conf.ObjectURLPattern
url.Path = strings.Replace(url.Path, "{path}", assetObjectPath(path, hash), -1)
return url.String(), nil
}

Expand Down
36 changes: 36 additions & 0 deletions server/assets/assets_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package assets_test

import (
"testing"

"github.com/chriskuehl/fluffy/server/assets"
"github.com/chriskuehl/fluffy/testfunc"
)

func TestAssetURLDev(t *testing.T) {
conf := testfunc.NewConfig()
conf.DevMode = true

got, err := assets.AssetURL(conf, "img/favicon.ico")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := "http://localhost:8080/dev/static/img/favicon.ico"
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
}

func TestAssetURLProd(t *testing.T) {
conf := testfunc.NewConfig()
conf.DevMode = false

got, err := assets.AssetURL(conf, "img/favicon.ico")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := "http://localhost:8080/dev/object/static/5b707398fe549635b8794ac8e73db6938dd7b6b7a28b339296bde1b0fdec764b/img/favicon.ico"
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
}
4 changes: 2 additions & 2 deletions server/assets/dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"github.com/chriskuehl/fluffy/server/logging"
)

func HandleDevStatic(config *config.Config, logger logging.Logger) http.HandlerFunc {
if !config.DevMode {
func HandleDevStatic(conf *config.Config, logger logging.Logger) http.HandlerFunc {
if !conf.DevMode {
return func(w http.ResponseWriter, r *http.Request) {
logger.Warn(r.Context(), "assets cannot be served from the server in production")
w.WriteHeader(http.StatusNotFound)
Expand Down
27 changes: 15 additions & 12 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,38 @@ type Config struct {
Version string
}

func (c *Config) Validate() []string {
func (conf *Config) Validate() []string {
var errs []string
if c.Branding == "" {
if conf.Branding == "" {
errs = append(errs, "Branding must not be empty")
}
if c.AbuseContactEmail == "" {
if conf.AbuseContactEmail == "" {
errs = append(errs, "AbuseContactEmail must not be empty")
}
if c.MaxUploadBytes <= 0 {
if conf.MaxUploadBytes <= 0 {
errs = append(errs, "MaxUploadBytes must be greater than 0")
}
if c.MaxMultipartMemoryBytes <= 0 {
if conf.MaxMultipartMemoryBytes <= 0 {
errs = append(errs, "MaxMultipartMemoryBytes must be greater than 0")
}
if strings.HasSuffix(c.HomeURL.Path, "/") {
if strings.HasSuffix(conf.HomeURL.Path, "/") {
errs = append(errs, "HomeURL must not end with a slash")
}
if !strings.Contains(c.ObjectURLPattern.Path, "%s") {
errs = append(errs, "ObjectURLPattern must contain a '%s' placeholder")
if !strings.Contains(conf.ObjectURLPattern.Path, "{path}") {
errs = append(errs, "ObjectURLPattern must contain a '{path}' placeholder")
}
if !strings.Contains(c.HTMLURLPattern.Path, "%s") {
errs = append(errs, "HTMLURLPattern must contain a '%s' placeholder")
if !strings.Contains(conf.HTMLURLPattern.Path, "{path}") {
errs = append(errs, "HTMLURLPattern must contain a '{path}' placeholder")
}
for ext := range c.ForbiddenFileExtensions {
if conf.ForbiddenFileExtensions == nil {
errs = append(errs, "ForbiddenFileExtensions must not be nil")
}
for ext := range conf.ForbiddenFileExtensions {
if strings.HasPrefix(ext, ".") {
errs = append(errs, "ForbiddenFileExtensions should not start with a dot: "+ext)
}
}
if c.Version == "" {
if conf.Version == "" {
errs = append(errs, "Version must not be empty")
}
return errs
Expand Down
125 changes: 125 additions & 0 deletions server/config/loader/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package loader

import (
"fmt"
"html/template"
"net/url"

"github.com/BurntSushi/toml"
"github.com/chriskuehl/fluffy/server/config"
"github.com/chriskuehl/fluffy/server/storage"
)

type filesystemStorageBackend struct {
ObjectRoot string `toml:"object_root"`
HTMLRoot string `toml:"html_root"`
}

type configFile struct {
Branding string `toml:"branding"`
CustomFooterHTML string `toml:"custom_footer_html"`
AbuseContactEmail string `toml:"abuse_contact_email"`
MaxUploadBytes int64 `toml:"max_upload_bytes"`
MaxMultipartMemoryBytes int64 `toml:"max_multipart_memory_bytes"`
HomeURL string `toml:"home_url"`
ObjectURLPattern string `toml:"object_url_pattern"`
HTMLURLPattern string `toml:"html_url_pattern"`
ForbiddenFileExtensions []string `toml:"forbidden_file_extensions"`
Host string `toml:"host"`
Port uint `toml:"port"`

FilesystemStorageBackend *filesystemStorageBackend `toml:"filesystem_storage_backend"`
}

func LoadConfigTOML(conf *config.Config, path string) error {
var cfg configFile
md, err := toml.DecodeFile(path, &cfg)
if err != nil {
return fmt.Errorf("decoding config: %w", err)
}
if len(md.Undecoded()) > 0 {
return fmt.Errorf("unknown keys in config: %v", md.Undecoded())
}
if cfg.Branding != "" {
conf.Branding = cfg.Branding
}
if cfg.CustomFooterHTML != "" {
conf.CustomFooterHTML = template.HTML(cfg.CustomFooterHTML)
}
if cfg.AbuseContactEmail != "" {
conf.AbuseContactEmail = cfg.AbuseContactEmail
}
if cfg.MaxUploadBytes != 0 {
conf.MaxUploadBytes = cfg.MaxUploadBytes
}
if cfg.MaxMultipartMemoryBytes != 0 {
conf.MaxMultipartMemoryBytes = cfg.MaxMultipartMemoryBytes
}
if cfg.HomeURL != "" {
u, err := url.ParseRequestURI(cfg.HomeURL)
if err != nil {
return fmt.Errorf("parsing HomeURL: %w", err)
}
conf.HomeURL = *u
}
if cfg.ObjectURLPattern != "" {
u, err := url.ParseRequestURI(cfg.ObjectURLPattern)
if err != nil {
return fmt.Errorf("parsing ObjectURLPattern: %w", err)
}
conf.ObjectURLPattern = *u
}
if cfg.HTMLURLPattern != "" {
u, err := url.ParseRequestURI(cfg.HTMLURLPattern)
if err != nil {
return fmt.Errorf("parsing HTMLURLPattern: %w", err)
}
conf.HTMLURLPattern = *u
}
for _, ext := range cfg.ForbiddenFileExtensions {
conf.ForbiddenFileExtensions[ext] = struct{}{}
}
if cfg.Host != "" {
conf.Host = cfg.Host
}
if cfg.Port != 0 {
conf.Port = cfg.Port
}
if cfg.FilesystemStorageBackend != nil {
conf.StorageBackend = &storage.FilesystemBackend{
ObjectRoot: cfg.FilesystemStorageBackend.ObjectRoot,
HTMLRoot: cfg.FilesystemStorageBackend.HTMLRoot,
}
}
return nil
}

func DumpConfigTOML(conf *config.Config) (string, error) {
cfg := configFile{
Branding: conf.Branding,
CustomFooterHTML: string(conf.CustomFooterHTML),
AbuseContactEmail: conf.AbuseContactEmail,
MaxUploadBytes: conf.MaxUploadBytes,
MaxMultipartMemoryBytes: conf.MaxMultipartMemoryBytes,
HomeURL: conf.HomeURL.String(),
ObjectURLPattern: conf.ObjectURLPattern.String(),
HTMLURLPattern: conf.HTMLURLPattern.String(),
ForbiddenFileExtensions: make([]string, 0, len(conf.ForbiddenFileExtensions)),
Host: conf.Host,
Port: conf.Port,
}
for ext := range conf.ForbiddenFileExtensions {
cfg.ForbiddenFileExtensions = append(cfg.ForbiddenFileExtensions, ext)
}
if fs, ok := conf.StorageBackend.(*storage.FilesystemBackend); ok {
cfg.FilesystemStorageBackend = &filesystemStorageBackend{
ObjectRoot: fs.ObjectRoot,
HTMLRoot: fs.HTMLRoot,
}
}
buf, err := toml.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshaling config: %w", err)
}
return string(buf), nil
}
Loading

0 comments on commit 37fac6a

Please sign in to comment.