From 21ddcaa1f1912a443858a244a8c9fe9de3cff755 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 8 Aug 2023 13:49:37 -0400
Subject: [PATCH] pr comments

- default to embeddings enabled
- move embedding logic for loaded model to request
- allow embedding full directory
- close llm on reload
---
 api/types.go     |   1 +
 server/images.go | 160 +++++++++++++++++++++++------------------------
 server/routes.go |  18 +++++-
 3 files changed, 97 insertions(+), 82 deletions(-)

diff --git a/api/types.go b/api/types.go
index 90b964a3..e6773791 100644
--- a/api/types.go
+++ b/api/types.go
@@ -275,6 +275,7 @@ func DefaultOptions() Options {
 		UseMLock:           false,
 		RopeFrequencyBase:  10000.0,
 		RopeFrequencyScale: 1.0,
+		EmbeddingOnly:      true,
 
 		RepeatLastN:   64,
 		RepeatPenalty: 1.1,
diff --git a/server/images.go b/server/images.go
index 4dcaf37d..b807e11b 100644
--- a/server/images.go
+++ b/server/images.go
@@ -23,7 +23,6 @@ import (
 	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/vector"
-	"gonum.org/v1/gonum/mat"
 )
 
 type RegistryOptions struct {
@@ -42,7 +41,7 @@ type Model struct {
 	Embeddings []vector.Embedding
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -67,26 +66,12 @@
 	vars.System = m.System
 	vars.Prompt = request.Prompt
 	vars.Context = request.Context
+	vars.Embed = embedding
 
 	if request.System != "" {
 		vars.System = request.System
 	}
 
-	if len(m.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(request.Prompt)
-		if err != nil {
-			return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
-		}
-		// TODO: set embed_top from specified parameters in modelfile
-		embed_top := 3
-		embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-		toEmbed := ""
-		for _, e := range embed {
-			toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
-		}
-		vars.Embed = toEmbed
-	}
-
 	var sb strings.Builder
 	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
@@ -432,85 +417,98 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 		}
 
-		for _, filePath := range e.files {
-			// TODO: check if txt file type
-			f, err := os.Open(filePath)
+		addedFiles := make(map[string]bool) // keep track of files that have already been added
+		for _, filePattern := range e.files {
+			matchingFiles, err := filepath.Glob(filePattern)
 			if err != nil {
-				return nil, fmt.Errorf("could not open embed file: %w", err)
+				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
 			}
-			scanner := bufio.NewScanner(f)
-			scanner.Split(bufio.ScanLines)
-
-			data := []string{}
-			for scanner.Scan() {
-				data = append(data, scanner.Text())
-			}
-			f.Close()
-
-			// the digest of the file is set here so that the client knows a new operation is in progress
-			fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
-
-			embeddings := []vector.Embedding{}
-			for i, d := range data {
-				if strings.TrimSpace(d) == "" {
+
+			for _, filePath := range matchingFiles {
+				if addedFiles[filePath] {
 					continue
 				}
-				e.fn(api.ProgressResponse{
-					Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
-					Digest:    fileDigest,
-					Total:     len(data) - 1,
-					Completed: i,
-				})
-				retry := 0
-			generate:
-				if retry > 3 {
-					log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
-					continue
-				}
-				embed, err := llm.Embedding(d)
+				addedFiles[filePath] = true
+				// TODO: check file type
+				f, err := os.Open(filePath)
 				if err != nil {
-					log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
-					retry++
-					goto generate
+					return nil, fmt.Errorf("could not open embed file: %w", err)
 				}
-				// Check for NaN and Inf in the embedding, which can't be stored
-				for _, value := range embed {
-					if math.IsNaN(value) || math.IsInf(value, 0) {
-						log.Printf("reloading model, embedding contains NaN or Inf")
-						// reload the model to get a new embedding
-						llm, err = llama.New(model.ModelPath, e.opts)
-						if err != nil {
-							return nil, fmt.Errorf("load model to generate embeddings: %v", err)
-						}
+				scanner := bufio.NewScanner(f)
+				scanner.Split(bufio.ScanLines)
+
+				data := []string{}
+				for scanner.Scan() {
+					data = append(data, scanner.Text())
+				}
+				f.Close()
+
+				// the digest of the file is set here so that the client knows a new operation is in progress
+				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+				embeddings := []vector.Embedding{}
+				for i, d := range data {
+					if strings.TrimSpace(d) == "" {
+						continue
+					}
+					e.fn(api.ProgressResponse{
+						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+						Digest:    fileDigest,
+						Total:     len(data) - 1,
+						Completed: i,
+					})
+					retry := 0
+				generate:
+					if retry > 3 {
+						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+						continue
+					}
+					embed, err := llm.Embedding(d)
+					if err != nil {
+						log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
 						retry++
 						goto generate
 					}
+					// Check for NaN and Inf in the embedding, which can't be stored
+					for _, value := range embed {
+						if math.IsNaN(value) || math.IsInf(value, 0) {
+							log.Printf("reloading model, embedding contains NaN or Inf")
+							// reload the model to get a new embedding, the seed can affect these outputs and reloading changes it
+							llm.Close()
+							llm, err = llama.New(model.ModelPath, e.opts)
+							if err != nil {
+								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+							}
+							retry++
+							goto generate
+						}
+					}
+					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
 				}
-				embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
-			}
 
-			b, err := json.Marshal(embeddings)
-			if err != nil {
-				return nil, fmt.Errorf("failed to encode embeddings: %w", err)
-			}
-			r := bytes.NewReader(b)
+				b, err := json.Marshal(embeddings)
+				if err != nil {
+					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+				}
+				r := bytes.NewReader(b)
 
-			digest, size := GetSHA256Digest(r)
-			// Reset the position of the reader after calculating the digest
-			if _, err := r.Seek(0, 0); err != nil {
-				return nil, fmt.Errorf("could not reset embed reader: %w", err)
-			}
+				digest, size := GetSHA256Digest(r)
+				// Reset the position of the reader after calculating the digest
+				if _, err := r.Seek(0, io.SeekStart); err != nil {
+					return nil, fmt.Errorf("could not reset embed reader: %w", err)
+				}
 
-			layer := &LayerReader{
-				Layer: Layer{
-					MediaType: "application/vnd.ollama.image.embed",
-					Digest:    digest,
-					Size:      size,
-				},
-				Reader: r,
-			}
+				layer := &LayerReader{
+					Layer: Layer{
+						MediaType: "application/vnd.ollama.image.embed",
+						Digest:    digest,
+						Size:      size,
+					},
+					Reader: r,
+				}
 
-			layers = append(layers, layer)
+				layers = append(layers, layer)
+			}
 		}
 	}
 	return layers, nil
diff --git a/server/routes.go b/server/routes.go
index 2a880aaa..b48d7902 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -17,6 +17,7 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"gonum.org/v1/gonum/mat"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
@@ -114,7 +115,22 @@ func GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()
 
-	prompt, err := model.Prompt(req)
+	embedding := ""
+	if model.Embeddings != nil && len(model.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		// TODO: set embed_top from specified parameters in modelfile
+		embed_top := 3
+		topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		for _, e := range topK {
+			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+		}
+	}
+
+	prompt, err := model.Prompt(req, embedding)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
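
Note on the EMBED glob handling above: it reduces to expanding each entry
with filepath.Glob and skipping paths that were already added, so a file
matched by more than one pattern is embedded only once. A minimal standalone
sketch of that expand-and-deduplicate pattern follows; the function name
expandPatterns and the example patterns are illustrative, not part of the
patch:

	package main

	import (
		"fmt"
		"log"
		"path/filepath"
	)

	// expandPatterns treats each entry as either a literal path or a glob
	// such as "docs/*.txt", and returns every matched file exactly once,
	// mirroring the addedFiles bookkeeping in embeddingLayers.
	func expandPatterns(patterns []string) ([]string, error) {
		added := make(map[string]bool) // files already collected
		files := []string{}
		for _, pattern := range patterns {
			matches, err := filepath.Glob(pattern)
			if err != nil {
				return nil, fmt.Errorf("could not find files with pattern %s: %w", pattern, err)
			}
			for _, path := range matches {
				if added[path] {
					continue
				}
				added[path] = true
				files = append(files, path)
			}
		}
		return files, nil
	}

	func main() {
		// "docs/*.txt" and "docs/notes.txt" are hypothetical inputs; the
		// second is dropped if the glob already matched it.
		files, err := expandPatterns([]string{"docs/*.txt", "docs/notes.txt"})
		if err != nil {
			log.Fatal(err)
		}
		for _, f := range files {
			fmt.Println(f)
		}
	}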