diff --git a/api/types.go b/api/types.go index 95ed5d37..3b67d57a 100644 --- a/api/types.go +++ b/api/types.go @@ -159,49 +159,18 @@ type Options struct { // Runner options which must be set when the model is loaded into memory type Runner struct { - UseNUMA bool `json:"numa,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` - NumBatch int `json:"num_batch,omitempty"` - NumGPU int `json:"num_gpu,omitempty"` - MainGPU int `json:"main_gpu,omitempty"` - LowVRAM bool `json:"low_vram,omitempty"` - F16KV bool `json:"f16_kv,omitempty"` - LogitsAll bool `json:"logits_all,omitempty"` - VocabOnly bool `json:"vocab_only,omitempty"` - UseMMap TriState `json:"use_mmap,omitempty"` - UseMLock bool `json:"use_mlock,omitempty"` - NumThread int `json:"num_thread,omitempty"` -} - -type TriState int - -const ( - TriStateUndefined TriState = -1 - TriStateFalse TriState = 0 - TriStateTrue TriState = 1 -) - -func (b *TriState) UnmarshalJSON(data []byte) error { - var v bool - if err := json.Unmarshal(data, &v); err != nil { - return err - } - if v { - *b = TriStateTrue - } - *b = TriStateFalse - return nil -} - -func (b *TriState) MarshalJSON() ([]byte, error) { - if *b == TriStateUndefined { - return nil, nil - } - var v bool - if *b == TriStateTrue { - v = true - } - return json.Marshal(v) + UseNUMA bool `json:"numa,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` + NumBatch int `json:"num_batch,omitempty"` + NumGPU int `json:"num_gpu,omitempty"` + MainGPU int `json:"main_gpu,omitempty"` + LowVRAM bool `json:"low_vram,omitempty"` + F16KV bool `json:"f16_kv,omitempty"` + LogitsAll bool `json:"logits_all,omitempty"` + VocabOnly bool `json:"vocab_only,omitempty"` + UseMMap *bool `json:"use_mmap,omitempty"` + UseMLock bool `json:"use_mlock,omitempty"` + NumThread int `json:"num_thread,omitempty"` } // EmbeddingRequest is the request passed to [Client.Embeddings]. @@ -437,19 +406,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error { continue } - if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) { - val, ok := val.(bool) - if !ok { - return fmt.Errorf("option %q must be of type boolean", key) - } - if val { - field.SetInt(int64(TriStateTrue)) - } else { - field.SetInt(int64(TriStateFalse)) - } - continue - } - switch field.Kind() { case reflect.Int: switch t := val.(type) { @@ -496,6 +452,17 @@ func (opts *Options) FromMap(m map[string]interface{}) error { slice[i] = str } field.Set(reflect.ValueOf(slice)) + case reflect.Pointer: + var b bool + if field.Type() == reflect.TypeOf(&b) { + val, ok := val.(bool) + if !ok { + return fmt.Errorf("option %q must be of type boolean", key) + } + field.Set(reflect.ValueOf(&val)) + } else { + return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type()) + } default: return fmt.Errorf("unknown type loading config params: %v", field.Kind()) } @@ -538,7 +505,7 @@ func DefaultOptions() Options { LowVRAM: false, F16KV: true, UseMLock: false, - UseMMap: TriStateUndefined, + UseMMap: nil, UseNUMA: false, }, } @@ -608,19 +575,6 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) { } else { field := valueOpts.FieldByName(opt.Name) if field.IsValid() && field.CanSet() { - if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) { - boolVal, err := strconv.ParseBool(vals[0]) - if err != nil { - return nil, fmt.Errorf("invalid bool value %s", vals) - } - if boolVal { - out[key] = TriStateTrue - } else { - out[key] = TriStateFalse - } - continue - } - switch field.Kind() { case reflect.Float32: floatVal, err := strconv.ParseFloat(vals[0], 32) @@ -648,6 +602,17 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) { case reflect.Slice: // TODO: only string slices are supported right now out[key] = vals + case reflect.Pointer: + var b bool + if field.Type() == reflect.TypeOf(&b) { + boolVal, err := strconv.ParseBool(vals[0]) + if err != nil { + return nil, fmt.Errorf("invalid bool value %s", vals) + } + out[key] = &boolVal + } else { + return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) + } default: return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) } diff --git a/api/types_test.go b/api/types_test.go index 8b6c60c6..c60ed90e 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -108,25 +108,27 @@ func TestDurationMarshalUnmarshal(t *testing.T) { } func TestUseMmapParsingFromJSON(t *testing.T) { + tr := true + fa := false tests := []struct { name string req string - exp TriState + exp *bool }{ { name: "Undefined", req: `{ }`, - exp: TriStateUndefined, + exp: nil, }, { name: "True", req: `{ "use_mmap": true }`, - exp: TriStateTrue, + exp: &tr, }, { name: "False", req: `{ "use_mmap": false }`, - exp: TriStateFalse, + exp: &fa, }, } @@ -144,50 +146,52 @@ func TestUseMmapParsingFromJSON(t *testing.T) { } func TestUseMmapFormatParams(t *testing.T) { + tr := true + fa := false tests := []struct { name string req map[string][]string - exp TriState + exp *bool err error }{ { name: "True", req: map[string][]string{ - "use_mmap": []string{"true"}, + "use_mmap": {"true"}, }, - exp: TriStateTrue, + exp: &tr, err: nil, }, { name: "False", req: map[string][]string{ - "use_mmap": []string{"false"}, + "use_mmap": {"false"}, }, - exp: TriStateFalse, + exp: &fa, err: nil, }, { name: "Numeric True", req: map[string][]string{ - "use_mmap": []string{"1"}, + "use_mmap": {"1"}, }, - exp: TriStateTrue, + exp: &tr, err: nil, }, { name: "Numeric False", req: map[string][]string{ - "use_mmap": []string{"0"}, + "use_mmap": {"0"}, }, - exp: TriStateFalse, + exp: &fa, err: nil, }, { name: "invalid string", req: map[string][]string{ - "use_mmap": []string{"foo"}, + "use_mmap": {"foo"}, }, - exp: TriStateUndefined, + exp: nil, err: fmt.Errorf("invalid bool value [foo]"), }, } @@ -195,11 +199,11 @@ func TestUseMmapFormatParams(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { resp, err := FormatParams(test.req) - require.Equal(t, err, test.err) + require.Equal(t, test.err, err) respVal, ok := resp["use_mmap"] - if test.exp != TriStateUndefined { + if test.exp != nil { assert.True(t, ok, "resp: %v", resp) - assert.Equal(t, test.exp, respVal) + assert.Equal(t, *test.exp, *respVal.(*bool)) } }) } diff --git a/llm/server.go b/llm/server.go index 61346069..821f6efd 100644 --- a/llm/server.go +++ b/llm/server.go @@ -208,7 +208,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr if g.Library == "metal" && uint64(opts.NumGPU) > 0 && uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 { - opts.UseMMap = api.TriStateFalse + opts.UseMMap = new(bool) + *opts.UseMMap = false } } @@ -219,10 +220,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing // For CPU loads we want the memory to be allocated, not FS cache - if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) || - (runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) || - (gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) || - opts.UseMMap == api.TriStateFalse { + if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) || + (runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) || + (gpus[0].Library == "cpu" && opts.UseMMap == nil) || + (opts.UseMMap != nil && !*opts.UseMMap) { params = append(params, "--no-mmap") }