From 291bb97e3d03606258dbee45c24ca5db358b46c7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 6 Jul 2023 16:53:14 -0700 Subject: [PATCH] client request options --- api/client.go | 82 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 29 deletions(-) diff --git a/api/client.go b/api/client.go index 8216fa74..f153f32e 100644 --- a/api/client.go +++ b/api/client.go @@ -25,19 +25,35 @@ func NewClient(hosts ...string) *Client { } } -func (c *Client) stream(ctx context.Context, method string, path string, reqData any, fn func(bts []byte) error) error { - var reqBody io.Reader - var data []byte - var err error - if reqData != nil { - data, err = json.Marshal(reqData) - if err != nil { - return err - } - reqBody = bytes.NewReader(data) +type options struct { + requestBody io.Reader + responseFunc func(bts []byte) error +} + +func OptionRequestBody(data any) func(*options) { + bts, err := json.Marshal(data) + if err != nil { + panic(err) } - request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), reqBody) + return func(opts *options) { + opts.requestBody = bytes.NewReader(bts) + } +} + +func OptionResponseFunc(fn func([]byte) error) func(*options) { + return func(opts *options) { + opts.responseFunc = fn + } +} + +func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*options)) error { + var opts options + for _, fn := range fns { + fn(&opts) + } + + request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), opts.requestBody) if err != nil { return err } @@ -51,10 +67,12 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData } defer response.Body.Close() - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - if err := fn(scanner.Bytes()); err != nil { - return err + if opts.responseFunc != nil { + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + if err := opts.responseFunc(scanner.Bytes()); err != nil { + return err + } } } @@ -64,25 +82,31 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData type GenerateResponseFunc func(GenerateResponse) error func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error { - return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error { - var resp GenerateResponse - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } + return c.stream(ctx, http.MethodPost, "/api/generate", + OptionRequestBody(req), + OptionResponseFunc(func(bts []byte) error { + var resp GenerateResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } - return fn(resp) - }) + return fn(resp) + }), + ) } type PullProgressFunc func(PullProgress) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { - return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { - var resp PullProgress - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } + return c.stream(ctx, http.MethodPost, "/api/pull", + OptionRequestBody(req), + OptionResponseFunc(func(bts []byte) error { + var resp PullProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } - return fn(resp) - }) + return fn(resp) + }), + ) }