From 46a98d4ece6f45bbf51447fb21ccfa3c647a9c4b Mon Sep 17 00:00:00 2001 From: Neko Ayaka Date: Mon, 30 Dec 2024 16:06:09 +0800 Subject: [PATCH] feat: support koemotion --- README.md | 3 +- cspell.config.yaml | 5 + go.mod | 3 + go.sum | 6 + pkg/backend/backend.go | 77 +++++++--- pkg/backend/elevenlabs.go | 9 +- pkg/backend/koemotion.go | 92 ++++++++++++ pkg/backend/openai.go | 14 +- pkg/utils/json.go | 5 +- pkg/utils/jsonpatch/jsonpatch.go | 67 +++++++++ pkg/utils/jsonpatch/jsonpatch_test.go | 54 +++++++ pkg/utils/mo.go | 35 +++++ pkg/utils/string.go | 140 ++++++++++++++++++ pkg/utils/string_test.go | 204 ++++++++++++++++++++++++++ 14 files changed, 681 insertions(+), 33 deletions(-) create mode 100644 pkg/backend/koemotion.go create mode 100644 pkg/utils/jsonpatch/jsonpatch.go create mode 100644 pkg/utils/jsonpatch/jsonpatch_test.go create mode 100644 pkg/utils/mo.go diff --git a/README.md b/README.md index 07f6c9a..73f753a 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ unSpeech lets you use various online TTS with OpenAI-compatible API. - [OpenAI](https://platform.openai.com/docs/api-reference/audio/createSpeech) - [ElevenLabs](https://elevenlabs.io/docs/api-reference/text-to-speech/convert) +- [Koemotion (by Rinna)](https://koemotion.rinna.co.jp/) ## Getting Started @@ -36,7 +37,7 @@ You can use unSpeech with most OpenAI clients. The `model` parameter should be provider + model, e.g. `openai/tts-1-hd`, `elevenlabs/eleven_multilingual_v2`. -The `Authorization` header is auto-converted to the vendor's corresponding auth method, such as `xi-api-key`. +The `Authorization` header is auto-converted to the vendor's corresponding auth method, such as `xi-api-key`. ###### `curl` diff --git a/cspell.config.yaml b/cspell.config.yaml index 8f7470e..22096d5 100644 --- a/cspell.config.yaml +++ b/cspell.config.yaml @@ -7,6 +7,7 @@ words: - containedctx - contextcheck - cyclop + - dataurl - depguard - Describedby - Detailf @@ -20,6 +21,7 @@ words: - exhaustive - exhaustruct - exportloopref + - Facemotion - flac - forcetypeassert - funlen @@ -43,6 +45,8 @@ words: - hreflang - ineffassign - ireturn + - jsonpatch + - koemotion - labstack - lll - maintidx @@ -62,6 +66,7 @@ words: - predeclared - reassign - revive + - Rinna - samber - staticcheck - strconv diff --git a/go.mod b/go.mod index 0657757..230682a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/moeru-ai/unspeech go 1.23.2 require ( + github.com/evanphx/json-patch/v5 v5.9.0 github.com/golang-module/carbon v1.7.3 github.com/labstack/echo/v4 v4.13.2 github.com/nekomeowww/fo v1.4.0 @@ -10,6 +11,7 @@ require ( github.com/samber/mo v1.13.0 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.10.0 + github.com/vincent-petithory/dataurl v1.0.0 k8s.io/client-go v0.32.0 ) @@ -23,6 +25,7 @@ require ( github.com/labstack/gommon v0.4.2 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.3.0 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index 96515db..7ac26c5 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= +github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gobuffalo/envy v1.7.0 h1:GlXgaiBkmrYMHco6t4j7SacKO4XUjvh5pwXh0f4uxXU= github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= @@ -50,6 +52,8 @@ github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh github.com/nekomeowww/fo v1.4.0 h1:ULX5KsnDzWHoDwHgtjd2wibpdpyh+5/5DITmvhJZyWY= github.com/nekomeowww/fo v1.4.0/go.mod h1:ctwQ+BZ0UYUb2s+yM7h9SFHjqGCXeUIXFLK2ujAneWw= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -89,6 +93,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8AbShPRpg2CI= +github.com/vincent-petithory/dataurl v1.0.0/go.mod h1:FHafX5vmDzyP+1CQATJn7WFKc9CvnvxyvZy6I1MrG/U= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index ddc8b3f..08ff456 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -1,6 +1,9 @@ package backend import ( + "bytes" + "encoding/json" + "io" "strings" "github.com/labstack/echo/v4" @@ -8,10 +11,11 @@ import ( "github.com/samber/mo" "github.com/moeru-ai/unspeech/pkg/apierrors" + "github.com/moeru-ai/unspeech/pkg/utils" ) -// Options represent API parameters refer to https://platform.openai.com/docs/api-reference/audio/createSpeech -type Options struct { +// OpenAISpeechRequestOptions represent API parameters refer to https://platform.openai.com/docs/api-reference/audio/createSpeech +type OpenAISpeechRequestOptions struct { // (required) One of the available TTS models. Model string `json:"model"` // (required) The text to generate audio for. @@ -29,20 +33,48 @@ type Options struct { Speed int `json:"speed,omitempty"` } -type FullOptions struct { - Options +type SpeechRequestOptions struct { + OpenAISpeechRequestOptions + Backend string `json:"backend"` Model string `json:"model"` + + body mo.Option[*bytes.Buffer] + bodyParsedMap map[string]any } -func Speech(c echo.Context) mo.Result[any] { - var options Options +func (o SpeechRequestOptions) AsBuffer() mo.Option[*bytes.Buffer] { + return o.body +} - if err := c.Bind(&options); err != nil { - return mo.Err[any](apierrors.NewErrBadRequest()) +func (o SpeechRequestOptions) AsMap() map[string]any { + return o.bodyParsedMap +} + +func NewSpeechRequestOptions(body io.ReadCloser) mo.Result[SpeechRequestOptions] { + buffer := new(bytes.Buffer) + + _, err := buffer.ReadFrom(body) + if err != nil { + return mo.Err[SpeechRequestOptions](apierrors.NewErrBadRequest().WithDetail(err.Error())) } + + var optionsMap map[string]any + + err = json.Unmarshal(buffer.Bytes(), &optionsMap) + if err != nil { + return mo.Err[SpeechRequestOptions](apierrors.NewErrBadRequest().WithDetail(err.Error())) + } + + var options OpenAISpeechRequestOptions + + err = json.Unmarshal(buffer.Bytes(), &options) + if err != nil { + return mo.Err[SpeechRequestOptions](apierrors.NewErrBadRequest().WithDetail(err.Error())) + } + if options.Model == "" || options.Input == "" || options.Voice == "" { - return mo.Err[any](apierrors.NewErrInvalidArgument().WithDetail("either one of model, input, and voice parameter is required")) + return mo.Err[SpeechRequestOptions](apierrors.NewErrInvalidArgument().WithDetail("either one of model, input, and voice parameter is required")) } backendAndModel := lo.Ternary( @@ -51,18 +83,29 @@ func Speech(c echo.Context) mo.Result[any] { []string{options.Model, ""}, ) - fullOptions := FullOptions{ - Options: options, - Backend: backendAndModel[0], - Model: backendAndModel[1], + return mo.Ok(SpeechRequestOptions{ + OpenAISpeechRequestOptions: options, + Backend: backendAndModel[0], + Model: backendAndModel[1], + body: mo.Some(buffer), + bodyParsedMap: optionsMap, + }) +} + +func Speech(c echo.Context) mo.Result[any] { + options := NewSpeechRequestOptions(c.Request().Body) + if options.IsError() { + return mo.Err[any](options.Error()) } - switch backendAndModel[0] { + switch options.MustGet().Backend { case "openai": - return openai(c, fullOptions) + return openai(c, utils.ResultToOption(options)) case "elevenlabs": - return elevenlabs(c, fullOptions) + return elevenlabs(c, utils.ResultToOption(options)) + case "koemotion": + return koemotion(c, utils.ResultToOption(options)) default: - return mo.Err[any](apierrors.NewErrBadRequest()) + return mo.Err[any](apierrors.NewErrBadRequest().WithDetail("unsupported backend")) } } diff --git a/pkg/backend/elevenlabs.go b/pkg/backend/elevenlabs.go index a0abe69..8c9b3b6 100644 --- a/pkg/backend/elevenlabs.go +++ b/pkg/backend/elevenlabs.go @@ -18,16 +18,17 @@ import ( type ElevenLabsOptions struct { Text string `json:"text"` ModelID string `json:"model_id,omitempty"` + // TODO: support other options } -func elevenlabs(c echo.Context, options FullOptions) mo.Result[any] { +func elevenlabs(c echo.Context, options mo.Option[SpeechRequestOptions]) mo.Result[any] { reqURL := lo.Must(url.Parse("https://api.elevenlabs.io/v1/text-to-speech")). - JoinPath(options.Voice). + JoinPath(options.MustGet().Voice). String() values := ElevenLabsOptions{ - Text: options.Input, - ModelID: options.Model, + Text: options.MustGet().Input, + ModelID: options.MustGet().Model, } payload := lo.Must(json.Marshal(values)) diff --git a/pkg/backend/koemotion.go b/pkg/backend/koemotion.go new file mode 100644 index 0000000..f437bc6 --- /dev/null +++ b/pkg/backend/koemotion.go @@ -0,0 +1,92 @@ +package backend + +import ( + "bytes" + "encoding/json" + "log/slog" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + "github.com/moeru-ai/unspeech/pkg/apierrors" + "github.com/moeru-ai/unspeech/pkg/utils" + "github.com/moeru-ai/unspeech/pkg/utils/jsonpatch" + "github.com/samber/mo" + "github.com/vincent-petithory/dataurl" +) + +func koemotion(c echo.Context, options mo.Option[SpeechRequestOptions]) mo.Result[any] { + patchedPayload := jsonpatch.ApplyPatches( + options.MustGet().body.OrElse(new(bytes.Buffer)).Bytes(), + mo.Some(jsonpatch.ApplyOptions{AllowMissingPathOnRemove: true}), + jsonpatch.NewRemove("/model"), + jsonpatch.NewRemove("/voice"), + jsonpatch.NewRemove("/input"), + jsonpatch.NewAdd("/text", options.MustGet().Input), + ) + if patchedPayload.IsError() { + return mo.Err[any](apierrors.NewErrInternal().WithDetail(patchedPayload.Error().Error()).WithCaller()) + } + + req, err := http.NewRequestWithContext( + c.Request().Context(), + http.MethodPost, + "https://api.rinna.co.jp/koemotion/infer", + bytes.NewBuffer(patchedPayload.MustGet()), + ) + if err != nil { + return mo.Err[any](apierrors.NewErrInternal().WithCaller()) + } + + // Rewrite the Authorization header + req.Header.Set("Ocp-Apim-Subscription-Key", strings.TrimPrefix( + c.Request().Header.Get("Authorization"), + "Bearer ", + )) + req.Header.Set("Content-Type", "application/json") + + res, err := http.DefaultClient.Do(req) + if err != nil { + return mo.Err[any](apierrors.NewErrBadGateway().WithDetail(err.Error()).WithError(err).WithCaller()) + } + + defer func() { _ = res.Body.Close() }() + + if res.StatusCode >= 400 && res.StatusCode < 600 { + switch { + case strings.HasPrefix(res.Header.Get("Content-Type"), "application/json"): + return mo.Err[any](apierrors. + NewUpstreamError(res.StatusCode). + WithDetail(NewJSONResponseError(res.StatusCode, res.Body).OrEmpty().Error())) + case strings.HasPrefix(res.Header.Get("Content-Type"), "text/"): + return mo.Err[any](apierrors. + NewUpstreamError(res.StatusCode). + WithDetail(NewTextResponseError(res.StatusCode, res.Body).OrEmpty().Error())) + default: + slog.Warn("unknown upstream error with unknown Content-Type", + slog.Int("status", res.StatusCode), + slog.String("content-type", res.Header.Get("Content-Type")), + slog.String("content-length", res.Header.Get("Content-Length")), + ) + } + } + + var resBody map[string]any + + err = json.NewDecoder(res.Body).Decode(&resBody) + if err != nil { + return mo.Err[any](apierrors.NewErrInternal().WithDetail(err.Error()).WithError(err).WithCaller()) + } + + audioDataURLString := utils.GetByJSONPath[string](resBody, "{ .audio }") + if audioDataURLString == "" { + return mo.Err[any](apierrors.NewErrInternal().WithDetail("upstream returned empty audio data URL").WithCaller()) + } + + audioDataURL, err := dataurl.DecodeString(audioDataURLString) + if err != nil { + return mo.Err[any](apierrors.NewErrInternal().WithDetail(err.Error()).WithError(err).WithCaller()) + } + + return mo.Ok[any](c.Blob(http.StatusOK, "audio/mp3", audioDataURL.Data)) +} diff --git a/pkg/backend/openai.go b/pkg/backend/openai.go index 6dd27af..e099a0b 100644 --- a/pkg/backend/openai.go +++ b/pkg/backend/openai.go @@ -13,13 +13,13 @@ import ( "github.com/samber/mo" ) -func openai(c echo.Context, options FullOptions) mo.Result[any] { - values := Options{ - Model: options.Model, - Input: options.Input, - Voice: options.Voice, - ResponseFormat: options.ResponseFormat, - Speed: options.Speed, +func openai(c echo.Context, options mo.Option[SpeechRequestOptions]) mo.Result[any] { + values := OpenAISpeechRequestOptions{ + Model: options.MustGet().Model, + Input: options.MustGet().Input, + Voice: options.MustGet().Voice, + ResponseFormat: options.MustGet().ResponseFormat, + Speed: options.MustGet().Speed, } payload := lo.Must(json.Marshal(values)) diff --git a/pkg/utils/json.go b/pkg/utils/json.go index 30625ab..8ed0e52 100644 --- a/pkg/utils/json.go +++ b/pkg/utils/json.go @@ -39,9 +39,7 @@ func GetByJSONPath[T any](input any, template string) T { } func ReadAsJSONWithClose(readCloser io.ReadCloser) (*bytes.Buffer, map[string]any, error) { - defer func() { - _ = readCloser.Close() - }() + defer func() { readCloser.Close() }() buffer, jsonMap, err := ReadAsJSON(readCloser) if err != nil { @@ -72,7 +70,6 @@ func FromMap[T any, MK comparable, MV any](m map[MK]MV) (*T, error) { if m == nil { return nil, nil } - if len(m) == 0 { return nil, nil } diff --git a/pkg/utils/jsonpatch/jsonpatch.go b/pkg/utils/jsonpatch/jsonpatch.go new file mode 100644 index 0000000..bd72986 --- /dev/null +++ b/pkg/utils/jsonpatch/jsonpatch.go @@ -0,0 +1,67 @@ +package jsonpatch + +import ( + "encoding/json" + + jsonpatch "github.com/evanphx/json-patch/v5" + "github.com/moeru-ai/unspeech/pkg/utils" + "github.com/samber/lo" + "github.com/samber/mo" +) + +type JSONPatchOperation string + +const ( + JSONPatchOperationAdd JSONPatchOperation = "add" + JSONPatchOperationRemove JSONPatchOperation = "remove" + JSONPatchOperationReplace JSONPatchOperation = "replace" +) + +type JSONPatchOperationObject struct { + Operation JSONPatchOperation `json:"op"` + Path string `json:"path"` + Value any `json:"value,omitempty"` +} + +func NewPatches(operations ...mo.Option[JSONPatchOperationObject]) []byte { + return lo.Must(json.Marshal(utils.MapOptionsPresent(operations))) +} + +func NewReplace(path string, to any) mo.Option[JSONPatchOperationObject] { + return mo.Some(JSONPatchOperationObject{ + Operation: JSONPatchOperationReplace, + Path: path, + Value: to, + }) +} + +func NewAdd(path string, value any) mo.Option[JSONPatchOperationObject] { + return mo.Some(JSONPatchOperationObject{ + Operation: JSONPatchOperationAdd, + Path: path, + Value: value, + }) +} + +func NewRemove(path string) mo.Option[JSONPatchOperationObject] { + return mo.Some(JSONPatchOperationObject{ + Operation: JSONPatchOperationRemove, + Path: path, + }) +} + +type ApplyOptions jsonpatch.ApplyOptions + +func ApplyPatches(bytes []byte, applyOpt mo.Option[ApplyOptions], patches ...mo.Option[JSONPatchOperationObject]) mo.Result[[]byte] { + patch, err := jsonpatch.DecodePatch(NewPatches(patches...)) + if err != nil { + return mo.Err[[]byte](err) + } + + patched, err := patch.Apply(bytes) + if err != nil { + return mo.Err[[]byte](err) + } + + return mo.Ok[[]byte](patched) +} diff --git a/pkg/utils/jsonpatch/jsonpatch_test.go b/pkg/utils/jsonpatch/jsonpatch_test.go new file mode 100644 index 0000000..cd90548 --- /dev/null +++ b/pkg/utils/jsonpatch/jsonpatch_test.go @@ -0,0 +1,54 @@ +package jsonpatch + +import ( + "encoding/json" + "testing" + + jsonpatch "github.com/evanphx/json-patch/v5" + "github.com/samber/lo" + "github.com/stretchr/testify/require" +) + +func TestJSONPatchReplace(t *testing.T) { + patch, err := jsonpatch.DecodePatch(NewPatches( + NewReplace("/model", "gpt-3.5-turbo"), + )) + require.NoError(t, err) + + patched, err := patch.Apply(lo.Must(json.Marshal(map[string]interface{}{ + "model": "gpt-3.5", + }))) + require.NoError(t, err) + + require.JSONEq(t, `{"model":"gpt-3.5-turbo"}`, string(patched)) +} + +func TestJSONPatchAdd(t *testing.T) { + patch, err := jsonpatch.DecodePatch(NewPatches( + NewAdd("/stream_options", map[string]any{ + "include_usage": true, + }), + )) + require.NoError(t, err) + + patched, err := patch.Apply(lo.Must(json.Marshal(map[string]interface{}{ + "model": "gpt-3.5", + }))) + require.NoError(t, err) + + require.JSONEq(t, `{"model":"gpt-3.5","stream_options":{"include_usage":true}}`, string(patched)) +} + +func TestJSONPatchRemove(t *testing.T) { + patch, err := jsonpatch.DecodePatch(NewPatches( + NewRemove("/model"), + )) + require.NoError(t, err) + + patched, err := patch.Apply(lo.Must(json.Marshal(map[string]interface{}{ + "model": "gpt-3.5", + }))) + require.NoError(t, err) + + require.JSONEq(t, `{}`, string(patched)) +} diff --git a/pkg/utils/mo.go b/pkg/utils/mo.go new file mode 100644 index 0000000..8f321e1 --- /dev/null +++ b/pkg/utils/mo.go @@ -0,0 +1,35 @@ +package utils + +import ( + "github.com/samber/lo" + "github.com/samber/mo" +) + +func FilterOptionPresent[T any](item mo.Option[T], _ int) bool { + return item.IsPresent() +} + +func FilterOptionAbsent[T any](item mo.Option[T], _ int) bool { + return item.IsAbsent() +} + +func MapOptionOrEmpty[T any](item mo.Option[T], _ int) T { + return item.OrEmpty() +} + +func MapOptionMust(item mo.Option[error], _ int) error { + return item.MustGet() +} + +func MapOptionsPresent[T any](items []mo.Option[T]) []T { + filtered := lo.Filter(items, FilterOptionPresent) + return lo.Map(filtered, MapOptionOrEmpty) +} + +func ResultToOption[T any](item mo.Result[T]) mo.Option[T] { + if item.IsError() { + return mo.None[T]() + } + + return mo.Some(item.MustGet()) +} diff --git a/pkg/utils/string.go b/pkg/utils/string.go index 1bed9fc..aa2dc4e 100644 --- a/pkg/utils/string.go +++ b/pkg/utils/string.go @@ -5,6 +5,8 @@ import ( "fmt" "strconv" "strings" + + "github.com/samber/lo" ) var ( @@ -39,6 +41,9 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo case string: val, _ := any(str).(T) return val, nil + case *string: + val, _ := any(&str).(T) + return val, nil case int: val, err := strconv.ParseInt(str, 10, 0) if err != nil { @@ -47,6 +52,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(int(val)).(T) + return typeVal, nil + case *int: + val, err := strconv.ParseInt(str, 10, 0) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(int(val))).(T) + return typeVal, nil case int8: val, err := strconv.ParseInt(str, 10, 8) @@ -56,6 +70,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(int8(val)).(T) + return typeVal, nil + case *int8: + val, err := strconv.ParseInt(str, 10, 8) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(int8(val))).(T) + return typeVal, nil case int16: val, err := strconv.ParseInt(str, 10, 16) @@ -65,6 +88,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(int16(val)).(T) + return typeVal, nil + case *int16: + val, err := strconv.ParseInt(str, 10, 16) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(int16(val))).(T) + return typeVal, nil case int32: val, err := strconv.ParseInt(str, 10, 32) @@ -74,6 +106,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(int32(val)).(T) + return typeVal, nil + case *int32: + val, err := strconv.ParseInt(str, 10, 32) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(int32(val))).(T) + return typeVal, nil case int64: val, err := strconv.ParseInt(str, 10, 64) @@ -83,6 +124,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(val).(T) + return typeVal, nil + case *int64: + val, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(val)).(T) + return typeVal, nil case uint: val, err := strconv.ParseUint(str, 10, 0) @@ -92,6 +142,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(uint(val)).(T) + return typeVal, nil + case *uint: + val, err := strconv.ParseUint(str, 10, 0) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(uint(val))).(T) + return typeVal, nil case uint8: val, err := strconv.ParseUint(str, 10, 8) @@ -101,6 +160,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(uint8(val)).(T) + return typeVal, nil + case *uint8: + val, err := strconv.ParseUint(str, 10, 8) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(uint8(val))).(T) + return typeVal, nil case uint16: val, err := strconv.ParseUint(str, 10, 16) @@ -110,6 +178,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(uint16(val)).(T) + return typeVal, nil + case *uint16: + val, err := strconv.ParseUint(str, 10, 16) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(uint16(val))).(T) + return typeVal, nil case uint32: val, err := strconv.ParseUint(str, 10, 32) @@ -119,6 +196,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(uint32(val)).(T) + return typeVal, nil + case *uint32: + val, err := strconv.ParseUint(str, 10, 32) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(uint32(val))).(T) + return typeVal, nil case uint64: val, err := strconv.ParseUint(str, 10, 64) @@ -128,6 +214,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(val).(T) + return typeVal, nil + case *uint64: + val, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(val)).(T) + return typeVal, nil case float32: val, err := strconv.ParseFloat(str, 32) @@ -137,6 +232,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(float32(val)).(T) + return typeVal, nil + case *float32: + val, err := strconv.ParseFloat(str, 32) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(float32(val))).(T) + return typeVal, nil case float64: val, err := strconv.ParseFloat(str, 64) @@ -146,6 +250,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(val).(T) + return typeVal, nil + case *float64: + val, err := strconv.ParseFloat(str, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(val)).(T) + return typeVal, nil case complex64: val, err := strconv.ParseComplex(str, 64) @@ -155,6 +268,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(complex64(val)).(T) + return typeVal, nil + case *complex64: + val, err := strconv.ParseComplex(str, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(complex64(val))).(T) + return typeVal, nil case complex128: val, err := strconv.ParseComplex(str, 128) @@ -164,6 +286,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(val).(T) + return typeVal, nil + case *complex128: + val, err := strconv.ParseComplex(str, 128) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(val)).(T) + return typeVal, nil case bool: val, err := strconv.ParseBool(str) @@ -173,6 +304,15 @@ func FromString[T any](str string) (T, error) { //nolint:gocyclo typeVal, _ := any(val).(T) + return typeVal, nil + case *bool: + val, err := strconv.ParseBool(str) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(lo.ToPtr(val)).(T) + return typeVal, nil case []byte: val, _ := any([]byte(str)).(T) diff --git a/pkg/utils/string_test.go b/pkg/utils/string_test.go index ace9e3e..29c5dfb 100644 --- a/pkg/utils/string_test.go +++ b/pkg/utils/string_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -46,66 +47,130 @@ func TestFromString(t *testing.T) { require.NoError(t, err) assert.Equal(t, "", stringVal) + stringPtrVal, err := FromString[*string]("") + require.NoError(t, err) + assert.Nil(t, stringPtrVal) + intVal, err := FromString[int]("") require.NoError(t, err) assert.Zero(t, intVal) + intPtrVal, err := FromString[*int]("") + require.NoError(t, err) + assert.Nil(t, intPtrVal) + int8Val, err := FromString[int8]("") require.NoError(t, err) assert.Zero(t, int8Val) + int8PtrVal, err := FromString[*int8]("") + require.NoError(t, err) + assert.Nil(t, int8PtrVal) + int16Val, err := FromString[int16]("") require.NoError(t, err) assert.Zero(t, int16Val) + int16PtrVal, err := FromString[*int16]("") + require.NoError(t, err) + assert.Nil(t, int16PtrVal) + int32Val, err := FromString[int32]("") require.NoError(t, err) assert.Zero(t, int32Val) + int32PtrVal, err := FromString[*int32]("") + require.NoError(t, err) + assert.Nil(t, int32PtrVal) + int64Val, err := FromString[int64]("") require.NoError(t, err) assert.Zero(t, int64Val) + int64PtrVal, err := FromString[*int64]("") + require.NoError(t, err) + assert.Nil(t, int64PtrVal) + uintVal, err := FromString[uint]("") require.NoError(t, err) assert.Zero(t, uintVal) + uintPtrVal, err := FromString[*uint]("") + require.NoError(t, err) + assert.Nil(t, uintPtrVal) + uint8Val, err := FromString[uint8]("") require.NoError(t, err) assert.Zero(t, uint8Val) + uint8PtrVal, err := FromString[*uint8]("") + require.NoError(t, err) + assert.Nil(t, uint8PtrVal) + uint16Val, err := FromString[uint16]("") require.NoError(t, err) assert.Zero(t, uint16Val) + uint16PtrVal, err := FromString[*uint16]("") + require.NoError(t, err) + assert.Nil(t, uint16PtrVal) + uint32Val, err := FromString[uint32]("") require.NoError(t, err) assert.Zero(t, uint32Val) + uint32PtrVal, err := FromString[*uint32]("") + require.NoError(t, err) + assert.Nil(t, uint32PtrVal) + uint64Val, err := FromString[uint64]("") require.NoError(t, err) assert.Zero(t, uint64Val) + uint64PtrVal, err := FromString[*uint64]("") + require.NoError(t, err) + assert.Nil(t, uint64PtrVal) + float32Val, err := FromString[float32]("") require.NoError(t, err) assert.Zero(t, float32Val) + float32PtrVal, err := FromString[*float32]("") + require.NoError(t, err) + assert.Nil(t, float32PtrVal) + float64Val, err := FromString[float64]("") require.NoError(t, err) assert.Zero(t, float64Val) + float64PtrVal, err := FromString[*float64]("") + require.NoError(t, err) + assert.Nil(t, float64PtrVal) + complex64Val, err := FromString[complex64]("") require.NoError(t, err) assert.Zero(t, complex64Val) + complex64PtrVal, err := FromString[*complex64]("") + require.NoError(t, err) + assert.Nil(t, complex64PtrVal) + complex128Val, err := FromString[complex128]("") require.NoError(t, err) assert.Zero(t, complex128Val) + complex128PtrVal, err := FromString[*complex128]("") + require.NoError(t, err) + assert.Nil(t, complex128PtrVal) + boolVal, err := FromString[bool]("") require.NoError(t, err) assert.False(t, boolVal) + boolPtrVal, err := FromString[*bool]("") + require.NoError(t, err) + assert.Nil(t, boolPtrVal) + bytesVal, err := FromString[[]byte]("") require.NoError(t, err) assert.Empty(t, bytesVal) @@ -142,76 +207,151 @@ func TestFromString(t *testing.T) { require.EqualError(t, err, "failed to convert string to type int: strconv.ParseInt: parsing \"invalid\": invalid syntax") assert.Zero(t, intVal) + intPtrVal, err := FromString[*int]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *int: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Nil(t, intPtrVal) + int8Val, err := FromString[int8]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type int8: strconv.ParseInt: parsing \"invalid\": invalid syntax") assert.Zero(t, int8Val) + int8PtrVal, err := FromString[*int8]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *int8: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Nil(t, int8PtrVal) + int16Val, err := FromString[int16]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type int16: strconv.ParseInt: parsing \"invalid\": invalid syntax") assert.Zero(t, int16Val) + int16PtrVal, err := FromString[*int16]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *int16: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Nil(t, int16PtrVal) + int32Val, err := FromString[int32]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type int32: strconv.ParseInt: parsing \"invalid\": invalid syntax") assert.Zero(t, int32Val) + int32PtrVal, err := FromString[*int32]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *int32: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Nil(t, int32PtrVal) + int64Val, err := FromString[int64]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type int64: strconv.ParseInt: parsing \"invalid\": invalid syntax") assert.Zero(t, int64Val) + int64PtrVal, err := FromString[*int64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *int64: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Nil(t, int64PtrVal) + uintVal, err := FromString[uint]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type uint: strconv.ParseUint: parsing \"invalid\": invalid syntax") assert.Zero(t, uintVal) + uintPtrVal, err := FromString[*uint]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *uint: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Nil(t, uintPtrVal) + uint8Val, err := FromString[uint8]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type uint8: strconv.ParseUint: parsing \"invalid\": invalid syntax") assert.Zero(t, uint8Val) + uint8PtrVal, err := FromString[*uint8]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *uint8: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Nil(t, uint8PtrVal) + uint16Val, err := FromString[uint16]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type uint16: strconv.ParseUint: parsing \"invalid\": invalid syntax") assert.Zero(t, uint16Val) + uint16PtrVal, err := FromString[*uint16]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *uint16: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Nil(t, uint16PtrVal) + uint32Val, err := FromString[uint32]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type uint32: strconv.ParseUint: parsing \"invalid\": invalid syntax") assert.Zero(t, uint32Val) + uint32PtrVal, err := FromString[*uint32]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *uint32: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Nil(t, uint32PtrVal) + uint64Val, err := FromString[uint64]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type uint64: strconv.ParseUint: parsing \"invalid\": invalid syntax") assert.Zero(t, uint64Val) + uint64PtrVal, err := FromString[*uint64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *uint64: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Nil(t, uint64PtrVal) + float32Val, err := FromString[float32]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type float32: strconv.ParseFloat: parsing \"invalid\": invalid syntax") assert.Zero(t, float32Val) + float32PtrVal, err := FromString[*float32]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *float32: strconv.ParseFloat: parsing \"invalid\": invalid syntax") + assert.Nil(t, float32PtrVal) + float64Val, err := FromString[float64]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type float64: strconv.ParseFloat: parsing \"invalid\": invalid syntax") assert.Zero(t, float64Val) + float64PtrVal, err := FromString[*float64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *float64: strconv.ParseFloat: parsing \"invalid\": invalid syntax") + assert.Nil(t, float64PtrVal) + complex64Val, err := FromString[complex64]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type complex64: strconv.ParseComplex: parsing \"invalid\": invalid syntax") assert.Zero(t, complex64Val) + complex64PtrVal, err := FromString[*complex64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *complex64: strconv.ParseComplex: parsing \"invalid\": invalid syntax") + assert.Nil(t, complex64PtrVal) + complex128Val, err := FromString[complex128]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type complex128: strconv.ParseComplex: parsing \"invalid\": invalid syntax") assert.Zero(t, complex128Val) + complex128PtrVal, err := FromString[*complex128]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *complex128: strconv.ParseComplex: parsing \"invalid\": invalid syntax") + assert.Nil(t, complex128PtrVal) + boolVal, err := FromString[bool]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type bool: strconv.ParseBool: parsing \"invalid\": invalid syntax") assert.False(t, boolVal) + boolPtrVal, err := FromString[*bool]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type *bool: strconv.ParseBool: parsing \"invalid\": invalid syntax") + assert.Nil(t, boolPtrVal) + mapVal, err := FromString[map[string]any]("invalid") require.Error(t, err) require.EqualError(t, err, "failed to convert string to type map[string]interface {}: invalid character 'i' looking for beginning of value") @@ -248,66 +388,130 @@ func TestFromString(t *testing.T) { require.NoError(t, err) assert.Equal(t, "abcd", stringVal) + stringPtrVal, err := FromString[*string]("abcd") + require.NoError(t, err) + assert.Equal(t, "abcd", lo.FromPtr(stringPtrVal)) + intVal, err := FromString[int]("1234") require.NoError(t, err) assert.Equal(t, 1234, intVal) + intPtrVal, err := FromString[*int]("1234") + require.NoError(t, err) + assert.Equal(t, 1234, lo.FromPtr(intPtrVal)) + int8Val, err := FromString[int8]("123") require.NoError(t, err) assert.Equal(t, int8(123), int8Val) + int8PtrVal, err := FromString[*int8]("123") + require.NoError(t, err) + assert.Equal(t, int8(123), lo.FromPtr(int8PtrVal)) + int16Val, err := FromString[int16]("1234") require.NoError(t, err) assert.Equal(t, int16(1234), int16Val) + int16PtrVal, err := FromString[*int16]("1234") + require.NoError(t, err) + assert.Equal(t, int16(1234), lo.FromPtr(int16PtrVal)) + int32Val, err := FromString[int32]("1234") require.NoError(t, err) assert.Equal(t, int32(1234), int32Val) + int32PtrVal, err := FromString[*int32]("1234") + require.NoError(t, err) + assert.Equal(t, int32(1234), lo.FromPtr(int32PtrVal)) + int64Val, err := FromString[int64]("1234") require.NoError(t, err) assert.Equal(t, int64(1234), int64Val) + int64PtrVal, err := FromString[*int64]("1234") + require.NoError(t, err) + assert.Equal(t, int64(1234), lo.FromPtr(int64PtrVal)) + uintVal, err := FromString[uint]("1234") require.NoError(t, err) assert.Equal(t, uint(1234), uintVal) + uintPtrVal, err := FromString[*uint]("1234") + require.NoError(t, err) + assert.Equal(t, uint(1234), lo.FromPtr(uintPtrVal)) + uint8Val, err := FromString[uint8]("123") require.NoError(t, err) assert.Equal(t, uint8(123), uint8Val) + uint8PtrVal, err := FromString[*uint8]("123") + require.NoError(t, err) + assert.Equal(t, uint8(123), lo.FromPtr(uint8PtrVal)) + uint16Val, err := FromString[uint16]("1234") require.NoError(t, err) assert.Equal(t, uint16(1234), uint16Val) + uint16PtrVal, err := FromString[*uint16]("1234") + require.NoError(t, err) + assert.Equal(t, uint16(1234), lo.FromPtr(uint16PtrVal)) + uint32Val, err := FromString[uint32]("1234") require.NoError(t, err) assert.Equal(t, uint32(1234), uint32Val) + uint32PtrVal, err := FromString[*uint32]("1234") + require.NoError(t, err) + assert.Equal(t, uint32(1234), lo.FromPtr(uint32PtrVal)) + uint64Val, err := FromString[uint64]("1234") require.NoError(t, err) assert.Equal(t, uint64(1234), uint64Val) + uint64PtrVal, err := FromString[*uint64]("1234") + require.NoError(t, err) + assert.Equal(t, uint64(1234), lo.FromPtr(uint64PtrVal)) + float32Val, err := FromString[float32]("1234.56") require.NoError(t, err) assert.InDelta(t, float32(1234.56), float32Val, 0.0001) + float32PtrVal, err := FromString[*float32]("1234.56") + require.NoError(t, err) + assert.InDelta(t, float32(1234.56), lo.FromPtr(float32PtrVal), 0.0001) + float64Val, err := FromString[float64]("1234.56") require.NoError(t, err) assert.InDelta(t, float64(1234.56), float64Val, 0.0001) + float64PtrVal, err := FromString[*float64]("1234.56") + require.NoError(t, err) + assert.InDelta(t, float64(1234.56), lo.FromPtr(float64PtrVal), 0.0001) + complex64Val, err := FromString[complex64]("1234.56") require.NoError(t, err) assert.Equal(t, complex64(1234.56), complex64Val) + complex64PtrVal, err := FromString[*complex64]("1234.56") + require.NoError(t, err) + assert.Equal(t, complex64(1234.56), lo.FromPtr(complex64PtrVal)) + complex128Val, err := FromString[complex128]("1234.56") require.NoError(t, err) assert.Equal(t, complex128(1234.56), complex128Val) + complex128PtrVal, err := FromString[*complex128]("1234.56") + require.NoError(t, err) + assert.Equal(t, complex128(1234.56), lo.FromPtr(complex128PtrVal)) + boolVal, err := FromString[bool]("true") require.NoError(t, err) assert.True(t, boolVal) + boolPtrVal, err := FromString[*bool]("true") + require.NoError(t, err) + assert.True(t, lo.FromPtr(boolPtrVal)) + bytesVal, err := FromString[[]byte]("abcd") require.NoError(t, err) assert.Equal(t, []byte("abcd"), bytesVal)