diff --git a/api/api.go b/api/api.go index 992472c..7c709ee 100644 --- a/api/api.go +++ b/api/api.go @@ -2,16 +2,17 @@ package api import ( "context" + "crypto/tls" "errors" "fmt" + "net/http" + "net/http/pprof" + "path" + "github.com/database64128/shadowsocks-go/api/internal/restapi" "github.com/database64128/shadowsocks-go/api/ssm" - "github.com/database64128/shadowsocks-go/jsonhelper" - "github.com/gofiber/contrib/fiberzap/v2" - "github.com/gofiber/fiber/v2" - fiberlog "github.com/gofiber/fiber/v2/log" - "github.com/gofiber/fiber/v2/middleware/etag" - "github.com/gofiber/fiber/v2/middleware/pprof" + "github.com/database64128/shadowsocks-go/conn" + "github.com/database64128/shadowsocks-go/tlscerts" "go.uber.org/zap" ) @@ -34,111 +35,269 @@ type Config struct { // If empty, the remote peer's address is used. ProxyHeader string `json:"proxyHeader"` - // ListenAddress is the address to listen on. - ListenAddress string `json:"listen"` - - // CertFile is the path to the certificate file. - // If empty, TLS is disabled. - CertFile string `json:"certFile"` - - // KeyFile is the path to the key file. - // This is required if CertFile is set. - KeyFile string `json:"keyFile"` - - // ClientCertFile is the path to the client certificate file. - // If empty, client certificate authentication is disabled. - ClientCertFile string `json:"clientCertFile"` - // StaticPath is the path where static files are served from. // If empty, static file serving is disabled. StaticPath string `json:"staticPath"` - // SecretPath adds a secret path prefix to all routes. - // If empty, no secret path is added. + // SecretPath adds a secret path prefix to API and pprof endpoints. + // Static files are not affected. If empty, no secret path is added. SecretPath string `json:"secretPath"` - // FiberConfigPath overrides the [fiber.Config] settings we use. - // If empty, no overrides are applied. - FiberConfigPath string `json:"fiberConfigPath"` + // Listeners is the list of server listeners. + Listeners []ListenerConfig `json:"listeners"` } -// Server returns a new API server from the config. -func (c *Config) Server(logger *zap.Logger) (*Server, *ssm.ServerManager, error) { - if !c.Enabled { - return nil, nil, nil - } +// ListenerConfig is the configuration for a server listener. +type ListenerConfig struct { + // Network is the network type. + Network string `json:"network"` + + // Address is the address to listen on. + Address string `json:"address"` + + // Fwmark sets the listener's fwmark on Linux, or user cookie on FreeBSD. + // + // Available on Linux and FreeBSD. + Fwmark int `json:"fwmark"` + + // TrafficClass sets the traffic class of the listener. + // + // Available on most platforms except Windows. + TrafficClass int `json:"trafficClass"` + + // FastOpenBacklog specifies the maximum number of pending TFO connections on Linux. + // If the value is 0, Go std's listen(2) backlog is used. + // + // On other platforms, a non-negative value is ignored, as they do not have the option to set the TFO backlog. + // + // On all platforms, a negative value disables TFO. + FastOpenBacklog int `json:"fastOpenBacklog"` + + // DeferAcceptSecs sets TCP_DEFER_ACCEPT to the given number of seconds on the listener. + // + // Available on Linux. + DeferAcceptSecs int `json:"deferAcceptSecs"` + + // UserTimeoutMsecs sets TCP_USER_TIMEOUT to the given number of milliseconds on the listener. + // + // Available on Linux. + UserTimeoutMsecs int `json:"userTimeoutMsecs"` + + // CertList is the name of the certificate list in the certificate store, + // used as the server certificate for HTTPS. + CertList string `json:"certList"` + + // ClientCAs is the name of the X.509 certificate pool in the certificate store, + // used as the root CA set for verifying client certificates. + ClientCAs string `json:"clientCAs"` + + // EnableTLS controls whether to enable TLS. + EnableTLS bool `json:"enableTLS"` - fiberlog.SetLogger(fiberzap.NewLogger(fiberzap.LoggerConfig{ - SetLogger: logger, - })) + // RequireAndVerifyClientCert controls whether to require and verify client certificates. + RequireAndVerifyClientCert bool `json:"requireAndVerifyClientCert"` - fc := fiber.Config{ - ProxyHeader: c.ProxyHeader, - DisableStartupMessage: true, - Network: "tcp", - EnableTrustedProxyCheck: c.EnableTrustedProxyCheck, - TrustedProxies: c.TrustedProxies, + // ReusePort enables SO_REUSEPORT on the listener. + // + // Available on Linux and the BSDs. + ReusePort bool `json:"reusePort"` + + // FastOpen enables TCP Fast Open on the listener. + // + // Available on Linux, macOS, FreeBSD, and Windows. + FastOpen bool `json:"fastOpen"` + + // FastOpenFallback enables runtime detection of TCP Fast Open support on the listener. + // + // When enabled, the listener will start without TFO if TFO is not available on the system. + // When disabled, the listener will abort if TFO cannot be enabled on the socket. + // + // Available on all platforms. + FastOpenFallback bool `json:"fastOpenFallback"` + + // Multipath enables multipath TCP on the listener. + // + // Unlike Go std, we make MPTCP strictly opt-in. + // That is, if this field is false, MPTCP will be explicitly disabled. + // This ensures that if Go std suddenly decides to enable MPTCP by default, + // existing configurations won't encounter issues due to missing features in the kernel MPTCP stack, + // such as TCP keepalive (as of Linux 6.5), and failed connect attempts won't always be retried once. + // + // Available on platforms supported by Go std's MPTCP implementation. + Multipath bool `json:"multipath"` +} + +// NewServer returns a new API server from the config. +func (c *Config) NewServer(logger *zap.Logger, listenConfigCache conn.ListenConfigCache, tlsCertStore *tlscerts.Store) (*Server, *ssm.ServerManager, error) { + if len(c.Listeners) == 0 { + return nil, nil, errors.New("no listeners specified") } - if c.FiberConfigPath != "" { - if err := jsonhelper.OpenAndDecodeDisallowUnknownFields(c.FiberConfigPath, &fc); err != nil { - return nil, nil, fmt.Errorf("failed to load fiber config: %w", err) + lcs := make([]listenConfig, len(c.Listeners)) + for i := range c.Listeners { + lnc := &c.Listeners[i] + lcs[i] = listenConfig{ + listenConfig: listenConfigCache.Get(conn.ListenerSocketOptions{ + Fwmark: lnc.Fwmark, + TrafficClass: lnc.TrafficClass, + TCPFastOpenBacklog: lnc.FastOpenBacklog, + TCPDeferAcceptSecs: lnc.DeferAcceptSecs, + TCPUserTimeoutMsecs: lnc.UserTimeoutMsecs, + ReusePort: lnc.ReusePort, + TCPFastOpen: lnc.FastOpen, + TCPFastOpenFallback: lnc.FastOpenFallback, + MultipathTCP: lnc.Multipath, + }), + network: lnc.Network, + address: lnc.Address, } - } - app := fiber.New(fc) + if lnc.EnableTLS { + var tlsConfig tls.Config - app.Use(etag.New()) + if lnc.CertList != "" { + certs, getCert, ok := tlsCertStore.GetCertList(lnc.CertList) + if !ok { + return nil, nil, fmt.Errorf("certificate list %q not found", lnc.CertList) + } + tlsConfig.Certificates = certs + tlsConfig.GetCertificate = getCert + } - app.Use(fiberzap.New(fiberzap.Config{ - Logger: logger, - })) + if lnc.ClientCAs != "" { + pool, ok := tlsCertStore.GetX509CertPool(lnc.ClientCAs) + if !ok { + return nil, nil, fmt.Errorf("client CA X.509 certificate pool %q not found", lnc.ClientCAs) + } + tlsConfig.ClientCAs = pool + } - var router fiber.Router = app - if c.SecretPath != "" { - if c.SecretPath[0] != '/' { - c.SecretPath = "/" + c.SecretPath + if lnc.RequireAndVerifyClientCert { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + lcs[i].tlsConfig = &tlsConfig } - router = app.Group(c.SecretPath) } - if c.DebugPprof { - app.Use(pprof.New(pprof.Config{ - Prefix: c.SecretPath, - })) + mux := http.NewServeMux() + + basePath := "/" + if c.SecretPath != "" { + basePath = joinPatternPath(basePath, c.SecretPath) } - api := router.Group("/api") + if c.DebugPprof { + register := func(path string, handler http.HandlerFunc) { + pattern := "GET " + joinPatternPath(basePath, path) + mux.Handle(pattern, logPprofRequests(logger, handler)) + } + + register("/debug/pprof/", pprof.Index) + register("/debug/pprof/cmdline", pprof.Cmdline) + register("/debug/pprof/profile", pprof.Profile) + register("/debug/pprof/symbol", pprof.Symbol) + register("/debug/pprof/trace", pprof.Trace) + } // /api/ssm/v1 + apiSSMv1Path := joinPatternPath(basePath, "/api/ssm/v1") sm := ssm.NewServerManager() - sm.RegisterRoutes(api.Group("/ssm/v1")) + sm.RegisterHandlers(func(method, path string, handler restapi.HandlerFunc) { + pattern := method + " " + joinPatternPath(apiSSMv1Path, path) + mux.Handle(pattern, logAPIRequests(logger, handler)) + }) if c.StaticPath != "" { - router.Static("/", c.StaticPath, fiber.Static{ - ByteRange: true, - }) + mux.Handle("GET /", logFileServerRequests(logger, http.FileServer(http.Dir(c.StaticPath)))) + } + + errorLog, err := zap.NewStdLogAt(logger, zap.ErrorLevel) + if err != nil { + return nil, nil, fmt.Errorf("failed to create error logger: %w", err) } return &Server{ - logger: logger, - app: app, - listenAddress: c.ListenAddress, - certFile: c.CertFile, - keyFile: c.KeyFile, - clientCertFile: c.ClientCertFile, + logger: logger, + lcs: lcs, + server: http.Server{ + Handler: mux, + ErrorLog: errorLog, + }, }, sm, nil } +// joinPatternPath joins path elements into a pattern path. +func joinPatternPath(elem ...string) string { + p := path.Join(elem...) + if p == "" { + return "" + } + // Add back the trailing slash removed by [path.Join]. + if last := elem[len(elem)-1]; last != "" && last[len(last)-1] == '/' { + if p[len(p)-1] != '/' { + return p + "/" + } + } + return p +} + +// logPprofRequests is a middleware that logs pprof requests. +func logPprofRequests(logger *zap.Logger, h http.HandlerFunc) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h(w, r) + logger.Info("Handled pprof request", + zap.String("proto", r.Proto), + zap.String("method", r.Method), + zap.String("requestURI", r.RequestURI), + zap.String("host", r.Host), + zap.String("remoteAddr", r.RemoteAddr), + ) + }) +} + +// logAPIRequests is a middleware that logs API requests. +func logAPIRequests(logger *zap.Logger, h restapi.HandlerFunc) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + status, err := h(w, r) + logger.Info("Handled API request", + zap.String("proto", r.Proto), + zap.String("method", r.Method), + zap.String("requestURI", r.RequestURI), + zap.String("host", r.Host), + zap.String("remoteAddr", r.RemoteAddr), + zap.Int("status", status), + zap.Error(err), + ) + }) +} + +// logFileServerRequests is a middleware that logs file server requests. +func logFileServerRequests(logger *zap.Logger, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + logger.Info("Served file", + zap.String("proto", r.Proto), + zap.String("method", r.Method), + zap.String("requestURI", r.RequestURI), + zap.String("host", r.Host), + zap.String("remoteAddr", r.RemoteAddr), + ) + }) +} + +type listenConfig struct { + listenConfig conn.ListenConfig + network string + address string + tlsConfig *tls.Config +} + // Server is the RESTful API server. type Server struct { - logger *zap.Logger - app *fiber.App - listenAddress string - certFile string - keyFile string - clientCertFile string - ctx context.Context + logger *zap.Logger + lcs []listenConfig + server http.Server } // String implements [service.Service.String]. @@ -148,35 +307,29 @@ func (s *Server) String() string { // Start starts the API server. func (s *Server) Start(ctx context.Context) error { - s.logger.Info("Starting API server", zap.String("listenAddress", s.listenAddress)) - s.ctx = ctx - go func() { - var err error - switch { - case s.clientCertFile != "": - err = s.app.ListenMutualTLS(s.listenAddress, s.certFile, s.keyFile, s.clientCertFile) - case s.certFile != "": - err = s.app.ListenTLS(s.listenAddress, s.certFile, s.keyFile) - default: - err = s.app.Listen(s.listenAddress) - } + for i := range s.lcs { + lc := &s.lcs[i] + ln, _, err := lc.listenConfig.Listen(ctx, lc.network, lc.address) if err != nil { - s.logger.Fatal("Failed to start API server", zap.Error(err)) + return err } - }() + + if lc.tlsConfig != nil { + ln = tls.NewListener(ln, lc.tlsConfig) + } + + go func() { + if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed { + s.logger.Error("Failed to serve API", zap.Error(err)) + } + }() + + s.logger.Info("Started API server listener", zap.Stringer("listenAddress", ln.Addr())) + } return nil } // Stop stops the API server. func (s *Server) Stop() error { - if err := s.app.ShutdownWithContext(s.ctx); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - if errors.Is(err, context.DeadlineExceeded) { - return nil - } - return err - } - return nil + return s.server.Close() } diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 0000000..c34ae63 --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,46 @@ +package api + +import "testing" + +func TestJoinPatternPath(t *testing.T) { + for _, c := range []struct { + elem []string + want string + }{ + {[]string{}, ""}, + {[]string{""}, ""}, + {[]string{"a"}, "a"}, + {[]string{"/"}, "/"}, + {[]string{"/a"}, "/a"}, + {[]string{"a/"}, "a/"}, + {[]string{"/a/"}, "/a/"}, + {[]string{"", "b"}, "b"}, + {[]string{"", "/b"}, "/b"}, + {[]string{"", "b/"}, "b/"}, + {[]string{"", "/b/"}, "/b/"}, + {[]string{"a", "b"}, "a/b"}, + {[]string{"a", "/b"}, "a/b"}, + {[]string{"a", "b/"}, "a/b/"}, + {[]string{"a", "/b/"}, "a/b/"}, + {[]string{"/", "b"}, "/b"}, + {[]string{"/", "/b"}, "/b"}, + {[]string{"/", "b/"}, "/b/"}, + {[]string{"/", "/b/"}, "/b/"}, + {[]string{"/a", "b"}, "/a/b"}, + {[]string{"/a", "/b"}, "/a/b"}, + {[]string{"/a", "b/"}, "/a/b/"}, + {[]string{"/a", "/b/"}, "/a/b/"}, + {[]string{"a/", "b"}, "a/b"}, + {[]string{"a/", "/b"}, "a/b"}, + {[]string{"a/", "b/"}, "a/b/"}, + {[]string{"a/", "/b/"}, "a/b/"}, + {[]string{"/a/", "b"}, "/a/b"}, + {[]string{"/a/", "/b"}, "/a/b"}, + {[]string{"/a/", "b/"}, "/a/b/"}, + {[]string{"/a/", "/b/"}, "/a/b/"}, + } { + if got := joinPatternPath(c.elem...); got != c.want { + t.Errorf("joinPatternPath(%#v) = %q; want %q", c.elem, got, c.want) + } + } +} diff --git a/api/internal/restapi/restapi.go b/api/internal/restapi/restapi.go new file mode 100644 index 0000000..35cb833 --- /dev/null +++ b/api/internal/restapi/restapi.go @@ -0,0 +1,28 @@ +package restapi + +import ( + "encoding/json" + "net/http" +) + +// HandlerFunc is like [http.HandlerFunc], but returns a status code and an error. +type HandlerFunc func(w http.ResponseWriter, r *http.Request) (status int, err error) + +// EncodeResponse sets the Content-Type header field to application/json, and writes +// to the response writer with the given status code and data encoded as JSON. +// +// If data is nil, the status code is written and no data is encoded. +func EncodeResponse(w http.ResponseWriter, status int, data any) (int, error) { + if data == nil { + w.WriteHeader(status) + return status, nil + } + w.Header()["Content-Type"] = []string{"application/json"} + w.WriteHeader(status) + return status, json.NewEncoder(w).Encode(data) +} + +// DecodeRequest decodes the request body as JSON into the provided value. +func DecodeRequest(r *http.Request, v any) error { + return json.NewDecoder(r.Body).Decode(v) +} diff --git a/api/ssm/ssm.go b/api/ssm/ssm.go index 9a6e0be..93c23b2 100644 --- a/api/ssm/ssm.go +++ b/api/ssm/ssm.go @@ -3,11 +3,12 @@ package ssm import ( "errors" + "net/http" "github.com/database64128/shadowsocks-go" + "github.com/database64128/shadowsocks-go/api/internal/restapi" "github.com/database64128/shadowsocks-go/cred" "github.com/database64128/shadowsocks-go/stats" - "github.com/gofiber/fiber/v2" ) // StandardError is the standard error response. @@ -15,22 +16,6 @@ type StandardError struct { Message string `json:"error"` } -// ServerInfo contains information about the API server. -type ServerInfo struct { - Name string `json:"server"` - APIVersion string `json:"apiVersion"` -} - -var serverInfo = ServerInfo{ - Name: "shadowsocks-go " + shadowsocks.Version, - APIVersion: "v1", -} - -// GetServerInfo returns information about the API server. -func GetServerInfo(c *fiber.Ctx) error { - return c.JSON(&serverInfo) -} - type managedServer struct { cms *cred.ManagedServer sc stats.Collector @@ -58,131 +43,125 @@ func (sm *ServerManager) AddServer(name string, cms *cred.ManagedServer, sc stat sm.managedServerNames = append(sm.managedServerNames, name) } -// RegisterRoutes sets up routes for the /servers endpoint. -func (sm *ServerManager) RegisterRoutes(v1 fiber.Router) { - v1.Get("/servers", sm.ListServers) +// RegisterHandlers sets up handlers for the /servers endpoint. +func (sm *ServerManager) RegisterHandlers(register func(method string, path string, handler restapi.HandlerFunc)) { + register(http.MethodGet, "/servers", sm.handleListServers) - server := v1.Group("/servers/:server", sm.ContextManagedServer) - server.Get("", GetServerInfo) - server.Get("/stats", sm.GetStats) + register(http.MethodGet, "/servers/{server}", sm.requireServerStats(handleGetServerInfo)) + register(http.MethodGet, "/servers/{server}/stats", sm.requireServerStats(handleGetStats)) - users := server.Group("/users", sm.CheckMultiUserSupport) - users.Get("", sm.ListUsers) - users.Post("", sm.AddUser) - users.Get("/:username", sm.GetUser) - users.Patch("/:username", sm.UpdateUser) - users.Delete("/:username", sm.DeleteUser) + register(http.MethodGet, "/servers/{server}/users", sm.requireServerUsers(handleListUsers)) + register(http.MethodPost, "/servers/{server}/users", sm.requireServerUsers(handleAddUser)) + register(http.MethodGet, "/servers/{server}/users/{username}", sm.requireServerUsers(handleGetUser)) + register(http.MethodPatch, "/servers/{server}/users/{username}", sm.requireServerUsers(handleUpdateUser)) + register(http.MethodDelete, "/servers/{server}/users/{username}", sm.requireServerUsers(handleDeleteUser)) } -// ListServers lists all managed servers. -func (sm *ServerManager) ListServers(c *fiber.Ctx) error { - return c.JSON(&sm.managedServerNames) +func (sm *ServerManager) handleListServers(w http.ResponseWriter, _ *http.Request) (int, error) { + return restapi.EncodeResponse(w, http.StatusOK, &sm.managedServerNames) } -// ContextManagedServer is a middleware for the servers group. -// It adds the server with the given name to the request context. -func (sm *ServerManager) ContextManagedServer(c *fiber.Ctx) error { - name := c.Params("server") - ms := sm.managedServers[name] - if ms == nil { - return c.Status(fiber.StatusNotFound).JSON(&StandardError{Message: "server not found"}) +func (sm *ServerManager) requireServerStats(h func(http.ResponseWriter, *http.Request, stats.Collector) (int, error)) restapi.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) (int, error) { + name := r.PathValue("server") + ms := sm.managedServers[name] + if ms == nil { + return restapi.EncodeResponse(w, http.StatusNotFound, StandardError{Message: "server not found"}) + } + return h(w, r, ms.sc) } - c.Locals(0, ms) - return c.Next() } -// managedServerFromContext returns the managed server from the request context. -func managedServerFromContext(c *fiber.Ctx) *managedServer { - return c.Locals(0).(*managedServer) -} +var serverInfoJSON = []byte(`{"server":"shadowsocks-go ` + shadowsocks.Version + `","apiVersion":"v1"}`) -// GetStats returns server traffic statistics. -func (sm *ServerManager) GetStats(c *fiber.Ctx) error { - ms := managedServerFromContext(c) - if c.QueryBool("clear") { - return c.JSON(ms.sc.SnapshotAndReset()) - } - return c.JSON(ms.sc.Snapshot()) +func handleGetServerInfo(w http.ResponseWriter, _ *http.Request, _ stats.Collector) (int, error) { + w.Header()["Content-Type"] = []string{"application/json"} + _, err := w.Write(serverInfoJSON) + return http.StatusOK, err } -// CheckMultiUserSupport is a middleware for the users group. -// It checks whether the selected server supports user management. -func (sm *ServerManager) CheckMultiUserSupport(c *fiber.Ctx) error { - ms := managedServerFromContext(c) - if ms.cms == nil { - return c.Status(fiber.StatusNotFound).JSON(&StandardError{Message: "The server does not support user management."}) +func handleGetStats(w http.ResponseWriter, r *http.Request, sc stats.Collector) (int, error) { + var serverStats stats.Server + if v := r.URL.Query()["clear"]; len(v) == 1 && (v[0] == "" || v[0] == "true") { + serverStats = sc.SnapshotAndReset() + } else { + serverStats = sc.Snapshot() } - return c.Next() + return restapi.EncodeResponse(w, http.StatusOK, serverStats) } -// UserList contains a list of user credentials. -type UserList struct { - Users []cred.UserCredential `json:"users"` +func (sm *ServerManager) requireServerUsers(h func(http.ResponseWriter, *http.Request, *managedServer) (int, error)) func(http.ResponseWriter, *http.Request) (int, error) { + return func(w http.ResponseWriter, r *http.Request) (int, error) { + name := r.PathValue("server") + ms := sm.managedServers[name] + if ms == nil { + return restapi.EncodeResponse(w, http.StatusNotFound, StandardError{Message: "server not found"}) + } + if ms.cms == nil { + return restapi.EncodeResponse(w, http.StatusNotFound, StandardError{Message: "The server does not support user management."}) + } + return h(w, r, ms) + } } -// ListUsers lists server users. -func (sm *ServerManager) ListUsers(c *fiber.Ctx) error { - ms := managedServerFromContext(c) - return c.JSON(&UserList{Users: ms.cms.Credentials()}) +func handleListUsers(w http.ResponseWriter, _ *http.Request, ms *managedServer) (int, error) { + type response struct { + Users []cred.UserCredential `json:"users"` + } + return restapi.EncodeResponse(w, http.StatusOK, response{Users: ms.cms.Credentials()}) } -// AddUser adds a new user credential to the server. -func (sm *ServerManager) AddUser(c *fiber.Ctx) error { +func handleAddUser(w http.ResponseWriter, r *http.Request, ms *managedServer) (int, error) { var uc cred.UserCredential - if err := c.BodyParser(&uc); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(&StandardError{Message: err.Error()}) + if err := restapi.DecodeRequest(r, &uc); err != nil { + return restapi.EncodeResponse(w, http.StatusBadRequest, StandardError{Message: err.Error()}) } - ms := managedServerFromContext(c) if err := ms.cms.AddCredential(uc.Name, uc.UPSK); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(&StandardError{Message: err.Error()}) + return restapi.EncodeResponse(w, http.StatusBadRequest, StandardError{Message: err.Error()}) } - return c.JSON(&uc) -} -// UserInfo contains information about a user. -type UserInfo struct { - cred.UserCredential - stats.Traffic + return restapi.EncodeResponse(w, http.StatusOK, &uc) } -// GetUser returns information about a user. -func (sm *ServerManager) GetUser(c *fiber.Ctx) error { - ms := managedServerFromContext(c) - username := c.Params("username") - uc, ok := ms.cms.GetCredential(username) +func handleGetUser(w http.ResponseWriter, r *http.Request, ms *managedServer) (int, error) { + type response struct { + cred.UserCredential + stats.Traffic + } + + username := r.PathValue("username") + userCred, ok := ms.cms.GetCredential(username) if !ok { - return c.Status(fiber.StatusNotFound).JSON(&StandardError{Message: "user not found"}) + return restapi.EncodeResponse(w, http.StatusNotFound, StandardError{Message: "user not found"}) } - return c.JSON(&UserInfo{uc, ms.sc.Snapshot().Traffic}) + + return restapi.EncodeResponse(w, http.StatusOK, response{userCred, ms.sc.Snapshot().Traffic}) } -// UpdateUser updates a user's credential. -func (sm *ServerManager) UpdateUser(c *fiber.Ctx) error { +func handleUpdateUser(w http.ResponseWriter, r *http.Request, ms *managedServer) (int, error) { var update struct { UPSK []byte `json:"uPSK"` } - if err := c.BodyParser(&update); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(&StandardError{Message: err.Error()}) + if err := restapi.DecodeRequest(r, &update); err != nil { + return restapi.EncodeResponse(w, http.StatusBadRequest, StandardError{Message: err.Error()}) } - ms := managedServerFromContext(c) - username := c.Params("username") + username := r.PathValue("username") if err := ms.cms.UpdateCredential(username, update.UPSK); err != nil { if errors.Is(err, cred.ErrNonexistentUser) { - return c.Status(fiber.StatusNotFound).JSON(&StandardError{Message: err.Error()}) + return restapi.EncodeResponse(w, http.StatusNotFound, StandardError{Message: err.Error()}) } - return c.Status(fiber.StatusBadRequest).JSON(&StandardError{Message: err.Error()}) + return restapi.EncodeResponse(w, http.StatusBadRequest, StandardError{Message: err.Error()}) } - return c.SendStatus(fiber.StatusNoContent) + + return restapi.EncodeResponse(w, http.StatusNoContent, nil) } -// DeleteUser deletes a user's credential. -func (sm *ServerManager) DeleteUser(c *fiber.Ctx) error { - ms := managedServerFromContext(c) - username := c.Params("username") +func handleDeleteUser(w http.ResponseWriter, r *http.Request, ms *managedServer) (int, error) { + username := r.PathValue("username") if err := ms.cms.DeleteCredential(username); err != nil { - return c.Status(fiber.StatusNotFound).JSON(&StandardError{Message: err.Error()}) + return restapi.EncodeResponse(w, http.StatusNotFound, StandardError{Message: err.Error()}) } - return c.SendStatus(fiber.StatusNoContent) + return restapi.EncodeResponse(w, http.StatusNoContent, nil) } diff --git a/conn/conn.go b/conn/conn.go index 732faee..eaca28b 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -70,6 +70,14 @@ type ListenConfig struct { fns setFuncSlice } +// Listen wraps [tfo.ListenConfig.Listen]. +func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (ln net.Listener, info SocketInfo, err error) { + tlc := lc.tlc + tlc.Control = lc.fns.controlFunc(&info) + ln, err = tlc.Listen(ctx, network, address) + return +} + // ListenTCP wraps [tfo.ListenConfig.Listen] and returns a [*net.TCPListener] directly. func (lc *ListenConfig) ListenTCP(ctx context.Context, network, address string) (tln *net.TCPListener, info SocketInfo, err error) { tlc := lc.tlc diff --git a/docs/config.json b/docs/config.json index 68c343d..6d1e61a 100644 --- a/docs/config.json +++ b/docs/config.json @@ -523,12 +523,27 @@ "enableTrustedProxyCheck": false, "trustedProxies": [], "proxyHeader": "X-Forwarded-For", - "listen": ":20221", - "certFile": "", - "keyFile": "", - "clientCertFile": "", - "secretPath": "/4paZvyoK3dCjyQXU33md5huJMMYVD9o8", - "fiberConfigPath": "" + "staticPath": "", + "secretPath": "4paZvyoK3dCjyQXU33md5huJMMYVD9o8", + "listeners": [ + { + "network": "tcp", + "address": ":20221", + "fwmark": 52140, + "trafficClass": 0, + "reusePort": false, + "fastOpen": true, + "fastOpenBacklog": 0, + "fastOpenFallback": true, + "multipath": false, + "deferAcceptSecs": 0, + "userTimeoutMsecs": 0, + "certList": "example.com", + "clientCAs": "my-root-ca", + "enableTLS": false, + "requireAndVerifyClientCert": false + } + ] }, "certs": { "certLists": [ diff --git a/go.mod b/go.mod index f41d71d..d79c496 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,6 @@ go 1.23.0 require ( github.com/database64128/netx-go v0.0.0-20241205055133-3d4b4d263f10 github.com/database64128/tfo-go/v2 v2.2.2 - github.com/gofiber/contrib/fiberzap/v2 v2.1.4 - github.com/gofiber/fiber/v2 v2.52.5 github.com/oschwald/geoip2-golang v1.11.0 go.uber.org/zap v1.27.0 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba @@ -16,17 +14,7 @@ require ( ) require ( - github.com/andybalholm/brotli v1.1.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect github.com/oschwald/maxminddb-golang v1.13.1 // indirect - github.com/rivo/uniseg v0.4.7 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.55.0 // indirect - github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index f5cbf9d..667260d 100644 --- a/go.sum +++ b/go.sum @@ -1,45 +1,19 @@ -github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= -github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/database64128/netx-go v0.0.0-20241205055133-3d4b4d263f10 h1:UJId3liaDh+tlJ1e3OmXqIevs9JFYXXo1K30Yx/nkrc= github.com/database64128/netx-go v0.0.0-20241205055133-3d4b4d263f10/go.mod h1:dqHsLB0Fb36Z2NSrzKklBf27+hLifGwPEGcGGXib3Rw= github.com/database64128/tfo-go/v2 v2.2.2 h1:BxynF4qGF5ct3DpPLEG62uyJZ3LQhqaf0Ken+kyy7PM= github.com/database64128/tfo-go/v2 v2.2.2/go.mod h1:2IW8jppdBwdVMjA08uEyMNnqiAHKUlqAA+J8NrsfktY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gofiber/contrib/fiberzap/v2 v2.1.4 h1:GCtCQnT4Cr9az4qab2Ozmqsomkxm4Ei86MfKk/1p5+0= -github.com/gofiber/contrib/fiberzap/v2 v2.1.4/go.mod h1:PkdXgUzw+oj4m6ksfKJ0Hs3H7iPhwvhfI4b2LSA9hhA= -github.com/gofiber/fiber/v2 v2.52.5 h1:tWoP1MJQjGEe4GB5TUGOi7P2E0ZMMRx5ZTG4rT+yGMo= -github.com/gofiber/fiber/v2 v2.52.5/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/oschwald/geoip2-golang v1.11.0 h1:hNENhCn1Uyzhf9PTmquXENiWS6AlxAEnBII6r8krA3w= github.com/oschwald/geoip2-golang v1.11.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8= -github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM= -github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= -github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -50,9 +24,7 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/service/service.go b/service/service.go index ec08c7b..457448f 100644 --- a/service/service.go +++ b/service/service.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/database64128/shadowsocks-go/api" + "github.com/database64128/shadowsocks-go/api/ssm" "github.com/database64128/shadowsocks-go/conn" "github.com/database64128/shadowsocks-go/cred" "github.com/database64128/shadowsocks-go/dns" @@ -144,14 +145,17 @@ func (sc *Config) Manager(logger *zap.Logger) (*Manager, error) { } credman := cred.NewManager(logger) - apiServer, apiSM, err := sc.API.Server(logger) - if err != nil { - return nil, fmt.Errorf("failed to create API server: %w", err) - } + var apiSM *ssm.ServerManager services := make([]Relay, 0, 2+2*len(sc.Servers)) services = append(services, credman) - if apiServer != nil { + + if sc.API.Enabled { + var apiServer *api.Server + apiServer, apiSM, err = sc.API.NewServer(logger, listenConfigCache, tlsCertStore) + if err != nil { + return nil, fmt.Errorf("failed to create API server: %w", err) + } services = append(services, apiServer) }