diff --git a/cmd/cmd.go b/cmd/cmd.go index 212c05e3..38928cd4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "log" "net" "net/http" @@ -13,11 +14,11 @@ import ( "strings" "time" + "github.com/chzyer/readline" "github.com/dustin/go-humanize" "github.com/olekukonko/tablewriter" "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" - "golang.org/x/term" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" @@ -111,7 +112,9 @@ func list(cmd *cobra.Command, args []string) error { var data [][]string for _, m := range models.Models { - data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")}) + if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) { + data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")}) + } } table := tablewriter.NewWriter(os.Stdout) @@ -169,7 +172,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error { return generate(cmd, args[0], strings.Join(args[1:], " ")) } - if term.IsTerminal(int(os.Stdin.Fd())) { + if readline.IsTerminal(int(os.Stdin.Fd())) { return generateInteractive(cmd, args[0]) } @@ -227,17 +230,99 @@ func generate(cmd *cobra.Command, model, prompt string) error { } func generateInteractive(cmd *cobra.Command, model string) error { - fmt.Print(">>> ") - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - if err := generate(cmd, model, scanner.Text()); err != nil { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + completer := readline.NewPrefixCompleter( + readline.PcItem("/help"), + readline.PcItem("/list"), + readline.PcItem("/set", + readline.PcItem("history"), + readline.PcItem("nohistory"), + readline.PcItem("mode", + readline.PcItem("vim"), + readline.PcItem("emacs"), + readline.PcItem("default"), + ), + ), + readline.PcItem("/exit"), + readline.PcItem("/bye"), + ) + + usage := func() { + fmt.Fprintln(os.Stderr, "commands:") + fmt.Fprintln(os.Stderr, completer.Tree(" ")) + } + + config := readline.Config{ + Prompt: ">>> ", + HistoryFile: filepath.Join(home, ".ollama", "history"), + AutoComplete: completer, + } + + scanner, err := readline.NewEx(&config) + if err != nil { + return err + } + defer scanner.Close() + + for { + line, err := scanner.Readline() + switch { + case errors.Is(err, io.EOF): + return nil + case errors.Is(err, readline.ErrInterrupt): + continue + case err != nil: return err } - fmt.Print(">>> ") - } + line = strings.TrimSpace(line) - return nil + switch { + case strings.HasPrefix(line, "/list"): + args := strings.Fields(line) + if err := list(cmd, args[1:]); err != nil { + return err + } + + continue + case strings.HasPrefix(line, "/set"): + args := strings.Fields(line) + if len(args) > 1 { + switch args[1] { + case "history": + scanner.HistoryEnable() + continue + case "nohistory": + scanner.HistoryDisable() + continue + case "mode": + if len(args) > 2 { + switch args[2] { + case "vim": + scanner.SetVimMode(true) + continue + case "emacs", "default": + scanner.SetVimMode(false) + continue + } + } + } + } + case line == "/help", line == "/?": + usage() + continue + case line == "/exit", line == "/bye": + return nil + } + + if err := generate(cmd, model, line); err != nil { + return err + } + } } func generateBatch(cmd *cobra.Command, model string) error { diff --git a/go.mod b/go.mod index 55dfbad3..c9bad06f 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( dario.cat/mergo v1.0.0 github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/chzyer/readline v1.5.1 github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect diff --git a/go.sum b/go.sum index 12c1b0b1..989f7d15 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,12 @@ github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -92,6 +98,7 @@ golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=