diff --git a/cmd/cmd.go b/cmd/cmd.go index 693d0cab..fda70b8b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -21,6 +21,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/progressbar" "github.com/jmorganca/ollama/server" ) @@ -36,6 +37,24 @@ func CreateHandler(cmd *cobra.Command, args []string) error { var spinner *Spinner + // pull the model file if needed + mf, err := os.Open(filename) + defer mf.Close() + cmds, err := parser.Parse(mf) + if err != nil { + return err + } + mf.Close() + for _, c := range cmds { + if c.Name == "model" { + // check if the model file needs to be pulled + checkPull(c.Args) + } + } + if err != nil { + return err + } + request := api.CreateRequest{Name: args[0], Path: filename} fn := func(resp api.CreateProgress) error { if spinner != nil { @@ -59,8 +78,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return nil } -func RunHandler(cmd *cobra.Command, args []string) error { - mp := server.ParseModelPath(args[0]) +func checkPull(model string) error { + mp := server.ParseModelPath(model) fp, err := mp.GetManifestPath(false) if err != nil { return err @@ -69,7 +88,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { _, err = os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): - if err := pull(args[0], false); err != nil { + if err := pull(model, false); err != nil { var apiStatusError api.StatusError if !errors.As(err, &apiStatusError) { return err @@ -83,6 +102,13 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } + return nil +} + +func RunHandler(cmd *cobra.Command, args []string) error { + if err := checkPull(args[0]); err != nil { + return err + } return RunGenerate(cmd, args) }