From 2e99e7d5cb573e97d3a7a5fe7ac1f3ad1bd45eb8 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Mon, 26 Jun 2023 12:08:03 -0400
Subject: [PATCH] restructure proto.py

---
 proto.py | 94 ++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 57 insertions(+), 37 deletions(-)

diff --git a/proto.py b/proto.py
index 54be0f52..be99550e 100644
--- a/proto.py
+++ b/proto.py
@@ -1,5 +1,6 @@
 import json
 import os
+import threading
 from llama_cpp import Llama
 from flask import Flask, Response, stream_with_context, request
 from flask_cors import CORS
@@ -9,77 +10,96 @@ CORS(app)  # enable CORS for all routes
 
 # llms tracks which models are loaded
 llms = {}
+lock = threading.Lock()
+
+
+def load(model):
+    with lock:
+        if not os.path.exists(f"./models/{model}.bin"):
+            return {"error": "The model does not exist."}
+        if model not in llms:
+            llms[model] = Llama(model_path=f"./models/{model}.bin")
+    return None
+
+
+def unload(model):
+    with lock:
+        if not os.path.exists(f"./models/{model}.bin"):
+            return {"error": "The model does not exist."}
+        llms.pop(model, None)
+    return None
+
+
+def generate(model, prompt):
+    # auto load
+    error = load(model)
+    if error is not None:
+        return error
+    stream = llms[model](
+        str(prompt),  # TODO: optimize prompt based on model
+        max_tokens=4096,
+        stop=["Q:", "\n"],
+        echo=True,
+        stream=True,
+    )
+    for output in stream:
+        yield json.dumps(output)
+
+
+def models():
+    all_files = os.listdir("./models")
+    bin_files = [
+        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
+    ]
+    return bin_files
 
 
 @app.route("/load", methods=["POST"])
-def load():
+def load_route_handler():
     data = request.get_json()
     model = data.get("model")
-
     if not model:
         return Response("Model is required", status=400)
-    if not os.path.exists(f"./models/{model}.bin"):
-        return {"error": "The model does not exist."}, 400
-
-    if model not in llms:
-        llms[model] = Llama(model_path=f"./models/{model}.bin")
-
+    error = load(model)
+    if error is not None:
+        return error
     return Response(status=204)
 
 
 @app.route("/unload", methods=["POST"])
-def unload():
+def unload_route_handler():
     data = request.get_json()
     model = data.get("model")
-
     if not model:
         return Response("Model is required", status=400)
-    if not os.path.exists(f"./models/{model}.bin"):
-        return {"error": "The model does not exist."}, 400
-
-    llms.pop(model, None)
-
+    error = unload(model)
+    if error is not None:
+        return error
     return Response(status=204)
 
 
 @app.route("/generate", methods=["POST"])
-def generate():
+def generate_route_handler():
     data = request.get_json()
     model = data.get("model")
     prompt = data.get("prompt")
-
     if not model:
         return Response("Model is required", status=400)
     if not prompt:
         return Response("Prompt is required", status=400)
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
-
-    if model not in llms:
-        # auto load
-        llms[model] = Llama(model_path=f"./models/{model}.bin")
-
-    def stream_response():
-        stream = llms[model](
-            str(prompt),  # TODO: optimize prompt based on model
-            max_tokens=4096,
-            stop=["Q:", "\n"],
-            echo=True,
-            stream=True,
-        )
-        for output in stream:
-            yield json.dumps(output)
-
     return Response(
-        stream_with_context(stream_response()), mimetype="text/event-stream"
+        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
     )
+
 
 @app.route("/models", methods=["GET"])
-def models():
-    all_files = os.listdir("./models")
-    bin_files = [file.replace(".bin", "") for file in all_files if file.endswith(".bin")]
+def models_route_handler():
+    bin_files = models()
     return Response(json.dumps(bin_files), mimetype="application/json")
+
 
 if __name__ == "__main__":
     app.run(debug=True, threaded=True, port=5001)
     app.run()
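
Not part of the patch: a minimal client sketch showing how the restructured endpoints can be exercised once proto.py is running on port 5001. The model name "llama-7b" and the use of the requests package are assumptions for illustration, not something the patch prescribes.

# Minimal client sketch (assumptions: server from proto.py running on port
# 5001, ./models/llama-7b.bin present on the server, requests installed).
import requests

BASE = "http://127.0.0.1:5001"
MODEL = "llama-7b"  # hypothetical model name

# Load the model into memory; the server responds 204 on success.
requests.post(f"{BASE}/load", json={"model": MODEL})

# Stream a completion; the body is a stream of JSON-encoded llama_cpp outputs.
with requests.post(
    f"{BASE}/generate",
    json={"model": MODEL, "prompt": "Q: What is the capital of France?"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

# List available models, then unload when finished.
print(requests.get(f"{BASE}/models").json())
requests.post(f"{BASE}/unload", json={"model": MODEL})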