Skip to content

Commit

Permalink
router supports non-standard methods
Browse files Browse the repository at this point in the history
  • Loading branch information
xgfone committed Feb 9, 2023
1 parent 405a71d commit 98f85c7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
41 changes: 37 additions & 4 deletions router/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ type methodHandler struct {
connect interface{}
propfind interface{}
report interface{}
any interface{}
others map[string]interface{}
}

func newMethodHandler() *methodHandler { return &methodHandler{} }
func newMethodHandler() *methodHandler { return &methodHandler{others: make(map[string]interface{})} }

func (mh *methodHandler) Range(f func(string, interface{})) {
if mh.get != nil {
Expand Down Expand Up @@ -112,6 +114,14 @@ func (mh *methodHandler) Range(f func(string, interface{})) {
if mh.report != nil {
f(REPORT, mh.report)
}

if mh.any != nil {
f("", mh.any)
}

for method, handler := range mh.others {
f(method, handler)
}
}

func (mh *methodHandler) Methods() []string {
Expand All @@ -124,6 +134,11 @@ func (mh *methodHandler) DelHandler(method string) { mh.AddHandler(method, nil)
func (mh *methodHandler) AddHandler(method string, handler interface{}) {
switch method {
case "": // For Any Method
if handler == nil {
for method := range mh.others {
delete(mh.others, method)
}
}
*mh = methodHandler{
get: handler,
put: handler,
Expand All @@ -136,6 +151,8 @@ func (mh *methodHandler) AddHandler(method string, handler interface{}) {
connect: handler,
propfind: handler,
report: handler,
any: handler,
others: mh.others,
}
case http.MethodGet:
mh.get = handler
Expand All @@ -159,6 +176,12 @@ func (mh *methodHandler) AddHandler(method string, handler interface{}) {
mh.propfind = handler
case REPORT:
mh.report = handler
default:
if handler == nil {
delete(mh.others, method)
} else {
mh.others[method] = handler
}
}
}

Expand Down Expand Up @@ -187,7 +210,10 @@ func (mh *methodHandler) FindHandler(method string) interface{} {
case REPORT:
return mh.report
default:
return nil
if h, ok := mh.others[method]; ok {
return h
}
return mh.any
}
}

Expand All @@ -214,6 +240,10 @@ func (mh *methodHandler) HasHandler() bool {
return true
} else if mh.report != nil {
return true
} else if mh.any != nil {
return true
} else if len(mh.others) > 0 {
return true
}
return false
}
Expand Down Expand Up @@ -383,7 +413,7 @@ type Config struct {
// - OPTIONS
// - PROPFIND
// - REPORT
//
// - Other non-standard methods
type Router struct {
conf Config
tree *node
Expand Down Expand Up @@ -481,6 +511,8 @@ func (r *Router) Add(name, path, method string, h interface{}) (n int, err error
return 0, fmt.Errorf("route handler must not be nil")
}

method = strings.ToUpper(method)

// Validate path
if r.conf.RemoveTrailingSlash {
path = strings.TrimRight(path, "/")
Expand Down Expand Up @@ -689,6 +721,7 @@ func (r *Router) Match(path, method string, pnames, pvalues []string) (
ns string // Next search
)

method = strings.ToUpper(method)
// Search order static > param > any
for {
if search == "" {
Expand Down Expand Up @@ -826,7 +859,7 @@ func (r *Router) Match(path, method string, pnames, pvalues []string) (
// Del deletes the given route.
func (r *Router) Del(path, method string) (err error) {
if path != "" {
err = r.delRoute(path, method)
err = r.delRoute(path, strings.ToUpper(method))
}
return
}
Expand Down
17 changes: 14 additions & 3 deletions router/echo/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,30 @@ func TestRouterAnyMethod(t *testing.T) {
handler2 := 2
handler3 := 3
handler4 := 4
handler5 := 5

router := NewRouter(nil)
router.Add("", "/path1", "GET", handler1)
router.Add("", "/path2", "PUT", handler2)
router.Add("", "/path2", "POST", handler3)
router.Add("", "/path2", "", handler4)

if rs := getRoutes(router); len(rs) != 12 {
handler, _ := router.Match("/path2", "nonstandard", nil, nil)
if handler != nil {
t.Errorf("unexpect to get the handler: %v, %T", handler, handler)
}
router.Add("", "/path2", "nonstandard", handler5)
handler, _ = router.Match("/path2", "nonstandard", nil, nil)
if h, ok := handler.(int); !ok || h != 5 {
t.Error("got an unexpected handler")
}

router.Add("", "/path2", "", handler4)
if rs := getRoutes(router); len(rs) != 14 {
t.Error(rs)
}

router.Del("/path2", "POST")
if rs := getRoutes(router); len(rs) != 11 {
if rs := getRoutes(router); len(rs) != 13 {
t.Error(rs)
} else {
for _, r := range rs {
Expand Down

0 comments on commit 98f85c7

Please sign in to comment.