diff --git a/api/client.go b/api/client.go index 3cf55a25..e86dc530 100644 --- a/api/client.go +++ b/api/client.go @@ -44,14 +44,24 @@ func checkError(resp *http.Response, body []byte) error { } func ClientFromEnvironment() (*Client, error) { + defaultPort := "11434" + scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") - if !ok { + switch { + case !ok: scheme, hostport = "http", os.Getenv("OLLAMA_HOST") + case scheme == "http": + defaultPort = "80" + case scheme == "https": + defaultPort = "443" } + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + host, port, err := net.SplitHostPort(hostport) if err != nil { - host, port = "127.0.0.1", "11434" + host, port = "127.0.0.1", defaultPort if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { host = ip.String() } else if hostport != "" { diff --git a/api/client_test.go b/api/client_test.go new file mode 100644 index 00000000..0eafedca --- /dev/null +++ b/api/client_test.go @@ -0,0 +1,43 @@ +package api + +import "testing" + +func TestClientFromEnvironment(t *testing.T) { + type testCase struct { + value string + expect string + err error + } + + testCases := map[string]*testCase{ + "empty": {value: "", expect: "http://127.0.0.1:11434"}, + "only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"}, + "only port": {value: ":1234", expect: "http://:1234"}, + "address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"}, + "scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"}, + "scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"}, + "scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"}, + "hostname": {value: "example.com", expect: "http://example.com:11434"}, + "hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"}, + "scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"}, + "scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"}, + "scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"}, + "trailing slash": {value: "example.com/", expect: "http://example.com:11434"}, + "trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"}, + } + + for k, v := range testCases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", v.value) + + client, err := ClientFromEnvironment() + if err != v.err { + t.Fatalf("expected %s, got %s", v.err, err) + } + + if client.base.String() != v.expect { + t.Fatalf("expected %s, got %s", v.expect, client.base.String()) + } + }) + } +}