From 326de489307b5d8115e216148a17abae6ee940aa Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 15 Aug 2023 10:35:39 -0300 Subject: [PATCH] use loaded llm for embeddings --- server/images.go | 38 ++++++++++++++------------------------ server/routes.go | 4 +++- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/server/images.go b/server/images.go index 960f2d58..8bf4344a 100644 --- a/server/images.go +++ b/server/images.go @@ -263,7 +263,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api var layers []*LayerReader params := make(map[string][]string) - embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()} + embed := EmbeddingParams{fn: fn} for _, c := range commands { log.Printf("[%s] - %s\n", c.Name, c.Args) switch c.Name { @@ -291,6 +291,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api return err } } else { + embed.model = modelFile // create a model from this specified file fn(api.ProgressResponse{Status: "creating model layer"}) file, err := os.Open(modelFile) @@ -422,8 +423,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api layers = append(layers, l) // apply these parameters to the embedding options, in case embeddings need to be generated using this model - embed.opts = api.DefaultOptions() - embed.opts.FromMap(formattedParams) + embed.opts = formattedParams } // generate the embedding layers @@ -469,7 +469,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api type EmbeddingParams struct { model string - opts api.Options + opts map[string]interface{} files []string // paths to files to embed fn func(resp api.ProgressResponse) } @@ -478,32 +478,22 @@ type EmbeddingParams struct { func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { layers := []*LayerReader{} if len(e.files) > 0 { - if _, err := os.Stat(e.model); err != nil { - if os.IsNotExist(err) { - // this is a model name rather than the file - model, err := GetModel(e.model) - if err != nil { - return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err) - } - e.model = model.ModelPath - } else { - return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err) + // check if the model is a file path or a model name + model, err := GetModel(e.model) + if err != nil { + if !strings.Contains(err.Error(), "couldn't open file") { + return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err) } + // the model may be a file path, create a model from this file + model = &Model{ModelPath: e.model} } - e.opts.EmbeddingOnly = true - llmModel, err := llm.New(e.model, []string{}, e.opts) - if err != nil { + if err := load(model, e.opts, defaultSessionDuration); err != nil { return nil, fmt.Errorf("load model to generate embeddings: %v", err) } - defer func() { - if llmModel != nil { - llmModel.Close() - } - }() // this will be used to check if we already have embeddings for a file - modelInfo, err := os.Stat(e.model) + modelInfo, err := os.Stat(model.ModelPath) if err != nil { return nil, fmt.Errorf("failed to get model file info: %v", err) } @@ -561,7 +551,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]}) continue } - embed, err := llmModel.Embedding(d) + embed, err := loaded.llm.Embedding(d) if err != nil { log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) continue diff --git a/server/routes.go b/server/routes.go index 3e13328b..7e78178c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,6 +38,8 @@ var loaded struct { options api.Options } +var defaultSessionDuration = 5 * time.Minute + // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { opts := api.DefaultOptions() @@ -134,7 +136,7 @@ func GenerateHandler(c *gin.Context) { return } - sessionDuration := 5 * time.Minute + sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified if err := load(model, req.Options, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return