From 8ea5e5e1471941c3d4d66338c7a729d48d9e3e24 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Wed, 5 Jul 2023 15:37:33 -0400
Subject: [PATCH] separate routes

---
 server/routes.go | 81 +++++++++++++++++++++++++++---------------------
 1 file changed, 45 insertions(+), 36 deletions(-)

diff --git a/server/routes.go b/server/routes.go
index bd68ce04..9490441e 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -14,54 +14,63 @@ import (
 	"github.com/jmorganca/ollama/api"
 )
 
-func Serve(ln net.Listener) error {
-	r := gin.Default()
+func pull(c *gin.Context) {
+	// TODO
+	c.JSON(http.StatusOK, gin.H{"message": "ok"})
+}
+
+func generate(c *gin.Context) {
 
 	// TODO: these should be request parameters
 	gpulayers := 0
 	tokens := 512
 	threads := runtime.NumCPU()
+	// TODO: set prompt from template
+	fmt.Println("Generating text...")
 
-	r.POST("/api/generate", func(c *gin.Context) {
-		// TODO: set prompt from template
-		fmt.Println("Generating text...")
+	var req api.GenerateRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+		return
+	}
 
-		var req api.GenerateRequest
-		if err := c.ShouldBindJSON(&req); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
-			return
-		}
+	fmt.Println(req)
 
-		fmt.Println(req)
+	l, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
+	if err != nil {
+		fmt.Println("Loading the model failed:", err.Error())
+		return
+	}
 
-		l, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
-		if err != nil {
-			fmt.Println("Loading the model failed:", err.Error())
-			return
-		}
+	ch := make(chan string)
 
-		ch := make(chan string)
-
-		go func() {
-			defer close(ch)
-			_, err := l.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
-				ch <- token
-				return true
-			}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
-			if err != nil {
-				panic(err)
-			}
-		}()
-
-		c.Stream(func(w io.Writer) bool {
-			tok, ok := <-ch
-			if !ok {
-				return false
-			}
-			c.SSEvent("token", tok)
+	go func() {
+		defer close(ch)
+		_, err := l.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
+			ch <- token
 			return true
-		})
+		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		if err != nil {
+			panic(err)
+		}
+	}()
+
+	c.Stream(func(w io.Writer) bool {
+		tok, ok := <-ch
+		if !ok {
+			return false
+		}
+		c.SSEvent("token", tok)
+		return true
 	})
+}
+
+func Serve(ln net.Listener) error {
+	r := gin.Default()
+
+	r.POST("/api/pull", pull)
+
+	r.POST("/api/generate", generate)
 
 	log.Printf("Listening on %s", ln.Addr())
 	s := &http.Server{
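
Note on the streaming handler: the new `generate` function pushes each predicted token onto a channel from the `Predict` callback, and `c.Stream` forwards every channel value to the client as an SSE `token` event until the channel is closed. Below is a minimal client sketch for consuming that stream; it is not part of the patch, and the listen address and request body fields are assumptions (the hunk shows the route registration but not the listener or the `api.GenerateRequest` definition).

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Hypothetical request body; the field names assume api.GenerateRequest
	// exposes Model and Prompt as used in the handler above.
	body := strings.NewReader(`{"model": "model.bin", "prompt": "Why is the sky blue?"}`)

	// The address is an assumption; Serve takes a net.Listener, so the
	// actual port is chosen by the caller, not by this patch.
	resp, err := http.Post("http://127.0.0.1:8080/api/generate", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// c.SSEvent("token", tok) emits frames of the form:
	//   event:token
	//   data:<token text>
	// separated by blank lines, so print the data payload of each frame.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if data, ok := strings.CutPrefix(scanner.Text(), "data:"); ok {
			fmt.Print(data)
		}
	}
	fmt.Println()
}

Because the channel is unbuffered and closed when Predict returns, the stream ends cleanly once generation finishes; a client simply reads until the response body is exhausted.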