From fc8c0445843859726776dc0ff632b32ea664306b Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 8 Mar 2024 22:23:47 -0800 Subject: [PATCH] add allowed host middleware and remove `workDir` middleware (#3018) --- server/routes.go | 77 +++++++++++++++++++++++++++++++++---------- server/routes_test.go | 10 +----- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/server/routes.go b/server/routes.go index e5adc345..cfcdcec4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -10,6 +10,7 @@ import ( "log/slog" "net" "net/http" + "net/netip" "os" "os/signal" "path/filepath" @@ -35,7 +36,7 @@ import ( var mode string = gin.DebugMode type Server struct { - WorkDir string + addr net.Addr } func init() { @@ -904,15 +905,64 @@ var defaultAllowOrigins = []string{ "0.0.0.0", } -func NewServer() (*Server, error) { - workDir, err := os.MkdirTemp("", "ollama") - if err != nil { - return nil, err +func allowedHost(host string) bool { + if host == "" || host == "localhost" { + return true } - return &Server{ - WorkDir: workDir, - }, nil + if hostname, err := os.Hostname(); err == nil && host == hostname { + return true + } + + var tlds = []string{ + ".localhost", + ".local", + ".internal", + } + + for _, tld := range tlds { + if strings.HasSuffix(host, "."+tld) { + return true + } + } + + return false +} + +func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { + return func(c *gin.Context) { + if addr == nil { + c.Next() + return + } + + if !netip.MustParseAddrPort(addr.String()).Addr().IsLoopback() { + c.Next() + return + } + + if addrPort, _ := netip.ParseAddrPort(c.Request.Host); addrPort.Addr().IsLoopback() { + c.Next() + return + } + + if addr, _ := netip.ParseAddr(c.Request.Host); addr.IsLoopback() { + c.Next() + return + } + + host, _, err := net.SplitHostPort(c.Request.Host) + if err != nil { + host = c.Request.Host + } + + if allowedHost(host) { + c.Next() + return + } + + c.AbortWithStatus(http.StatusForbidden) + } } func (s *Server) GenerateRoutes() http.Handler { @@ -938,10 +988,7 @@ func (s *Server) GenerateRoutes() http.Handler { r := gin.Default() r.Use( cors.New(config), - func(c *gin.Context) { - c.Set("workDir", s.WorkDir) - c.Next() - }, + allowedHostsMiddleware(s.addr), ) r.POST("/api/pull", PullModelHandler) @@ -1010,10 +1057,7 @@ func Serve(ln net.Listener) error { } } - s, err := NewServer() - if err != nil { - return err - } + s := &Server{addr: ln.Addr()} r := s.GenerateRoutes() slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) @@ -1029,7 +1073,6 @@ func Serve(ln net.Listener) error { if loaded.runner != nil { loaded.runner.Close() } - os.RemoveAll(s.WorkDir) os.Exit(0) }() diff --git a/server/routes_test.go b/server/routes_test.go index 9cf96f10..bbed02ed 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -21,12 +21,6 @@ import ( "github.com/jmorganca/ollama/version" ) -func setupServer(t *testing.T) (*Server, error) { - t.Helper() - - return NewServer() -} - func Test_Routes(t *testing.T) { type testCase struct { Name string @@ -207,9 +201,7 @@ func Test_Routes(t *testing.T) { }, } - s, err := setupServer(t) - assert.Nil(t, err) - + s := Server{} router := s.GenerateRoutes() httpSrv := httptest.NewServer(router)