pr comments

- default to embeddings enabled
- move embedding logic for loaded model to request
- allow embedding full directory
- close llm on reload

commit 21ddcaa1f1
parent a6f6d18f83
@@ -275,6 +275,7 @@ func DefaultOptions() Options {
 		UseMLock:           false,
 		RopeFrequencyBase:  10000.0,
 		RopeFrequencyScale: 1.0,
+		EmbeddingOnly:      true,

 		RepeatLastN:   64,
 		RepeatPenalty: 1.1,
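The hunk above is the "default to embeddings enabled" part of the commit: `EmbeddingOnly` now defaults to true, so a loaded model can serve embedding requests without the client opting in. A minimal sketch of how the new default surfaces, assuming only the `DefaultOptions` function and `EmbeddingOnly` field shown above:

```go
package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// EmbeddingOnly now defaults to true, so callers no longer have to
	// opt in before a loaded model can generate embeddings.
	opts := api.DefaultOptions()
	fmt.Println(opts.EmbeddingOnly) // expected: true
}
```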
server/images.go (160 lines changed)
@@ -23,7 +23,6 @@ import (
 	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/vector"
-	"gonum.org/v1/gonum/mat"
 )

 type RegistryOptions struct {
@@ -42,7 +41,7 @@ type Model struct {
 	Embeddings []vector.Embedding
 }

-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -67,26 +66,12 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 	vars.System = m.System
 	vars.Prompt = request.Prompt
 	vars.Context = request.Context
+	vars.Embed = embedding

 	if request.System != "" {
 		vars.System = request.System
 	}

-	if len(m.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(request.Prompt)
-		if err != nil {
-			return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
-		}
-		// TODO: set embed_top from specified parameters in modelfile
-		embed_top := 3
-		embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-		toEmbed := ""
-		for _, e := range embed {
-			toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
-		}
-		vars.Embed = toEmbed
-	}
-
 	var sb strings.Builder
 	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
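These two hunks implement "move embedding logic for loaded model to request": Prompt becomes a pure function of its arguments, with the embedding text computed by the caller (see the GenerateHandler hunk below) instead of Prompt reaching into the server's loaded model. A sketch of the new call shape; the request values here are illustrative, not from the diff:

```go
// retrievedContext is whatever the caller assembled from the model's
// stored embeddings; it may be "" when the model has none.
prompt, err := model.Prompt(api.GenerateRequest{
	Model:  "mymodel",
	Prompt: "why is the sky blue?",
}, retrievedContext)
if err != nil {
	return err
}
```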
@@ -432,85 +417,98 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 		return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 	}

-	for _, filePath := range e.files {
-		// TODO: check if txt file type
-		f, err := os.Open(filePath)
-		if err != nil {
-			return nil, fmt.Errorf("could not open embed file: %w", err)
-		}
-		scanner := bufio.NewScanner(f)
-		scanner.Split(bufio.ScanLines)
-
-		data := []string{}
-		for scanner.Scan() {
-			data = append(data, scanner.Text())
-		}
-		f.Close()
-
-		// the digest of the file is set here so that the client knows a new operation is in progress
-		fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
-
-		embeddings := []vector.Embedding{}
-		for i, d := range data {
-			if strings.TrimSpace(d) == "" {
-				continue
-			}
-			e.fn(api.ProgressResponse{
-				Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
-				Digest:    fileDigest,
-				Total:     len(data) - 1,
-				Completed: i,
-			})
-			retry := 0
-		generate:
-			if retry > 3 {
-				log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
-				continue
-			}
-			embed, err := llm.Embedding(d)
-			if err != nil {
-				log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
-				retry++
-				goto generate
-			}
-			// Check for NaN and Inf in the embedding, which can't be stored
-			for _, value := range embed {
-				if math.IsNaN(value) || math.IsInf(value, 0) {
-					log.Printf("reloading model, embedding contains NaN or Inf")
-					// reload the model to get a new embedding
-					llm, err = llama.New(model.ModelPath, e.opts)
-					if err != nil {
-						return nil, fmt.Errorf("load model to generate embeddings: %v", err)
-					}
-					retry++
-					goto generate
-				}
-			}
-			embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
-		}
+	addedFiles := make(map[string]bool) // keep track of files that have already been added
+	for _, filePattern := range e.files {
+		matchingFiles, err := filepath.Glob(filePattern)
+		if err != nil {
+			return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
+		}
+
+		for _, filePath := range matchingFiles {
+			if addedFiles[filePath] {
+				continue
+			}
+			addedFiles[filePath] = true
+			// TODO: check file type
+			f, err := os.Open(filePath)
+			if err != nil {
+				return nil, fmt.Errorf("could not open embed file: %w", err)
+			}
+			scanner := bufio.NewScanner(f)
+			scanner.Split(bufio.ScanLines)
+
+			data := []string{}
+			for scanner.Scan() {
+				data = append(data, scanner.Text())
+			}
+			f.Close()
+
+			// the digest of the file is set here so that the client knows a new operation is in progress
+			fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+			embeddings := []vector.Embedding{}
+			for i, d := range data {
+				if strings.TrimSpace(d) == "" {
+					continue
+				}
+				e.fn(api.ProgressResponse{
+					Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+					Digest:    fileDigest,
+					Total:     len(data) - 1,
+					Completed: i,
+				})
+				retry := 0
+			generate:
+				if retry > 3 {
+					log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
+					continue
+				}
+				embed, err := llm.Embedding(d)
+				if err != nil {
+					log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
+					retry++
+					goto generate
+				}
+				// Check for NaN and Inf in the embedding, which can't be stored
+				for _, value := range embed {
+					if math.IsNaN(value) || math.IsInf(value, 0) {
+						log.Printf("reloading model, embedding contains NaN or Inf")
+						// reload the model to get a new embedding; the seed can affect these outputs and reloading changes it
+						llm.Close()
+						llm, err = llama.New(model.ModelPath, e.opts)
+						if err != nil {
+							return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+						}
+						retry++
+						goto generate
+					}
+				}
+				embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
+			}

 		b, err := json.Marshal(embeddings)
 		if err != nil {
 			return nil, fmt.Errorf("failed to encode embeddings: %w", err)
 		}
 		r := bytes.NewReader(b)

 		digest, size := GetSHA256Digest(r)
 		// Reset the position of the reader after calculating the digest
-		if _, err := r.Seek(0, 0); err != nil {
+		if _, err := r.Seek(0, io.SeekStart); err != nil {
 			return nil, fmt.Errorf("could not reset embed reader: %w", err)
 		}

 		layer := &LayerReader{
 			Layer: Layer{
 				MediaType: "application/vnd.ollama.image.embed",
 				Digest:    digest,
 				Size:      size,
 			},
 			Reader: r,
 		}

 		layers = append(layers, layer)
+		}
 	}
 }
 return layers, nil
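This hunk is "allow embedding full directory" and "close llm on reload": each entry of e.files is now passed through filepath.Glob, with a map deduplicating paths matched by more than one pattern, and the model is explicitly closed before being reloaded when an embedding comes back with NaN or Inf. Two stdlib behaviors the glob change relies on: Glob fails only on a malformed pattern (a pattern matching nothing returns an empty slice and is silently skipped), and a literal path with no metacharacters matches itself when the file exists, so single-file entries keep working. A self-contained sketch:

```go
package main

import (
	"fmt"
	"log"
	"path/filepath"
)

func main() {
	// "docs/*.txt" expands to every .txt file in docs/; the only possible
	// error is filepath.ErrBadPattern, never "no matches".
	matches, err := filepath.Glob("docs/*.txt")
	if err != nil {
		log.Fatal(err)
	}
	for _, path := range matches {
		fmt.Println(path)
	}
}
```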
@@ -17,6 +17,7 @@ import (

 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"gonum.org/v1/gonum/mat"

 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
@@ -114,7 +115,22 @@ func GenerateHandler(c *gin.Context) {

 	checkpointLoaded := time.Now()

-	prompt, err := model.Prompt(req)
+	embedding := ""
+	if model.Embeddings != nil && len(model.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		// TODO: set embed_top from specified parameters in modelfile
+		embed_top := 3
+		topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		for _, e := range topK {
+			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+		}
+	}
+
+	prompt, err := model.Prompt(req, embedding)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
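At request time the handler now embeds the prompt itself and splices the top three nearest stored chunks into the template through the new embedding argument. The vector package's TopK implementation is not shown in this diff; the sketch below is a generic cosine-similarity rendering of the idea (all names hypothetical), not ollama's actual code:

```go
package vectorsketch

import (
	"math"
	"sort"
)

// Embedding pairs a source text chunk with its vector, mirroring the
// Data / Vector fields used in the diff above.
type Embedding struct {
	Data   string
	Vector []float64
}

// cosine returns the cosine similarity of two equal-length vectors.
func cosine(a, b []float64) float64 {
	var dot, na, nb float64
	for i := range a {
		dot += a[i] * b[i]
		na += a[i] * a[i]
		nb += b[i] * b[i]
	}
	return dot / (math.Sqrt(na) * math.Sqrt(nb))
}

// topK returns the k stored embeddings most similar to the query vector.
func topK(k int, query []float64, store []Embedding) []Embedding {
	ranked := append([]Embedding(nil), store...)
	sort.Slice(ranked, func(i, j int) bool {
		return cosine(ranked[i].Vector, query) > cosine(ranked[j].Vector, query)
	})
	if k > len(ranked) {
		k = len(ranked)
	}
	return ranked[:k]
}
```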