diff --git a/cmd/cmd.go b/cmd/cmd.go index 2315ad1a..e3c1d873 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,7 +17,6 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "runtime" "strings" "syscall" @@ -57,12 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() - modelfile, err := os.ReadFile(filename) + modelfile, err := os.Open(filename) if err != nil { return err } + defer modelfile.Close() - commands, err := parser.Parse(bytes.NewReader(modelfile)) + commands, err := parser.Parse(modelfile) if err != nil { return err } @@ -76,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner := progress.NewSpinner(status) p.Add(status, spinner) - for _, c := range commands { - switch c.Name { + for i := range commands { + switch commands[i].Name { case "model", "adapter": - path := c.Args + path := commands[i].Args if path == "~" { path = home } else if strings.HasPrefix(path, "~/") { @@ -91,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } fi, err := os.Stat(path) - if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" { continue } else if err != nil { return err @@ -114,13 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - name := c.Name - if c.Name == "model" { - name = "from" - } - - re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args)) - modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest)) + commands[i].Args = "@"+digest } } @@ -150,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { quantization, _ := cmd.Flags().GetString("quantization") - request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} + request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization} if err := client.Create(cmd.Context(), &request, fn); err != nil { return err } diff --git a/parser/parser.go b/parser/parser.go index c6667d66..22e07235 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -31,6 +31,33 @@ var ( errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") ) +func Format(cmds []Command) string { + var b bytes.Buffer + for _, cmd := range cmds { + name := cmd.Name + args := cmd.Args + + switch cmd.Name { + case "model": + name = "from" + args = cmd.Args + case "license", "template", "system", "adapter": + args = quote(args) + // pass + case "message": + role, message, _ := strings.Cut(cmd.Args, ": ") + args = role + " " + quote(message) + default: + name = "parameter" + args = cmd.Name + " " + cmd.Args + } + + fmt.Fprintln(&b, strings.ToUpper(name), args) + } + + return b.String() +} + func Parse(r io.Reader) (cmds []Command, err error) { var cmd Command var curr state @@ -197,6 +224,18 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { } } +func quote(s string) string { + if strings.Contains(s, "\n") || strings.HasSuffix(s, " ") { + if strings.Contains(s, "\"") { + return `"""` + s + `"""` + } + + return strconv.Quote(s) + } + + return s +} + func unquote(s string) (string, bool) { if len(s) == 0 { return "", false diff --git a/parser/parser_test.go b/parser/parser_test.go index 1eb10157..0b08f1ab 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -429,3 +429,63 @@ FROM foo }) } } + +func TestParseFormatParse(t *testing.T) { + var cases = []string{ + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system You are a Parser. Always Parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responsed with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE """ +Very long and boring legal text. +Blah blah blah. +"Oh look, a quote!" +""" + +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responsed with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c)) + assert.NoError(t, err) + + commands2, err := Parse(strings.NewReader(Format(commands))) + assert.NoError(t, err) + + assert.Equal(t, commands, commands2) + }) + } + +}