search command

Author: Bruce MacDonald
Date: 2023-06-30 16:27:47 -04:00
Parent: fd1207a44b
Commit: d01be075b6

5 changed files with 62 additions and 24 deletions


@@ -87,8 +87,6 @@ Download a model
 ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
 ```
 
-## Coming Soon
-
 ### `ollama.search("query")`
 
 Search for compatible models that Ollama can run
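For context, a minimal sketch of how the call the README now documents might be used, assuming `ollama.search` returns a list of model-name strings (the Python client method itself is not part of this diff):

    import ollama

    # "llama" is an illustrative query; results depend on the remote directory
    for name in ollama.search("llama"):
        print(name)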


@@ -37,14 +37,6 @@ def main():
         title='commands',
     )
 
-    server.set_parser(
-        subparsers.add_parser(
-            "serve",
-            description="Start a persistent server to interact with models via the API.",
-            help="Start a persistent server to interact with models via the API.",
-        )
-    )
-
     list_parser = subparsers.add_parser(
         "models",
         description="List all available models stored locally.",
@@ -52,6 +44,18 @@ def main():
     )
     list_parser.set_defaults(fn=list_models)
 
+    search_parser = subparsers.add_parser(
+        "search",
+        description="Search for compatible models that Ollama can run.",
+        help="Search for compatible models that Ollama can run. Usage: search [model]",
+    )
+    search_parser.add_argument(
+        "query",
+        nargs="?",
+        help="Optional name of the model to search for.",
+    )
+    search_parser.set_defaults(fn=search)
+
     pull_parser = subparsers.add_parser(
         "pull",
         description="Download a specified model from a remote source.",
@@ -73,6 +77,14 @@ def main():
     )
     run_parser.set_defaults(fn=run)
 
+    server.set_parser(
+        subparsers.add_parser(
+            "serve",
+            description="Start a persistent server to interact with models via the API.",
+            help="Start a persistent server to interact with models via the API.",
+        )
+    )
+
     args = parser.parse_args()
     args = vars(args)
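The parser wiring follows argparse's `set_defaults` dispatch pattern: each subparser stores its handler under `fn`, and the caller pops it out of the parsed arguments and invokes it with the rest. The exact dispatch in `main` is not shown in these hunks; a self-contained sketch of the pattern, with a stub handler:

    import argparse

    def search(query=None):
        print(f"searching for {query!r}")

    parser = argparse.ArgumentParser(prog="ollama")
    subparsers = parser.add_subparsers(title="commands")
    search_parser = subparsers.add_parser("search", help="Search for compatible models.")
    search_parser.add_argument("query", nargs="?")
    search_parser.set_defaults(fn=search)

    args = vars(parser.parse_args(["search", "llama"]))
    fn = args.pop("fn")  # the handler registered with set_defaults
    fn(**args)           # prints: searching for 'llama'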
@@ -146,6 +158,22 @@ def generate_batch(*args, **kwargs):
     generate_oneshot(*args, **kwargs)
 
 
+def search(*args, **kwargs):
+    try:
+        model_names = model.search_directory(*args, **kwargs)
+        if len(model_names) == 0:
+            print("No models found.")
+            return
+        elif len(model_names) == 1:
+            print(f"Found {len(model_names)} available model:")
+        else:
+            print(f"Found {len(model_names)} available models:")
+        for model_name in model_names:
+            print(model_name.lower())
+    except Exception as e:
+        print("Failed to fetch available models, check your network connection")
+
+
 def pull(*args, **kwargs):
     model.pull(model_name=kwargs.pop('model'), *args, **kwargs)


@@ -1,6 +1,7 @@
 import os
 import sys
 from os import path
+from pathlib import Path
 from contextlib import contextmanager
 from fuzzywuzzy import process
 from llama_cpp import Llama
@@ -30,7 +31,7 @@ def load(model_name, models={}):
     if not models.get(model_name, None):
         model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            model_path = MODELS_CACHE_PATH / model_name + ".bin"
+            model_path = str(MODELS_CACHE_PATH / (model_name + ".bin"))
 
         runners = {
             model_type: cls
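The one-line change above fixes an operator-precedence bug: `/` binds before `+`, so the old expression tried to add a string to a Path. A quick illustration (the cache location here is a made-up stand-in for MODELS_CACHE_PATH, not taken from the diff):

    from pathlib import Path

    cache = Path("/tmp/models")   # illustrative stand-in for MODELS_CACHE_PATH
    name = "llama-7b"

    # cache / name + ".bin" raises TypeError: (cache / name) is a Path,
    # and Path + str is not defined
    model_path = str(cache / (name + ".bin"))
    print(model_path)             # /tmp/models/llama-7b.bin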
@@ -52,14 +53,10 @@ def unload(model_name, models={}):
 
 
 class LlamaCppRunner:
 
     def __init__(self, model_path, model_type):
         try:
             with suppress(sys.stderr), suppress(sys.stdout):
-                self.model = Llama(model_path,
-                                   verbose=False,
-                                   n_gpu_layers=1,
-                                   seed=-1)
+                self.model = Llama(model_path, verbose=False, n_gpu_layers=1, seed=-1)
         except Exception:
             raise Exception("Failed to load model", model_path, model_type)
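Note that `suppress(sys.stderr)` here is the project's own context manager (hence the `contextmanager` import above), not `contextlib.suppress`, which silences exceptions rather than output streams. A sketch of what such a helper plausibly looks like, redirecting a stream's file descriptor to the null device while llama.cpp loads:

    import os
    import sys
    from contextlib import contextmanager

    @contextmanager
    def suppress(stream):
        # temporarily point the stream's file descriptor at os.devnull,
        # hiding load-time logging; restore it on exit
        fd = stream.fileno()
        saved = os.dup(fd)
        try:
            with open(os.devnull, "wb") as devnull:
                os.dup2(devnull.fileno(), fd)
            yield
        finally:
            os.dup2(saved, fd)
            os.close(saved)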
@@ -88,10 +85,10 @@ class LlamaCppRunner:
 
 
 class CtransformerRunner:
 
     def __init__(self, model_path, model_type):
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_path, model_type=model_type, local_files_only=True)
-
+            model_path, model_type=model_type, local_files_only=True
+        )
 
     @staticmethod
     def model_types():


@@ -18,13 +18,26 @@ def models(*args, **kwargs):
         yield base
 
 
+# search the directory and return all models which contain the search term as a substring,
+# or all models if no search term is provided
+def search_directory(query):
+    response = requests.get(MODELS_MANIFEST)
+    response.raise_for_status()
+    directory = response.json()
+    model_names = []
+    for model_info in directory:
+        if not query or query.lower() in model_info.get('name', '').lower():
+            model_names.append(model_info.get('name'))
+    return model_names
+
+
 # get the url of the model from our curated directory
 def get_url_from_directory(model):
     response = requests.get(MODELS_MANIFEST)
     response.raise_for_status()
     directory = response.json()
     for model_info in directory:
-        if model_info.get('name') == model:
+        if model_info.get('name').lower() == model.lower():
             return model_info.get('url')
     return model
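The filtering rule is plain case-insensitive substring matching, with an empty query matching everything. The same logic run against an inline directory (the model names are invented for illustration):

    directory = [{"name": "Wizard-Vicuna-13B"}, {"name": "orca-mini"}]

    def filter_models(query):
        return [
            info.get("name")
            for info in directory
            if not query or query.lower() in info.get("name", "").lower()
        ]

    print(filter_models("vicuna"))  # ['Wizard-Vicuna-13B']
    print(filter_models(None))      # both names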
@@ -42,7 +55,6 @@ def download_from_repo(url, file_name):
     location = location.strip('/')
     if file_name == '':
         file_name = path.basename(location).lower()
-
     download_url = urlunsplit(
         (
             'https',
@@ -78,7 +90,7 @@ def find_bin_file(json_response, location, branch):
 
 
 def download_file(download_url, file_name, file_size):
-    local_filename = MODELS_CACHE_PATH / file_name + '.bin'
+    local_filename = MODELS_CACHE_PATH / str(file_name + '.bin')
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
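`first_byte` feeds a resumable download: if a partial `.bin` file is already on disk, the transfer picks up from its current size. The hunk does not show the request itself; a sketch of the usual pattern, with the Range header as an assumption rather than something visible in this diff:

    import requests
    from os import path

    def resume_download(download_url, local_filename, file_size):
        first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
        if first_byte >= file_size:
            return local_filename  # already complete
        headers = {"Range": f"bytes={first_byte}-"}  # assumed, not shown in the diff
        with requests.get(download_url, headers=headers, stream=True) as response:
            response.raise_for_status()
            with open(local_filename, "ab") as f:  # append to the partial file
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return local_filename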
@@ -111,7 +123,8 @@ def download_file(download_url, file_name, file_size):
 
 
 def pull(model_name, *args, **kwargs):
-    if path.exists(model_name):
+    maybe_existing_model_location = MODELS_CACHE_PATH / str(model_name + '.bin')
+    if path.exists(model_name) or path.exists(maybe_existing_model_location):
         # a file on the filesystem is being specified
         return model_name
     # check the remote model location and see if it needs to be downloaded
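With this change, `pull` short-circuits both for an explicit file path and for a model already downloaded into the cache. The same check in isolation, with an illustrative cache location (the real MODELS_CACHE_PATH definition is not shown in this diff):

    from os import path
    from pathlib import Path

    MODELS_CACHE_PATH = Path("/tmp/models")  # illustrative, not the real default

    def is_already_available(model_name):
        # mirrors pull()'s short-circuit: a literal file on disk,
        # or a previously downloaded <name>.bin in the cache
        cached = MODELS_CACHE_PATH / str(model_name + ".bin")
        return path.exists(model_name) or path.exists(cached)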
@@ -120,7 +133,6 @@ def pull(model_name, *args, **kwargs):
     if not validators.url(url) and not url.startswith('huggingface.co'):
         url = get_url_from_directory(model_name)
         file_name = model_name
-
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'


@@ -1,9 +1,12 @@
+from os import path
 from difflib import get_close_matches
 from jinja2 import Environment, PackageLoader
 
 
 def template(name, prompt):
     environment = Environment(loader=PackageLoader(__name__, 'templates'))
-    best_templates = get_close_matches(name, environment.list_templates(), n=1, cutoff=0)
+    best_templates = get_close_matches(
+        path.basename(name), environment.list_templates(), n=1, cutoff=0
+    )
     template = environment.get_template(best_templates.pop())
     return template.render(prompt=prompt)
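Two details make this robust: with `cutoff=0`, `get_close_matches` always returns the single closest name for a non-empty template list, so the `pop()` cannot fail; and taking `path.basename` first keeps directory components of a model path from skewing the match. For example (template names invented for illustration):

    from difflib import get_close_matches
    from os import path

    templates = ["llama.prompt", "falcon.prompt", "starcoder.prompt"]

    name = "/home/user/models/llama-7b.bin"
    # match against the file name only, never the directory
    best = get_close_matches(path.basename(name), templates, n=1, cutoff=0)
    print(best)  # the closest name, e.g. ['llama.prompt']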