![Michael Yang](/assets/img/avatar_default.png)
Block on write, which only returns when the channel is closed. This is contrary to the previous arrangement, where the handler could return before the stream had finished writing. That could lead to the client receiving unexpected responses (since the request had already been handled) or, in the worst case, a nil-pointer dereference as the stream tries to flush a nil writer.
197 lines
3.9 KiB
Go
197 lines
3.9 KiB
Go
package server
|
|
|
|
import (
|
|
"embed"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"log"
|
|
"math"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
"text/template"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/lithammer/fuzzysearch/fuzzy"
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
"github.com/jmorganca/ollama/llama"
|
|
)
|
|
|
|
//go:embed templates/*
|
|
var templatesFS embed.FS
|
|
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
|
|
|
|
// cacheDir returns the ".ollama" directory located under the current user's
// home directory. It panics when the home directory cannot be determined.
func cacheDir() string {
	homeDir, err := os.UserHomeDir()
	if err == nil {
		return path.Join(homeDir, ".ollama")
	}

	panic(err)
}
|
|
|
|
func generate(c *gin.Context) {
|
|
start := time.Now()
|
|
|
|
req := api.GenerateRequest{
|
|
Options: api.DefaultOptions(),
|
|
}
|
|
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
|
|
req.Model = remoteModel.FullName()
|
|
}
|
|
if _, err := os.Stat(req.Model); err != nil {
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
|
|
}
|
|
|
|
templateNames := make([]string, 0, len(templates.Templates()))
|
|
for _, template := range templates.Templates() {
|
|
templateNames = append(templateNames, template.Name())
|
|
}
|
|
|
|
match, _ := matchRankOne(path.Base(req.Model), templateNames)
|
|
if template := templates.Lookup(match); template != nil {
|
|
var sb strings.Builder
|
|
if err := template.Execute(&sb, req); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
req.Prompt = sb.String()
|
|
}
|
|
|
|
llm, err := llama.New(req.Model, req.Options)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
defer llm.Close()
|
|
|
|
ch := make(chan any)
|
|
go func() {
|
|
defer close(ch)
|
|
llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
|
|
r.Model = req.Model
|
|
r.CreatedAt = time.Now().UTC()
|
|
if r.Done {
|
|
r.TotalDuration = time.Since(start)
|
|
}
|
|
|
|
ch <- r
|
|
})
|
|
}()
|
|
|
|
streamResponse(c, ch)
|
|
}
|
|
|
|
func pull(c *gin.Context) {
|
|
var req api.PullRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
remote, err := getRemote(req.Model)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// check if completed file exists
|
|
fi, err := os.Stat(remote.FullName())
|
|
switch {
|
|
case errors.Is(err, os.ErrNotExist):
|
|
// noop, file doesn't exist so create it
|
|
case err != nil:
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
default:
|
|
c.JSON(http.StatusOK, api.PullProgress{
|
|
Total: fi.Size(),
|
|
Completed: fi.Size(),
|
|
Percent: 100,
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
ch := make(chan any)
|
|
go func() {
|
|
defer close(ch)
|
|
saveModel(remote, func(total, completed int64) {
|
|
ch <- api.PullProgress{
|
|
Total: total,
|
|
Completed: completed,
|
|
Percent: float64(completed) / float64(total) * 100,
|
|
}
|
|
})
|
|
}()
|
|
|
|
streamResponse(c, ch)
|
|
}
|
|
|
|
func Serve(ln net.Listener) error {
|
|
r := gin.Default()
|
|
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "Ollama is running")
|
|
})
|
|
|
|
r.POST("/api/pull", pull)
|
|
r.POST("/api/generate", generate)
|
|
|
|
log.Printf("Listening on %s", ln.Addr())
|
|
s := &http.Server{
|
|
Handler: r,
|
|
}
|
|
|
|
return s.Serve(ln)
|
|
}
|
|
|
|
func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
|
|
bestRank = math.MaxInt
|
|
for _, target := range targets {
|
|
if rank := fuzzy.LevenshteinDistance(source, target); bestRank > rank {
|
|
bestRank = rank
|
|
bestMatch = target
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func streamResponse(c *gin.Context, ch chan any) {
|
|
c.Stream(func(w io.Writer) bool {
|
|
val, ok := <-ch
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
bts, err := json.Marshal(val)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
bts = append(bts, '\n')
|
|
if _, err := w.Write(bts); err != nil {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|