From a6f6d18f83408a8d06c77f07bc51f734e866e990 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Fri, 4 Aug 2023 18:56:40 -0400
Subject: [PATCH 1/5] embed text document in modelfile

---
 cmd/cmd.go       |  18 ++--
 go.mod           |   1 +
 go.sum           |   2 +
 llama/llama.go   |  37 +++++++
 parser/parser.go |   2 +-
 server/images.go | 250 +++++++++++++++++++++++++++++++++++++----------
 server/routes.go |  10 +-
 vector/store.go  |  69 +++++++++++++
 8 files changed, 330 insertions(+), 59 deletions(-)
 create mode 100644 vector/store.go

diff --git a/cmd/cmd.go b/cmd/cmd.go
index 9526c864..1e0d1f59 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
                spinner.Stop()
            }
            currentDigest = resp.Digest
-           bar = progressbar.DefaultBytes(
-               int64(resp.Total),
-               fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
-           )
-
-           bar.Set(resp.Completed)
+           switch {
+           case strings.Contains(resp.Status, "embeddings"):
+               bar = progressbar.Default(int64(resp.Total), resp.Status)
+               bar.Set(resp.Completed)
+           default:
+               // pulling
+               bar = progressbar.DefaultBytes(
+                   int64(resp.Total),
+                   resp.Status,
+               )
+               bar.Set(resp.Completed)
+           }
        } else if resp.Digest == currentDigest && resp.Digest != "" {
            bar.Set(resp.Completed)
        } else {
diff --git a/go.mod b/go.mod
index 554473cb..a0583e65 100644
--- a/go.mod
+++ b/go.mod
@@ -42,6 +42,7 @@ require (
    golang.org/x/sys v0.10.0 // indirect
    golang.org/x/term v0.10.0
    golang.org/x/text v0.10.0 // indirect
+   gonum.org/v1/gonum v0.13.0
    google.golang.org/protobuf v1.30.0 // indirect
    gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index c4097bdb..7ec060d3 100644
--- a/go.sum
+++ b/go.sum
@@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
+gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
diff --git a/llama/llama.go b/llama/llama.go
index 0a523321..2c11cbc3 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -85,6 +85,7 @@ llama_token llama_sample(
 }
 */
 import "C"
+
 import (
    "bytes"
    "embed"
@@ -93,6 +94,7 @@ import (
    "io"
    "log"
    "os"
+   "reflect"
    "strings"
    "sync"
    "unicode/utf8"
@@ -414,3 +416,38 @@ func (llm *LLM) next() (C.llama_token, error) {

    return token, nil
 }
+
+func (llm *LLM) Embedding(input string) ([]float64, error) {
+   if !llm.EmbeddingOnly {
+       return nil, errors.New("llama: embedding not enabled")
+   }
+
+   tokens := llm.tokenize(input)
+   if tokens == nil {
+       return nil, errors.New("llama: tokenize embedding")
+   }
+
+   retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+   if retval != 0 {
+       return nil, errors.New("llama: eval")
+   }
+
+   n := int(C.llama_n_embd(llm.ctx))
+   if n <= 0 {
+       return nil, errors.New("llama: no embeddings generated")
+   }
+
+   embedPtr := C.llama_get_embeddings(llm.ctx)
+   if embedPtr == nil {
+       return nil, errors.New("llama: embedding retrieval failed")
errors.New("llama: embedding retrieval failed") + } + + header := reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(embedPtr)), + Len: n, + Cap: n, + } + embedSlice := *(*[]float64)(unsafe.Pointer(&header)) + + return embedSlice, nil +} diff --git a/parser/parser.go b/parser/parser.go index c89b13e6..06ccf786 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) { command.Args = string(fields[1]) // copy command for validation modelCommand = command - case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": + case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED": command.Name = string(bytes.ToLower(fields[0])) command.Args = string(fields[1]) case "PARAMETER": diff --git a/server/images.go b/server/images.go index e06a40a1..4dcaf37d 100644 --- a/server/images.go +++ b/server/images.go @@ -1,6 +1,7 @@ package server import ( + "bufio" "bytes" "crypto/sha256" "encoding/json" @@ -9,6 +10,7 @@ import ( "html/template" "io" "log" + "math" "net/http" "os" "path" @@ -18,7 +20,10 @@ import ( "strings" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/parser" + "github.com/jmorganca/ollama/vector" + "gonum.org/v1/gonum/mat" ) type RegistryOptions struct { @@ -28,12 +33,13 @@ type RegistryOptions struct { } type Model struct { - Name string `json:"name"` - ModelPath string - Template string - System string - Digest string - Options map[string]interface{} + Name string `json:"name"` + ModelPath string + Template string + System string + Digest string + Options map[string]interface{} + Embeddings []vector.Embedding } func (m *Model) Prompt(request api.GenerateRequest) (string, error) { @@ -51,6 +57,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) { First bool System string Prompt string + Embed string // deprecated: versions <= 0.0.7 used this to omit the system prompt Context []int @@ -65,6 +72,21 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) { vars.System = request.System } + if len(m.Embeddings) > 0 { + promptEmbed, err := loaded.llm.Embedding(request.Prompt) + if err != nil { + return "", fmt.Errorf("failed to get embedding for prompt: %v", err) + } + // TODO: set embed_top from specified parameters in modelfile + embed_top := 3 + embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings) + toEmbed := "" + for _, e := range embed { + toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data) + } + vars.Embed = toEmbed + } + var sb strings.Builder if err := tmpl.Execute(&sb, vars); err != nil { return "", err @@ -157,6 +179,16 @@ func GetModel(name string) (*Model, error) { switch layer.MediaType { case "application/vnd.ollama.image.model": model.ModelPath = filename + case "application/vnd.ollama.image.embed": + file, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("failed to open file: %s", filename) + } + defer file.Close() + + if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil { + return nil, err + } case "application/vnd.ollama.image.template": bts, err := os.ReadFile(filename) if err != nil { @@ -195,6 +227,26 @@ func GetModel(name string) (*Model, error) { return model, nil } +func filenameWithPath(path, f string) (string, error) { + // if filePath starts with ~/, replace it with the user's home directory. 
+   if strings.HasPrefix(f, "~/") {
+       parts := strings.Split(f, "/")
+       home, err := os.UserHomeDir()
+       if err != nil {
+           return "", fmt.Errorf("failed to open file: %v", err)
+       }
+
+       f = filepath.Join(home, filepath.Join(parts[1:]...))
+   }
+
+   // if filePath is not an absolute path, make it relative to the modelfile path
+   if !filepath.IsAbs(f) {
+       f = filepath.Join(filepath.Dir(path), f)
+   }
+
+   return f, nil
+}
+
 func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
    mf, err := os.Open(path)
    if err != nil {
@@ -211,52 +263,37 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
    var layers []*LayerReader
    params := make(map[string][]string)
-
+   embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
    for _, c := range commands {
        log.Printf("[%s] - %s\n", c.Name, c.Args)
        switch c.Name {
        case "model":
            fn(api.ProgressResponse{Status: "looking for model"})
+           embed.model = c.Args
            mf, err := GetManifest(ParseModelPath(c.Args))
            if err != nil {
-               fp := c.Args
-
-               // If filePath starts with ~/, replace it with the user's home directory.
-               if strings.HasPrefix(fp, "~/") {
-                   parts := strings.Split(fp, "/")
-                   home, err := os.UserHomeDir()
-                   if err != nil {
-                       return fmt.Errorf("failed to open file: %v", err)
-                   }
-
-                   fp = filepath.Join(home, filepath.Join(parts[1:]...))
+               modelFile, err := filenameWithPath(path, c.Args)
+               if err != nil {
+                   return err
                }
-
-               // If filePath is not an absolute path, make it relative to the modelfile path
-               if !filepath.IsAbs(fp) {
-                   fp = filepath.Join(filepath.Dir(path), fp)
-               }
-
-               if _, err := os.Stat(fp); err != nil {
+               if _, err := os.Stat(modelFile); err != nil {
                    // the model file does not exist, try pulling it
                    if errors.Is(err, os.ErrNotExist) {
                        fn(api.ProgressResponse{Status: "pulling model file"})
                        if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
                            return err
                        }
-                       mf, err = GetManifest(ParseModelPath(c.Args))
+                       mf, err = GetManifest(ParseModelPath(modelFile))
                        if err != nil {
                            return fmt.Errorf("failed to open file after pull: %v", err)
                        }
-
                    } else {
                        return err
                    }
                } else {
                    // create a model from this specified file
                    fn(api.ProgressResponse{Status: "creating model layer"})
-
-                   file, err := os.Open(fp)
+                   file, err := os.Open(modelFile)
                    if err != nil {
                        return fmt.Errorf("failed to open file: %v", err)
                    }
@@ -280,19 +317,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
                layers = append(layers, newLayer)
            }
        }
-       case "license":
-           fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-           // remove the prompt layer if one exists
-           mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
-
-           layer, err := CreateLayer(strings.NewReader(c.Args))
+       case "embed":
+           // TODO: support entire directories here
+           embedFilePath, err := filenameWithPath(path, c.Args)
            if err != nil {
                return err
            }
-
-           layer.MediaType = mediaType
-           layers = append(layers, layer)
-       case "template", "system", "prompt":
+           embed.files = append(embed.files, embedFilePath)
+       case "license", "template", "system", "prompt":
            fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
            // remove the prompt layer if one exists
            mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@@ -315,18 +347,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
    if len(params) > 0 {
        fn(api.ProgressResponse{Status: "creating parameter layer"})
        layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
-       paramData, err := paramsToReader(params)
+       formattedParams, err := formatParams(params)
        if err != nil {
            return fmt.Errorf("couldn't create params json: %v", err)
        }
-       l, err := CreateLayer(paramData)
+
+       bts, err := json.Marshal(formattedParams)
+       if err != nil {
+           return err
+       }
+
+       l, err := CreateLayer(bytes.NewReader(bts))
        if err != nil {
            return fmt.Errorf("failed to create layer: %v", err)
        }
        l.MediaType = "application/vnd.ollama.image.params"
        layers = append(layers, l)
+
+       // apply these parameters to the embedding options, in case embeddings need to be generated using this model
+       embed.opts = api.DefaultOptions()
+       embed.opts.FromMap(formattedParams)
    }

+   // generate the embedding layers
+   embeddingLayers, err := embeddingLayers(embed)
+   if err != nil {
+       return err
+   }
+   layers = append(layers, embeddingLayers...)
+
    digests, err := getLayerDigests(layers)
    if err != nil {
        return err
    }
@@ -361,6 +410,112 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
    return nil
 }

+type EmbeddingParams struct {
+   model string
+   opts  api.Options
+   files []string // paths to files to embed
+   fn    func(resp api.ProgressResponse)
+}
+
+// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
+func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
+   layers := []*LayerReader{}
+   if len(e.files) > 0 {
+       model, err := GetModel(e.model)
+       if err != nil {
+           return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+       }
+
+       e.opts.EmbeddingOnly = true
+       llm, err := llama.New(model.ModelPath, e.opts)
+       if err != nil {
+           return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+       }
+
+       for _, filePath := range e.files {
+           // TODO: check if txt file type
+           f, err := os.Open(filePath)
+           if err != nil {
+               return nil, fmt.Errorf("could not open embed file: %w", err)
+           }
+           scanner := bufio.NewScanner(f)
+           scanner.Split(bufio.ScanLines)
+
+           data := []string{}
+           for scanner.Scan() {
+               data = append(data, scanner.Text())
+           }
+           f.Close()
+
+           // the digest of the file is set here so that the client knows a new operation is in progress
+           fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+           embeddings := []vector.Embedding{}
+           for i, d := range data {
+               if strings.TrimSpace(d) == "" {
+                   continue
+               }
+               e.fn(api.ProgressResponse{
+                   Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+                   Digest:    fileDigest,
+                   Total:     len(data) - 1,
+                   Completed: i,
+               })
+               retry := 0
+           generate:
+               if retry > 3 {
+                   log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+                   continue
+               }
+               embed, err := llm.Embedding(d)
+               if err != nil {
+                   log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
+                   retry++
+                   goto generate
+               }
+               // Check for NaN and Inf in the embedding, which can't be stored
+               for _, value := range embed {
+                   if math.IsNaN(value) || math.IsInf(value, 0) {
+                       log.Printf("reloading model, embedding contains NaN or Inf")
+                       // reload the model to get a new embedding
+                       llm, err = llama.New(model.ModelPath, e.opts)
+                       if err != nil {
+                           return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+                       }
+                       retry++
+                       goto generate
+                   }
+               }
+               embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
+           }
+
+           b, err := json.Marshal(embeddings)
+           if err != nil {
+               return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+           }
+           r := bytes.NewReader(b)
+
+           digest, size := GetSHA256Digest(r)
+           // Reset the position of the reader after calculating the digest
+           if _, err := r.Seek(0, 0); err != nil {
+               return nil, fmt.Errorf("could not reset embed reader: %w", err)
+           }
+
+           layer := &LayerReader{
+               Layer: Layer{
+                   MediaType: "application/vnd.ollama.image.embed",
+                   Digest:    digest,
+                   Size:      size,
+               },
+               Reader: r,
+           }
+
+           layers = append(layers, layer)
+       }
+   }
+   return layers, nil
+}
+
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
    j := 0
    for _, l := range layers {
@@ -449,8 +604,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
    return newLayer, nil
 }

-// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
-func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
+// formatParams converts specified parameter options to their correct types
+func formatParams(params map[string][]string) (map[string]interface{}, error) {
    opts := api.Options{}
    valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
    typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
@@ -504,12 +659,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
        }
    }

-   bts, err := json.Marshal(out)
-   if err != nil {
-       return nil, err
-   }
-
-   return bytes.NewReader(bts), nil
+   return out, nil
 }

 func getLayerDigests(layers []*LayerReader) ([]string, error) {
@@ -1042,7 +1192,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
    for {
        fn(api.ProgressResponse{
-           Status:    fmt.Sprintf("downloading %s", digest),
+           Status:    fmt.Sprintf("pulling %s...", digest[7:19]),
            Digest:    digest,
            Total:     int(total),
            Completed: int(completed),
diff --git a/server/routes.go b/server/routes.go
index 83afef1a..2a880aaa 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -20,12 +20,14 @@ import (
    "github.com/gin-contrib/cors"
    "github.com/gin-gonic/gin"

    "github.com/jmorganca/ollama/api"
    "github.com/jmorganca/ollama/llama"
+   "github.com/jmorganca/ollama/vector"
 )

 var loaded struct {
    mu sync.Mutex

-   llm *llama.LLM
+   llm        *llama.LLM
+   Embeddings []vector.Embedding

    expireAt    time.Time
    expireTimer *time.Timer
@@ -72,6 +74,11 @@ func GenerateHandler(c *gin.Context) {
            loaded.digest = ""
        }

+       if model.Embeddings != nil && len(model.Embeddings) > 0 {
+           opts.EmbeddingOnly = true // this is required to generate embeddings, completions will still work
+           loaded.Embeddings = model.Embeddings
+       }
+
        llm, err := llama.New(model.ModelPath, opts)
        if err != nil {
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -82,7 +89,6 @@ func GenerateHandler(c *gin.Context) {
        loaded.digest = model.Digest
        loaded.options = opts
    }
-
    sessionDuration := 5 * time.Minute
    loaded.expireAt = time.Now().Add(sessionDuration)
diff --git a/vector/store.go b/vector/store.go
new file mode 100644
index 00000000..510470d8
--- /dev/null
+++ b/vector/store.go
@@ -0,0 +1,69 @@
+package vector
+
+import (
+   "container/heap"
+   "sort"
+
+   "gonum.org/v1/gonum/mat"
+)
+
+type Embedding struct {
+   Vector []float64 // the embedding vector
+   Data   string    // the data represented by the embedding
+}
+
+type EmbeddingSimilarity struct {
+   Embedding  Embedding // the embedding that was used to calculate the similarity
+   Similarity float64   // the similarity between the embedding and the query
+}
+
+type Heap []EmbeddingSimilarity
+
+func (h Heap) Len() int           { return len(h) }
+func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
+func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+func (h *Heap) Push(e any) {
+   *h = append(*h, e.(EmbeddingSimilarity))
+}
+
+func (h *Heap) Pop() interface{} {
+   old := *h
+   n := len(old)
+   x := old[n-1]
+   *h = old[0 : n-1]
+   return x
+}
+
+// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
+// This value will range from -1 to 1, where 1 means the vectors are identical.
+func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
+   dotProduct := mat.Dot(vec1, vec2)
+   norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
+
+   if norms == 0 {
+       return 0
+   }
+   return dotProduct / norms
+}
+
+func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
+   h := &Heap{}
+   heap.Init(h)
+   for _, emb := range embeddings {
+       similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
+       heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
+       if h.Len() > k {
+           heap.Pop(h)
+       }
+   }
+
+   topK := make([]EmbeddingSimilarity, 0, h.Len())
+   for h.Len() > 0 {
+       topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
+   }
+   sort.Slice(topK, func(i, j int) bool {
+       return topK[i].Similarity > topK[j].Similarity
+   })
+
+   return topK
+}

From 21ddcaa1f1912a443858a244a8c9fe9de3cff755 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 8 Aug 2023 13:49:37 -0400
Subject: [PATCH 2/5] pr comments

- default to embeddings enabled
- move embedding logic for loaded model to request
- allow embedding full directory
- close llm on reload
---
 api/types.go     |   1 +
 server/images.go | 160 +++++++++++++++++++++++------------------------
 server/routes.go |  18 +++++-
 3 files changed, 97 insertions(+), 82 deletions(-)

diff --git a/api/types.go b/api/types.go
index 90b964a3..e6773791 100644
--- a/api/types.go
+++ b/api/types.go
@@ -275,6 +275,7 @@ func DefaultOptions() Options {
        UseMLock:           false,
        RopeFrequencyBase:  10000.0,
        RopeFrequencyScale: 1.0,
+       EmbeddingOnly:      true,

        RepeatLastN:   64,
        RepeatPenalty: 1.1,
diff --git a/server/images.go b/server/images.go
index 4dcaf37d..b807e11b 100644
--- a/server/images.go
+++ b/server/images.go
@@ -23,7 +23,6 @@ import (
    "github.com/jmorganca/ollama/llama"
    "github.com/jmorganca/ollama/parser"
    "github.com/jmorganca/ollama/vector"
-   "gonum.org/v1/gonum/mat"
 )

 type RegistryOptions struct {
@@ -42,7 +41,7 @@ type Model struct {
    Embeddings []vector.Embedding
 }

-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
    t := m.Template
    if request.Template != "" {
        t = request.Template
@@ -67,26 +66,12 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
    vars.System = m.System
    vars.Prompt = request.Prompt
    vars.Context = request.Context
+   vars.Embed = embedding

    if request.System != "" {
        vars.System = request.System
    }

-   if len(m.Embeddings) > 0 {
-       promptEmbed, err := loaded.llm.Embedding(request.Prompt)
-       if err != nil {
-           return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
-       }
-       // TODO: set embed_top from specified parameters in modelfile
-       embed_top := 3
-       embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-       toEmbed := ""
-       for _, e := range embed {
-           toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
-       }
-       vars.Embed = toEmbed
-   }
-
    var sb strings.Builder
    if err := tmpl.Execute(&sb, vars); err != nil {
        return "", err
@@ -432,85 +417,98 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
            return nil, fmt.Errorf("load model to generate embeddings: %v", err)
        }

-       for _, filePath := range e.files {
-           // TODO: check if txt file type
-           f, err := os.Open(filePath)
+       addedFiles := make(map[string]bool) // keep track of files that have already been added
+       for _, filePattern := range e.files {
+           matchingFiles, err := filepath.Glob(filePattern)
            if err != nil {
-               return nil, fmt.Errorf("could not open embed file: %w", err)
+               return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
            }
-           scanner := bufio.NewScanner(f)
-           scanner.Split(bufio.ScanLines)
-
-           data := []string{}
-           for scanner.Scan() {
-               data = append(data, scanner.Text())
-           }
-           f.Close()
-
-           // the digest of the file is set here so that the client knows a new operation is in progress
-           fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
-
-           embeddings := []vector.Embedding{}
-           for i, d := range data {
-               if strings.TrimSpace(d) == "" {
+           for _, filePath := range matchingFiles {
+               if addedFiles[filePath] {
                    continue
                }
-               e.fn(api.ProgressResponse{
-                   Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
-                   Digest:    fileDigest,
-                   Total:     len(data) - 1,
-                   Completed: i,
-               })
-               retry := 0
-           generate:
-               if retry > 3 {
-                   log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
-                   continue
-               }
-               embed, err := llm.Embedding(d)
+               addedFiles[filePath] = true
+               // TODO: check file type
+               f, err := os.Open(filePath)
                if err != nil {
-                   log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
-                   retry++
-                   goto generate
+                   return nil, fmt.Errorf("could not open embed file: %w", err)
                }
-               // Check for NaN and Inf in the embedding, which can't be stored
-               for _, value := range embed {
-                   if math.IsNaN(value) || math.IsInf(value, 0) {
-                       log.Printf("reloading model, embedding contains NaN or Inf")
-                       // reload the model to get a new embedding
-                       llm, err = llama.New(model.ModelPath, e.opts)
-                       if err != nil {
-                           return nil, fmt.Errorf("load model to generate embeddings: %v", err)
-                       }
+               scanner := bufio.NewScanner(f)
+               scanner.Split(bufio.ScanLines)
+
+               data := []string{}
+               for scanner.Scan() {
+                   data = append(data, scanner.Text())
+               }
+               f.Close()
+
+               // the digest of the file is set here so that the client knows a new operation is in progress
+               fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+               embeddings := []vector.Embedding{}
+               for i, d := range data {
+                   if strings.TrimSpace(d) == "" {
+                       continue
+                   }
+                   e.fn(api.ProgressResponse{
+                       Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+                       Digest:    fileDigest,
+                       Total:     len(data) - 1,
+                       Completed: i,
+                   })
+                   retry := 0
+               generate:
+                   if retry > 3 {
+                       log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+                       continue
+                   }
+                   embed, err := llm.Embedding(d)
+                   if err != nil {
+                       log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
                        retry++
                        goto generate
                    }
+                   // Check for NaN and Inf in the embedding, which can't be stored
+                   for _, value := range embed {
+                       if math.IsNaN(value) || math.IsInf(value, 0) {
+                           log.Printf("reloading model, embedding contains NaN or Inf")
+                           // reload the model to get a new embedding, the seed can affect these outputs and reloading changes it
+                           llm.Close()
+                           llm, err = llama.New(model.ModelPath, e.opts)
+                           if err != nil {
+                               return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+                           }
+                           retry++
+                           goto generate
+                       }
+                   }
+                   embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
                }
-               embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
-           }

-           b, err := json.Marshal(embeddings)
-           if err != nil {
-               return nil, fmt.Errorf("failed to encode embeddings: %w", err)
-           }
-           r := bytes.NewReader(b)
+               b, err := json.Marshal(embeddings)
+               if err != nil {
+                   return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+               }
+               r := bytes.NewReader(b)

-           digest, size := GetSHA256Digest(r)
-           // Reset the position of the reader after calculating the digest
-           if _, err := r.Seek(0, 0); err != nil {
-               return nil, fmt.Errorf("could not reset embed reader: %w", err)
-           }
+               digest, size := GetSHA256Digest(r)
+               // Reset the position of the reader after calculating the digest
+               if _, err := r.Seek(0, io.SeekStart); err != nil {
+                   return nil, fmt.Errorf("could not reset embed reader: %w", err)
+               }

-           layer := &LayerReader{
-               Layer: Layer{
-                   MediaType: "application/vnd.ollama.image.embed",
-                   Digest:    digest,
-                   Size:      size,
-               },
-               Reader: r,
-           }
+               layer := &LayerReader{
+                   Layer: Layer{
+                       MediaType: "application/vnd.ollama.image.embed",
+                       Digest:    digest,
+                       Size:      size,
+                   },
+                   Reader: r,
+               }

-           layers = append(layers, layer)
+               layers = append(layers, layer)
+           }
        }
    }
    return layers, nil
diff --git a/server/routes.go b/server/routes.go
index 2a880aaa..b48d7902 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -17,6 +17,7 @@ import (
    "github.com/gin-contrib/cors"
    "github.com/gin-gonic/gin"
+   "gonum.org/v1/gonum/mat"

    "github.com/jmorganca/ollama/api"
    "github.com/jmorganca/ollama/llama"
@@ -114,7 +115,22 @@ func GenerateHandler(c *gin.Context) {

    checkpointLoaded := time.Now()

-   prompt, err := model.Prompt(req)
+   embedding := ""
+   if model.Embeddings != nil && len(model.Embeddings) > 0 {
+       promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+       if err != nil {
+           c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+           return
+       }
+       // TODO: set embed_top from specified parameters in modelfile
+       embed_top := 3
+       topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+       for _, e := range topK {
+           embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+       }
+   }
+
+   prompt, err := model.Prompt(req, embedding)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return

From 3ceac05108fec492b8957370677f06de4dffea64 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 8 Aug 2023 14:04:11 -0400
Subject: [PATCH 3/5] Add embedding docs

---
 docs/modelfile.md | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/docs/modelfile.md b/docs/modelfile.md
index 3e8a7ee8..77dd765b 100644
--- a/docs/modelfile.md
+++ b/docs/modelfile.md
@@ -12,6 +12,7 @@ A model file is the blueprint to create and share models with Ollama.
- [FROM (Required)](#from-required)
  - [Build from llama2](#build-from-llama2)
  - [Build from a bin file](#build-from-a-bin-file)
+- [EMBED](#embed)
- [PARAMETER](#parameter)
  - [Valid Parameters and Values](#valid-parameters-and-values)
- [TEMPLATE](#template)
@@ -88,12 +89,23 @@ FROM ./ollama-model.bin

This bin file location should be specified as an absolute path or relative to the Modelfile location.

+### EMBED
+
+The EMBED instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer.
+
+```
+FROM <model name>:<tag>
+EMBED <file path>.txt
+```
+
### PARAMETER

The `PARAMETER` instruction defines a parameter that can be set when the model is run.
```
+
PARAMETER <parameter> <parametervalue>
+
```

### Valid Parameters and Values

| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |

```
+
TEMPLATE """
{{- if .First }}
+
### System:
+
{{ .System }}
{{- end }}

### User:
+
{{ .Prompt }}

### Response:
+
"""

SYSTEM """<system message>"""
+
```

### SYSTEM

The `SYSTEM` instruction specifies the system prompt to be used in the template, if applicable.

```
+
SYSTEM """<system message>"""
+
```

### LICENSE

The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.

```
+
LICENSE """
<license text>
"""
+
```

## Notes

- the **modelfile is not case sensitive**. In the examples, we use uppercase for instructions to make it easier to distinguish it from arguments.
- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
+
+```
+
+```

From 884d78ceb3f0b6a3867186460d23015cdca7fab4 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 8 Aug 2023 14:38:57 -0400
Subject: [PATCH 4/5] allow embedding from model binary

---
 server/images.go | 22 ++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/server/images.go b/server/images.go
index b807e11b..fe41c9be 100644
--- a/server/images.go
+++ b/server/images.go
@@ -268,7 +268,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
                    if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
                        return err
                    }
-                   mf, err = GetManifest(ParseModelPath(modelFile))
+                   mf, err = GetManifest(ParseModelPath(c.Args))
                    if err != nil {
                        return fmt.Errorf("failed to open file after pull: %v", err)
                    }
@@ -354,6 +354,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
        embed.opts.FromMap(formattedParams)
    }

+   fmt.Println(embed.model)
+
    // generate the embedding layers
    embeddingLayers, err := embeddingLayers(embed)
    if err != nil {
@@ -406,13 +408,21 @@ type EmbeddingParams struct {
 func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
    layers := []*LayerReader{}
    if len(e.files) > 0 {
-       model, err := GetModel(e.model)
-       if err != nil {
-           return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+       if _, err := os.Stat(e.model); err != nil {
+           if os.IsNotExist(err) {
+               // this is a model name rather than the file
+               model, err := GetModel(e.model)
+               if err != nil {
+                   return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+               }
+               e.model = model.ModelPath
+           } else {
+               return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
+           }
        }

        e.opts.EmbeddingOnly = true
-       llm, err := llama.New(model.ModelPath, e.opts)
+       llm, err := llama.New(e.model, e.opts)
        if err != nil {
            return nil, fmt.Errorf("load model to generate embeddings: %v", err)
        }
@@ -475,7 +485,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
                    log.Printf("reloading model, embedding contains NaN or Inf")
                    // reload the model to get a new embedding, the seed can affect these outputs and reloading changes it
                    llm.Close()
-                   llm, err = llama.New(model.ModelPath, e.opts)
+                   llm, err = llama.New(e.model, e.opts)
                    if err != nil {
                        return nil, fmt.Errorf("load model to generate embeddings: %v", err)
                    }

From 1bee2347bedb0ea43d15ffc42b481cb9a4804aea Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 8 Aug 2023 16:56:48 -0400
Subject: [PATCH 5/5] pr feedback

- defer closing llm on embedding
- do not override licenses
- remove debugging print line
- reformat model file docs
---
 docs/modelfile.md | 18 +-----------------
 server/images.go  | 21 ++++++++++++++++-----
 2 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/docs/modelfile.md b/docs/modelfile.md
index 77dd765b..a90cbc0d 100644
--- a/docs/modelfile.md
+++ b/docs/modelfile.md
@@ -103,9 +103,7 @@ EMBED
The `PARAMETER` instruction defines a parameter that can be set when the model is run.

```
-
PARAMETER <parameter> <parametervalue>
-
```

### Valid Parameters and Values
@@ -139,25 +137,19 @@ PARAMETER
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |

```
-
TEMPLATE """
{{- if .First }}
-
### System:
-
{{ .System }}
{{- end }}

### User:
-
{{ .Prompt }}

### Response:
-
"""

SYSTEM """<system message>"""
-
```

### SYSTEM
@@ -165,9 +157,7 @@ SYSTEM """<system message>"""
The `SYSTEM` instruction specifies the system prompt to be used in the template, if applicable.

```
-
SYSTEM """<system message>"""
-
```

### LICENSE
@@ -175,18 +165,12 @@ SYSTEM """<system message>"""
The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.

```
-
LICENSE """
<license text>
"""
-
```

## Notes

- the **modelfile is not case sensitive**. In the examples, we use uppercase for instructions to make it easier to distinguish it from arguments.
-- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
-
-```
-
-```
+- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.
\ No newline at end of file
diff --git a/server/images.go b/server/images.go
index fe41c9be..5796d2f4 100644
--- a/server/images.go
+++ b/server/images.go
@@ -303,13 +303,23 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
            }
        }
    case "embed":
-       // TODO: support entire directories here
        embedFilePath, err := filenameWithPath(path, c.Args)
        if err != nil {
            return err
        }
        embed.files = append(embed.files, embedFilePath)
-   case "license", "template", "system", "prompt":
+   case "license":
+       fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
+       mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
+
+       layer, err := CreateLayer(strings.NewReader(c.Args))
+       if err != nil {
+           return err
+       }
+
+       layer.MediaType = mediaType
+       layers = append(layers, layer)
+   case "template", "system", "prompt":
        fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
        // remove the prompt layer if one exists
        mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@@ -354,8 +364,6 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
        embed.opts.FromMap(formattedParams)
    }

-   fmt.Println(embed.model)
-
    // generate the embedding layers
    embeddingLayers, err := embeddingLayers(embed)
    if err != nil {
@@ -426,6 +434,11 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
        if err != nil {
            return nil, fmt.Errorf("load model to generate embeddings: %v", err)
        }
+       defer func() {
+           if llm != nil {
+               llm.Close()
+           }
+       }()

        addedFiles := make(map[string]bool) // keep track of files that have already been added
        for _, filePattern := range e.files {
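---

For readers trying the series out, here is a minimal sketch of how the `vector` package from PATCH 1 is meant to be consumed, mirroring the retrieval step that `GenerateHandler` performs in PATCH 2. The import path comes from the patches themselves; the three-dimensional vectors and line contents are invented stand-ins for real `llm.Embedding` output, which has `llama_n_embd` entries per line.

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/mat"

	"github.com/jmorganca/ollama/vector"
)

func main() {
	// Hypothetical embeddings, one per non-empty line of an EMBED file,
	// as embeddingLayers would have stored them in the embed layer.
	store := []vector.Embedding{
		{Data: "ollama runs large language models locally", Vector: []float64{0.9, 0.1, 0.0}},
		{Data: "the modelfile describes how a model is built", Vector: []float64{0.2, 0.9, 0.1}},
		{Data: "gonum provides vector and matrix primitives", Vector: []float64{0.1, 0.2, 0.9}},
	}

	// Hypothetical embedding of the user's prompt.
	query := mat.NewVecDense(3, []float64{0.8, 0.2, 0.1})

	// Rank stored lines by cosine similarity and keep the best two; the
	// server concatenates e.Embedding.Data into the Embed template variable.
	for _, e := range vector.TopK(2, query, store) {
		fmt.Printf("%.3f  %s\n", e.Similarity, e.Embedding.Data)
	}
}
```

Because `Heap` orders entries by ascending `Similarity`, `TopK` evicts the weakest candidate whenever the heap grows past k, so memory stays proportional to k regardless of how many lines were embedded; the final `sort.Slice` then returns the survivors in descending similarity order.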