diff --git a/pprof.go b/pprof.go index 7cf185e..55b7001 100644 --- a/pprof.go +++ b/pprof.go @@ -24,9 +24,16 @@ func getPrefix(prefixOptions ...string) string { // the provided gin.Engine. prefixOptions is a optional. If not prefixOptions, // the default path prefix is used, otherwise first prefixOptions will be path prefix. func Register(r *gin.Engine, prefixOptions ...string) { + RouteRegister(&(r.RouterGroup), prefixOptions...) +} + +// RouteRegister the standard HandlerFuncs from the net/http/pprof package with +// the provided gin.GrouterGroup. prefixOptions is a optional. If not prefixOptions, +// the default path prefix is used, otherwise first prefixOptions will be path prefix. +func RouteRegister(rg *gin.RouterGroup, prefixOptions ...string) { prefix := getPrefix(prefixOptions...) - prefixRouter := r.Group(prefix) + prefixRouter := rg.Group(prefix) { prefixRouter.GET("/", pprofHandler(pprof.Index)) prefixRouter.GET("/cmdline", pprofHandler(pprof.Cmdline)) diff --git a/pprof_test.go b/pprof_test.go index 59e57ec..8d2d528 100644 --- a/pprof_test.go +++ b/pprof_test.go @@ -1,6 +1,12 @@ package pprof -import "testing" +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) func Test_getPrefix(t *testing.T) { tests := []struct { @@ -18,3 +24,43 @@ func Test_getPrefix(t *testing.T) { } } } + +func TestRegisterAndRouteRegister(t *testing.T) { + bearerToken := "Bearer token" + gin.SetMode(gin.ReleaseMode) + r := gin.New() + Register(r) + adminGroup := r.Group("/admin", func(c *gin.Context) { + if c.Request.Header.Get("Authorization") != bearerToken { + c.AbortWithStatus(http.StatusForbidden) + return + } + c.Next() + }) + RouteRegister(adminGroup, "pprof") + + req, _ := http.NewRequest(http.MethodGet, "/debug/pprof/", nil) + rw := httptest.NewRecorder() + r.ServeHTTP(rw, req) + + if expected, got := http.StatusOK, rw.Code; expected != got { + t.Errorf("expected: %d, got: %d", expected, got) + } + + req, _ = http.NewRequest(http.MethodGet, "/admin/pprof/", nil) + rw = httptest.NewRecorder() + r.ServeHTTP(rw, req) + + if expected, got := http.StatusForbidden, rw.Code; expected != got { + t.Errorf("expected: %d, got: %d", expected, got) + } + + req, _ = http.NewRequest(http.MethodGet, "/admin/pprof/", nil) + req.Header.Set("Authorization", bearerToken) + rw = httptest.NewRecorder() + r.ServeHTTP(rw, req) + + if expected, got := http.StatusOK, rw.Code; expected != got { + t.Errorf("expected: %d, got: %d", expected, got) + } +}