diff --git a/server/model.go b/server/model.go
index 55fb2d8d..124693d3 100644
--- a/server/model.go
+++ b/server/model.go
@@ -272,6 +272,30 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil
}
+func parseObjects(s string) []map[string]any {
+ var objs []map[string]any
+ for offset := 0; offset < len(s); {
+ var obj map[string]any
+ decoder := json.NewDecoder(strings.NewReader(s[offset:]))
+ if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+ break
+ } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
+ // skip over any syntax errors
+ offset += int(syntax.Offset)
+ } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
+ // skip over any unmarshalable types
+ offset += int(unmarshalType.Offset)
+ } else if err != nil {
+ return nil
+ } else {
+ offset += int(decoder.InputOffset())
+ objs = append(objs, obj)
+ }
+ }
+
+ return objs
+}
+
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
@@ -304,16 +328,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false
}
- var kv map[string]any
- // execute the subtree with placeholders to identify the keys
- // trim any commands that might exist in the template
- if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
+ templateObjects := parseObjects(b.String())
+ if len(templateObjects) == 0 {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
- for k, v := range kv {
+ for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
@@ -326,43 +348,32 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false
}
- var objs []map[string]any
- for offset := 0; offset < len(s); {
- var obj map[string]any
- decoder := json.NewDecoder(strings.NewReader(s[offset:]))
- if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
- break
- } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
- // skip over any syntax errors
- offset += int(syntax.Offset)
- } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
- // skip over any unmarshalable types
- offset += int(unmarshalType.Offset)
- } else if err != nil {
- slog.Error("parseToolCalls", "error", err)
- return nil, false
- } else {
- offset += int(decoder.InputOffset())
+ responseObjects := parseObjects(s)
+ if len(responseObjects) == 0 {
+ return nil, false
+ }
- // collect all nested objects
- var collect func(any) []map[string]any
- collect = func(obj any) (all []map[string]any) {
- switch o := obj.(type) {
- case map[string]any:
- all = append(all, o)
- for _, v := range o {
- all = append(all, collect(v)...)
- }
- case []any:
- for _, v := range o {
- all = append(all, collect(v)...)
- }
- }
-
- return all
+ // collect all nested objects
+ var collect func(any) []map[string]any
+ collect = func(obj any) (all []map[string]any) {
+ switch o := obj.(type) {
+ case map[string]any:
+ all = append(all, o)
+ for _, v := range o {
+ all = append(all, collect(v)...)
+ }
+ case []any:
+ for _, v := range o {
+ all = append(all, collect(v)...)
}
- objs = append(objs, collect(obj)...)
}
+
+ return all
+ }
+
+ var objs []map[string]any
+ for _, p := range responseObjects {
+ objs = append(objs, collect(p)...)
}
var toolCalls []api.ToolCall
diff --git a/server/model_test.go b/server/model_test.go
index e1737a5b..304d4655 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -69,6 +69,7 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
+ {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true},
}
var tools []api.Tool
@@ -217,3 +218,45 @@ func TestParseLayerFromCopy(t *testing.T) {
t.Fatalf("got %d != want 5", len(layers))
}
}
+
+func TestParseObjects(t *testing.T) {
+ tests := []struct {
+ input string
+ want []map[string]any
+ }{
+ {
+ input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
+ want: []map[string]any{
+ {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
+ {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
+ },
+ },
+ {
+ input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
+ want: []map[string]any{
+ {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
+ },
+ },
+ {
+ input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `,
+ want: []map[string]any{
+ {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
+ {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
+ },
+ },
+ {
+ input: `{"name": "get_current_weather", "arguments": `,
+ want: nil,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.input, func(t *testing.T) {
+ got := parseObjects(tc.input)
+
+ if diff := cmp.Diff(got, tc.want); diff != "" {
+ t.Errorf("mismatch (-got +want):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/server/testdata/tools/nemotron.gotmpl b/server/testdata/tools/nemotron.gotmpl
new file mode 100644
index 00000000..1b6b89ec
--- /dev/null
+++ b/server/testdata/tools/nemotron.gotmpl
@@ -0,0 +1,33 @@
+{{- if (or .Tools .System) }}System
+{{ if .System }}{{ .System }}
+
+
+{{ end }}
+{{- if .Tools }}
+{{- range .Tools }} {{ . }} {{ end }}
+
+
+{{ end }}
+{{- end }}
+{{- range $i, $m := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}User
+{{ .Content }}
+{{- if $last }}
+Assistant
+{{- end }}
+{{ else if eq .Role "tool" }}Tool
+{{ .Content }}
+{{- if $last }}
+Assistant
+{{- end }}
+{{ else if eq .Role "assistant" }}Assistant
+{{- if .ToolCalls }}
+{{ range .ToolCalls }} {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} {{ end }}
+{{ else }}
+{{ .Content }}
+{{- if not $last }}
+{{ end }}
+{{- end }}
+{{- end }}
+{{- end }}
\ No newline at end of file
diff --git a/server/testdata/tools/nemotron.out b/server/testdata/tools/nemotron.out
new file mode 100644
index 00000000..2166b202
--- /dev/null
+++ b/server/testdata/tools/nemotron.out
@@ -0,0 +1,18 @@
+System
+You are a knowledgable assistant. You can answer questions and perform tasks.
+
+
+ {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
+
+
+User
+What's the weather like today in Paris?
+Assistant
+ {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
+Tool
+22
+Assistant
+The current temperature in Paris, France is 22 degrees Celsius.
+User
+What's the weather like today in San Francisco and Toronto?
+Assistant