diff --git a/README.md b/README.md index 6ebe2827..9495a911 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Ollama -A fast runtime for large language models, powered by [llama.cpp](https://github.com/ggerganov/llama.cpp). +An easy, fast runtime for large language models, powered by `llama.cpp`. > _Note: this project is a work in progress. Certain models that can be run with `ollama` are intended for research and/or non-commercial use only._ @@ -38,6 +38,13 @@ Or directly via downloaded model files: ollama run ~/Downloads/orca-mini-13b.ggmlv3.q4_0.bin ``` +## Building + +``` +go generate ./... +go build . +``` + ## Documentation - [Development](docs/development.md) diff --git a/api/client.go b/api/client.go index c653fbaa..e445e589 100644 --- a/api/client.go +++ b/api/client.go @@ -8,7 +8,7 @@ import ( "io" "net/http" - "github.com/ollama/ollama/signature" + "github.com/jmorganca/ollama/signature" ) type Client struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index b6a45258..c8e5b4c4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -3,7 +3,6 @@ package cmd import ( "context" "fmt" - "io/ioutil" "log" "net" "net/http" @@ -13,8 +12,8 @@ import ( "github.com/spf13/cobra" - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/server" + "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/server" ) func NewAPIClient(cmd *cobra.Command) (*api.Client, error) { @@ -36,7 +35,7 @@ func NewAPIClient(cmd *cobra.Command) (*api.Client, error) { if k != "" { fn := path.Join(home, ".ollama/keys/", k) - rawKey, err = ioutil.ReadFile(fn) + rawKey, err = os.ReadFile(fn) if err != nil { return nil, err } @@ -59,7 +58,7 @@ func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile) rootCmd := &cobra.Command{ - Use: "gollama", + Use: "ollama", Short: "Run any large language model on any machine.", CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, diff --git a/go.mod b/go.mod index d79b0c19..219e14cf 100644 --- a/go.mod +++ b/go.mod @@ 
-1,11 +1,9 @@ -module github.com/ollama/ollama +module github.com/jmorganca/ollama go 1.20 require ( github.com/gin-gonic/gin v1.9.1 - github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144 - github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc github.com/spf13/cobra v1.7.0 golang.org/x/crypto v0.10.0 ) @@ -19,6 +17,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect @@ -35,6 +34,5 @@ require ( golang.org/x/sys v0.9.0 // indirect golang.org/x/text v0.10.0 // indirect google.golang.org/protobuf v1.30.0 // indirect - gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 36c0acc6..7be172ba 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,7 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhD github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= @@ -13,18 +14,19 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= 
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= -github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144 h1:fszkmZG3pW9/bqhuWB6sfJMArJPx1RPzjZSqNdhuSQ0= -github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -44,8 +46,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G 
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc/go.mod h1:S8xSOnV3CgpNrWd0GQ/OoQfMtlg2uPRSuTzcSGrzwK8= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= @@ -55,12 +57,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= @@ -69,27 +71,23 @@ github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= -golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 
h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/.gitignore b/lib/.gitignore deleted file mode 100644 index 378eac25..00000000 --- a/lib/.gitignore +++ /dev/null @@ -1 +0,0 @@ -build diff --git a/lib/README.md b/lib/README.md deleted file mode 100644 index 1addfbe6..00000000 --- a/lib/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Bindings - -These are Llama.cpp bindings - -## Build - -``` -cmake -S . 
-B build -cmake --build build -``` diff --git a/lib/binding.h b/lib/binding.h deleted file mode 100644 index 7bf02a1a..00000000 --- a/lib/binding.h +++ /dev/null @@ -1,41 +0,0 @@ -#ifdef __cplusplus -#include -#include -extern "C" { -#endif - -#include - -extern unsigned char tokenCallback(void *, char *); - -int load_state(void *ctx, char *statefile, char*modes); - -int eval(void* params_ptr, void *ctx, char*text); - -void save_state(void *ctx, char *dst, char*modes); - -void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, bool vocab_only, int n_gpu, int n_batch, const char *maingpu, const char *tensorsplit, bool numa); - -int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings); - -int get_token_embeddings(void* params_ptr, void* state_pr, int *tokens, int tokenSize, float * res_embeddings); - -void* llama_allocate_params(const char *prompt, int seed, int threads, int tokens, - int top_k, float top_p, float temp, float repeat_penalty, - int repeat_last_n, bool ignore_eos, bool memory_f16, - int n_batch, int n_keep, const char** antiprompt, int antiprompt_count, - float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu, const char *tensorsplit , bool prompt_cache_ro); - -void llama_free_params(void* params_ptr); - -void llama_binding_free_model(void* state); - -int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug); - -#ifdef __cplusplus -} - - -std::vector create_vector(const char** strings, int count); -void delete_vector(std::vector* vec); -#endif diff --git a/llama/.gitignore b/llama/.gitignore new file mode 100644 index 00000000..c795b054 --- /dev/null +++ b/llama/.gitignore @@ -0,0 +1 @@ +build \ No newline at end of file diff --git 
a/lib/CMakeLists.txt b/llama/CMakeLists.txt similarity index 56% rename from lib/CMakeLists.txt rename to llama/CMakeLists.txt index 4c5b96dd..b2ef7b39 100644 --- a/lib/CMakeLists.txt +++ b/llama/CMakeLists.txt @@ -9,13 +9,19 @@ FetchContent_Declare( FetchContent_MakeAvailable(llama_cpp) +if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(LLAMA_METAL ON) + add_compile_definitions(GGML_USE_METAL) +endif() + project(binding) -set(LLAMA_METAL ON CACHE BOOL "Enable Llama Metal by default on macOS") - -add_library(binding binding.cpp ${llama_cpp_SOURCE_DIR}/examples/common.cpp) +add_library(binding ${CMAKE_CURRENT_SOURCE_DIR}/binding/binding.cpp ${llama_cpp_SOURCE_DIR}/examples/common.cpp) target_compile_features(binding PRIVATE cxx_std_11) target_include_directories(binding PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(binding PRIVATE ${llama_cpp_SOURCE_DIR}) target_include_directories(binding PRIVATE ${llama_cpp_SOURCE_DIR}/examples) target_link_libraries(binding llama ggml_static) + +configure_file(${llama_cpp_BINARY_DIR}/libllama.a ${CMAKE_CURRENT_BINARY_DIR}/libllama.a COPYONLY) +configure_file(${llama_cpp_BINARY_DIR}/libggml_static.a ${CMAKE_CURRENT_BINARY_DIR}/libggml_static.a COPYONLY) diff --git a/lib/binding.cpp b/llama/binding/binding.cpp similarity index 95% rename from lib/binding.cpp rename to llama/binding/binding.cpp index f29afbae..eff84d45 100644 --- a/lib/binding.cpp +++ b/llama/binding/binding.cpp @@ -1,3 +1,25 @@ +// MIT License + +// Copyright (c) 2023 go-skynet authors + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The 
above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + #include "common.h" #include "llama.h" diff --git a/llama/binding/binding.h b/llama/binding/binding.h new file mode 100644 index 00000000..56dd4d8e --- /dev/null +++ b/llama/binding/binding.h @@ -0,0 +1,71 @@ +// MIT License + +// Copyright (c) 2023 go-skynet authors + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#ifdef __cplusplus +#include +#include +extern "C" { +#endif + +#include + +extern unsigned char tokenCallback(void *, char *); + +int load_state(void *ctx, char *statefile, char *modes); + +int eval(void *params_ptr, void *ctx, char *text); + +void save_state(void *ctx, char *dst, char *modes); + +void *load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, + bool mlock, bool embeddings, bool mmap, bool low_vram, + bool vocab_only, int n_gpu, int n_batch, const char *maingpu, + const char *tensorsplit, bool numa); + +int get_embeddings(void *params_ptr, void *state_pr, float *res_embeddings); + +int get_token_embeddings(void *params_ptr, void *state_pr, int *tokens, + int tokenSize, float *res_embeddings); + +void *llama_allocate_params( + const char *prompt, int seed, int threads, int tokens, int top_k, + float top_p, float temp, float repeat_penalty, int repeat_last_n, + bool ignore_eos, bool memory_f16, int n_batch, int n_keep, + const char **antiprompt, int antiprompt_count, float tfs_z, float typical_p, + float frequency_penalty, float presence_penalty, int mirostat, + float mirostat_eta, float mirostat_tau, bool penalize_nl, + const char *logit_bias, const char *session_file, bool prompt_cache_all, + bool mlock, bool mmap, const char *maingpu, const char *tensorsplit, + bool prompt_cache_ro); + +void llama_free_params(void *params_ptr); + +void llama_binding_free_model(void *state); + +int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug); + +#ifdef __cplusplus +} + +std::vector create_vector(const char **strings, int count); +void delete_vector(std::vector *vec); +#endif diff --git a/llama/llama.go b/llama/llama.go new file mode 100644 index 00000000..9da5de3e --- /dev/null +++ b/llama/llama.go @@ -0,0 +1,302 @@ +// MIT License + +// Copyright (c) 2023 go-skynet authors + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the 
"Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +//go:generate cmake -S . -B build +//go:generate cmake --build build +package llama + +// #cgo LDFLAGS: -Lbuild -lbinding -lllama -lggml_static -lstdc++ +// #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders +// #cgo darwin CXXFLAGS: -std=c++11 +// #include "binding/binding.h" +import "C" +import ( + "fmt" + "os" + "strings" + "sync" + "unsafe" +) + +type LLama struct { + state unsafe.Pointer + embeddings bool + contextSize int +} + +func New(model string, opts ...ModelOption) (*LLama, error) { + mo := NewModelOptions(opts...) 
+ modelPath := C.CString(model) + result := C.load_model(modelPath, C.int(mo.ContextSize), C.int(mo.Seed), C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.bool(mo.LowVRAM), C.bool(mo.VocabOnly), C.int(mo.NGPULayers), C.int(mo.NBatch), C.CString(mo.MainGPU), C.CString(mo.TensorSplit), C.bool(mo.NUMA)) + if result == nil { + return nil, fmt.Errorf("failed loading model") + } + + ll := &LLama{state: result, contextSize: mo.ContextSize, embeddings: mo.Embeddings} + + return ll, nil +} + +func (l *LLama) Free() { + C.llama_binding_free_model(l.state) +} + +func (l *LLama) LoadState(state string) error { + d := C.CString(state) + w := C.CString("rb") + + result := C.load_state(l.state, d, w) + if result != 0 { + return fmt.Errorf("error while loading state") + } + + return nil +} + +func (l *LLama) SaveState(dst string) error { + d := C.CString(dst) + w := C.CString("wb") + + C.save_state(l.state, d, w) + + _, err := os.Stat(dst) + return err +} + +// Token Embeddings +func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32, error) { + if !l.embeddings { + return []float32{}, fmt.Errorf("model loaded without embeddings") + } + + po := NewPredictOptions(opts...) 
+ + outSize := po.Tokens + if po.Tokens == 0 { + outSize = 9999999 + } + + floats := make([]float32, outSize) + + myArray := (*C.int)(C.malloc(C.size_t(len(tokens)) * C.sizeof_int)) + + // Copy the values from the Go slice to the C array + for i, v := range tokens { + (*[1<<31 - 1]int32)(unsafe.Pointer(myArray))[i] = int32(v) + } + + params := C.llama_allocate_params(C.CString(""), C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), + C.bool(po.IgnoreEOS), C.bool(po.F16KV), + C.int(po.Batch), C.int(po.NKeep), nil, C.int(0), + C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty), + C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias), + C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), + C.CString(po.MainGPU), C.CString(po.TensorSplit), + C.bool(po.PromptCacheRO), + ) + ret := C.get_token_embeddings(params, l.state, myArray, C.int(len(tokens)), (*C.float)(&floats[0])) + if ret != 0 { + return floats, fmt.Errorf("embedding inference failed") + } + return floats, nil +} + +// Embeddings +func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error) { + if !l.embeddings { + return []float32{}, fmt.Errorf("model loaded without embeddings") + } + + po := NewPredictOptions(opts...) 
+ + input := C.CString(text) + if po.Tokens == 0 { + po.Tokens = 99999999 + } + floats := make([]float32, po.Tokens) + reverseCount := len(po.StopPrompts) + reversePrompt := make([]*C.char, reverseCount) + var pass **C.char + for i, s := range po.StopPrompts { + cs := C.CString(s) + reversePrompt[i] = cs + pass = &reversePrompt[0] + } + + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), + C.bool(po.IgnoreEOS), C.bool(po.F16KV), + C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), + C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty), + C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias), + C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), + C.CString(po.MainGPU), C.CString(po.TensorSplit), + C.bool(po.PromptCacheRO), + ) + + ret := C.get_embeddings(params, l.state, (*C.float)(&floats[0])) + if ret != 0 { + return floats, fmt.Errorf("embedding inference failed") + } + + return floats, nil +} + +func (l *LLama) Eval(text string, opts ...PredictOption) error { + po := NewPredictOptions(opts...) 
+ + input := C.CString(text) + if po.Tokens == 0 { + po.Tokens = 99999999 + } + + reverseCount := len(po.StopPrompts) + reversePrompt := make([]*C.char, reverseCount) + var pass **C.char + for i, s := range po.StopPrompts { + cs := C.CString(s) + reversePrompt[i] = cs + pass = &reversePrompt[0] + } + + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), + C.bool(po.IgnoreEOS), C.bool(po.F16KV), + C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), + C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty), + C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias), + C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), + C.CString(po.MainGPU), C.CString(po.TensorSplit), + C.bool(po.PromptCacheRO), + ) + ret := C.eval(params, l.state, input) + if ret != 0 { + return fmt.Errorf("inference failed") + } + + C.llama_free_params(params) + + return nil +} + +func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { + po := NewPredictOptions(opts...) 
+ + if po.TokenCallback != nil { + setCallback(l.state, po.TokenCallback) + } + + input := C.CString(text) + if po.Tokens == 0 { + po.Tokens = 99999999 + } + out := make([]byte, po.Tokens) + + reverseCount := len(po.StopPrompts) + reversePrompt := make([]*C.char, reverseCount) + var pass **C.char + for i, s := range po.StopPrompts { + cs := C.CString(s) + reversePrompt[i] = cs + pass = &reversePrompt[0] + } + + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), + C.bool(po.IgnoreEOS), C.bool(po.F16KV), + C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), + C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty), + C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias), + C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), + C.CString(po.MainGPU), C.CString(po.TensorSplit), + C.bool(po.PromptCacheRO), + ) + ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode)) + if ret != 0 { + return "", fmt.Errorf("inference failed") + } + res := C.GoString((*C.char)(unsafe.Pointer(&out[0]))) + + res = strings.TrimPrefix(res, " ") + res = strings.TrimPrefix(res, text) + res = strings.TrimPrefix(res, "\n") + + for _, s := range po.StopPrompts { + res = strings.TrimRight(res, s) + } + + C.llama_free_params(params) + + if po.TokenCallback != nil { + setCallback(l.state, nil) + } + + return res, nil +} + +// CGo only allows us to use static calls from C to Go, we can't just dynamically pass in func's. +// This is the next best thing, we register the callbacks in this map and call tokenCallback from +// the C code. We also attach a finalizer to LLama, so it will unregister the callback when the +// garbage collection frees it. 
+ +// SetTokenCallback registers a callback for the individual tokens created when running Predict. It +// will be called once for each token. The callback shall return true as long as the model should +// continue predicting the next token. When the callback returns false the predictor will return. +// The tokens are just converted into Go strings, they are not trimmed or otherwise changed. Also +// the tokens may not be valid UTF-8. +// Pass in nil to remove a callback. +// +// It is save to call this method while a prediction is running. +func (l *LLama) SetTokenCallback(callback func(token string) bool) { + setCallback(l.state, callback) +} + +var ( + m sync.Mutex + callbacks = map[uintptr]func(string) bool{} +) + +//export tokenCallback +func tokenCallback(statePtr unsafe.Pointer, token *C.char) bool { + m.Lock() + defer m.Unlock() + + if callback, ok := callbacks[uintptr(statePtr)]; ok { + return callback(C.GoString(token)) + } + + return true +} + +// setCallback can be used to register a token callback for LLama. Pass in a nil callback to +// remove the callback. 
+func setCallback(statePtr unsafe.Pointer, callback func(string) bool) { + m.Lock() + defer m.Unlock() + + if callback == nil { + delete(callbacks, uintptr(statePtr)) + } else { + callbacks[uintptr(statePtr)] = callback + } +} diff --git a/llama/llama_cublas.go b/llama/llama_cublas.go new file mode 100644 index 00000000..efd15192 --- /dev/null +++ b/llama/llama_cublas.go @@ -0,0 +1,9 @@ +//go:build cublas +// +build cublas + +package llama + +/* +#cgo LDFLAGS: -lcublas -lcudart -L/usr/local/cuda/lib64/ +*/ +import "C" diff --git a/llama/llama_openblas.go b/llama/llama_openblas.go new file mode 100644 index 00000000..31e09f7e --- /dev/null +++ b/llama/llama_openblas.go @@ -0,0 +1,9 @@ +//go:build openblas +// +build openblas + +package llama + +/* +#cgo LDFLAGS: -lopenblas +*/ +import "C" diff --git a/llama/options.go b/llama/options.go new file mode 100644 index 00000000..3cc72a53 --- /dev/null +++ b/llama/options.go @@ -0,0 +1,392 @@ +// MIT License + +// Copyright (c) 2023 go-skynet authors + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package llama + +type ModelOptions struct { + ContextSize int + Seed int + NBatch int + F16Memory bool + MLock bool + MMap bool + VocabOnly bool + LowVRAM bool + Embeddings bool + NUMA bool + NGPULayers int + MainGPU string + TensorSplit string +} + +type PredictOptions struct { + Seed, Threads, Tokens, TopK, Repeat, Batch, NKeep int + TopP, Temperature, Penalty float64 + F16KV bool + DebugMode bool + StopPrompts []string + IgnoreEOS bool + + TailFreeSamplingZ float64 + TypicalP float64 + FrequencyPenalty float64 + PresencePenalty float64 + Mirostat int + MirostatETA float64 + MirostatTAU float64 + PenalizeNL bool + LogitBias string + TokenCallback func(string) bool + + PathPromptCache string + MLock, MMap, PromptCacheAll bool + PromptCacheRO bool + MainGPU string + TensorSplit string +} + +type PredictOption func(p *PredictOptions) + +type ModelOption func(p *ModelOptions) + +var DefaultModelOptions ModelOptions = ModelOptions{ + ContextSize: 512, + Seed: 0, + F16Memory: false, + MLock: false, + Embeddings: false, + MMap: true, + LowVRAM: false, +} + +var DefaultOptions PredictOptions = PredictOptions{ + Seed: -1, + Threads: 4, + Tokens: 128, + Penalty: 1.1, + Repeat: 64, + Batch: 512, + NKeep: 64, + TopK: 40, + TopP: 0.95, + TailFreeSamplingZ: 1.0, + TypicalP: 1.0, + Temperature: 0.8, + FrequencyPenalty: 0.0, + PresencePenalty: 0.0, + Mirostat: 0, + MirostatTAU: 5.0, + MirostatETA: 0.1, + MMap: true, +} + +// SetContext sets the context size. +func SetContext(c int) ModelOption { + return func(p *ModelOptions) { + p.ContextSize = c + } +} + +func SetModelSeed(c int) ModelOption { + return func(p *ModelOptions) { + p.Seed = c + } +} + +// SetContext sets the context size. 
+func SetMMap(b bool) ModelOption {
+	return func(p *ModelOptions) {
+		p.MMap = b
+	}
+}
+
+// SetNBatch sets the n_Batch
+func SetNBatch(n_batch int) ModelOption {
+	return func(p *ModelOptions) {
+		p.NBatch = n_batch
+	}
+}
+
+// SetTensorSplit sets the tensor split for the GPU
+func SetTensorSplit(maingpu string) ModelOption {
+	return func(p *ModelOptions) {
+		p.TensorSplit = maingpu
+	}
+}
+
+// SetMainGPU sets the main_gpu
+func SetMainGPU(maingpu string) ModelOption {
+	return func(p *ModelOptions) {
+		p.MainGPU = maingpu
+	}
+}
+
+// SetPredictionTensorSplit sets the tensor split for the GPU
+func SetPredictionTensorSplit(maingpu string) PredictOption {
+	return func(p *PredictOptions) {
+		p.TensorSplit = maingpu
+	}
+}
+
+// SetPredictionMainGPU sets the main_gpu
+func SetPredictionMainGPU(maingpu string) PredictOption {
+	return func(p *PredictOptions) {
+		p.MainGPU = maingpu
+	}
+}
+
+var VocabOnly ModelOption = func(p *ModelOptions) {
+	p.VocabOnly = true
+}
+
+var EnabelLowVRAM ModelOption = func(p *ModelOptions) {
+	p.LowVRAM = true
+}
+
+var EnableNUMA ModelOption = func(p *ModelOptions) {
+	p.NUMA = true
+}
+
+var EnableEmbeddings ModelOption = func(p *ModelOptions) {
+	p.Embeddings = true
+}
+
+var EnableF16Memory ModelOption = func(p *ModelOptions) {
+	p.F16Memory = true
+}
+
+var EnableF16KV PredictOption = func(p *PredictOptions) {
+	p.F16KV = true
+}
+
+var Debug PredictOption = func(p *PredictOptions) {
+	p.DebugMode = true
+}
+
+var EnablePromptCacheAll PredictOption = func(p *PredictOptions) {
+	p.PromptCacheAll = true
+}
+
+var EnablePromptCacheRO PredictOption = func(p *PredictOptions) {
+	p.PromptCacheRO = true
+}
+
+var EnableMLock ModelOption = func(p *ModelOptions) {
+	p.MLock = true
+}
+
+// NewModelOptions creates a new ModelOptions object with the given options.
+func NewModelOptions(opts ...ModelOption) ModelOptions {
+	p := DefaultModelOptions
+	for _, opt := range opts {
+		opt(&p)
+	}
+	return p
+}
+
+var IgnoreEOS PredictOption = func(p *PredictOptions) {
+	p.IgnoreEOS = true
+}
+
+// SetMlock sets the memory lock.
+func SetMlock(b bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.MLock = b
+	}
+}
+
+// SetMemoryMap sets memory mapping.
+func SetMemoryMap(b bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.MMap = b
+	}
+}
+
+// SetGPULayers sets the number of GPU layers to use to offload computation
+func SetGPULayers(n int) ModelOption {
+	return func(p *ModelOptions) {
+		p.NGPULayers = n
+	}
+}
+
+// SetTokenCallback sets the callback that is invoked for each generated token.
+func SetTokenCallback(fn func(string) bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.TokenCallback = fn
+	}
+}
+
+// SetStopWords sets the prompts that will stop predictions.
+func SetStopWords(stop ...string) PredictOption {
+	return func(p *PredictOptions) {
+		p.StopPrompts = stop
+	}
+}
+
+// SetSeed sets the random seed for sampling text generation.
+func SetSeed(seed int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Seed = seed
+	}
+}
+
+// SetThreads sets the number of threads to use for text generation.
+func SetThreads(threads int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Threads = threads
+	}
+}
+
+// SetTokens sets the number of tokens to generate.
+func SetTokens(tokens int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Tokens = tokens
+	}
+}
+
+// SetTopK sets the value for top-K sampling.
+func SetTopK(topk int) PredictOption {
+	return func(p *PredictOptions) {
+		p.TopK = topk
+	}
+}
+
+// SetTopP sets the value for nucleus sampling.
+func SetTopP(topp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.TopP = topp
+	}
+}
+
+// SetTemperature sets the temperature value for text generation.
+func SetTemperature(temp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.Temperature = temp
+	}
+}
+
+// SetPathPromptCache sets the session file to store the prompt cache.
+func SetPathPromptCache(f string) PredictOption {
+	return func(p *PredictOptions) {
+		p.PathPromptCache = f
+	}
+}
+
+// SetPenalty sets the repetition penalty for text generation.
+func SetPenalty(penalty float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.Penalty = penalty
+	}
+}
+
+// SetRepeat sets the number of times to repeat text generation.
+func SetRepeat(repeat int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Repeat = repeat
+	}
+}
+
+// SetBatch sets the batch size.
+func SetBatch(size int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Batch = size
+	}
+}
+
+// SetNKeep sets the number of tokens from initial prompt to keep.
+func SetNKeep(n int) PredictOption {
+	return func(p *PredictOptions) {
+		p.NKeep = n
+	}
+}
+
+// Create a new PredictOptions object with the given options.
+func NewPredictOptions(opts ...PredictOption) PredictOptions {
+	p := DefaultOptions
+	for _, opt := range opts {
+		opt(&p)
+	}
+	return p
+}
+
+// SetTailFreeSamplingZ sets the tail free sampling, parameter z.
+func SetTailFreeSamplingZ(tfz float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.TailFreeSamplingZ = tfz
+	}
+}
+
+// SetTypicalP sets the typicality parameter, p_typical.
+func SetTypicalP(tp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.TypicalP = tp
+	}
+}
+
+// SetFrequencyPenalty sets the frequency penalty parameter, freq_penalty.
+func SetFrequencyPenalty(fp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.FrequencyPenalty = fp
+	}
+}
+
+// SetPresencePenalty sets the presence penalty parameter, presence_penalty.
+func SetPresencePenalty(pp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.PresencePenalty = pp
+	}
+}
+
+// SetMirostat sets the mirostat parameter.
+func SetMirostat(m int) PredictOption { + return func(p *PredictOptions) { + p.Mirostat = m + } +} + +// SetMirostatETA sets the mirostat ETA parameter. +func SetMirostatETA(me float64) PredictOption { + return func(p *PredictOptions) { + p.MirostatETA = me + } +} + +// SetMirostatTAU sets the mirostat TAU parameter. +func SetMirostatTAU(mt float64) PredictOption { + return func(p *PredictOptions) { + p.MirostatTAU = mt + } +} + +// SetPenalizeNL sets whether to penalize newlines or not. +func SetPenalizeNL(pnl bool) PredictOption { + return func(p *PredictOptions) { + p.PenalizeNL = pnl + } +} + +// SetLogitBias sets the logit bias parameter. +func SetLogitBias(lb string) PredictOption { + return func(p *PredictOptions) { + p.LogitBias = lb + } +} diff --git a/main.go b/main.go index a7759c5e..b445e7ce 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,7 @@ package main import ( - "github.com/ollama/ollama/cmd" + "github.com/jmorganca/ollama/cmd" ) func main() { diff --git a/server/routes.go b/server/routes.go index f8d2bd72..6cea2eb8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,9 +9,9 @@ import ( "runtime" "github.com/gin-gonic/gin" - llama "github.com/go-skynet/go-llama.cpp" + llama "github.com/jmorganca/ollama/llama" - "github.com/ollama/ollama/api" + "github.com/jmorganca/ollama/api" ) func Serve(ln net.Listener) error {