From 01c31aac78c15f65dcb988b371b10ca9b0af39f6 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Thu, 29 Jun 2023 18:22:45 -0400
Subject: [PATCH] consistency between generate and add naming

---
 ollama/cmd/cli.py | 20 ++++++++++++--------
 ollama/engine.py  | 20 +++++++++++---------
 ollama/model.py   | 28 +++++++++++++++-------------
 3 files changed, 38 insertions(+), 30 deletions(-)

diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py
index 28006141..c09741b4 100644
--- a/ollama/cmd/cli.py
+++ b/ollama/cmd/cli.py
@@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs):
     spinner = yaspin()
     spinner.start()
     spinner_running = True
-    for output in engine.generate(*args, **kwargs):
-        choices = output.get("choices", [])
-        if len(choices) > 0:
-            if spinner_running:
-                spinner.stop()
-                spinner_running = False
-                print("\r", end="")  # move cursor back to beginning of line again
-            print(choices[0].get("text", ""), end="", flush=True)
+    try:
+        for output in engine.generate(*args, **kwargs):
+            choices = output.get("choices", [])
+            if len(choices) > 0:
+                if spinner_running:
+                    spinner.stop()
+                    spinner_running = False
+                    print("\r", end="")  # move cursor back to beginning of line again
+                print(choices[0].get("text", ""), end="", flush=True)
+    except Exception:
+        spinner.stop()
+        raise
 
     # end with a new line
     print(flush=True)
diff --git a/ollama/engine.py b/ollama/engine.py
index b45dcd6d..6b7944d1 100644
--- a/ollama/engine.py
+++ b/ollama/engine.py
@@ -1,5 +1,4 @@
-import os
-import json
+from os import path, dup, dup2, devnull
 import sys
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM
@@ -10,12 +9,12 @@ import ollama.prompt
 
 @contextmanager
 def suppress_stderr():
-    stderr = os.dup(sys.stderr.fileno())
-    with open(os.devnull, "w") as devnull:
-        os.dup2(devnull.fileno(), sys.stderr.fileno())
+    stderr = dup(sys.stderr.fileno())
+    with open(devnull, "w") as devnull:
+        dup2(devnull.fileno(), sys.stderr.fileno())
         yield
 
-    os.dup2(stderr, sys.stderr.fileno())
+    dup2(stderr, sys.stderr.fileno())
 
 
 def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
 def load(model, models_home=".", llms={}):
     llm = llms.get(model, None)
     if not llm:
-        stored_model_path = os.path.join(models_home, model, ".bin")
-        if os.path.exists(stored_model_path):
+        stored_model_path = path.join(models_home, model) + ".bin"
+        if path.exists(stored_model_path):
             model_path = stored_model_path
         else:
             # try loading this as a path to a model, rather than a model name
-            model_path = os.path.abspath(model)
+            model_path = path.abspath(model)
+
+        if not path.exists(model_path):
+            raise Exception(f"Model not found: {model}")
 
         try:
             # suppress LLM's output
diff --git a/ollama/model.py b/ollama/model.py
index 5f99a13a..b1b686e8 100644
--- a/ollama/model.py
+++ b/ollama/model.py
@@ -1,6 +1,6 @@
-import os
 import requests
 import validators
+from os import path, walk
 from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm
 
@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
 
 
 def models(models_home='.', *args, **kwargs):
-    for _, _, files in os.walk(models_home):
+    for _, _, files in walk(models_home):
         for file in files:
-            base, ext = os.path.splitext(file)
+            base, ext = path.splitext(file)
             if ext == '.bin':
                 yield base
 
@@ -27,7 +27,7 @@ def get_url_from_directory(model):
     return model
 
 
-def download_from_repo(url, models_home='.'):
+def download_from_repo(url, file_name, models_home='.'):
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')
 
@@ -38,6 +38,8 @@ def download_from_repo(url, file_name, models_home='.'):
         location, branch = path_parts
 
     location = location.strip('/')
+    if file_name == '':
+        file_name = path.basename(location)
 
     download_url = urlunsplit(
         (
@@ -53,7 +55,7 @@ def download_from_repo(url, file_name, models_home='.'):
     json_response = response.json()
 
     download_url, file_size = find_bin_file(json_response, location, branch)
-    return download_file(download_url, models_home, location, file_size)
+    return download_file(download_url, models_home, file_name, file_size)
 
 
 def find_bin_file(json_response, location, branch):
@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
     return download_url, file_size
 
 
-def download_file(download_url, models_home, location, file_size):
-    local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
+def download_file(download_url, models_home, file_name, file_size):
+    local_filename = path.join(models_home, file_name) + '.bin'
 
-    first_byte = (
-        os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
-    )
+    first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
 
     if first_byte >= file_size:
         return local_filename
 
-    print(f'Pulling {os.path.basename(location)}...')
+    print(f'Pulling {file_name}...')
 
     header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
 
@@ -109,13 +109,15 @@ def download_file(download_url, models_home, file_name, file_size):
 def pull(model, models_home='.', *args, **kwargs):
-    if os.path.exists(model):
+    if path.exists(model):
         # a file on the filesystem is being specified
         return model
 
     # check the remote model location and see if it needs to be downloaded
     url = model
+    file_name = ""
     if not validators.url(url) and not url.startswith('huggingface.co'):
         url = get_url_from_directory(model)
+        file_name = model
 
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
             return model
 
     raise Exception(f'Unknown model {model}')
 
-    local_filename = download_from_repo(url, models_home)
+    local_filename = download_from_repo(url, file_name, models_home)
     return local_filename