diff --git a/api/client.go b/api/client.go index 4b9aee04..8216fa74 100644 --- a/api/client.go +++ b/api/client.go @@ -5,14 +5,24 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net/http" + "net/url" ) type Client struct { - URL string - HTTP http.Client + base url.URL +} + +func NewClient(hosts ...string) *Client { + host := "127.0.0.1:11434" + if len(hosts) > 0 { + host = hosts[0] + } + + return &Client{ + base: url.URL{Scheme: "http", Host: host}, + } } func (c *Client) stream(ctx context.Context, method string, path string, reqData any, fn func(bts []byte) error) error { @@ -27,23 +37,21 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData reqBody = bytes.NewReader(data) } - url := fmt.Sprintf("%s%s", c.URL, path) - - req, err := http.NewRequestWithContext(ctx, method, url, reqBody) + request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), reqBody) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") - res, err := c.HTTP.Do(req) + response, err := http.DefaultClient.Do(request) if err != nil { return err } - defer res.Body.Close() + defer response.Body.Close() - scanner := bufio.NewScanner(res.Body) + scanner := bufio.NewScanner(response.Body) for scanner.Scan() { if err := fn(scanner.Bytes()); err != nil { return err diff --git a/cmd/cmd.go b/cmd/cmd.go index 38277c1e..4ee74e77 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -36,10 +36,7 @@ func RunRun(cmd *cobra.Command, args []string) error { } func pull(model string) error { - client, err := NewAPIClient() - if err != nil { - return err - } + client := api.NewClient() var bar *progressbar.ProgressBar return client.Pull( @@ -68,10 +65,7 @@ func RunGenerate(_ *cobra.Command, args []string) error { } func generate(model string, prompts ...string) error { - client, err := NewAPIClient() - if err != nil { - return err - } + client := api.NewClient() for _, prompt := range prompts { client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error { @@ -121,12 +115,6 @@ func RunServer(_ *cobra.Command, _ []string) error { return server.Serve(ln) } -func NewAPIClient() (*api.Client, error) { - return &api.Client{ - URL: "http://localhost:11434", - }, nil -} - func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile)