Merge pull request #4570 from ollama/mxyng/slices
lint some of the things
This commit is contained in:
commit
89d9900152
4
.github/workflows/test.yaml
vendored
4
.github/workflows/test.yaml
vendored
|
@ -269,9 +269,9 @@ jobs:
|
||||||
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
||||||
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
||||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||||
- uses: golangci/golangci-lint-action@v4
|
- uses: golangci/golangci-lint-action@v6
|
||||||
with:
|
with:
|
||||||
args: --timeout 8m0s -v
|
args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }}
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|
|
@ -9,9 +9,26 @@ linters:
|
||||||
- contextcheck
|
- contextcheck
|
||||||
- exportloopref
|
- exportloopref
|
||||||
- gocheckcompilerdirectives
|
- gocheckcompilerdirectives
|
||||||
# FIXME: for some reason this errors on windows
|
# conditionally enable this on linux/macos
|
||||||
# - gofmt
|
# - gofmt
|
||||||
# - goimports
|
# - goimports
|
||||||
|
- intrange
|
||||||
- misspell
|
- misspell
|
||||||
- nilerr
|
- nilerr
|
||||||
|
- nolintlint
|
||||||
|
- nosprintfhostport
|
||||||
|
- testifylint
|
||||||
|
- unconvert
|
||||||
- unused
|
- unused
|
||||||
|
- wastedassign
|
||||||
|
- whitespace
|
||||||
|
- usestdlibvars
|
||||||
|
severity:
|
||||||
|
default-severity: error
|
||||||
|
rules:
|
||||||
|
- linters:
|
||||||
|
- gofmt
|
||||||
|
- goimports
|
||||||
|
- intrange
|
||||||
|
- usestdlibvars
|
||||||
|
severity: info
|
||||||
|
|
|
@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"positive duration",
|
"positive duration",
|
||||||
time.Duration(42 * time.Second),
|
42 * time.Second,
|
||||||
time.Duration(42 * time.Second),
|
42 * time.Second,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"another positive duration",
|
"another positive duration",
|
||||||
time.Duration(42 * time.Minute),
|
42 * time.Minute,
|
||||||
time.Duration(42 * time.Minute),
|
42 * time.Minute,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"zero duration",
|
"zero duration",
|
||||||
|
|
|
@ -69,7 +69,6 @@ func init() {
|
||||||
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
|
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if runtime.GOOS == "darwin" {
|
} else if runtime.GOOS == "darwin" {
|
||||||
// TODO
|
// TODO
|
||||||
AppName += ".app"
|
AppName += ".app"
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func getCLIFullPath(command string) string {
|
func getCLIFullPath(command string) string {
|
||||||
cmdPath := ""
|
var cmdPath string
|
||||||
appExe, err := os.Executable()
|
appExe, err := os.Executable()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cmdPath = filepath.Join(filepath.Dir(appExe), command)
|
cmdPath = filepath.Join(filepath.Dir(appExe), command)
|
||||||
|
@ -65,7 +65,6 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||||
|
|
|
@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer dll.Release() // nolint: errcheck
|
//nolint:errcheck
|
||||||
|
defer dll.Release()
|
||||||
|
|
||||||
pid := cmd.Process.Pid
|
pid := cmd.Process.Pid
|
||||||
|
|
||||||
|
@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to open process: %v", err)
|
return false, fmt.Errorf("failed to open process: %v", err)
|
||||||
}
|
}
|
||||||
defer windows.CloseHandle(hProcess) // nolint: errcheck
|
//nolint:errcheck
|
||||||
|
defer windows.CloseHandle(hProcess)
|
||||||
|
|
||||||
var exitCode uint32
|
var exitCode uint32
|
||||||
err = windows.GetExitCodeProcess(hProcess, &exitCode)
|
err = windows.GetExitCodeProcess(hProcess, &exitCode)
|
||||||
|
|
|
@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode == 204 {
|
if resp.StatusCode == http.StatusNoContent {
|
||||||
slog.Debug("check update response 204 (current version is up to date)")
|
slog.Debug("check update response 204 (current version is up to date)")
|
||||||
return false, updateResp
|
return false, updateResp
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
||||||
slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
|
slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != http.StatusOK {
|
||||||
slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
|
slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
|
||||||
return false, updateResp
|
return false, updateResp
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error checking update: %w", err)
|
return fmt.Errorf("error checking update: %w", err)
|
||||||
}
|
}
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
|
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
|
@ -29,7 +29,6 @@ func GetID() string {
|
||||||
initStore()
|
initStore()
|
||||||
}
|
}
|
||||||
return store.ID
|
return store.ID
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFirstTimeRun() bool {
|
func GetFirstTimeRun() bool {
|
||||||
|
|
|
@ -47,7 +47,6 @@ func nativeLoop() {
|
||||||
default:
|
default:
|
||||||
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
|
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
|
||||||
pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
|
pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,8 +159,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
|
||||||
lResult, _, _ = pDefWindowProc.Call(
|
lResult, _, _ = pDefWindowProc.Call(
|
||||||
uintptr(hWnd),
|
uintptr(hWnd),
|
||||||
uintptr(message),
|
uintptr(message),
|
||||||
uintptr(wParam),
|
wParam,
|
||||||
uintptr(lParam),
|
lParam,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -186,7 +186,7 @@ func (t *winTray) initInstance() error {
|
||||||
t.muNID.Lock()
|
t.muNID.Lock()
|
||||||
defer t.muNID.Unlock()
|
defer t.muNID.Unlock()
|
||||||
t.nid = &notifyIconData{
|
t.nid = &notifyIconData{
|
||||||
Wnd: windows.Handle(t.window),
|
Wnd: t.window,
|
||||||
ID: 100,
|
ID: 100,
|
||||||
Flags: NIF_MESSAGE,
|
Flags: NIF_MESSAGE,
|
||||||
CallbackMessage: t.wmSystrayMessage,
|
CallbackMessage: t.wmSystrayMessage,
|
||||||
|
@ -197,7 +197,6 @@ func (t *winTray) initInstance() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTray) createMenu() error {
|
func (t *winTray) createMenu() error {
|
||||||
|
|
||||||
menuHandle, _, err := pCreatePopupMenu.Call()
|
menuHandle, _, err := pCreatePopupMenu.Call()
|
||||||
if menuHandle == 0 {
|
if menuHandle == 0 {
|
||||||
return err
|
return err
|
||||||
|
@ -246,7 +245,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
|
||||||
mi := menuItemInfo{
|
mi := menuItemInfo{
|
||||||
Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
|
Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
|
||||||
Type: MFT_STRING,
|
Type: MFT_STRING,
|
||||||
ID: uint32(menuItemId),
|
ID: menuItemId,
|
||||||
TypeData: titlePtr,
|
TypeData: titlePtr,
|
||||||
Cch: uint32(len(title)),
|
Cch: uint32(len(title)),
|
||||||
}
|
}
|
||||||
|
@ -302,11 +301,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
||||||
|
|
||||||
mi := menuItemInfo{
|
mi := menuItemInfo{
|
||||||
Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
|
Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
|
||||||
Type: MFT_SEPARATOR,
|
Type: MFT_SEPARATOR,
|
||||||
ID: uint32(menuItemId),
|
ID: menuItemId,
|
||||||
}
|
}
|
||||||
|
|
||||||
mi.Size = uint32(unsafe.Sizeof(mi))
|
mi.Size = uint32(unsafe.Sizeof(mi))
|
||||||
|
@ -426,7 +424,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
|
||||||
// Loads an image from file and shows it in tray.
|
// Loads an image from file and shows it in tray.
|
||||||
// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
|
// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
|
||||||
func (t *winTray) setIcon(src string) error {
|
func (t *winTray) setIcon(src string) error {
|
||||||
|
|
||||||
h, err := t.loadIconFrom(src)
|
h, err := t.loadIconFrom(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -444,7 +441,6 @@ func (t *winTray) setIcon(src string) error {
|
||||||
// Loads an image from file to be shown in tray or menu item.
|
// Loads an image from file to be shown in tray or menu item.
|
||||||
// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
|
// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
|
||||||
func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
|
func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
|
||||||
|
|
||||||
// Save and reuse handles of loaded images
|
// Save and reuse handles of loaded images
|
||||||
t.muLoadedImages.RLock()
|
t.muLoadedImages.RLock()
|
||||||
h, ok := t.loadedImages[src]
|
h, ok := t.loadedImages[src]
|
||||||
|
|
21
cmd/cmd.go
21
cmd/cmd.go
|
@ -20,6 +20,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -29,7 +30,6 @@ import (
|
||||||
"github.com/olekukonko/tablewriter"
|
"github.com/olekukonko/tablewriter"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
@ -746,7 +746,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||||
if wordWrap && termWidth >= 10 {
|
if wordWrap && termWidth >= 10 {
|
||||||
for _, ch := range content {
|
for _, ch := range content {
|
||||||
if state.lineLength+1 > termWidth-5 {
|
if state.lineLength+1 > termWidth-5 {
|
||||||
|
|
||||||
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
|
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
|
||||||
fmt.Printf("%s%c", state.wordBuffer, ch)
|
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||||
state.wordBuffer = ""
|
state.wordBuffer = ""
|
||||||
|
@ -1030,24 +1029,6 @@ func initializeKeypair() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:unused
|
|
||||||
func waitForServer(ctx context.Context, client *api.Client) error {
|
|
||||||
// wait for the server to start
|
|
||||||
timeout := time.After(5 * time.Second)
|
|
||||||
tick := time.Tick(500 * time.Millisecond)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
return errors.New("timed out waiting for server to start")
|
|
||||||
case <-tick:
|
|
||||||
if err := client.Heartbeat(ctx); err == nil {
|
|
||||||
return nil // server has started
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -8,11 +8,11 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
@ -85,11 +86,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||||
`
|
`
|
||||||
|
|
||||||
tmpl, err := template.New("").Parse(expectedModelfile)
|
tmpl, err := template.New("").Parse(expectedModelfile)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
err = tmpl.Execute(&buf, opts)
|
err = tmpl.Execute(&buf, opts)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, buf.String(), mf)
|
assert.Equal(t, buf.String(), mf)
|
||||||
|
|
||||||
opts.ParentModel = "horseshark"
|
opts.ParentModel = "horseshark"
|
||||||
|
@ -107,10 +108,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||||
`
|
`
|
||||||
|
|
||||||
tmpl, err = template.New("").Parse(expectedModelfile)
|
tmpl, err = template.New("").Parse(expectedModelfile)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var parentBuf bytes.Buffer
|
var parentBuf bytes.Buffer
|
||||||
err = tmpl.Execute(&parentBuf, opts)
|
err = tmpl.Execute(&parentBuf, opts)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, parentBuf.String(), mf)
|
assert.Equal(t, parentBuf.String(), mf)
|
||||||
}
|
}
|
||||||
|
|
27
cmd/start.go
Normal file
27
cmd/start.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
//go:build darwin || windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func waitForServer(ctx context.Context, client *api.Client) error {
|
||||||
|
// wait for the server to start
|
||||||
|
timeout := time.After(5 * time.Second)
|
||||||
|
tick := time.Tick(500 * time.Millisecond)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
return errors.New("timed out waiting for server to start")
|
||||||
|
case <-tick:
|
||||||
|
if err := client.Heartbeat(ctx); err == nil {
|
||||||
|
return nil // server has started
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
|
||||||
if params.VocabSize > len(v.Tokens) {
|
if params.VocabSize > len(v.Tokens) {
|
||||||
missingTokens := params.VocabSize - len(v.Tokens)
|
missingTokens := params.VocabSize - len(v.Tokens)
|
||||||
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
||||||
for cnt := 0; cnt < missingTokens; cnt++ {
|
for cnt := range missingTokens {
|
||||||
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
||||||
v.Scores = append(v.Scores, -1)
|
v.Scores = append(v.Scores, -1)
|
||||||
v.Types = append(v.Types, tokenTypeUserDefined)
|
v.Types = append(v.Types, tokenTypeUserDefined)
|
||||||
|
|
|
@ -35,7 +35,6 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
|
||||||
f32s = append(f32s, t...)
|
f32s = append(f32s, t...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return f32s, nil
|
return f32s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -119,11 +119,12 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([
|
||||||
}
|
}
|
||||||
|
|
||||||
var heads int
|
var heads int
|
||||||
if strings.HasSuffix(name, "attn_q.weight") {
|
switch {
|
||||||
|
case strings.HasSuffix(name, "attn_q.weight"):
|
||||||
heads = params.AttentionHeads
|
heads = params.AttentionHeads
|
||||||
} else if strings.HasSuffix(name, "attn_k.weight") {
|
case strings.HasSuffix(name, "attn_k.weight"):
|
||||||
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
|
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
|
||||||
} else {
|
default:
|
||||||
return nil, fmt.Errorf("unknown tensor name: %s", name)
|
return nil, fmt.Errorf("unknown tensor name: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||||
Name: name,
|
Name: name,
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
Shape: shape[:],
|
Shape: shape,
|
||||||
}
|
}
|
||||||
|
|
||||||
t.WriterTo = safetensorWriterTo{
|
t.WriterTo = safetensorWriterTo{
|
||||||
|
|
|
@ -85,13 +85,10 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e
|
||||||
|
|
||||||
sha256sum := sha256.New()
|
sha256sum := sha256.New()
|
||||||
for _, pt := range t.PreTokenizer.PreTokenizers {
|
for _, pt := range t.PreTokenizer.PreTokenizers {
|
||||||
switch pt.Type {
|
if pt.Type == "Split" && pt.Pattern.Regex != "" {
|
||||||
case "Split":
|
|
||||||
if pt.Pattern.Regex != "" {
|
|
||||||
sha256sum.Write([]byte(pt.Pattern.Regex))
|
sha256sum.Write([]byte(pt.Pattern.Regex))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
|
switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
|
||||||
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
|
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
|
||||||
|
|
|
@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
|
||||||
Name: ggufName,
|
Name: ggufName,
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Offset: offset, // calculate the offset
|
Offset: offset, // calculate the offset
|
||||||
Shape: shape[:],
|
Shape: shape,
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor.WriterTo = torchWriterTo{
|
tensor.WriterTo = torchWriterTo{
|
||||||
|
@ -104,7 +104,6 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensors, nil
|
return tensors, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAltParams(dirpath string) (*Params, error) {
|
func getAltParams(dirpath string) (*Params, error) {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package envconfig
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -126,7 +127,7 @@ func LoadConfig() {
|
||||||
var paths []string
|
var paths []string
|
||||||
for _, root := range []string{filepath.Dir(appExe), cwd} {
|
for _, root := range []string{filepath.Dir(appExe), cwd} {
|
||||||
paths = append(paths,
|
paths = append(paths,
|
||||||
filepath.Join(root),
|
root,
|
||||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||||
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
||||||
)
|
)
|
||||||
|
@ -184,8 +185,8 @@ func LoadConfig() {
|
||||||
AllowOrigins = append(AllowOrigins,
|
AllowOrigins = append(AllowOrigins,
|
||||||
fmt.Sprintf("http://%s", allowOrigin),
|
fmt.Sprintf("http://%s", allowOrigin),
|
||||||
fmt.Sprintf("https://%s", allowOrigin),
|
fmt.Sprintf("https://%s", allowOrigin),
|
||||||
fmt.Sprintf("http://%s:*", allowOrigin),
|
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
|
||||||
fmt.Sprintf("https://%s:*", allowOrigin),
|
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHumanNumber(t *testing.T) {
|
func TestHumanNumber(t *testing.T) {
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
input uint64
|
input uint64
|
||||||
expected string
|
expected string
|
||||||
|
|
|
@ -65,7 +65,7 @@ func AMDGetGPUInfo() []GpuInfo {
|
||||||
|
|
||||||
slog.Debug("detected hip devices", "count", count)
|
slog.Debug("detected hip devices", "count", count)
|
||||||
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
err = hl.HipSetDevice(i)
|
err = hl.HipSetDevice(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("set device", "id", i, "error", err)
|
slog.Warn("set device", "id", i, "error", err)
|
||||||
|
|
|
@ -80,7 +80,7 @@ func cleanupTmpDirs() {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
pid, err := strconv.Atoi(string(raw))
|
pid, err := strconv.Atoi(string(raw))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||||
// Another running ollama, ignore this tmpdir
|
// Another running ollama, ignore this tmpdir
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,5 +18,4 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||||
ids = append(ids, info.ID)
|
ids = append(ids, info.ID)
|
||||||
}
|
}
|
||||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,7 +187,7 @@ func GetGPUInfo() GpuInfoList {
|
||||||
resp := []GpuInfo{}
|
resp := []GpuInfo{}
|
||||||
|
|
||||||
// NVIDIA first
|
// NVIDIA first
|
||||||
for i := 0; i < gpuHandles.deviceCount; i++ {
|
for i := range gpuHandles.deviceCount {
|
||||||
// TODO once we support CPU compilation variants of GPU libraries refine this...
|
// TODO once we support CPU compilation variants of GPU libraries refine this...
|
||||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||||
continue
|
continue
|
||||||
|
@ -221,8 +221,8 @@ func GetGPUInfo() GpuInfoList {
|
||||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||||
gpuInfo.DependencyPath = depPath
|
gpuInfo.DependencyPath = depPath
|
||||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||||
gpuInfo.DriverMajor = int(driverMajor)
|
gpuInfo.DriverMajor = driverMajor
|
||||||
gpuInfo.DriverMinor = int(driverMinor)
|
gpuInfo.DriverMinor = driverMinor
|
||||||
|
|
||||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||||
resp = append(resp, gpuInfo)
|
resp = append(resp, gpuInfo)
|
||||||
|
|
|
@ -5,11 +5,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBasicGetGPUInfo(t *testing.T) {
|
func TestBasicGetGPUInfo(t *testing.T) {
|
||||||
info := GetGPUInfo()
|
info := GetGPUInfo()
|
||||||
assert.Greater(t, len(info), 0)
|
assert.NotEmpty(t, len(info))
|
||||||
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
||||||
if info[0].Library != "cpu" {
|
if info[0].Library != "cpu" {
|
||||||
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
||||||
|
@ -19,7 +20,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
|
||||||
|
|
||||||
func TestCPUMemInfo(t *testing.T) {
|
func TestCPUMemInfo(t *testing.T) {
|
||||||
info, err := GetCPUMem()
|
info, err := GetCPUMem()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
t.Skip("CPU memory not populated on darwin")
|
t.Skip("CPU memory not populated on darwin")
|
||||||
|
|
|
@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dims := 0
|
var dims int
|
||||||
for cnt := 0; cnt < len(tensor.Shape); cnt++ {
|
for cnt := range len(tensor.Shape) {
|
||||||
if tensor.Shape[cnt] > 0 {
|
if tensor.Shape[cnt] > 0 {
|
||||||
dims++
|
dims++
|
||||||
}
|
}
|
||||||
|
@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < dims; i++ {
|
for i := range dims {
|
||||||
if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
|
if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,9 +5,9 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||||
|
@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||||
}
|
}
|
||||||
|
|
||||||
var layerCount int
|
var layerCount int
|
||||||
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
for i := range int(ggml.KV().BlockCount()) {
|
||||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||||
memoryLayer := blk.size()
|
memoryLayer := blk.size()
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,9 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
|
|
|
@ -85,7 +85,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
var systemMemory uint64
|
var systemMemory uint64
|
||||||
gpuCount := len(gpus)
|
gpuCount := len(gpus)
|
||||||
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
|
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
|
||||||
|
|
||||||
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
|
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
|
||||||
|
|
||||||
cpuRunner = serverForCpu()
|
cpuRunner = serverForCpu()
|
||||||
|
@ -104,21 +103,22 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
var layers int
|
var layers int
|
||||||
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
|
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
|
|
||||||
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
|
switch {
|
||||||
|
case gpus[0].Library == "metal" && estimatedVRAM > systemMemory:
|
||||||
// disable partial offloading when model is greater than total system memory as this
|
// disable partial offloading when model is greater than total system memory as this
|
||||||
// can lead to locking up the system
|
// can lead to locking up the system
|
||||||
opts.NumGPU = 0
|
opts.NumGPU = 0
|
||||||
} else if gpus[0].Library != "metal" && layers == 0 {
|
case gpus[0].Library != "metal" && layers == 0:
|
||||||
// Don't bother loading into the GPU if no layers can fit
|
// Don't bother loading into the GPU if no layers can fit
|
||||||
cpuRunner = serverForCpu()
|
cpuRunner = serverForCpu()
|
||||||
gpuCount = 0
|
gpuCount = 0
|
||||||
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
|
case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu":
|
||||||
opts.NumGPU = layers
|
opts.NumGPU = layers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loop through potential servers
|
// Loop through potential servers
|
||||||
finalErr := fmt.Errorf("no suitable llama servers found")
|
finalErr := errors.New("no suitable llama servers found")
|
||||||
|
|
||||||
if len(adapters) > 1 {
|
if len(adapters) > 1 {
|
||||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||||
|
@ -232,7 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
|
|
||||||
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
||||||
|
|
||||||
for i := 0; i < len(servers); i++ {
|
for i := range len(servers) {
|
||||||
dir := availableServers[servers[i]]
|
dir := availableServers[servers[i]]
|
||||||
if dir == "" {
|
if dir == "" {
|
||||||
// Shouldn't happen
|
// Shouldn't happen
|
||||||
|
@ -284,7 +284,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
|
|
||||||
server := filepath.Join(dir, "ollama_llama_server")
|
server := filepath.Join(dir, "ollama_llama_server")
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
server = server + ".exe"
|
server += ".exe"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect tmp cleaners wiping out the file
|
// Detect tmp cleaners wiping out the file
|
||||||
|
@ -315,7 +315,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
s.cmd.Stdout = os.Stdout
|
s.cmd.Stdout = os.Stdout
|
||||||
s.cmd.Stderr = s.status
|
s.cmd.Stderr = s.status
|
||||||
|
|
||||||
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
|
visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
|
||||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||||
|
|
||||||
// Update or add the path and visible devices variable with our adjusted version
|
// Update or add the path and visible devices variable with our adjusted version
|
||||||
|
@ -459,7 +459,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
return ServerStatusNotResponding, fmt.Errorf("server not responding")
|
return ServerStatusNotResponding, errors.New("server not responding")
|
||||||
}
|
}
|
||||||
return ServerStatusError, fmt.Errorf("health resp: %w", err)
|
return ServerStatusError, fmt.Errorf("health resp: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -245,7 +245,6 @@ func (w *writer) writeResponse(data []byte) (int, error) {
|
||||||
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseFileFile(t *testing.T) {
|
func TestParseFileFile(t *testing.T) {
|
||||||
|
@ -25,7 +26,7 @@ TEMPLATE template1
|
||||||
reader := strings.NewReader(input)
|
reader := strings.NewReader(input)
|
||||||
|
|
||||||
modelfile, err := ParseFile(reader)
|
modelfile, err := ParseFile(reader)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expectedCommands := []Command{
|
expectedCommands := []Command{
|
||||||
{Name: "model", Args: "model1"},
|
{Name: "model", Args: "model1"},
|
||||||
|
@ -88,7 +89,7 @@ func TestParseFileFrom(t *testing.T) {
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||||
assert.ErrorIs(t, err, c.err)
|
require.ErrorIs(t, err, c.err)
|
||||||
if modelfile != nil {
|
if modelfile != nil {
|
||||||
assert.Equal(t, c.expected, modelfile.Commands)
|
assert.Equal(t, c.expected, modelfile.Commands)
|
||||||
}
|
}
|
||||||
|
@ -105,7 +106,7 @@ PARAMETER param1
|
||||||
reader := strings.NewReader(input)
|
reader := strings.NewReader(input)
|
||||||
|
|
||||||
_, err := ParseFile(reader)
|
_, err := ParseFile(reader)
|
||||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileBadCommand(t *testing.T) {
|
func TestParseFileBadCommand(t *testing.T) {
|
||||||
|
@ -114,8 +115,7 @@ FROM foo
|
||||||
BADCOMMAND param1 value1
|
BADCOMMAND param1 value1
|
||||||
`
|
`
|
||||||
_, err := ParseFile(strings.NewReader(input))
|
_, err := ParseFile(strings.NewReader(input))
|
||||||
assert.ErrorIs(t, err, errInvalidCommand)
|
require.ErrorIs(t, err, errInvalidCommand)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileMessages(t *testing.T) {
|
func TestParseFileMessages(t *testing.T) {
|
||||||
|
@ -201,7 +201,7 @@ MESSAGE system`,
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||||
assert.ErrorIs(t, err, c.err)
|
require.ErrorIs(t, err, c.err)
|
||||||
if modelfile != nil {
|
if modelfile != nil {
|
||||||
assert.Equal(t, c.expected, modelfile.Commands)
|
assert.Equal(t, c.expected, modelfile.Commands)
|
||||||
}
|
}
|
||||||
|
@ -355,7 +355,7 @@ TEMPLATE """
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
modelfile, err := ParseFile(strings.NewReader(c.multiline))
|
modelfile, err := ParseFile(strings.NewReader(c.multiline))
|
||||||
assert.ErrorIs(t, err, c.err)
|
require.ErrorIs(t, err, c.err)
|
||||||
if modelfile != nil {
|
if modelfile != nil {
|
||||||
assert.Equal(t, c.expected, modelfile.Commands)
|
assert.Equal(t, c.expected, modelfile.Commands)
|
||||||
}
|
}
|
||||||
|
@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) {
|
||||||
fmt.Fprintln(&b, "FROM foo")
|
fmt.Fprintln(&b, "FROM foo")
|
||||||
fmt.Fprintln(&b, "PARAMETER", k)
|
fmt.Fprintln(&b, "PARAMETER", k)
|
||||||
modelfile, err := ParseFile(&b)
|
modelfile, err := ParseFile(&b)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, []Command{
|
assert.Equal(t, []Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
|
@ -442,7 +442,7 @@ FROM foo
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, c.expected, modelfile.Commands)
|
assert.Equal(t, c.expected, modelfile.Commands)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -501,15 +501,14 @@ SYSTEM ""
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
modelfile, err := ParseFile(strings.NewReader(c))
|
modelfile, err := ParseFile(strings.NewReader(c))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
|
modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, modelfile, modelfile2)
|
assert.Equal(t, modelfile, modelfile2)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileUTF16ParseFile(t *testing.T) {
|
func TestParseFileUTF16ParseFile(t *testing.T) {
|
||||||
|
@ -522,10 +521,10 @@ SYSTEM You are a utf16 file.
|
||||||
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
|
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
err := binary.Write(buf, binary.LittleEndian, utf16File)
|
err := binary.Write(buf, binary.LittleEndian, utf16File)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := ParseFile(buf)
|
actual, err := ParseFile(buf)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := []Command{
|
expected := []Command{
|
||||||
{Name: "model", Args: "bob"},
|
{Name: "model", Args: "bob"},
|
||||||
|
@ -539,9 +538,9 @@ SYSTEM You are a utf16 file.
|
||||||
// simulate a utf16 be file
|
// simulate a utf16 be file
|
||||||
buf = new(bytes.Buffer)
|
buf = new(bytes.Buffer)
|
||||||
err = binary.Write(buf, binary.BigEndian, utf16File)
|
err = binary.Write(buf, binary.BigEndian, utf16File)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err = ParseFile(buf)
|
actual, err = ParseFile(buf)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual.Commands)
|
assert.Equal(t, expected, actual.Commands)
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool {
|
||||||
stopped := p.stop()
|
stopped := p.stop()
|
||||||
if stopped {
|
if stopped {
|
||||||
// clear all progress lines
|
// clear all progress lines
|
||||||
for i := 0; i < p.pos; i++ {
|
for i := range p.pos {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
fmt.Fprint(p.w, "\033[A")
|
fmt.Fprint(p.w, "\033[A")
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ func (p *Progress) render() {
|
||||||
defer fmt.Fprint(p.w, "\033[?25h")
|
defer fmt.Fprint(p.w, "\033[?25h")
|
||||||
|
|
||||||
// clear already rendered progress lines
|
// clear already rendered progress lines
|
||||||
for i := 0; i < p.pos; i++ {
|
for i := range p.pos {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
fmt.Fprint(p.w, "\033[A")
|
fmt.Fprint(p.w, "\033[A")
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,6 @@ func (b *Buffer) GetLineSpacing(line int) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
return hasSpace.(bool)
|
return hasSpace.(bool)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) MoveLeft() {
|
func (b *Buffer) MoveLeft() {
|
||||||
|
@ -117,15 +116,12 @@ func (b *Buffer) MoveRight() {
|
||||||
|
|
||||||
if b.DisplayPos%b.LineWidth == 0 {
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
||||||
|
|
||||||
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
|
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
|
||||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
|
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
|
||||||
b.DisplayPos += 1
|
b.DisplayPos += 1
|
||||||
|
|
||||||
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
|
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
|
||||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
||||||
b.DisplayPos += 1
|
b.DisplayPos += 1
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorRightN(rLength))
|
fmt.Print(cursorRightN(rLength))
|
||||||
}
|
}
|
||||||
|
@ -154,7 +150,7 @@ func (b *Buffer) MoveToStart() {
|
||||||
if b.Pos > 0 {
|
if b.Pos > 0 {
|
||||||
currLine := b.DisplayPos / b.LineWidth
|
currLine := b.DisplayPos / b.LineWidth
|
||||||
if currLine > 0 {
|
if currLine > 0 {
|
||||||
for cnt := 0; cnt < currLine; cnt++ {
|
for range currLine {
|
||||||
fmt.Print(CursorUp)
|
fmt.Print(CursorUp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -169,7 +165,7 @@ func (b *Buffer) MoveToEnd() {
|
||||||
currLine := b.DisplayPos / b.LineWidth
|
currLine := b.DisplayPos / b.LineWidth
|
||||||
totalLines := b.DisplaySize() / b.LineWidth
|
totalLines := b.DisplaySize() / b.LineWidth
|
||||||
if currLine < totalLines {
|
if currLine < totalLines {
|
||||||
for cnt := 0; cnt < totalLines-currLine; cnt++ {
|
for range totalLines - currLine {
|
||||||
fmt.Print(CursorDown)
|
fmt.Print(CursorDown)
|
||||||
}
|
}
|
||||||
remainder := b.DisplaySize() % b.LineWidth
|
remainder := b.DisplaySize() % b.LineWidth
|
||||||
|
@ -185,7 +181,7 @@ func (b *Buffer) MoveToEnd() {
|
||||||
|
|
||||||
func (b *Buffer) DisplaySize() int {
|
func (b *Buffer) DisplaySize() int {
|
||||||
sum := 0
|
sum := 0
|
||||||
for i := 0; i < b.Buf.Size(); i++ {
|
for i := range b.Buf.Size() {
|
||||||
if e, ok := b.Buf.Get(i); ok {
|
if e, ok := b.Buf.Get(i); ok {
|
||||||
if r, ok := e.(rune); ok {
|
if r, ok := e.(rune); ok {
|
||||||
sum += runewidth.RuneWidth(r)
|
sum += runewidth.RuneWidth(r)
|
||||||
|
@ -197,7 +193,6 @@ func (b *Buffer) DisplaySize() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) Add(r rune) {
|
func (b *Buffer) Add(r rune) {
|
||||||
|
|
||||||
if b.Pos == b.Buf.Size() {
|
if b.Pos == b.Buf.Size() {
|
||||||
b.AddChar(r, false)
|
b.AddChar(r, false)
|
||||||
} else {
|
} else {
|
||||||
|
@ -210,7 +205,6 @@ func (b *Buffer) AddChar(r rune, insert bool) {
|
||||||
b.DisplayPos += rLength
|
b.DisplayPos += rLength
|
||||||
|
|
||||||
if b.Pos > 0 {
|
if b.Pos > 0 {
|
||||||
|
|
||||||
if b.DisplayPos%b.LineWidth == 0 {
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
fmt.Printf("%c", r)
|
fmt.Printf("%c", r)
|
||||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||||
|
@ -235,7 +229,6 @@ func (b *Buffer) AddChar(r rune, insert bool) {
|
||||||
} else {
|
} else {
|
||||||
b.LineHasSpace.Add(true)
|
b.LineHasSpace.Add(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("%c", r)
|
fmt.Printf("%c", r)
|
||||||
}
|
}
|
||||||
|
@ -356,7 +349,6 @@ func (b *Buffer) drawRemaining() {
|
||||||
|
|
||||||
func (b *Buffer) Remove() {
|
func (b *Buffer) Remove() {
|
||||||
if b.Buf.Size() > 0 && b.Pos > 0 {
|
if b.Buf.Size() > 0 && b.Pos > 0 {
|
||||||
|
|
||||||
if e, ok := b.Buf.Get(b.Pos - 1); ok {
|
if e, ok := b.Buf.Get(b.Pos - 1); ok {
|
||||||
if r, ok := e.(rune); ok {
|
if r, ok := e.(rune); ok {
|
||||||
rLength := runewidth.RuneWidth(r)
|
rLength := runewidth.RuneWidth(r)
|
||||||
|
@ -382,7 +374,6 @@ func (b *Buffer) Remove() {
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(" " + CursorLeft)
|
fmt.Print(" " + CursorLeft)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
|
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
|
||||||
fmt.Printf(CursorBOL + ClearToEOL)
|
fmt.Printf(CursorBOL + ClearToEOL)
|
||||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
||||||
|
@ -391,10 +382,9 @@ func (b *Buffer) Remove() {
|
||||||
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
||||||
}
|
}
|
||||||
b.DisplayPos -= 1
|
b.DisplayPos -= 1
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorLeftN(rLength))
|
fmt.Print(cursorLeftN(rLength))
|
||||||
for i := 0; i < rLength; i++ {
|
for range rLength {
|
||||||
fmt.Print(" ")
|
fmt.Print(" ")
|
||||||
}
|
}
|
||||||
fmt.Print(cursorLeftN(rLength))
|
fmt.Print(cursorLeftN(rLength))
|
||||||
|
@ -451,7 +441,7 @@ func (b *Buffer) DeleteBefore() {
|
||||||
func (b *Buffer) DeleteRemaining() {
|
func (b *Buffer) DeleteRemaining() {
|
||||||
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
|
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
|
||||||
charsToDel := b.Buf.Size() - b.Pos
|
charsToDel := b.Buf.Size() - b.Pos
|
||||||
for cnt := 0; cnt < charsToDel; cnt++ {
|
for range charsToDel {
|
||||||
b.Delete()
|
b.Delete()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -495,7 +485,7 @@ func (b *Buffer) ClearScreen() {
|
||||||
if currPos > 0 {
|
if currPos > 0 {
|
||||||
targetLine := currPos / b.LineWidth
|
targetLine := currPos / b.LineWidth
|
||||||
if targetLine > 0 {
|
if targetLine > 0 {
|
||||||
for cnt := 0; cnt < targetLine; cnt++ {
|
for range targetLine {
|
||||||
fmt.Print(CursorDown)
|
fmt.Print(CursorDown)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -525,7 +515,7 @@ func (b *Buffer) Replace(r []rune) {
|
||||||
|
|
||||||
fmt.Printf(CursorBOL + ClearToEOL)
|
fmt.Printf(CursorBOL + ClearToEOL)
|
||||||
|
|
||||||
for i := 0; i < lineNums; i++ {
|
for range lineNums {
|
||||||
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
|
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,7 @@ func (h *History) Add(l []rune) {
|
||||||
func (h *History) Compact() {
|
func (h *History) Compact() {
|
||||||
s := h.Buf.Size()
|
s := h.Buf.Size()
|
||||||
if s > h.Limit {
|
if s > h.Limit {
|
||||||
for cnt := 0; cnt < s-h.Limit; cnt++ {
|
for range s - h.Limit {
|
||||||
h.Buf.Remove(0)
|
h.Buf.Remove(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -139,7 +139,7 @@ func (h *History) Save() error {
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
buf := bufio.NewWriter(f)
|
buf := bufio.NewWriter(f)
|
||||||
for cnt := 0; cnt < h.Size(); cnt++ {
|
for cnt := range h.Size() {
|
||||||
v, _ := h.Buf.Get(cnt)
|
v, _ := h.Buf.Get(cnt)
|
||||||
line, _ := v.([]rune)
|
line, _ := v.([]rune)
|
||||||
if _, err := buf.WriteString(string(line) + "\n"); err != nil {
|
if _, err := buf.WriteString(string(line) + "\n"); err != nil {
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Prompt struct {
|
type Prompt struct {
|
||||||
|
@ -63,7 +62,7 @@ func New(prompt Prompt) (*Instance, error) {
|
||||||
|
|
||||||
func (i *Instance) Readline() (string, error) {
|
func (i *Instance) Readline() (string, error) {
|
||||||
if !i.Terminal.rawmode {
|
if !i.Terminal.rawmode {
|
||||||
fd := int(syscall.Stdin)
|
fd := os.Stdin.Fd()
|
||||||
termios, err := SetRawMode(fd)
|
termios, err := SetRawMode(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -80,7 +79,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
fmt.Print(prompt)
|
fmt.Print(prompt)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
fd := int(syscall.Stdin)
|
fd := os.Stdin.Fd()
|
||||||
//nolint:errcheck
|
//nolint:errcheck
|
||||||
UnsetRawMode(fd, i.Terminal.termios)
|
UnsetRawMode(fd, i.Terminal.termios)
|
||||||
i.Terminal.rawmode = false
|
i.Terminal.rawmode = false
|
||||||
|
@ -136,7 +135,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
buf.MoveRight()
|
buf.MoveRight()
|
||||||
case CharBracketedPaste:
|
case CharBracketedPaste:
|
||||||
var code string
|
var code string
|
||||||
for cnt := 0; cnt < 3; cnt++ {
|
for range 3 {
|
||||||
r, err = i.Terminal.Read()
|
r, err = i.Terminal.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", io.EOF
|
return "", io.EOF
|
||||||
|
@ -198,7 +197,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
buf.Remove()
|
buf.Remove()
|
||||||
case CharTab:
|
case CharTab:
|
||||||
// todo: convert back to real tabs
|
// todo: convert back to real tabs
|
||||||
for cnt := 0; cnt < 8; cnt++ {
|
for range 8 {
|
||||||
buf.Add(' ')
|
buf.Add(' ')
|
||||||
}
|
}
|
||||||
case CharDelete:
|
case CharDelete:
|
||||||
|
@ -216,7 +215,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
case CharCtrlW:
|
case CharCtrlW:
|
||||||
buf.DeleteWord()
|
buf.DeleteWord()
|
||||||
case CharCtrlZ:
|
case CharCtrlZ:
|
||||||
fd := int(syscall.Stdin)
|
fd := os.Stdin.Fd()
|
||||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||||
case CharEnter, CharCtrlJ:
|
case CharEnter, CharCtrlJ:
|
||||||
output := buf.String()
|
output := buf.String()
|
||||||
|
@ -248,7 +247,7 @@ func (i *Instance) HistoryDisable() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTerminal() (*Terminal, error) {
|
func NewTerminal() (*Terminal, error) {
|
||||||
fd := int(syscall.Stdin)
|
fd := os.Stdin.Fd()
|
||||||
termios, err := SetRawMode(fd)
|
termios, err := SetRawMode(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleCharCtrlZ(fd int, termios any) (string, error) {
|
func handleCharCtrlZ(fd uintptr, termios any) (string, error) {
|
||||||
t := termios.(*Termios)
|
t := termios.(*Termios)
|
||||||
if err := UnsetRawMode(fd, t); err != nil {
|
if err := UnsetRawMode(fd, t); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package readline
|
package readline
|
||||||
|
|
||||||
func handleCharCtrlZ(fd int, state any) (string, error) {
|
func handleCharCtrlZ(fd uintptr, state any) (string, error) {
|
||||||
// not supported
|
// not supported
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
|
|
||||||
type Termios syscall.Termios
|
type Termios syscall.Termios
|
||||||
|
|
||||||
func SetRawMode(fd int) (*Termios, error) {
|
func SetRawMode(fd uintptr) (*Termios, error) {
|
||||||
termios, err := getTermios(fd)
|
termios, err := getTermios(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -25,13 +25,13 @@ func SetRawMode(fd int) (*Termios, error) {
|
||||||
return termios, setTermios(fd, &newTermios)
|
return termios, setTermios(fd, &newTermios)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnsetRawMode(fd int, termios any) error {
|
func UnsetRawMode(fd uintptr, termios any) error {
|
||||||
t := termios.(*Termios)
|
t := termios.(*Termios)
|
||||||
return setTermios(fd, t)
|
return setTermios(fd, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||||
func IsTerminal(fd int) bool {
|
func IsTerminal(fd uintptr) bool {
|
||||||
_, err := getTermios(fd)
|
_, err := getTermios(fd)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,17 +7,17 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getTermios(fd int) (*Termios, error) {
|
func getTermios(fd uintptr) (*Termios, error) {
|
||||||
termios := new(Termios)
|
termios := new(Termios)
|
||||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return termios, nil
|
return termios, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setTermios(fd int, termios *Termios) error {
|
func setTermios(fd uintptr, termios *Termios) error {
|
||||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,17 +10,17 @@ import (
|
||||||
const tcgets = 0x5401
|
const tcgets = 0x5401
|
||||||
const tcsets = 0x5402
|
const tcsets = 0x5402
|
||||||
|
|
||||||
func getTermios(fd int) (*Termios, error) {
|
func getTermios(fd uintptr) (*Termios, error) {
|
||||||
termios := new(Termios)
|
termios := new(Termios)
|
||||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return termios, nil
|
return termios, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setTermios(fd int, termios *Termios) error {
|
func setTermios(fd uintptr, termios *Termios) error {
|
||||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,13 +9,13 @@ type State struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTerminal checks if the given file descriptor is associated with a terminal
|
// IsTerminal checks if the given file descriptor is associated with a terminal
|
||||||
func IsTerminal(fd int) bool {
|
func IsTerminal(fd uintptr) bool {
|
||||||
var st uint32
|
var st uint32
|
||||||
err := windows.GetConsoleMode(windows.Handle(fd), &st)
|
err := windows.GetConsoleMode(windows.Handle(fd), &st)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetRawMode(fd int) (*State, error) {
|
func SetRawMode(fd uintptr) (*State, error) {
|
||||||
var st uint32
|
var st uint32
|
||||||
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
|
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -32,7 +32,7 @@ func SetRawMode(fd int) (*State, error) {
|
||||||
return &State{st}, nil
|
return &State{st}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnsetRawMode(fd int, state any) error {
|
func UnsetRawMode(fd uintptr, state any) error {
|
||||||
s := state.(*State)
|
s := state.(*State)
|
||||||
return windows.SetConsoleMode(windows.Handle(fd), s.mode)
|
return windows.SetConsoleMode(windows.Handle(fd), s.mode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,17 +18,16 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
"github.com/ollama/ollama/auth"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
@ -988,7 +987,7 @@ func getTokenSubject(token string) string {
|
||||||
|
|
||||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||||
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
||||||
for i := 0; i < 2; i++ {
|
for range 2 {
|
||||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
|
|
@ -72,7 +72,6 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||||
default:
|
default:
|
||||||
layers = append(layers, &layerWithGGML{layer, nil})
|
layers = append(layers, &layerWithGGML{layer, nil})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return layers, nil
|
return layers, nil
|
||||||
|
|
|
@ -6,12 +6,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetBlobsPath(t *testing.T) {
|
func TestGetBlobsPath(t *testing.T) {
|
||||||
// GetBlobsPath expects an actual directory to exist
|
// GetBlobsPath expects an actual directory to exist
|
||||||
dir, err := os.MkdirTemp("", "ollama-test")
|
dir, err := os.MkdirTemp("", "ollama-test")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.RemoveAll(dir)
|
defer os.RemoveAll(dir)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -63,7 +64,7 @@ func TestGetBlobsPath(t *testing.T) {
|
||||||
|
|
||||||
got, err := GetBlobsPath(tc.digest)
|
got, err := GetBlobsPath(tc.digest)
|
||||||
|
|
||||||
assert.ErrorIs(t, tc.err, err, tc.name)
|
require.ErrorIs(t, tc.err, err, tc.name)
|
||||||
assert.Equal(t, tc.expected, got, tc.name)
|
assert.Equal(t, tc.expected, got, tc.name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -23,7 +24,6 @@ import (
|
||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
@ -77,7 +77,6 @@ func isSupportedImageType(image []byte) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
|
@ -942,7 +941,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowedHost(host) {
|
if allowedHost(host) {
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == http.MethodOptions {
|
||||||
c.AbortWithStatus(http.StatusNoContent)
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1306,7 +1305,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
|
||||||
fn := func(r llm.CompletionResponse) {
|
fn := func(r llm.CompletionResponse) {
|
||||||
|
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
@ -25,20 +26,20 @@ func createTestFile(t *testing.T, name string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), name)
|
f, err := os.CreateTemp(t.TempDir(), name)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
|
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = binary.Write(f, binary.LittleEndian, uint32(3))
|
err = binary.Write(f, binary.LittleEndian, uint32(3))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return f.Name()
|
return f.Name()
|
||||||
}
|
}
|
||||||
|
@ -57,12 +58,12 @@ func Test_Routes(t *testing.T) {
|
||||||
|
|
||||||
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
|
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
|
||||||
modelfile, err := parser.ParseFile(r)
|
modelfile, err := parser.ParseFile(r)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
fn := func(resp api.ProgressResponse) {
|
fn := func(resp api.ProgressResponse) {
|
||||||
t.Logf("Status: %s", resp.Status)
|
t.Logf("Status: %s", resp.Status)
|
||||||
}
|
}
|
||||||
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
|
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
|
@ -74,9 +75,9 @@ func Test_Routes(t *testing.T) {
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
|
assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -86,17 +87,17 @@ func Test_Routes(t *testing.T) {
|
||||||
Path: "/api/tags",
|
Path: "/api/tags",
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var modelList api.ListResponse
|
var modelList api.ListResponse
|
||||||
|
|
||||||
err = json.Unmarshal(body, &modelList)
|
err = json.Unmarshal(body, &modelList)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.NotNil(t, modelList.Models)
|
assert.NotNil(t, modelList.Models)
|
||||||
assert.Equal(t, 0, len(modelList.Models))
|
assert.Empty(t, len(modelList.Models))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -108,16 +109,16 @@ func Test_Routes(t *testing.T) {
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var modelList api.ListResponse
|
var modelList api.ListResponse
|
||||||
err = json.Unmarshal(body, &modelList)
|
err = json.Unmarshal(body, &modelList)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, 1, len(modelList.Models))
|
assert.Len(t, modelList.Models, 1)
|
||||||
assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
|
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -134,7 +135,7 @@ func Test_Routes(t *testing.T) {
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(createReq)
|
jsonData, err := json.Marshal(createReq)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
},
|
},
|
||||||
|
@ -142,11 +143,11 @@ func Test_Routes(t *testing.T) {
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
assert.Equal(t, "application/json", contentType)
|
assert.Equal(t, "application/json", contentType)
|
||||||
_, err := io.ReadAll(resp.Body)
|
_, err := io.ReadAll(resp.Body)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, resp.StatusCode, 200)
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
model, err := GetModel("t-bone")
|
model, err := GetModel("t-bone")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "t-bone:latest", model.ShortName)
|
assert.Equal(t, "t-bone:latest", model.ShortName)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -161,13 +162,13 @@ func Test_Routes(t *testing.T) {
|
||||||
Destination: "beefsteak",
|
Destination: "beefsteak",
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(copyReq)
|
jsonData, err := json.Marshal(copyReq)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
model, err := GetModel("beefsteak")
|
model, err := GetModel("beefsteak")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "beefsteak:latest", model.ShortName)
|
assert.Equal(t, "beefsteak:latest", model.ShortName)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -179,18 +180,18 @@ func Test_Routes(t *testing.T) {
|
||||||
createTestModel(t, "show-model")
|
createTestModel(t, "show-model")
|
||||||
showReq := api.ShowRequest{Model: "show-model"}
|
showReq := api.ShowRequest{Model: "show-model"}
|
||||||
jsonData, err := json.Marshal(showReq)
|
jsonData, err := json.Marshal(showReq)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var showResp api.ShowResponse
|
var showResp api.ShowResponse
|
||||||
err = json.Unmarshal(body, &showResp)
|
err = json.Unmarshal(body, &showResp)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var params []string
|
var params []string
|
||||||
paramsSplit := strings.Split(showResp.Parameters, "\n")
|
paramsSplit := strings.Split(showResp.Parameters, "\n")
|
||||||
|
@ -221,14 +222,14 @@ func Test_Routes(t *testing.T) {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
u := httpSrv.URL + tc.Path
|
u := httpSrv.URL + tc.Path
|
||||||
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
|
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if tc.Setup != nil {
|
if tc.Setup != nil {
|
||||||
tc.Setup(t, req)
|
tc.Setup(t, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := httpSrv.Client().Do(req)
|
resp, err := httpSrv.Client().Do(req)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if tc.Expected != nil {
|
if tc.Expected != nil {
|
||||||
|
|
|
@ -7,17 +7,17 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LlmRequest struct {
|
type LlmRequest struct {
|
||||||
|
@ -66,7 +66,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options,
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
|
|
||||||
opts.NumCtx = opts.NumCtx * envconfig.NumParallel
|
opts.NumCtx *= envconfig.NumParallel
|
||||||
|
|
||||||
req := &LlmRequest{
|
req := &LlmRequest{
|
||||||
ctx: c,
|
ctx: c,
|
||||||
|
@ -370,7 +370,6 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
|
||||||
r.refMu.Lock()
|
r.refMu.Lock()
|
||||||
gpuIDs := make([]string, 0, len(r.gpus))
|
gpuIDs := make([]string, 0, len(r.gpus))
|
||||||
if r.llama != nil {
|
if r.llama != nil {
|
||||||
|
|
||||||
// TODO this should be broken down by GPU instead of assuming uniform spread
|
// TODO this should be broken down by GPU instead of assuming uniform spread
|
||||||
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
|
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
|
||||||
for _, gpu := range r.gpus {
|
for _, gpu := range r.gpus {
|
||||||
|
@ -529,7 +528,6 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return finished
|
return finished
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ByDuration []*runnerRef
|
type ByDuration []*runnerRef
|
||||||
|
|
|
@ -12,11 +12,10 @@ import (
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,10 +52,10 @@ func TestLoad(t *testing.T) {
|
||||||
}
|
}
|
||||||
gpus := gpu.GpuInfoList{}
|
gpus := gpu.GpuInfoList{}
|
||||||
s.load(req, ggml, gpus)
|
s.load(req, ggml, gpus)
|
||||||
require.Len(t, req.successCh, 0)
|
require.Empty(t, req.successCh)
|
||||||
require.Len(t, req.errCh, 1)
|
require.Len(t, req.errCh, 1)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 0)
|
require.Empty(t, s.loaded)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
err := <-req.errCh
|
err := <-req.errCh
|
||||||
require.Contains(t, err.Error(), "this model may be incompatible")
|
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||||
|
@ -113,7 +112,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
gguf := llm.NewGGUFV3(binary.LittleEndian)
|
gguf := llm.NewGGUFV3(binary.LittleEndian)
|
||||||
|
@ -131,7 +130,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||||
}, []llm.Tensor{
|
}, []llm.Tensor{
|
||||||
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
fname := f.Name()
|
fname := f.Name()
|
||||||
model := &Model{Name: modelName, ModelPath: fname}
|
model := &Model{Name: modelName, ModelPath: fname}
|
||||||
|
@ -190,8 +189,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario1a.req.successCh:
|
case resp := <-scenario1a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario1a.req.errCh, 0)
|
require.Empty(t, scenario1a.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -203,8 +202,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario1b.req.successCh:
|
case resp := <-scenario1b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario1b.req.errCh, 0)
|
require.Empty(t, scenario1b.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -221,8 +220,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario2a.req.successCh:
|
case resp := <-scenario2a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario2a.srv)
|
require.Equal(t, resp.llama, scenario2a.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario2a.req.errCh, 0)
|
require.Empty(t, scenario2a.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -237,8 +236,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3a.req.successCh:
|
case resp := <-scenario3a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3a.srv)
|
require.Equal(t, resp.llama, scenario3a.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario3a.req.errCh, 0)
|
require.Empty(t, scenario3a.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -253,8 +252,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3b.req.successCh:
|
case resp := <-scenario3b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3b.srv)
|
require.Equal(t, resp.llama, scenario3b.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario3b.req.errCh, 0)
|
require.Empty(t, scenario3b.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -269,8 +268,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3c.req.successCh:
|
case resp := <-scenario3c.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3c.srv)
|
require.Equal(t, resp.llama, scenario3c.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario3c.req.errCh, 0)
|
require.Empty(t, scenario3c.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -296,8 +295,8 @@ func TestRequests(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3d.req.successCh:
|
case resp := <-scenario3d.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3d.srv)
|
require.Equal(t, resp.llama, scenario3d.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, scenario3d.req.errCh, 0)
|
require.Empty(t, scenario3d.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -332,7 +331,7 @@ func TestGetRunner(t *testing.T) {
|
||||||
slog.Info("scenario1b")
|
slog.Info("scenario1b")
|
||||||
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
require.Len(t, successCh1b, 0)
|
require.Empty(t, successCh1b)
|
||||||
require.Len(t, errCh1b, 1)
|
require.Len(t, errCh1b, 1)
|
||||||
err := <-errCh1b
|
err := <-errCh1b
|
||||||
require.Contains(t, err.Error(), "server busy")
|
require.Contains(t, err.Error(), "server busy")
|
||||||
|
@ -340,8 +339,8 @@ func TestGetRunner(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-successCh1a:
|
case resp := <-successCh1a:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, errCh1a, 0)
|
require.Empty(t, errCh1a)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Errorf("timeout")
|
||||||
}
|
}
|
||||||
|
@ -355,9 +354,9 @@ func TestGetRunner(t *testing.T) {
|
||||||
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
||||||
// Starts in pending channel, then should be quickly processsed to return an error
|
// Starts in pending channel, then should be quickly processsed to return an error
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(5 * time.Millisecond)
|
||||||
require.Len(t, successCh1c, 0)
|
require.Empty(t, successCh1c)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 0)
|
require.Empty(t, s.loaded)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
require.Len(t, errCh1c, 1)
|
require.Len(t, errCh1c, 1)
|
||||||
err = <-errCh1c
|
err = <-errCh1c
|
||||||
|
@ -386,8 +385,8 @@ func TestPrematureExpired(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case resp := <-successCh1a:
|
case resp := <-successCh1a:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Len(t, s.pendingReqCh, 0)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Len(t, errCh1a, 0)
|
require.Empty(t, errCh1a)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 1)
|
require.Len(t, s.loaded, 1)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
@ -401,9 +400,9 @@ func TestPrematureExpired(t *testing.T) {
|
||||||
time.Sleep(20 * time.Millisecond)
|
time.Sleep(20 * time.Millisecond)
|
||||||
require.LessOrEqual(t, len(s.finishedReqCh), 1)
|
require.LessOrEqual(t, len(s.finishedReqCh), 1)
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
require.Len(t, s.finishedReqCh, 0)
|
require.Empty(t, s.finishedReqCh)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 0)
|
require.Empty(t, s.loaded)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// also shouldn't happen in real life
|
// also shouldn't happen in real life
|
||||||
|
@ -487,7 +486,6 @@ func TestFindRunnerToUnload(t *testing.T) {
|
||||||
r2.refCount = 1
|
r2.refCount = 1
|
||||||
resp = s.findRunnerToUnload()
|
resp = s.findRunnerToUnload()
|
||||||
require.Equal(t, r1, resp)
|
require.Equal(t, r1, resp)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNeedsReload(t *testing.T) {
|
func TestNeedsReload(t *testing.T) {
|
||||||
|
|
|
@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||||
case requestURL := <-b.nextURL:
|
case requestURL := <-b.nextURL:
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
var err error
|
var err error
|
||||||
for try := 0; try < maxRetries; try++ {
|
for try := range maxRetries {
|
||||||
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
|
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
|
@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||||
headers.Set("Content-Type", "application/octet-stream")
|
headers.Set("Content-Type", "application/octet-stream")
|
||||||
headers.Set("Content-Length", "0")
|
headers.Set("Content-Length", "0")
|
||||||
|
|
||||||
for try := 0; try < maxRetries; try++ {
|
for try := range maxRetries {
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
|
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
|
@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||||
}
|
}
|
||||||
|
|
||||||
// retry uploading to the redirect URL
|
// retry uploading to the redirect URL
|
||||||
for try := 0; try < maxRetries; try++ {
|
for try := range maxRetries {
|
||||||
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
|
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
|
|
|
@ -268,7 +268,6 @@ func TestNameIsValidPart(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilepathAllocs(t *testing.T) {
|
func TestFilepathAllocs(t *testing.T) {
|
||||||
|
@ -325,7 +324,7 @@ func TestParseNameFromFilepath(t *testing.T) {
|
||||||
filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"},
|
filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"},
|
||||||
filepath.Join("namespace", "model", "tag"): {},
|
filepath.Join("namespace", "model", "tag"): {},
|
||||||
filepath.Join("model", "tag"): {},
|
filepath.Join("model", "tag"): {},
|
||||||
filepath.Join("model"): {},
|
"model": {},
|
||||||
filepath.Join("..", "..", "model", "tag"): {},
|
filepath.Join("..", "..", "model", "tag"): {},
|
||||||
filepath.Join("", "namespace", ".", "tag"): {},
|
filepath.Join("", "namespace", ".", "tag"): {},
|
||||||
filepath.Join(".", ".", ".", "."): {},
|
filepath.Join(".", ".", ".", "."): {},
|
||||||
|
@ -382,7 +381,6 @@ func FuzzName(f *testing.F) {
|
||||||
t.Errorf("String() = %q; want %q", n.String(), s)
|
t.Errorf("String() = %q; want %q", n.String(), s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue