From c69bc19e46bf40b24518444cd6754453ac41cdd0 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Wed, 12 Jun 2024 18:48:16 -0400 Subject: [PATCH] move OLLAMA_HOST to envconfig (#5009) --- api/client.go | 55 ++------------------------------- api/client_test.go | 41 ++---------------------- api/types.go | 3 -- cmd/cmd.go | 8 +---- envconfig/config.go | 67 +++++++++++++++++++++++++++++++++++++++- envconfig/config_test.go | 48 ++++++++++++++++++++++++++++ 6 files changed, 119 insertions(+), 103 deletions(-) diff --git a/api/client.go b/api/client.go index dc099e95..fccbc9ad 100644 --- a/api/client.go +++ b/api/client.go @@ -23,11 +23,9 @@ import ( "net" "net/http" "net/url" - "os" "runtime" - "strconv" - "strings" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/version" ) @@ -65,10 +63,7 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { - ollamaHost, err := GetOllamaHost() - if err != nil { - return nil, err - } + ollamaHost := envconfig.Host return &Client{ base: &url.URL{ @@ -79,52 +74,6 @@ func ClientFromEnvironment() (*Client, error) { }, nil } -type OllamaHost struct { - Scheme string - Host string - Port string -} - -func GetOllamaHost() (OllamaHost, error) { - defaultPort := "11434" - - hostVar := os.Getenv("OLLAMA_HOST") - hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) - - scheme, hostport, ok := strings.Cut(hostVar, "://") - switch { - case !ok: - scheme, hostport = "http", hostVar - case scheme == "http": - defaultPort = "80" - case scheme == "https": - defaultPort = "443" - } - - // trim trailing slashes - hostport = strings.TrimRight(hostport, "/") - - host, port, err := net.SplitHostPort(hostport) - if err != nil { - host, port = "127.0.0.1", defaultPort - if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { - host = ip.String() - } else if hostport != "" { - host = hostport - } - } - - if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { - return OllamaHost{}, ErrInvalidHostPort - } - - return OllamaHost{ - Scheme: scheme, - Host: host, - Port: port, - }, nil -} - func NewClient(base *url.URL, http *http.Client) *Client { return &Client{ base: base, diff --git a/api/client_test.go b/api/client_test.go index b2c51d00..fe9fd74f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,11 +1,9 @@ package api import ( - "fmt" - "net" "testing" - "github.com/stretchr/testify/assert" + "github.com/ollama/ollama/envconfig" ) func TestClientFromEnvironment(t *testing.T) { @@ -35,6 +33,7 @@ func TestClientFromEnvironment(t *testing.T) { for k, v := range testCases { t.Run(k, func(t *testing.T) { t.Setenv("OLLAMA_HOST", v.value) + envconfig.LoadConfig() client, err := ClientFromEnvironment() if err != v.err { @@ -46,40 +45,4 @@ func TestClientFromEnvironment(t *testing.T) { } }) } - - hostTestCases := map[string]*testCase{ - "empty": {value: "", expect: "127.0.0.1:11434"}, - "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, - "only port": {value: ":1234", expect: ":1234"}, - "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, - "hostname": {value: "example.com", expect: "example.com:11434"}, - "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, - "zero port": {value: ":0", expect: ":0"}, - "too large port": {value: ":66000", err: ErrInvalidHostPort}, - "too small port": {value: ":-1", err: ErrInvalidHostPort}, - "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, - "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, - "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, - "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, - "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, - "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, - "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, - "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, - } - - for k, v := range hostTestCases { - t.Run(k, func(t *testing.T) { - t.Setenv("OLLAMA_HOST", v.value) - - oh, err := GetOllamaHost() - if err != v.err { - t.Fatalf("expected %s, got %s", v.err, err) - } - - if err == nil { - host := net.JoinHostPort(oh.Host, oh.Port) - assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) - } - }) - } } diff --git a/api/types.go b/api/types.go index caf2ad70..d99cf3bc 100644 --- a/api/types.go +++ b/api/types.go @@ -2,7 +2,6 @@ package api import ( "encoding/json" - "errors" "fmt" "log/slog" "math" @@ -377,8 +376,6 @@ func (m *Metrics) Summary() { } } -var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") - func (opts *Options) FromMap(m map[string]interface{}) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct diff --git a/cmd/cmd.go b/cmd/cmd.go index b5747543..ae7c8da8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -960,17 +960,11 @@ func generate(cmd *cobra.Command, opts runOptions) error { } func RunServer(cmd *cobra.Command, _ []string) error { - // retrieve the OLLAMA_HOST environment variable - ollamaHost, err := api.GetOllamaHost() - if err != nil { - return err - } - if err := initializeKeypair(); err != nil { return err } - ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port)) + ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port)) if err != nil { return err } diff --git a/envconfig/config.go b/envconfig/config.go index ae4e9939..2c3b6f77 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -1,6 +1,7 @@ package envconfig import ( + "errors" "fmt" "log/slog" "net" @@ -11,6 +12,18 @@ import ( "strings" ) +type OllamaHost struct { + Scheme string + Host string + Port string +} + +func (o OllamaHost) String() string { + return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port) +} + +var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") + var ( // Set via OLLAMA_ORIGINS in the environment AllowOrigins []string @@ -34,6 +47,8 @@ var ( NoPrune bool // Set via OLLAMA_NUM_PARALLEL in the environment NumParallel int + // Set via OLLAMA_HOST in the environment + Host *OllamaHost // Set via OLLAMA_RUNNERS_DIR in the environment RunnersDir string // Set via OLLAMA_TMPDIR in the environment @@ -50,7 +65,7 @@ func AsMap() map[string]EnvVar { return map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, - "OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"}, + "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"}, @@ -216,4 +231,54 @@ func LoadConfig() { } KeepAlive = clean("OLLAMA_KEEP_ALIVE") + + var err error + Host, err = getOllamaHost() + if err != nil { + slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) + } +} + +func getOllamaHost() (*OllamaHost, error) { + defaultPort := "11434" + + hostVar := os.Getenv("OLLAMA_HOST") + hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) + + scheme, hostport, ok := strings.Cut(hostVar, "://") + switch { + case !ok: + scheme, hostport = "http", hostVar + case scheme == "http": + defaultPort = "80" + case scheme == "https": + defaultPort = "443" + } + + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + + host, port, err := net.SplitHostPort(hostport) + if err != nil { + host, port = "127.0.0.1", defaultPort + if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { + host = ip.String() + } else if hostport != "" { + host = hostport + } + } + + if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { + return &OllamaHost{ + Scheme: scheme, + Host: host, + Port: defaultPort, + }, ErrInvalidHostPort + } + + return &OllamaHost{ + Scheme: scheme, + Host: host, + Port: port, + }, nil } diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 429434ae..7d923d62 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -1,8 +1,11 @@ package envconfig import ( + "fmt" + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,3 +24,48 @@ func TestConfig(t *testing.T) { LoadConfig() require.True(t, FlashAttention) } + +func TestClientFromEnvironment(t *testing.T) { + type testCase struct { + value string + expect string + err error + } + + hostTestCases := map[string]*testCase{ + "empty": {value: "", expect: "127.0.0.1:11434"}, + "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, + "only port": {value: ":1234", expect: ":1234"}, + "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, + "hostname": {value: "example.com", expect: "example.com:11434"}, + "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, + "zero port": {value: ":0", expect: ":0"}, + "too large port": {value: ":66000", err: ErrInvalidHostPort}, + "too small port": {value: ":-1", err: ErrInvalidHostPort}, + "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, + "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, + "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, + "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, + "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, + "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, + "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, + "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, + } + + for k, v := range hostTestCases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", v.value) + LoadConfig() + + oh, err := getOllamaHost() + if err != v.err { + t.Fatalf("expected %s, got %s", v.err, err) + } + + if err == nil { + host := net.JoinHostPort(oh.Host, oh.Port) + assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) + } + }) + } +}