diff --git a/server/model.go b/server/model.go index 55fb2d8d..124693d3 100644 --- a/server/model.go +++ b/server/model.go @@ -272,6 +272,30 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } +func parseObjects(s string) []map[string]any { + var objs []map[string]any + for offset := 0; offset < len(s); { + var obj map[string]any + decoder := json.NewDecoder(strings.NewReader(s[offset:])) + if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break + } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { + // skip over any syntax errors + offset += int(syntax.Offset) + } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { + // skip over any unmarshalable types + offset += int(unmarshalType.Offset) + } else if err != nil { + return nil + } else { + offset += int(decoder.InputOffset()) + objs = append(objs, obj) + } + } + + return objs +} + // parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. // mxyng: this only really works if the input contains tool calls in some JSON format func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { @@ -304,16 +328,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { return nil, false } - var kv map[string]any - // execute the subtree with placeholders to identify the keys - // trim any commands that might exist in the template - if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil { + templateObjects := parseObjects(b.String()) + if len(templateObjects) == 0 { return nil, false } // find the keys that correspond to the name and arguments fields var name, arguments string - for k, v := range kv { + for k, v := range templateObjects[0] { switch v.(type) { case string: name = k @@ -326,43 +348,32 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { return nil, false } - var objs []map[string]any - for offset := 0; offset < len(s); { - var obj map[string]any - decoder := json.NewDecoder(strings.NewReader(s[offset:])) - if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break - } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { - // skip over any syntax errors - offset += int(syntax.Offset) - } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { - // skip over any unmarshalable types - offset += int(unmarshalType.Offset) - } else if err != nil { - slog.Error("parseToolCalls", "error", err) - return nil, false - } else { - offset += int(decoder.InputOffset()) + responseObjects := parseObjects(s) + if len(responseObjects) == 0 { + return nil, false + } - // collect all nested objects - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - } - - return all + // collect all nested objects + var collect func(any) []map[string]any + collect = func(obj any) (all []map[string]any) { + switch o := obj.(type) { + case map[string]any: + all = append(all, o) + for _, v := range o { + all = append(all, collect(v)...) + } + case []any: + for _, v := range o { + all = append(all, collect(v)...) } - objs = append(objs, collect(obj)...) } + + return all + } + + var objs []map[string]any + for _, p := range responseObjects { + objs = append(objs, collect(p)...) } var toolCalls []api.ToolCall diff --git a/server/model_test.go b/server/model_test.go index e1737a5b..304d4655 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -69,6 +69,7 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} `, true}, {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, + {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true}, } var tools []api.Tool @@ -217,3 +218,45 @@ func TestParseLayerFromCopy(t *testing.T) { t.Fatalf("got %d != want 5", len(layers)) } } + +func TestParseObjects(t *testing.T) { + tests := []struct { + input string + want []map[string]any + }{ + { + input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + want: []map[string]any{ + {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, + {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, + }, + }, + { + input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + want: []map[string]any{ + {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, + }, + }, + { + input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, + want: []map[string]any{ + {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, + {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, + }, + }, + { + input: `{"name": "get_current_weather", "arguments": `, + want: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + got := parseObjects(tc.input) + + if diff := cmp.Diff(got, tc.want); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/server/testdata/tools/nemotron.gotmpl b/server/testdata/tools/nemotron.gotmpl new file mode 100644 index 00000000..1b6b89ec --- /dev/null +++ b/server/testdata/tools/nemotron.gotmpl @@ -0,0 +1,33 @@ +{{- if (or .Tools .System) }}System +{{ if .System }}{{ .System }} + + +{{ end }} +{{- if .Tools }} +{{- range .Tools }} {{ . }} {{ end }} + + +{{ end }} +{{- end }} +{{- range $i, $m := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}User +{{ .Content }} +{{- if $last }} +Assistant +{{- end }} +{{ else if eq .Role "tool" }}Tool +{{ .Content }} +{{- if $last }} +Assistant +{{- end }} +{{ else if eq .Role "assistant" }}Assistant +{{- if .ToolCalls }} +{{ range .ToolCalls }} {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} {{ end }} +{{ else }} +{{ .Content }} +{{- if not $last }} +{{ end }} +{{- end }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/server/testdata/tools/nemotron.out b/server/testdata/tools/nemotron.out new file mode 100644 index 00000000..2166b202 --- /dev/null +++ b/server/testdata/tools/nemotron.out @@ -0,0 +1,18 @@ +System +You are a knowledgable assistant. You can answer questions and perform tasks. + + + {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +User +What's the weather like today in Paris? +Assistant + {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +Tool +22 +Assistant +The current temperature in Paris, France is 22 degrees Celsius. +User +What's the weather like today in San Francisco and Toronto? +Assistant