From 3ca8f7232722b0936d437daa1835be1d842b8e73 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Mon, 26 Jun 2023 13:00:40 -0400
Subject: [PATCH] add generate command

---
 proto.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/proto.py b/proto.py
index 0292c201..00139973 100644
--- a/proto.py
+++ b/proto.py
@@ -31,19 +31,19 @@ def unload(model):
     return None
 
 
-def generate(model, prompt):
+def query(model, prompt):
     # auto load
     error = load(model)
     if error is not None:
         return error
-    stream = llms[model](
+    generated = llms[model](
         str(prompt),  # TODO: optimize prompt based on model
         max_tokens=4096,
         stop=["Q:", "\n"],
         echo=True,
         stream=True,
     )
-    for output in stream:
+    for output in generated:
         yield json.dumps(output)
 
 
@@ -91,7 +91,7 @@ def generate_route_handler():
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
     return Response(
-        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
+        stream_with_context(query(model, prompt)), mimetype="text/event-stream"
     )
 
 
@@ -117,5 +117,19 @@ def serve(port, debug):
     app.run(host="0.0.0.0", port=port, debug=debug)
 
 
+@cli.command()
+@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
+@click.option("--prompt", default="", help="The prompt for the model")
+def generate(model, prompt):
+    if prompt == "":
+        prompt = input("Prompt: ")
+    output = ""
+    for generated in query(model, prompt):
+        generated_json = json.loads(generated)
+        text = generated_json["choices"][0]["text"]
+        output += text
+        print(f"\r{output}", end="", flush=True)
+
+
 if __name__ == "__main__":
     cli()
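
A rough usage sketch for the new subcommand, assuming the click group above is exposed by running proto.py directly and the referenced model file exists under ./models/:

    python proto.py generate --model vicuna-7b-v1.3.ggmlv3.q8_0 --prompt "Q: What is the capital of France? A:"

If --prompt is omitted, the command falls back to reading a prompt interactively, then streams the completion to the terminal, rewriting the current line with the accumulated output as each JSON chunk arrives from query().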