diff --git a/cmd/cmd.go b/cmd/cmd.go index 41276f8f..436c02c5 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -33,6 +33,17 @@ import ( "github.com/jmorganca/ollama/version" ) +type Painter struct{} + +func (p Painter) Paint(line []rune, l int) []rune { + termType := os.Getenv("TERM") + if termType == "xterm-256color" && len(line) == 0 { + prompt := "Send a message (/? for help)" + return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt))) + } + return line +} + func CreateHandler(cmd *cobra.Command, args []string) error { filename, _ := cmd.Flags().GetString("file") filename, err := filepath.Abs(filename) @@ -387,71 +398,71 @@ func RunGenerate(cmd *cobra.Command, args []string) error { type generateContextKey string func generate(cmd *cobra.Command, model, prompt string) error { - if len(strings.TrimSpace(prompt)) > 0 { - client, err := api.FromEnv() - if err != nil { - return err - } - - spinner := NewSpinner("") - go spinner.Spin(60 * time.Millisecond) - - var latest api.GenerateResponse - - generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) - if !ok { - generateContext = []int{} - } - - request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} - fn := func(response api.GenerateResponse) error { - if !spinner.IsFinished() { - spinner.Finish() - } - - latest = response - - fmt.Print(response.Response) - return nil - } - - if err := client.Generate(context.Background(), &request, fn); err != nil { - if strings.Contains(err.Error(), "failed to load model") { - // tell the user to check the server log, if it exists locally - home, nestedErr := os.UserHomeDir() - if nestedErr != nil { - // return the original error - return err - } - logPath := filepath.Join(home, ".ollama", "logs", "server.log") - if _, nestedErr := os.Stat(logPath); nestedErr == nil { - err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath) - } - } - return err - } - - fmt.Println() - fmt.Println() - - if !latest.Done { - return errors.New("unexpected end of response") - } - - verbose, err := cmd.Flags().GetBool("verbose") - if err != nil { - return err - } - - if verbose { - latest.Summary() - } - - ctx := cmd.Context() - ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) - cmd.SetContext(ctx) + client, err := api.FromEnv() + if err != nil { + return err } + spinner := NewSpinner("") + go spinner.Spin(60 * time.Millisecond) + + var latest api.GenerateResponse + + generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) + if !ok { + generateContext = []int{} + } + + request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} + fn := func(response api.GenerateResponse) error { + if !spinner.IsFinished() { + spinner.Finish() + } + + latest = response + + fmt.Print(response.Response) + return nil + } + + if err := client.Generate(context.Background(), &request, fn); err != nil { + if strings.Contains(err.Error(), "failed to load model") { + // tell the user to check the server log, if it exists locally + home, nestedErr := os.UserHomeDir() + if nestedErr != nil { + // return the original error + return err + } + logPath := filepath.Join(home, ".ollama", "logs", "server.log") + if _, nestedErr := os.Stat(logPath); nestedErr == nil { + err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath) + } + } + return err + } + + if prompt != "" { + fmt.Println() + fmt.Println() + } + + if !latest.Done { + return errors.New("unexpected end of response") + } + + verbose, err := cmd.Flags().GetBool("verbose") + if err != nil { + return err + } + + if verbose { + latest.Summary() + } + + ctx := cmd.Context() + ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) + cmd.SetContext(ctx) + return nil } @@ -461,6 +472,11 @@ func generateInteractive(cmd *cobra.Command, model string) error { return err } + // load the model + if err := generate(cmd, model, ""); err != nil { + return err + } + completer := readline.NewPrefixCompleter( readline.PcItem("/help"), readline.PcItem("/list"), @@ -492,6 +508,7 @@ func generateInteractive(cmd *cobra.Command, model string) error { } config := readline.Config{ + Painter: Painter{}, Prompt: ">>> ", HistoryFile: filepath.Join(home, ".ollama", "history"), AutoComplete: completer, @@ -621,8 +638,10 @@ func generateInteractive(cmd *cobra.Command, model string) error { return nil } - if err := generate(cmd, model, line); err != nil { - return err + if len(line) > 0 && line[0] != '/' { + if err := generate(cmd, model, line); err != nil { + return err + } } } } diff --git a/server/routes.go b/server/routes.go index b1fa2aff..d3d3d11c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -218,8 +218,12 @@ func GenerateHandler(c *gin.Context) { ch <- r } - if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { - ch <- gin.H{"error": err.Error()} + if req.Prompt == "" { + ch <- api.GenerateResponse{Model: req.Model, Done: true} + } else { + if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { + ch <- gin.H{"error": err.Error()} + } } }()