diff --git a/server/README.md b/server/README.md
new file mode 100644
index 00000000..0607716f
--- /dev/null
+++ b/server/README.md
@@ -0,0 +1,34 @@
+# Server
+
+🙊
+
+## Installation
+
+If using Apple silicon, you need a Python version that supports arm64:
+
+```bash
+wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
+bash Miniforge3-MacOSX-arm64.sh
+```
+
+Get the dependencies:
+
+```bash
+pip install llama-cpp-python
+pip install -r requirements.txt
+```
+
+## Running
+
+Put your model in `models/` and run:
+
+```bash
+python server.py
+```
+
+## API
+
+### `POST /generate`
+
+- `model`: `string` - The name of the model to use in the `models` folder.
+- `prompt`: `string` - The prompt to use.
diff --git a/server/build.sh b/server/build.sh
deleted file mode 100755
index 0dc7242d..00000000
--- a/server/build.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-LIBRARY_PATH=$PWD/go-llama.cpp C_INCLUDE_PATH=$PWD/go-llama.cpp go build .
-
diff --git a/server/go.mod b/server/go.mod
deleted file mode 100644
index 99945d85..00000000
--- a/server/go.mod
+++ /dev/null
@@ -1,8 +0,0 @@
-module github.com/keypairdev/keypair
-
-go 1.20
-
-require (
-    github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1
-    github.com/sashabaranov/go-openai v1.11.3
-)
diff --git a/server/go.sum b/server/go.sum
deleted file mode 100644
index 3a48bc2d..00000000
--- a/server/go.sum
+++ /dev/null
@@ -1,15 +0,0 @@
-github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1 h1:UQ8y3kHxBgh3BnaW06y/X97fEN48yHPwWobMz8/aztU=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
-github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
-github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
-github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
-github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
-github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
-github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
-golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
-golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
-golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
-golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/server/main.go b/server/main.go
deleted file mode 100644
index 31c18cf4..00000000
--- a/server/main.go
+++ /dev/null
@@ -1,113 +0,0 @@
-package main
-
-import (
-    "bytes"
-    "context"
-    "fmt"
-    "io"
-    "net/http"
-    "os"
-    "runtime"
-
-    "github.com/sashabaranov/go-openai"
-
-    llama "github.com/go-skynet/go-llama.cpp"
-)
-
-
-type Model interface {
-    Name() string
-    Handler(w http.ResponseWriter, r *http.Request)
-}
-
-type LLama7B struct {
-    llama *llama.LLama
-}
-
-func NewLLama7B() *LLama7B {
-    llama, err := llama.New("./models/7B/ggml-model-q4_0.bin", llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(128))
-    if err != nil {
-        fmt.Println("Loading the model failed:", err.Error())
-        os.Exit(1)
-    }
-
-    return &LLama7B{
-        llama: llama,
-    }
-}
-
-func (l *LLama7B) Name() string {
-    return "LLaMA 7B"
-}
-
-func (m *LLama7B) Handler(w http.ResponseWriter, r *http.Request) {
-    var text bytes.Buffer
-    io.Copy(&text, r.Body)
-
-    _, err := m.llama.Predict(text.String(), llama.Debug, llama.SetTokenCallback(func(token string) bool {
-        w.Write([]byte(token))
-        return true
-    }), llama.SetTokens(512), llama.SetThreads(runtime.NumCPU()), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
-
-    if err != nil {
-        fmt.Println("Predict failed:", err.Error())
-        os.Exit(1)
-    }
-
-    embeds, err := m.llama.Embeddings(text.String())
-    if err != nil {
-        fmt.Printf("Embeddings: error %s \n", err.Error())
-    }
-    fmt.Printf("Embeddings: %v", embeds)
-
-    w.Header().Set("Content-Type", "text/event-stream")
-    w.Header().Set("Cache-Control", "no-cache")
-    w.Header().Set("Connection", "keep-alive")
-}
-
-type GPT4 struct {
-    apiKey string
-}
-
-func (g *GPT4) Name() string {
-    return "OpenAI GPT-4"
-}
-
-func (g *GPT4) Handler(w http.ResponseWriter, r *http.Request) {
-    w.WriteHeader(http.StatusOK)
-    client := openai.NewClient("your token")
-    resp, err := client.CreateChatCompletion(
-        context.Background(),
-        openai.ChatCompletionRequest{
-            Model: openai.GPT3Dot5Turbo,
-            Messages: []openai.ChatCompletionMessage{
-                {
-                    Role:    openai.ChatMessageRoleUser,
-                    Content: "Hello!",
-                },
-            },
-        },
-    )
-    if err != nil {
-        fmt.Printf("chat completion error: %v\n", err)
-        return
-    }
-
-    fmt.Println(resp.Choices[0].Message.Content)
-
-    w.Header().Set("Content-Type", "text/plain; charset=utf-8")
-    w.WriteHeader(http.StatusOK)
-}
-
-// TODO: add subcommands to spawn different models
-func main() {
-    model := &LLama7B{}
-
-    http.HandleFunc("/generate", model.Handler)
-
-    fmt.Println("Starting server on :8080")
-    if err := http.ListenAndServe(":8080", nil); err != nil {
-        fmt.Printf("Error starting server: %s\n", err)
-        return
-    }
-}
diff --git a/server/requirements.txt b/server/requirements.txt
new file mode 100644
index 00000000..91c27d18
--- /dev/null
+++ b/server/requirements.txt
@@ -0,0 +1,2 @@
+Flask==2.3.2
+flask_cors==3.0.10
diff --git a/server/server.py b/server/server.py
new file mode 100644
index 00000000..813836f3
--- /dev/null
+++ b/server/server.py
@@ -0,0 +1,47 @@
+import json
+import os
+from llama_cpp import Llama
+from flask import Flask, Response, stream_with_context, request
+from flask_cors import CORS, cross_origin
+
+app = Flask(__name__)
+CORS(app)  # enable CORS for all routes
+
+# llms tracks which models are loaded
+llms = {}
+
+
+@app.route("/generate", methods=["POST"])
+def generate():
+    data = request.get_json()
+    model = data.get("model")
+    prompt = data.get("prompt")
+
+    if not model:
+        return Response("Model is required", status=400)
+    if not prompt:
+        return Response("Prompt is required", status=400)
+    if not os.path.exists(f"../models/{model}.bin"):
+        return {"error": "The model file does not exist."}, 400
+
+    if model not in llms:
+        llms[model] = Llama(model_path=f"../models/{model}.bin")
+
+    def stream_response():
+        stream = llms[model](
+            str(prompt),  # TODO: optimize prompt based on model
+            max_tokens=4096,
+            stop=["Q:", "\n"],
+            echo=True,
+            stream=True,
+        )
+        for output in stream:
+            yield json.dumps(output)
+
+    return Response(
+        stream_with_context(stream_response()), mimetype="text/event-stream"
+    )
+
+
+if __name__ == "__main__":
+    app.run(debug=True, threaded=True, port=5000)
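For reference, here is a minimal client sketch for the new endpoint. It assumes only what `server.py` above defines: a `POST /generate` route on port 5000 that takes a JSON body with `model` and `prompt` and streams back JSON-encoded completion chunks. The model name `ggml-model-q4_0` is a placeholder for whichever `.bin` file you put in `models/`.

```python
# Sketch of a client for POST /generate, using only the standard library.
import json
import urllib.request

# "model" is a placeholder: the server looks for models/<model>.bin
body = json.dumps({
    "model": "ggml-model-q4_0",
    "prompt": "Q: What is a llama? A:",
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:5000/generate",
    data=body,
    headers={"Content-Type": "application/json"},
    method="POST",
)

# The server streams JSON-encoded chunks as tokens are generated;
# print the raw stream as it arrives.
with urllib.request.urlopen(req) as resp:
    while True:
        chunk = resp.read(1024)
        if not chunk:
            break
        print(chunk.decode("utf-8"), end="", flush=True)
```

A `curl -X POST http://localhost:5000/generate` with the same JSON body and a `Content-Type: application/json` header works equally well for quick manual testing.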