From 2a66a1164a009f597f8931f155e18b05777c6602 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jul 2023 11:54:22 -0700 Subject: [PATCH 1/4] common stream producer --- server/models.go | 32 ++----------- server/routes.go | 114 +++++++++++++++++++++++------------------------ 2 files changed, 61 insertions(+), 85 deletions(-) diff --git a/server/models.go b/server/models.go index 813cccc9..496b2c45 100644 --- a/server/models.go +++ b/server/models.go @@ -8,8 +8,6 @@ import ( "os" "path" "strconv" - - "github.com/jmorganca/ollama/api" ) const directoryURL = "https://ollama.ai/api/models" @@ -36,14 +34,6 @@ func (m *Model) FullName() string { return path.Join(home, ".ollama", "models", m.Name+".bin") } -func pull(model string, progressCh chan<- api.PullProgress) error { - remote, err := getRemote(model) - if err != nil { - return fmt.Errorf("failed to pull model: %w", err) - } - return saveModel(remote, progressCh) -} - func getRemote(model string) (*Model, error) { // resolve the model download from our directory resp, err := http.Get(directoryURL) @@ -68,7 +58,7 @@ func getRemote(model string) (*Model, error) { return nil, fmt.Errorf("model not found in directory: %s", model) } -func saveModel(model *Model, progressCh chan<- api.PullProgress) error { +func saveModel(model *Model, fn func(total, completed int64)) error { // this models cache directory is created by the server on startup client := &http.Client{} @@ -98,11 +88,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { // already downloaded - progressCh <- api.PullProgress{ - Total: alreadyDownloaded, - Completed: alreadyDownloaded, - Percent: 100, - } + fn(alreadyDownloaded, alreadyDownloaded) return nil } @@ -136,19 +122,9 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { totalBytes += int64(n) - // send progress updates - progressCh <- api.PullProgress{ - Total: totalSize, - Completed: totalBytes, - Percent: float64(totalBytes) / float64(totalSize) * 100, - } - } - - progressCh <- api.PullProgress{ - Total: totalSize, - Completed: totalSize, - Percent: 100, + fn(totalSize, totalBytes) } + fn(totalSize, totalSize) return nil } diff --git a/server/routes.go b/server/routes.go index 47551f15..94894fdb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -79,35 +79,54 @@ func generate(c *gin.Context) { req.Prompt = sb.String() } - ch := make(chan string) + ch := make(chan any) g, _ := errgroup.WithContext(c.Request.Context()) g.Go(func() error { defer close(ch) return llm.Predict(req.Prompt, func(s string) { - ch <- s + ch <- api.GenerateResponse{Response: s} }) }) g.Go(func() error { - c.Stream(func(w io.Writer) bool { - s, ok := <-ch - if !ok { - return false - } + stream(c, ch) + return nil + }) - bts, err := json.Marshal(api.GenerateResponse{Response: s}) - if err != nil { - return false - } + if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } +} - bts = append(bts, '\n') - if _, err := w.Write(bts); err != nil { - return false - } +func pull(c *gin.Context) { + var req api.PullRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } - return true + remote, err := getRemote(req.Model) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ch := make(chan any) + g, _ := errgroup.WithContext(c.Request.Context()) + g.Go(func() error { + defer close(ch) + return saveModel(remote, func(total, completed int64) { + ch <- api.PullProgress{ + Total: total, + Completed: completed, + Percent: float64(total) / float64(completed) * 100, + } }) + }) + g.Go(func() error { + stream(c, ch) return nil }) @@ -124,47 +143,7 @@ func Serve(ln net.Listener) error { c.String(http.StatusOK, "Ollama is running") }) - r.POST("api/pull", func(c *gin.Context) { - var req api.PullRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - progressCh := make(chan api.PullProgress) - go func() { - defer close(progressCh) - if err := pull(req.Model, progressCh); err != nil { - var opError *net.OpError - if errors.As(err, &opError) { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - }() - - c.Stream(func(w io.Writer) bool { - progress, ok := <-progressCh - if !ok { - return false - } - - bts, err := json.Marshal(progress) - if err != nil { - return false - } - - bts = append(bts, '\n') - if _, err := w.Write(bts); err != nil { - return false - } - - return true - }) - }) - + r.POST("api/pull", pull) r.POST("/api/generate", generate) log.Printf("Listening on %s", ln.Addr()) @@ -186,3 +165,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i return } + +func stream(c *gin.Context, ch chan any) { + c.Stream(func(w io.Writer) bool { + val, ok := <-ch + if !ok { + return false + } + + bts, err := json.Marshal(val) + if err != nil { + return false + } + + bts = append(bts, '\n') + if _, err := w.Write(bts); err != nil { + return false + } + + return true + }) +} From e243329e2eb529e8ffab1f3d9af5f5aefd7ed8f1 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jul 2023 13:05:51 -0700 Subject: [PATCH 2/4] check api status --- api/client.go | 28 +++++++++++++++++++++++----- cmd/cmd.go | 19 ++++++++++++++----- server/routes.go | 2 +- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/api/client.go b/api/client.go index ccbcbf6b..29ab2698 100644 --- a/api/client.go +++ b/api/client.go @@ -10,6 +10,20 @@ import ( "net/url" ) +type StatusError struct { + StatusCode int + Status string + Message string +} + +func (e StatusError) Error() string { + if e.Message != "" { + return fmt.Sprintf("%s: %s", e.Status, e.Message) + } + + return e.Status +} + type Client struct { base url.URL } @@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client { } } -func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error { +func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { var buf *bytes.Buffer if data != nil { bts, err := json.Marshal(data) @@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call scanner := bufio.NewScanner(response.Body) for scanner.Scan() { var errorResponse struct { - Error string `json:"error"` + Error string `json:"error,omitempty"` } bts := scanner.Bytes() @@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call return fmt.Errorf("unmarshal: %w", err) } - if len(errorResponse.Error) > 0 { - return fmt.Errorf("stream: %s", errorResponse.Error) + if response.StatusCode >= 400 { + return StatusError{ + StatusCode: response.StatusCode, + Status: response.Status, + Message: errorResponse.Error, + } } - if err := callback(bts); err != nil { + if err := fn(bts); err != nil { return err } } diff --git a/cmd/cmd.go b/cmd/cmd.go index 8421b8f5..ca924ae9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/http" "os" "path" "strings" @@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error { switch { case errors.Is(err, os.ErrNotExist): if err := pull(args[0]); err != nil { - return err + var apiStatusError api.StatusError + if !errors.As(err, &apiStatusError) { + return err + } + + if apiStatusError.StatusCode != http.StatusBadGateway { + return err + } } case err != nil: return err @@ -50,11 +58,12 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { - if bar == nil && progress.Percent == 100 { - // already downloaded - return nil - } if bar == nil { + if progress.Percent == 100 { + // already downloaded + return nil + } + bar = progressbar.DefaultBytes(progress.Total) } diff --git a/server/routes.go b/server/routes.go index 94894fdb..1478f9ae 100644 --- a/server/routes.go +++ b/server/routes.go @@ -108,7 +108,7 @@ func pull(c *gin.Context) { remote, err := getRemote(req.Model) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return } From 948323fa78067084ba2046e750a1d0d8e7de3034 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jul 2023 13:36:35 -0700 Subject: [PATCH 3/4] rename partial file --- server/models.go | 75 ++++++++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/server/models.go b/server/models.go index 496b2c45..fd689ed6 100644 --- a/server/models.go +++ b/server/models.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -34,6 +35,14 @@ func (m *Model) FullName() string { return path.Join(home, ".ollama", "models", m.Name+".bin") } +func (m *Model) TempFile() string { + fullName := m.FullName() + return path.Join( + path.Dir(fullName), + fmt.Sprintf(".%s.part", path.Base(fullName)), + ) +} + func getRemote(model string) (*Model, error) { // resolve the model download from our directory resp, err := http.Get(directoryURL) @@ -66,37 +75,45 @@ func saveModel(model *Model, fn func(total, completed int64)) error { if err != nil { return fmt.Errorf("failed to download model: %w", err) } - // check for resume - alreadyDownloaded := int64(0) - fileInfo, err := os.Stat(model.FullName()) - if err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("failed to check resume model file: %w", err) - } - // file doesn't exist, create it now - } else { - alreadyDownloaded = fileInfo.Size() - req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded)) + + // check if completed file exists + fi, err := os.Stat(model.FullName()) + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + fn(fi.Size(), fi.Size()) + return nil } + var size int64 + + // completed file doesn't exist, check partial file + fi, err = os.Stat(model.TempFile()) + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + size = fi.Size() + } + + req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size)) + resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to download model: %w", err) } - defer resp.Body.Close() - if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { - // already downloaded - fn(alreadyDownloaded, alreadyDownloaded) - return nil - } - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + if resp.StatusCode >= 400 { return fmt.Errorf("failed to download model: %s", resp.Status) } - out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { panic(err) } @@ -104,27 +121,23 @@ func saveModel(model *Model, fn func(total, completed int64)) error { totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - buf := make([]byte, 1024) - totalBytes := alreadyDownloaded - totalSize += alreadyDownloaded + totalBytes := size + totalSize += size for { - n, err := resp.Body.Read(buf) - if err != nil && err != io.EOF { + n, err := io.CopyN(out, resp.Body, 8192) + if err != nil && !errors.Is(err, io.EOF) { return err } + if n == 0 { break } - if _, err := out.Write(buf[:n]); err != nil { - return err - } - - totalBytes += int64(n) + totalBytes += n fn(totalSize, totalBytes) } fn(totalSize, totalSize) - return nil + return os.Rename(model.TempFile(), model.FullName()) } From a806b03f6213d12e0b901f236bfeda306324c17e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jul 2023 14:57:17 -0700 Subject: [PATCH 4/4] no errgroup --- go.mod | 1 - go.sum | 2 -- server/routes.go | 62 +++++++++++++++++++----------------------------- 3 files changed, 24 insertions(+), 41 deletions(-) diff --git a/go.mod b/go.mod index 8beb32bd..c2e15346 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,6 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sync v0.3.0 golang.org/x/sys v0.10.0 // indirect golang.org/x/term v0.10.0 golang.org/x/text v0.10.0 // indirect diff --git a/go.sum b/go.sum index 9189b115..2adee49d 100644 --- a/go.sum +++ b/go.sum @@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/server/routes.go b/server/routes.go index 1478f9ae..ef19f3c2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -16,7 +16,6 @@ import ( "github.com/gin-gonic/gin" "github.com/lithammer/fuzzysearch/fuzzy" - "golang.org/x/sync/errgroup" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llama" @@ -56,12 +55,8 @@ func generate(c *gin.Context) { req.Model = path.Join(cacheDir(), "models", req.Model+".bin") } - llm, err := llama.New(req.Model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer llm.Close() + ch := make(chan any) + go stream(c, ch) templateNames := make([]string, 0, len(templates.Templates())) for _, template := range templates.Templates() { @@ -79,24 +74,22 @@ func generate(c *gin.Context) { req.Prompt = sb.String() } - ch := make(chan any) - g, _ := errgroup.WithContext(c.Request.Context()) - g.Go(func() error { - defer close(ch) - return llm.Predict(req.Prompt, func(s string) { - ch <- api.GenerateResponse{Response: s} - }) - }) - - g.Go(func() error { - stream(c, ch) - return nil - }) - - if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { + llm, err := llama.New(req.Model, req.Options) + if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + defer llm.Close() + + fn := func(s string) { + ch <- api.GenerateResponse{Response: s} + } + + if err := llm.Predict(req.Prompt, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } func pull(c *gin.Context) { @@ -113,24 +106,17 @@ func pull(c *gin.Context) { } ch := make(chan any) - g, _ := errgroup.WithContext(c.Request.Context()) - g.Go(func() error { - defer close(ch) - return saveModel(remote, func(total, completed int64) { - ch <- api.PullProgress{ - Total: total, - Completed: completed, - Percent: float64(total) / float64(completed) * 100, - } - }) - }) + go stream(c, ch) - g.Go(func() error { - stream(c, ch) - return nil - }) + fn := func(total, completed int64) { + ch <- api.PullProgress{ + Total: total, + Completed: completed, + Percent: float64(total) / float64(completed) * 100, + } + } - if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { + if err := saveModel(remote, fn); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return }