From a806b03f6213d12e0b901f236bfeda306324c17e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jul 2023 14:57:17 -0700 Subject: [PATCH] no errgroup --- go.mod | 1 - go.sum | 2 -- server/routes.go | 62 +++++++++++++++++++----------------------------- 3 files changed, 24 insertions(+), 41 deletions(-) diff --git a/go.mod b/go.mod index 8beb32bd..c2e15346 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,6 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sync v0.3.0 golang.org/x/sys v0.10.0 // indirect golang.org/x/term v0.10.0 golang.org/x/text v0.10.0 // indirect diff --git a/go.sum b/go.sum index 9189b115..2adee49d 100644 --- a/go.sum +++ b/go.sum @@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/server/routes.go b/server/routes.go index 1478f9ae..ef19f3c2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -16,7 +16,6 @@ import ( "github.com/gin-gonic/gin" "github.com/lithammer/fuzzysearch/fuzzy" - "golang.org/x/sync/errgroup" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llama" @@ -56,12 +55,8 @@ func generate(c *gin.Context) { req.Model = path.Join(cacheDir(), "models", req.Model+".bin") } - llm, err := llama.New(req.Model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer llm.Close() + ch := make(chan any) + go stream(c, ch) templateNames := make([]string, 0, len(templates.Templates())) for _, template := range templates.Templates() { @@ -79,24 +74,22 @@ func generate(c *gin.Context) { req.Prompt = sb.String() } - ch := make(chan any) - g, _ := errgroup.WithContext(c.Request.Context()) - g.Go(func() error { - defer close(ch) - return llm.Predict(req.Prompt, func(s string) { - ch <- api.GenerateResponse{Response: s} - }) - }) - - g.Go(func() error { - stream(c, ch) - return nil - }) - - if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { + llm, err := llama.New(req.Model, req.Options) + if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + defer llm.Close() + + fn := func(s string) { + ch <- api.GenerateResponse{Response: s} + } + + if err := llm.Predict(req.Prompt, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } func pull(c *gin.Context) { @@ -113,24 +106,17 @@ func pull(c *gin.Context) { } ch := make(chan any) - g, _ := errgroup.WithContext(c.Request.Context()) - g.Go(func() error { - defer close(ch) - return saveModel(remote, func(total, completed int64) { - ch <- api.PullProgress{ - Total: total, - Completed: completed, - Percent: float64(total) / float64(completed) * 100, - } - }) - }) + go stream(c, ch) - g.Go(func() error { - stream(c, ch) - return nil - }) + fn := func(total, completed int64) { + ch <- api.PullProgress{ + Total: total, + Completed: completed, + Percent: float64(total) / float64(completed) * 100, + } + } - if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { + if err := saveModel(remote, fn); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return }