diff --git a/README.md b/README.md index d456aaef..6ebe2827 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Ollama -Ollama is a tool for running large language models. It's designed to be easy to use and fast. +A fast runtime for large language models, powered by [llama.cpp](https://github.com/ggerganov/llama.cpp). > _Note: this project is a work in progress. Certain models that can be run with `ollama` are intended for research and/or non-commercial use only._ diff --git a/api/client.go b/api/client.go new file mode 100644 index 00000000..c653fbaa --- /dev/null +++ b/api/client.go @@ -0,0 +1,99 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/ollama/ollama/signature" +) + +type Client struct { + Name string + Version string + URL string + HTTP http.Client + Headers http.Header + PrivateKey []byte +} + +func checkError(resp *http.Response, body []byte) error { + if resp.StatusCode >= 200 && resp.StatusCode < 400 { + return nil + } + + apiError := Error{Code: int32(resp.StatusCode)} + + err := json.Unmarshal(body, &apiError) + if err != nil { + // Use the full body as the message if we fail to decode a response. + apiError.Message = string(body) + } + + return apiError +} + +func (c *Client) do(ctx context.Context, method string, path string, stream bool, reqData any, respData any) error { + var reqBody io.Reader + var data []byte + var err error + if reqData != nil { + data, err = json.Marshal(reqData) + if err != nil { + return err + } + reqBody = bytes.NewReader(data) + } + + url := fmt.Sprintf("%s%s", c.URL, path) + + req, err := http.NewRequestWithContext(ctx, method, url, reqBody) + if err != nil { + return err + } + + if c.PrivateKey != nil { + s := signature.SignatureData{ + Method: method, + Path: url, + Data: data, + } + authHeader, err := signature.SignAuthData(s, c.PrivateKey) + if err != nil { + return err + } + req.Header.Set("Authorization", authHeader) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + for k, v := range c.Headers { + req.Header[k] = v + } + + respObj, err := c.HTTP.Do(req) + if err != nil { + return err + } + defer respObj.Body.Close() + + respBody, err := io.ReadAll(respObj.Body) + if err != nil { + return err + } + + if err := checkError(respObj, respBody); err != nil { + return err + } + + if len(respBody) > 0 && respData != nil { + if err := json.Unmarshal(respBody, respData); err != nil { + return err + } + } + return nil +} diff --git a/api/types.go b/api/types.go new file mode 100644 index 00000000..5f104415 --- /dev/null +++ b/api/types.go @@ -0,0 +1,28 @@ +package api + +import ( + "fmt" + "net/http" + "strings" +) + +type Error struct { + Code int32 `json:"code"` + Message string `json:"message"` +} + +func (e Error) Error() string { + if e.Message == "" { + return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code)))) + } + return e.Message +} + +type GenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type GenerateResponse struct { + Response string `json:"response"` +} diff --git a/cmd/cmd.go b/cmd/cmd.go new file mode 100644 index 00000000..b6a45258 --- /dev/null +++ b/cmd/cmd.go @@ -0,0 +1,139 @@ +package cmd + +import ( + "context" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "path" + "time" + + "github.com/spf13/cobra" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/server" +) + +func NewAPIClient(cmd *cobra.Command) 
(*api.Client, error) { + var rawKey []byte + var err error + + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + socket := path.Join(home, ".ollama", "ollama.sock") + + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + } + + k, _ := cmd.Flags().GetString("key") + + if k != "" { + fn := path.Join(home, ".ollama/keys/", k) + rawKey, err = ioutil.ReadFile(fn) + if err != nil { + return nil, err + } + } + + return &api.Client{ + URL: "http://localhost", + HTTP: http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", socket) + }, + }, + }, + PrivateKey: rawKey, + }, nil +} + +func NewCLI() *cobra.Command { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + rootCmd := &cobra.Command{ + Use: "gollama", + Short: "Run any large language model on any machine.", + CompletionOptions: cobra.CompletionOptions{ + DisableDefaultCmd: true, + }, + PersistentPreRun: func(cmd *cobra.Command, args []string) { + // Disable usage printing on errors + cmd.SilenceUsage = true + }, + } + + rootCmd.PersistentFlags().StringP("key", "k", "", "Private key to use for authenticating") + + cobra.EnableCommandSorting = false + + modelsCmd := &cobra.Command{ + Use: "models", + Args: cobra.MaximumNArgs(1), + Short: "List models", + Long: "List the models", + RunE: func(cmd *cobra.Command, args []string) error { + client, err := NewAPIClient(cmd) + if err != nil { + return err + } + fmt.Printf("client = %q\n", client) + return nil + }, + } + +/* + runCmd := &cobra.Command{ + Use: "run", + Short: "Run a model and submit prompts.", + RunE: func(cmd *cobra.Command. args []string) error { + }, + } +*/ + + serveCmd := &cobra.Command{ + Use: "serve", + Aliases: []string{"start"}, + Short: "Start ollama", + RunE: func(cmd *cobra.Command, args []string) error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + socket := path.Join(home, ".ollama", "ollama.sock") + if err := os.MkdirAll(path.Dir(socket), 0o700); err != nil { + return err + } + + if err := os.RemoveAll(socket); err != nil { + return err + } + + ln, err := net.Listen("unix", socket) + if err != nil { + return err + } + + if err := os.Chmod(socket, 0o700); err != nil { + return err + } + + return server.Serve(ln) + }, + } + + rootCmd.AddCommand( + modelsCmd, + serveCmd, + ) + + return rootCmd +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..d79b0c19 --- /dev/null +++ b/go.mod @@ -0,0 +1,40 @@ +module github.com/ollama/ollama + +go 1.20 + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144 + github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc + github.com/spf13/cobra v1.7.0 + golang.org/x/crypto v0.10.0 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + 
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..36c0acc6 --- /dev/null +++ b/go.sum @@ -0,0 +1,96 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144 h1:fszkmZG3pW9/bqhuWB6sfJMArJPx1RPzjZSqNdhuSQ0= +github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= 
+github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc/go.mod h1:S8xSOnV3CgpNrWd0GQ/OoQfMtlg2uPRSuTzcSGrzwK8= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec 
v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= +golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/lib/.gitignore b/lib/.gitignore new file mode 100644 index 00000000..378eac25 --- /dev/null +++ b/lib/.gitignore @@ -0,0 +1 @@ +build diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt new file mode 100644 index 00000000..4c5b96dd --- /dev/null +++ b/lib/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.10) +include(FetchContent) + +FetchContent_Declare( + llama_cpp + GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git + GIT_TAG master +) + +FetchContent_MakeAvailable(llama_cpp) + +project(binding) + +set(LLAMA_METAL ON CACHE BOOL "Enable Llama Metal by default on macOS") + +add_library(binding binding.cpp ${llama_cpp_SOURCE_DIR}/examples/common.cpp) +target_compile_features(binding PRIVATE cxx_std_11) +target_include_directories(binding PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) +target_include_directories(binding PRIVATE ${llama_cpp_SOURCE_DIR}) +target_include_directories(binding PRIVATE ${llama_cpp_SOURCE_DIR}/examples) +target_link_libraries(binding llama 
ggml_static) diff --git a/lib/README.md b/lib/README.md new file mode 100644 index 00000000..1addfbe6 --- /dev/null +++ b/lib/README.md @@ -0,0 +1,10 @@ +# Bindings + +These are Llama.cpp bindings + +## Build + +``` +cmake -S . -B build +cmake --build build +``` diff --git a/lib/binding.cpp b/lib/binding.cpp new file mode 100644 index 00000000..f29afbae --- /dev/null +++ b/lib/binding.cpp @@ -0,0 +1,708 @@ +#include "common.h" +#include "llama.h" + +#include "binding.h" + +#include <cassert> +#include <cinttypes> +#include <cmath> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <iostream> +#include <regex> +#include <sstream> +#include <string> +#include <vector> +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) +#include <signal.h> +#include <unistd.h> +#elif defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include <windows.h> +#include <signal.h> +#endif + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || \ + defined(_WIN32) +void sigint_handler(int signo) { + if (signo == SIGINT) { + _exit(130); + } +} +#endif + +int get_embeddings(void *params_ptr, void *state_pr, float *res_embeddings) { + gpt_params *params_p = (gpt_params *)params_ptr; + llama_context *ctx = (llama_context *)state_pr; + gpt_params params = *params_p; + + if (params.seed <= 0) { + params.seed = time(NULL); + } + + std::mt19937 rng(params.seed); + + llama_init_backend(params.numa); + + int n_past = 0; + + // Add a space in front of the first character to match OG llama tokenizer + // behavior + params.prompt.insert(0, 1, ' '); + + // tokenize the prompt + auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); + + // determine newline token + auto llama_token_newline = ::llama_tokenize(ctx, "\n", false); + + if (embd_inp.size() > 0) { + if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, + params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + } + + const int n_embd = llama_n_embd(ctx); + + const auto embeddings = llama_get_embeddings(ctx); + + for (int i = 0; i < n_embd; i++) { + res_embeddings[i] = embeddings[i]; + } + + return 0; +} + +int get_token_embeddings(void *params_ptr, void *state_pr, int *tokens, + int tokenSize, float *res_embeddings) { + gpt_params *params_p = (gpt_params *)params_ptr; + llama_context *ctx = (llama_context *)state_pr; + gpt_params params = *params_p; + + for (int i = 0; i < tokenSize; i++) { + auto token_str = llama_token_to_str(ctx, tokens[i]); + if (token_str == nullptr) { + continue; + } + std::vector<int> my_vector; + std::string str_token(token_str); // create a new std::string from the char* + params_p->prompt += str_token; + } + + return get_embeddings(params_ptr, state_pr, res_embeddings); +} + +int eval(void *params_ptr, void *state_pr, char *text) { + gpt_params *params_p = (gpt_params *)params_ptr; + llama_context *ctx = (llama_context *)state_pr; + + auto n_past = 0; + auto last_n_tokens_data = + std::vector<llama_token>(params_p->repeat_last_n, 0); + + auto tokens = std::vector<llama_token>(params_p->n_ctx); + auto n_prompt_tokens = + llama_tokenize(ctx, text, tokens.data(), tokens.size(), true); + + if (n_prompt_tokens < 1) { + fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); + return 1; + } + + // evaluate prompt + return llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, + params_p->n_threads); +} + +int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug) { + gpt_params *params_p = (gpt_params *)params_ptr; + llama_context *ctx = (llama_context *)state_pr; + + gpt_params params = *params_p; + + const int n_ctx = llama_n_ctx(ctx); + + if (params.seed <= 0) { + params.seed =
time(NULL); + } + + std::mt19937 rng(params.seed); + + std::string path_session = params.path_prompt_cache; + std::vector<llama_token> session_tokens; + + if (!path_session.empty()) { + if (debug) { + fprintf(stderr, "%s: attempting to load saved session from '%s'\n", + __func__, path_session.c_str()); + } + // fopen to check for existing session + FILE *fp = std::fopen(path_session.c_str(), "rb"); + if (fp != NULL) { + std::fclose(fp); + + session_tokens.resize(n_ctx); + size_t n_token_count_out = 0; + if (!llama_load_session_file( + ctx, path_session.c_str(), session_tokens.data(), + session_tokens.capacity(), &n_token_count_out)) { + fprintf(stderr, "%s: error: failed to load session file '%s'\n", + __func__, path_session.c_str()); + return 1; + } + session_tokens.resize(n_token_count_out); + llama_set_rng_seed(ctx, params.seed); + if (debug) { + fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", + __func__, (int)session_tokens.size()); + } + } else { + if (debug) { + fprintf(stderr, "%s: session file does not exist, will create\n", + __func__); + } + } + } + + std::vector<llama_token> embd_inp; + if (!params.prompt.empty() || session_tokens.empty()) { + // Add a space in front of the first character to match OG llama tokenizer + // behavior + params.prompt.insert(0, 1, ' '); + + embd_inp = ::llama_tokenize(ctx, params.prompt, true); + } else { + embd_inp = session_tokens; + } + + // debug message about similarity of saved session, if applicable + size_t n_matching_session_tokens = 0; + if (session_tokens.size()) { + for (llama_token id : session_tokens) { + if (n_matching_session_tokens >= embd_inp.size() || + id != embd_inp[n_matching_session_tokens]) { + break; + } + n_matching_session_tokens++; + } + if (debug) { + if (params.prompt.empty() && + n_matching_session_tokens == embd_inp.size()) { + fprintf(stderr, "%s: using full prompt from session file\n", __func__); + } else if (n_matching_session_tokens >= embd_inp.size()) { + fprintf(stderr, "%s: session file has exact match for prompt!\n", + __func__); + } else if (n_matching_session_tokens < (embd_inp.size() / 2)) { + fprintf(stderr, + "%s: warning: session file has low similarity to prompt (%zu / " + "%zu tokens); will mostly be reevaluated\n", + __func__, n_matching_session_tokens, embd_inp.size()); + } else { + fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n", + __func__, n_matching_session_tokens, embd_inp.size()); + } + } + } + // if we will use the cache for the full prompt without reaching the end of + // the cache, force reevaluation of the last token token to recalculate the + // cached logits + if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && + session_tokens.size() > embd_inp.size()) { + session_tokens.resize(embd_inp.size() - 1); + } + // number of tokens to keep when resetting context + if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size()) { + params.n_keep = (int)embd_inp.size(); + } + + // determine newline token + auto llama_token_newline = ::llama_tokenize(ctx, "\n", false); + + // TODO: replace with ring-buffer + std::vector<llama_token> last_n_tokens(n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + + bool need_to_save_session = + !path_session.empty() && n_matching_session_tokens < embd_inp.size(); + int n_past = 0; + int n_remain = params.n_predict; + int n_consumed = 0; + int n_session_consumed = 0; + + std::vector<llama_token> embd; + std::string res = ""; + + // do one empty run to warm up the model + { + const std::vector<llama_token> tmp = { + llama_token_bos(), + }; +
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); + llama_reset_timings(ctx); + } + + while (n_remain != 0) { + // predict + if (embd.size() > 0) { + // infinite text generation via context swapping + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the + // logits in batches + if (n_past + (int)embd.size() > n_ctx) { + const int n_left = n_past - params.n_keep; + + // always keep the first token - BOS + n_past = std::max(1, params.n_keep); + + // insert n_left/2 tokens at the start of embd from last_n_tokens + embd.insert(embd.begin(), + last_n_tokens.begin() + n_ctx - n_left / 2 - embd.size(), + last_n_tokens.end() - embd.size()); + + // stop saving session if we run out of context + path_session.clear(); + + // printf("\n---\n"); + // printf("resetting: '"); + // for (int i = 0; i < (int) embd.size(); i++) { + // printf("%s", llama_token_to_str(ctx, embd[i])); + // } + // printf("'\n"); + // printf("\n---\n"); + } + + // try to reuse a matching prefix from the loaded session instead of + // re-eval (via n_past) + if (n_session_consumed < (int)session_tokens.size()) { + size_t i = 0; + for (; i < embd.size(); i++) { + if (embd[i] != session_tokens[n_session_consumed]) { + session_tokens.resize(n_session_consumed); + break; + } + + n_past++; + n_session_consumed++; + + if (n_session_consumed >= (int)session_tokens.size()) { + ++i; + break; + } + } + if (i > 0) { + embd.erase(embd.begin(), embd.begin() + i); + } + } + + // evaluate tokens in batches + // embd is typically prepared beforehand to fit within a batch, but not + // always + for (int i = 0; i < (int)embd.size(); i += params.n_batch) { + int n_eval = (int)embd.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + n_past += n_eval; + } + + if (embd.size() > 0 && !path_session.empty()) { + session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); + n_session_consumed = session_tokens.size(); + } + } + + embd.clear(); + + if ((int)embd_inp.size() <= n_consumed) { + // out of user input, sample next token + const float temp = params.temp; + const int32_t top_k = + params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t repeat_last_n = + params.repeat_last_n < 0 ? 
n_ctx : params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; + const float alpha_presence = params.presence_penalty; + const float alpha_frequency = params.frequency_penalty; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + + // optionally save the session on first sample (for faster prompt loading + // next time) + if (!path_session.empty() && need_to_save_session && + !params.prompt_cache_ro) { + need_to_save_session = false; + llama_save_session_file(ctx, path_session.c_str(), + session_tokens.data(), session_tokens.size()); + } + + llama_token id = 0; + + { + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); + + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); + it++) { + logits[it->first] += it->second; + } + + std::vector candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back( + llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array candidates_p = {candidates.data(), + candidates.size(), false}; + + // Apply penalties + float nl_logit = logits[llama_token_nl()]; + auto last_n_repeat = + std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); + llama_sample_repetition_penalty( + ctx, &candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); + llama_sample_frequency_and_presence_penalties( + ctx, &candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + if (!penalize_nl) { + logits[llama_token_nl()] = nl_logit; + } + + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &candidates_p); + } else { + if (mirostat == 1) { + static float mirostat_mu = 2.0f * mirostat_tau; + const int mirostat_m = 100; + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, + mirostat_eta, mirostat_m, + &mirostat_mu); + } else if (mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token_mirostat_v2( + ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); + } else { + // Temperature sampling + llama_sample_top_k(ctx, &candidates_p, top_k, 1); + llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); + llama_sample_typical(ctx, &candidates_p, typical_p, 1); + llama_sample_top_p(ctx, &candidates_p, top_p, 1); + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token(ctx, &candidates_p); + } + } + // printf("`%d`", candidates_p.size); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); + } + + // add it to the context + embd.push_back(id); + + // decrement remaining sampling budget + --n_remain; + + // call the token callback, no need to check if one is actually + // registered, that will be handled on the Go side. 
+ auto token_str = llama_token_to_str(ctx, id); + if (!tokenCallback(state_pr, (char *)token_str)) { + break; + } + } else { + // some user input remains from prompt or interaction, forward it to + // processing + while ((int)embd_inp.size() > n_consumed) { + embd.push_back(embd_inp[n_consumed]); + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(embd_inp[n_consumed]); + ++n_consumed; + if ((int)embd.size() >= params.n_batch) { + break; + } + } + } + + for (auto id : embd) { + res += llama_token_to_str(ctx, id); + } + + // check for stop prompt + if (params.antiprompt.size()) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } + // Check if each of the reverse prompts appears at the end of the output. + for (std::string &antiprompt : params.antiprompt) { + // size_t extra_padding = params.interactive ? 0 : 2; + size_t extra_padding = 2; + size_t search_start_pos = + last_output.length() > + static_cast<size_t>(antiprompt.length() + extra_padding) + ? last_output.length() - + static_cast<size_t>(antiprompt.length() + extra_padding) + : 0; + + if (last_output.find(antiprompt.c_str(), search_start_pos) != + std::string::npos) { + goto end; + } + } + } + + // end of text token + if (!embd.empty() && embd.back() == llama_token_eos()) { + break; + } + } + + if (!path_session.empty() && params.prompt_cache_all && + !params.prompt_cache_ro) { + if (debug) { + fprintf(stderr, "\n%s: saving final output to session file '%s'\n", + __func__, path_session.c_str()); + } + llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), + session_tokens.size()); + } + +end: +#if defined(_WIN32) + signal(SIGINT, SIG_DFL); +#endif + + if (debug) { + llama_print_timings(ctx); + llama_reset_timings(ctx); + } + + strcpy(result, res.c_str()); + return 0; +} + +void llama_binding_free_model(void *state_ptr) { + llama_context *ctx = (llama_context *)state_ptr; + llama_free(ctx); +} + +void llama_free_params(void *params_ptr) { + gpt_params *params = (gpt_params *)params_ptr; + delete params; +} + +std::vector<std::string> create_vector(const char **strings, int count) { + std::vector<std::string> *vec = new std::vector<std::string>; + for (int i = 0; i < count; i++) { + vec->push_back(std::string(strings[i])); + } + return *vec; +} + +void delete_vector(std::vector<std::string> *vec) { delete vec; } + +int load_state(void *ctx, char *statefile, char *modes) { + llama_context *state = (llama_context *)ctx; + const llama_context *constState = static_cast<const llama_context *>(state); + const size_t state_size = llama_get_state_size(state); + uint8_t *state_mem = new uint8_t[state_size]; + + { + FILE *fp_read = fopen(statefile, modes); + if (state_size != llama_get_state_size(constState)) { + fprintf(stderr, "\n%s : failed to validate state size\n", __func__); + return 1; + } + + const size_t ret = fread(state_mem, 1, state_size, fp_read); + if (ret != state_size) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + return 1; + } + + llama_set_state_data( + state, state_mem); // could also read directly from memory mapped file + fclose(fp_read); + } + + return 0; +} + +void save_state(void *ctx, char *dst, char *modes) { + llama_context *state = (llama_context *)ctx; + + const size_t state_size = llama_get_state_size(state); + uint8_t *state_mem = new uint8_t[state_size]; + + // Save state (rng, logits, embedding and kv_cache) to file + { + FILE *fp_write = fopen(dst, modes); + llama_copy_state_data( + state, state_mem); // could also copy directly to memory mapped file + fwrite(state_mem, 1,
state_size, fp_write); + fclose(fp_write); + } +} + +void *llama_allocate_params( + const char *prompt, int seed, int threads, int tokens, int top_k, + float top_p, float temp, float repeat_penalty, int repeat_last_n, + bool ignore_eos, bool memory_f16, int n_batch, int n_keep, + const char **antiprompt, int antiprompt_count, float tfs_z, float typical_p, + float frequency_penalty, float presence_penalty, int mirostat, + float mirostat_eta, float mirostat_tau, bool penalize_nl, + const char *logit_bias, const char *session_file, bool prompt_cache_all, + bool mlock, bool mmap, const char *maingpu, const char *tensorsplit, + bool prompt_cache_ro) { + gpt_params *params = new gpt_params; + params->seed = seed; + params->n_threads = threads; + params->n_predict = tokens; + params->repeat_last_n = repeat_last_n; + params->prompt_cache_ro = prompt_cache_ro; + params->top_k = top_k; + params->top_p = top_p; + params->memory_f16 = memory_f16; + params->temp = temp; + params->use_mmap = mmap; + params->use_mlock = mlock; + params->repeat_penalty = repeat_penalty; + params->n_batch = n_batch; + params->n_keep = n_keep; + if (maingpu[0] != '\0') { + params->main_gpu = std::stoi(maingpu); + } + + if (tensorsplit[0] != '\0') { + std::string arg_next = tensorsplit; + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + + for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { + if (i < split_arg.size()) { + params->tensor_split[i] = std::stof(split_arg[i]); + } else { + params->tensor_split[i] = 0.0f; + } + } + } + + params->prompt_cache_all = prompt_cache_all; + params->path_prompt_cache = session_file; + + if (ignore_eos) { + params->logit_bias[llama_token_eos()] = -INFINITY; + } + if (antiprompt_count > 0) { + params->antiprompt = create_vector(antiprompt, antiprompt_count); + } + params->tfs_z = tfs_z; + params->typical_p = typical_p; + params->presence_penalty = presence_penalty; + params->mirostat = mirostat; + params->mirostat_eta = mirostat_eta; + params->mirostat_tau = mirostat_tau; + params->penalize_nl = penalize_nl; + std::stringstream ss(logit_bias); + llama_token key; + char sign; + std::string value_str; + if (ss >> key && ss >> sign && std::getline(ss, value_str) && + (sign == '+' || sign == '-')) { + params->logit_bias[key] = + std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); + } + params->frequency_penalty = frequency_penalty; + params->prompt = prompt; + + return params; +} + +void *load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, + bool mlock, bool embeddings, bool mmap, bool low_vram, + bool vocab_only, int n_gpu_layers, int n_batch, + const char *maingpu, const char *tensorsplit, bool numa) { + // load the model + auto lparams = llama_context_default_params(); + + lparams.n_ctx = n_ctx; + lparams.seed = n_seed; + lparams.f16_kv = memory_f16; + lparams.embedding = embeddings; + lparams.use_mlock = mlock; + lparams.n_gpu_layers = n_gpu_layers; + lparams.use_mmap = mmap; + lparams.low_vram = low_vram; + lparams.vocab_only = vocab_only; + + if (maingpu[0] != '\0') { + lparams.main_gpu = std::stoi(maingpu); + } + + if (tensorsplit[0] != '\0') { + std::string arg_next = tensorsplit; + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + + for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { + if (i < split_arg.size()) { + lparams.tensor_split[i] = std::stof(split_arg[i]); + } else { + lparams.tensor_split[i] = 0.0f; + } + } + } + + lparams.n_batch = n_batch; + + llama_init_backend(numa); + void *res = nullptr; + try { + llama_model *model = llama_load_model_from_file(fname, lparams); + if (model == NULL) { + fprintf(stderr, "error: failed to load model \n"); + return res; + } + + llama_context *lctx = llama_new_context_with_model(model, lparams); + if (lctx == NULL) { + fprintf(stderr, "error: failed to create context with model \n"); + llama_free_model(model); + return res; + } + + } catch (std::runtime_error &e) { + fprintf(stderr, "failed %s", e.what()); + return res; + } + + return res; +} diff --git a/lib/binding.h b/lib/binding.h new file mode 100644 index 00000000..7bf02a1a --- /dev/null +++ b/lib/binding.h @@ -0,0 +1,41 @@ +#ifdef __cplusplus +#include +#include +extern "C" { +#endif + +#include + +extern unsigned char tokenCallback(void *, char *); + +int load_state(void *ctx, char *statefile, char*modes); + +int eval(void* params_ptr, void *ctx, char*text); + +void save_state(void *ctx, char *dst, char*modes); + +void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, bool vocab_only, int n_gpu, int n_batch, const char *maingpu, const char *tensorsplit, bool numa); + +int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings); + +int get_token_embeddings(void* params_ptr, void* state_pr, int *tokens, int tokenSize, float * res_embeddings); + +void* llama_allocate_params(const char *prompt, int seed, int threads, int tokens, + int top_k, float top_p, float temp, float repeat_penalty, + int repeat_last_n, bool ignore_eos, bool memory_f16, + int n_batch, int n_keep, const char** antiprompt, int antiprompt_count, + float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu, const char *tensorsplit , bool prompt_cache_ro); + +void llama_free_params(void* params_ptr); + +void llama_binding_free_model(void* state); + +int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug); + +#ifdef __cplusplus +} + + +std::vector create_vector(const 
char** strings, int count); +void delete_vector(std::vector* vec); +#endif diff --git a/main.go b/main.go new file mode 100644 index 00000000..a7759c5e --- /dev/null +++ b/main.go @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/ollama/ollama/cmd" +) + +func main() { + cmd.NewCLI().Execute() +} diff --git a/python/README.md b/python/README.md new file mode 100644 index 00000000..b871efe7 --- /dev/null +++ b/python/README.md @@ -0,0 +1,39 @@ +# Ollama Python bindings + +``` +pip install ollama +``` + +## Developing + +Ollama is built using Python 3 and uses [Poetry](https://python-poetry.org/) to manage dependencies and build packages. + +``` +pip install poetry +``` + +Install ollama and its dependencies: + +``` +poetry install --extras server --with dev +``` + +Run ollama server: + +``` +poetry run ollama server +``` + +Update dependencies: + +``` +poetry update --extras server --with dev +poetry lock +poetry export >requirements.txt +``` + +Build binary package: + +``` +poetry build +``` diff --git a/ollama/__init__.py b/python/ollama/__init__.py similarity index 100% rename from ollama/__init__.py rename to python/ollama/__init__.py diff --git a/ollama/__main__.py b/python/ollama/__main__.py similarity index 100% rename from ollama/__main__.py rename to python/ollama/__main__.py diff --git a/ollama/cmd/__init__.py b/python/ollama/cmd/__init__.py similarity index 100% rename from ollama/cmd/__init__.py rename to python/ollama/cmd/__init__.py diff --git a/ollama/cmd/cli.py b/python/ollama/cmd/cli.py similarity index 100% rename from ollama/cmd/cli.py rename to python/ollama/cmd/cli.py diff --git a/ollama/cmd/server.py b/python/ollama/cmd/server.py similarity index 100% rename from ollama/cmd/server.py rename to python/ollama/cmd/server.py diff --git a/ollama/engine.py b/python/ollama/engine.py similarity index 100% rename from ollama/engine.py rename to python/ollama/engine.py diff --git a/ollama/model.py b/python/ollama/model.py similarity index 100% rename from ollama/model.py rename to python/ollama/model.py diff --git a/ollama/prompt.py b/python/ollama/prompt.py similarity index 100% rename from ollama/prompt.py rename to python/ollama/prompt.py diff --git a/ollama/templates/alpaca.prompt b/python/ollama/templates/alpaca.prompt similarity index 100% rename from ollama/templates/alpaca.prompt rename to python/ollama/templates/alpaca.prompt diff --git a/ollama/templates/falcon.prompt b/python/ollama/templates/falcon.prompt similarity index 100% rename from ollama/templates/falcon.prompt rename to python/ollama/templates/falcon.prompt diff --git a/ollama/templates/gpt4.prompt b/python/ollama/templates/gpt4.prompt similarity index 100% rename from ollama/templates/gpt4.prompt rename to python/ollama/templates/gpt4.prompt diff --git a/ollama/templates/hermes.prompt b/python/ollama/templates/hermes.prompt similarity index 100% rename from ollama/templates/hermes.prompt rename to python/ollama/templates/hermes.prompt diff --git a/ollama/templates/mpt.prompt b/python/ollama/templates/mpt.prompt similarity index 100% rename from ollama/templates/mpt.prompt rename to python/ollama/templates/mpt.prompt diff --git a/ollama/templates/oasst.prompt b/python/ollama/templates/oasst.prompt similarity index 100% rename from ollama/templates/oasst.prompt rename to python/ollama/templates/oasst.prompt diff --git a/ollama/templates/orca.prompt b/python/ollama/templates/orca.prompt similarity index 100% rename from ollama/templates/orca.prompt rename to python/ollama/templates/orca.prompt diff --git 
a/ollama/templates/qlora.prompt b/python/ollama/templates/qlora.prompt similarity index 100% rename from ollama/templates/qlora.prompt rename to python/ollama/templates/qlora.prompt diff --git a/ollama/templates/tulu.prompt b/python/ollama/templates/tulu.prompt similarity index 100% rename from ollama/templates/tulu.prompt rename to python/ollama/templates/tulu.prompt diff --git a/ollama/templates/ultralm.prompt b/python/ollama/templates/ultralm.prompt similarity index 100% rename from ollama/templates/ultralm.prompt rename to python/ollama/templates/ultralm.prompt diff --git a/ollama/templates/vicuna.prompt b/python/ollama/templates/vicuna.prompt similarity index 100% rename from ollama/templates/vicuna.prompt rename to python/ollama/templates/vicuna.prompt diff --git a/ollama/templates/wizardcoder.prompt b/python/ollama/templates/wizardcoder.prompt similarity index 100% rename from ollama/templates/wizardcoder.prompt rename to python/ollama/templates/wizardcoder.prompt diff --git a/ollama/templates/wizardlm.prompt b/python/ollama/templates/wizardlm.prompt similarity index 100% rename from ollama/templates/wizardlm.prompt rename to python/ollama/templates/wizardlm.prompt diff --git a/poetry.lock b/python/poetry.lock similarity index 100% rename from poetry.lock rename to python/poetry.lock diff --git a/pyproject.toml b/python/pyproject.toml similarity index 100% rename from pyproject.toml rename to python/pyproject.toml diff --git a/requirements.txt b/python/requirements.txt similarity index 100% rename from requirements.txt rename to python/requirements.txt diff --git a/server/routes.go b/server/routes.go new file mode 100644 index 00000000..f8d2bd72 --- /dev/null +++ b/server/routes.go @@ -0,0 +1,82 @@ +package server + +import ( + "fmt" + "io" + "log" + "net" + "net/http" + "runtime" + + "github.com/gin-gonic/gin" + llama "github.com/go-skynet/go-llama.cpp" + + "github.com/ollama/ollama/api" +) + +func Serve(ln net.Listener) error { + r := gin.Default() + + var l *llama.LLama + + gpulayers := 1 + tokens := 512 + threads := runtime.NumCPU() + model := "/Users/pdevine/.cache/gpt4all/GPT4All-13B-snoozy.ggmlv3.q4_0.bin" + + r.POST("/api/load", func(c *gin.Context) { + var err error + l, err = llama.New(model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers)) + if err != nil { + fmt.Println("Loading the model failed:", err.Error()) + } + }) + + r.POST("/api/unload", func(c *gin.Context) { + }) + + r.POST("/api/generate", func(c *gin.Context) { + var req api.GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + + ch := make(chan string) + + go func() { + defer close(ch) + _, err := l.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool { + ch <- token + return true + }), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama")) + if err != nil { + panic(err) + } + }() + + c.Stream(func(w io.Writer) bool { + tok, ok := <-ch + if !ok { + return false + } + c.SSEvent("token", tok) + return true + }) + +/* + embeds, err := l.Embeddings(text) + if err != nil { + fmt.Printf("Embeddings: error %s \n", err.Error()) + } +*/ + + }) + + log.Printf("Listening on %s", ln.Addr()) + s := &http.Server{ + Handler: r, + } + + return s.Serve(ln) +} diff --git a/signature/signature.go b/signature/signature.go new file mode 100644 index 00000000..5f228144 --- /dev/null +++ 
b/signature/signature.go @@ -0,0 +1,63 @@ +package signature + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + + "golang.org/x/crypto/ssh" +) + +type SignatureData struct { + Method string + Path string + Data []byte +} + +func GetBytesToSign(s SignatureData) []byte { + // contentHash = base64(hex(sha256(s.Data))) + hash := sha256.Sum256(s.Data) + hashHex := make([]byte, hex.EncodedLen(len(hash))) + hex.Encode(hashHex, hash[:]) + contentHash := base64.StdEncoding.EncodeToString(hashHex) + + // bytesToSign e.g.: "GET,http://localhost,OTdkZjM1O... + bytesToSign := []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ",")) + + return bytesToSign +} + +// SignAuthData takes a SignatureData object and signs it with a raw private key +func SignAuthData(s SignatureData, rawKey []byte) (string, error) { + bytesToSign := GetBytesToSign(s) + + // TODO replace this w/ a non-SSH based private key + privateKey, err := ssh.ParseRawPrivateKey(rawKey) + if err != nil { + return "", err + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + return "", err + } + + // get the pubkey, but remove the type + pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey()) + parts := bytes.Split(pubKey, []byte(" ")) + if len(parts) < 2 { + return "", fmt.Errorf("malformed private key") + } + + signedData, err := signer.Sign(nil, bytesToSign) + if err != nil { + return "", err + } + + // signature is <pubkey>:<signature> + sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)) + return sig, nil +}
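
For reference, below is a minimal sketch of a client for the new `/api/generate` route added in `server/routes.go`. It is not part of the change above: it assumes `ollama serve` is already listening on `~/.ollama/ollama.sock` (the socket path set up in `cmd/cmd.go`), that the request body follows `api.GenerateRequest` (`model`, `prompt`), and that the handler streams gin SSE events named `token` as produced by `c.SSEvent("token", tok)`; the `model` field is currently ignored by the server, which loads a hardcoded model path, so the model name used here is a hypothetical placeholder.

```go
package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"os"
	"path"
	"strings"
)

func main() {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}
	socket := path.Join(home, ".ollama", "ollama.sock")

	// Dial the unix socket the serve command listens on; the hostname in the
	// URL is ignored because of the custom DialContext.
	client := &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				return (&net.Dialer{}).DialContext(ctx, "unix", socket)
			},
		},
	}

	// Request body shaped like api.GenerateRequest.
	body, err := json.Marshal(map[string]string{
		"model":  "llama",
		"prompt": "Why is the sky blue?",
	})
	if err != nil {
		panic(err)
	}

	resp, err := client.Post("http://localhost/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler emits SSE frames; print the data portion of each event.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data:") {
			fmt.Print(strings.TrimPrefix(line, "data:"))
		}
	}
	fmt.Println()
}
```

A production client would more likely extend `api.Client` and reuse the `r3labs/sse` dependency already present in `go.mod` rather than hand-parsing the stream; this sketch only illustrates the wire format of the new route.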