From 34d5ef29b3d01e2a0785af96df1135dfec567a3e Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 17 May 2024 12:11:49 -0700
Subject: [PATCH] fix conversion for f16 or f32 inputs

---
 convert/gemma.go       |  49 +++++----------
 convert/llama.go       | 136 ++++++++++++++++-------------------
 convert/mistral.go     |  91 ++-------------------------
 convert/mixtral.go     |   6 +-
 convert/safetensors.go |  85 ++++++++++++++------------
 convert/torch.go       |  77 +++++++++--------------
 go.mod                 |   2 +-
 7 files changed, 152 insertions(+), 294 deletions(-)

diff --git a/convert/gemma.go b/convert/gemma.go
index e24b8ec5..9dc406e0 100644
--- a/convert/gemma.go
+++ b/convert/gemma.go
@@ -1,14 +1,11 @@
 package convert
 
 import (
-	"encoding/binary"
 	"fmt"
 	"io"
 	"log/slog"
-	"os"
 	"strings"
 
-	"github.com/d4l3k/go-bfloat16"
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"
 
@@ -19,49 +16,27 @@ type GemmaModel struct {
 	ModelData
 }
 
-func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
-	slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
-
-	data := make([]byte, r.end-r.start)
-	if err := binary.Read(f, r.bo, data); err != nil {
-		return err
-	}
-
-	tDataF32 := bfloat16.DecodeFloat32(data)
-
-	var err error
-	tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
-	if err != nil {
-		return err
-	}
-
-	if err := binary.Write(w, r.bo, tDataF32); err != nil {
-		return err
-	}
-	return nil
-}
-
 func addOnes(data []float32, vectorSize int) ([]float32, error) {
 	n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
 	ones := tensor.Ones(tensor.Float32, vectorSize)
 
-	var err error
-	n, err = n.Add(ones)
+	n, err := n.Add(ones)
 	if err != nil {
-		return []float32{}, err
+		return nil, err
 	}
 
-	newN, err := native.SelectF32(n, 0)
+	ts, err := native.SelectF32(n, 0)
 	if err != nil {
-		return []float32{}, err
+		return nil, err
 	}
 
-	var fullTensor []float32
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
+	var f32s []float32
+	for _, t := range ts {
+		f32s = append(f32s, t...)
 	}
-	return fullTensor, nil
+
+	return f32s, nil
 }
 
 func (m *GemmaModel) GetTensors() error {
@@ -74,7 +49,7 @@ func (m *GemmaModel) GetTensors() error {
 	for _, l := range t {
 		if strings.HasSuffix(l.Name, "norm.weight") {
 			wt := l.WriterTo.(safetensorWriterTo)
-			wt.handler = gemmaLayerHandler
+			wt.repacker = m.Repack
 			l.WriterTo = wt
 		}
 		m.Tensors = append(m.Tensors, l)
@@ -92,6 +67,10 @@ func (m *GemmaModel) LoadVocab() error {
 	return nil
 }
 
+func (m *GemmaModel) Repack(_ string, data []float32, shape []uint64) ([]float32, error) {
+	return addOnes(data, int(shape[0]))
+}
+
 func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "gemma",
diff --git a/convert/llama.go b/convert/llama.go
index a10670e6..7853c4cf 100644
--- a/convert/llama.go
+++ b/convert/llama.go
@@ -1,7 +1,7 @@
 package convert
 
 import (
-	"encoding/binary"
+	"cmp"
 	"errors"
 	"fmt"
 	"io"
@@ -10,10 +10,8 @@ import (
 	"regexp"
 	"strings"
 
-	"github.com/nlpodyssey/gopickle/pytorch"
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"
-	"github.com/x448/float16"
 
 	"github.com/ollama/ollama/llm"
 )
@@ -22,83 +20,6 @@ type LlamaModel struct {
 	ModelData
 }
 
-func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
-
-	var tData []uint16
-	switch r.storage.(type) {
-	case *pytorch.HalfStorage:
-		data := r.storage.(*pytorch.HalfStorage).Data
-		tData = make([]uint16, len(data))
-		for cnt, v := range data {
-			tData[cnt] = uint16(float16.Fromfloat32(v))
-		}
-	case *pytorch.BFloat16Storage:
-		data := r.storage.(*pytorch.BFloat16Storage).Data
-		tData = make([]uint16, len(data))
-
-		for cnt, v := range data {
-			tData[cnt] = uint16(float16.Fromfloat32(v))
-		}
-	default:
-		return fmt.Errorf("unknown storage type for torch")
-	}
-
-	var err error
-	var heads uint32
-	if strings.Contains(r.t.Name, "attn_q") {
-		heads = uint32(r.params.AttentionHeads)
-	} else if strings.Contains(r.t.Name, "attn_k") {
-		heads = uint32(r.params.KeyValHeads)
-		if heads == 0 {
-			heads = uint32(r.params.AttentionHeads)
-		}
-	} else {
-		return fmt.Errorf("unknown layer type")
-	}
-
-	tData, err = llamaRepack(tData, int(heads), r.t.Shape)
-	if err != nil {
-		return err
-	}
-
-	if err = binary.Write(w, r.bo, tData); err != nil {
-		return err
-	}
-	return nil
-}
-
-func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
-	n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
-	origShape := n.Shape().Clone()
-
-	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
-	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
-		return nil, err
-	}
-
-	if err := n.T(0, 2, 1, 3); err != nil {
-		return nil, err
-	}
-
-	if err := n.Reshape(origShape...); err != nil {
-		return nil, err
-	}
-
-	if err := n.Transpose(); err != nil {
-		return nil, err
-	}
-	newN, err := native.SelectU16(n, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	var fullTensor []uint16
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
-	}
-	return fullTensor, nil
-}
-
 func (m *LlamaModel) GetTensors() error {
 	t, err := m.Format.GetTensors(m.Path, m.Params)
 	if err != nil {
@@ -117,11 +38,11 @@ func (m *LlamaModel) GetTensors() error {
 		switch m.Format.(type) {
 		case *TorchFormat:
 			wt := l.WriterTo.(torchWriterTo)
-			wt.handler = llamaTorchLayerHandler
+			wt.repacker = m.Repack
 			l.WriterTo = wt
 		case *SafetensorFormat:
 			wt := l.WriterTo.(safetensorWriterTo)
-			wt.handler = mistralLayerHandler
+			wt.repacker = m.Repack
 			l.WriterTo = wt
 		}
 	}
@@ -184,3 +105,54 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 
 	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
+
+func (m *LlamaModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
+	return llamaRepack(name, m.Params, data, shape)
+}
+
+func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([]float32, error) {
+	var dims []int
+	for _, dim := range shape {
+		if dim != 0 {
+			dims = append(dims, int(dim))
+		}
+	}
+
+	var heads int
+	if strings.HasSuffix(name, "attn_q.weight") {
+		heads = params.AttentionHeads
+	} else if strings.HasSuffix(name, "attn_k.weight") {
+		heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
+	} else {
+		return nil, fmt.Errorf("unknown tensor name: %s", name)
+	}
+
+	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+	if err := n.Reshape(append([]int{heads, 2, dims[0] / heads / 2}, dims[1:]...)...); err != nil {
+		return nil, err
+	}
+
+	if err := n.T(0, 2, 1, 3); err != nil {
+		return nil, err
+	}
+
+	if err := n.Reshape(dims...); err != nil {
+		return nil, err
+	}
+
+	if err := n.Transpose(); err != nil {
+		return nil, err
+	}
+
+	ts, err := native.SelectF32(n, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	var f32s []float32
+	for _, t := range ts {
+		f32s = append(f32s, t...)
+	}
+
+	return f32s, nil
+}
diff --git a/convert/mistral.go b/convert/mistral.go
index 89d2e084..da6874cf 100644
--- a/convert/mistral.go
+++ b/convert/mistral.go
@@ -1,17 +1,8 @@
 package convert
 
 import (
-	"encoding/binary"
-	"fmt"
 	"io"
-	"os"
 	"regexp"
-	"strings"
-
-	"github.com/d4l3k/go-bfloat16"
-	"github.com/pdevine/tensor"
-	"github.com/pdevine/tensor/native"
-	"github.com/x448/float16"
 
 	"github.com/ollama/ollama/llm"
 )
@@ -20,82 +11,6 @@ type MistralModel struct {
 	ModelData
 }
 
-func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
-	layerSize := r.end - r.start
-
-	var err error
-	tData := make([]uint16, layerSize/2)
-	if err = binary.Read(f, r.bo, tData); err != nil {
-		return err
-	}
-
-	var heads uint32
-	if strings.Contains(r.t.Name, "attn_q") {
-		heads = uint32(r.params.AttentionHeads)
-	} else if strings.Contains(r.t.Name, "attn_k") {
-		heads = uint32(r.params.KeyValHeads)
-		if heads == 0 {
-			heads = uint32(r.params.AttentionHeads)
-		}
-	} else {
-		return fmt.Errorf("unknown layer type")
-	}
-
-	tData, err = repack(tData, int(heads), r.t.Shape)
-	if err != nil {
-		return err
-	}
-
-	var buf []byte
-	for _, n := range tData {
-		buf = r.bo.AppendUint16(buf, n)
-	}
-
-	tempBuf := make([]uint16, len(tData))
-	tDataF32 := bfloat16.DecodeFloat32(buf)
-	for cnt, v := range tDataF32 {
-		tDataF16 := float16.Fromfloat32(v)
-		tempBuf[cnt] = uint16(tDataF16)
-	}
-
-	if err = binary.Write(w, r.bo, tempBuf); err != nil {
-		return err
-	}
-	return nil
-}
-
-func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
-	n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
-	origShape := n.Shape().Clone()
-
-	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
-	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
-		return nil, err
-	}
-
-	if err := n.T(0, 2, 1, 3); err != nil {
-		return nil, err
-	}
-
-	if err := n.Reshape(origShape...); err != nil {
-		return nil, err
-	}
-
-	if err := n.Transpose(); err != nil {
-		return nil, err
-	}
-	newN, err := native.SelectU16(n, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	var fullTensor []uint16
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
-	}
-	return fullTensor, nil
-}
-
 func (m *MistralModel) GetTensors() error {
 	t, err := m.Format.GetTensors(m.Path, m.Params)
 	if err != nil {
@@ -112,7 +27,7 @@ func (m *MistralModel) GetTensors() error {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
 			wt := l.WriterTo.(safetensorWriterTo)
-			wt.handler = mistralLayerHandler
+			wt.repacker = m.Repack
 			l.WriterTo = wt
 		}
 		m.Tensors = append(m.Tensors, l)
@@ -158,3 +73,7 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 
 	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
+
+func (m *MistralModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
+	return llamaRepack(name, m.Params, data, shape)
+}
diff --git a/convert/mixtral.go b/convert/mixtral.go
index 66546fd7..baea68cd 100644
--- a/convert/mixtral.go
+++ b/convert/mixtral.go
@@ -27,7 +27,7 @@ func (m *MixtralModel) GetTensors() error {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
 			wt := l.WriterTo.(safetensorWriterTo)
-			wt.handler = mistralLayerHandler
+			wt.repacker = m.Repack
 			l.WriterTo = wt
 		}
 		m.Tensors = append(m.Tensors, l)
@@ -81,3 +81,7 @@ func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
 
 	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
+
+func (m *MixtralModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
+	return llamaRepack(name, m.Params, data, shape)
+}
diff --git a/convert/safetensors.go b/convert/safetensors.go
index 2107ae81..9de9a002 100644
--- a/convert/safetensors.go
+++ b/convert/safetensors.go
@@ -27,9 +27,10 @@ type safetensorWriterTo struct {
 	bo     ByteOrder
 
 	filename string
+	dtype    string
 
 	start, end, padding uint64
-	handler             func(w io.Writer, r safetensorWriterTo, f *os.File) error
+	repacker            func(string, []float32, []uint64) ([]float32, error)
 }
 
 type tensorMetaData struct {
@@ -150,6 +151,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			params:   params,
 			bo:       params.ByteOrder,
 			filename: fn,
+			dtype:    data.Type,
 			start:    uint64(data.Offsets[0]),
 			end:      uint64(data.Offsets[1]),
 			padding:  8 + jsonSize,
@@ -235,51 +237,54 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
 		return 0, err
 	}
 
-	// use the handler if one is present
-	if r.handler != nil {
-		return 0, r.handler(w, r, f)
-	}
-
-	remaining := r.end - r.start
-
-	bufSize := uint64(10240)
-	var finished bool
-	for {
-		data := make([]byte, min(bufSize, remaining))
-
-		b, err := io.ReadFull(f, data)
-		remaining -= uint64(b)
-
-		if err == io.EOF || remaining <= 0 {
-			finished = true
-		} else if err != nil {
+	var f32s []float32
+	switch r.dtype {
+	case "F32":
+		f32s = make([]float32, (r.end-r.start)/4)
+		if err = binary.Read(f, r.bo, f32s); err != nil {
+			return 0, err
+		}
+	case "F16":
+		bts := make([]uint16, (r.end-r.start)/2)
+		if err = binary.Read(f, r.bo, bts); err != nil {
 			return 0, err
 		}
 
-		// convert bfloat16 -> ieee float32
-		tDataF32 := bfloat16.DecodeFloat32(data)
-
-		switch r.t.Kind {
-		case 0:
-			if err := binary.Write(w, r.bo, tDataF32); err != nil {
-				return 0, err
-			}
-		case 1:
-			// convert float32 -> float16
-			tempBuf := make([]uint16, len(data)/2)
-			for cnt, v := range tDataF32 {
-				tDataF16 := float16.Fromfloat32(v)
-				tempBuf[cnt] = uint16(tDataF16)
-			}
-			if err := binary.Write(w, r.bo, tempBuf); err != nil {
-				return 0, err
-			}
+		for _, b := range bts {
+			f32s = append(f32s, float16.Frombits(b).Float32())
 		}
-		if finished {
-			break
+
+	case "BF16":
+		bts := make([]byte, r.end-r.start)
+		if err = binary.Read(f, r.bo, bts); err != nil {
+			return 0, err
+		}
+
+		f32s = bfloat16.DecodeFloat32(bts)
+	default:
+		return 0, fmt.Errorf("unknown data type: %s", r.dtype)
+	}
+
+	if r.repacker != nil {
+		f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
+		if err != nil {
+			return 0, err
 		}
 	}
-	return 0, nil
+
+	switch r.t.Kind {
+	case 0:
+		return 0, binary.Write(w, r.bo, f32s)
+	case 1:
+		f16s := make([]uint16, len(f32s))
+		for i := range f32s {
+			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
+		}
+
+		return 0, binary.Write(w, r.bo, f16s)
+	default:
+		return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
+	}
 }
 
 func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
diff --git a/convert/torch.go b/convert/torch.go
index cb8d74b0..b7ae0f76 100644
--- a/convert/torch.go
+++ b/convert/torch.go
@@ -24,8 +24,8 @@ type torchWriterTo struct {
 	params *Params
 	bo     ByteOrder
 
-	storage pytorch.StorageInterface
-	handler func(w io.Writer, r torchWriterTo) error
+	storage  pytorch.StorageInterface
+	repacker func(string, []float32, []uint64) ([]float32, error)
 }
 
 type TorchFormat struct{}
@@ -230,59 +230,38 @@ func (m *TorchFormat) GetLayerName(n string) (string, error) {
 }
 
 func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
-	// use the handler if one is present
-	if r.handler != nil {
-		return 0, r.handler(w, r)
-	}
-
-	switch storage := r.storage.(type) {
+	var f32s []float32
+	switch s := r.storage.(type) {
 	case *pytorch.FloatStorage:
-		slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name))
-		return 0, nil
+		f32s = s.Data
 	case *pytorch.HalfStorage:
-		switch r.t.Kind {
-		case 0:
-			data := r.storage.(*pytorch.HalfStorage).Data
-			slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
-			if err := binary.Write(w, r.bo, data); err != nil {
-				return 0, err
-			}
-		case 1:
-			data := r.storage.(*pytorch.HalfStorage).Data
-			tData := make([]uint16, len(data))
-			for cnt, v := range data {
-				tData[cnt] = uint16(float16.Fromfloat32(v))
-			}
-			slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
-			if err := binary.Write(w, r.bo, tData); err != nil {
-				return 0, err
-			}
-		}
+		f32s = s.Data
 	case *pytorch.BFloat16Storage:
-		data := r.storage.(*pytorch.BFloat16Storage).Data
-		switch r.t.Kind {
-		case 0:
-			if err = binary.Write(w, r.bo, data); err != nil {
-				return 0, err
-			}
-		case 1:
-			tData := make([]uint16, len(data))
-
-			for cnt, v := range data {
-				tData[cnt] = uint16(float16.Fromfloat32(v))
-			}
-
-			if err = binary.Write(w, r.bo, tData); err != nil {
-				return 0, err
-			}
-		default:
-			return 0, fmt.Errorf("unknown storage kind: %d", r.t.Kind)
-		}
+		f32s = s.Data
 	default:
-		return 0, fmt.Errorf("unknown storage type: %T", storage)
+		return 0, fmt.Errorf("unknown data type: %T", s)
 	}
 
-	return 0, nil
+	if r.repacker != nil {
+		f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	switch r.t.Kind {
+	case 0:
+		return 0, binary.Write(w, r.bo, f32s)
+	case 1:
+		f16s := make([]uint16, len(f32s))
+		for i := range f32s {
+			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
+		}
+
+		return 0, binary.Write(w, r.bo, f16s)
+	default:
+		return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
+	}
 }
 
 func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
diff --git a/go.mod b/go.mod
index 5d0d3c33..255c8a04 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,6 @@ go 1.22.0
 
 require (
 	github.com/containerd/console v1.0.3
-	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/emirpasic/gods v1.18.1
 	github.com/gin-gonic/gin v1.10.0
 	github.com/golang/protobuf v1.5.4 // indirect
@@ -18,6 +17,7 @@ require (
 )
 
 require (
+	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
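
Note (reviewer commentary, not part of the patch and ignored by `git am`): the heart of this change is that both `safetensorWriterTo.WriteTo` and `torchWriterTo.WriteTo` now decode every supported input dtype (F32, F16, BF16, or the corresponding pytorch storages) into a single `[]float32` intermediate, run the optional `repacker` callback on it, and only then emit the requested output kind (0 = F32, 1 = F16). The standalone sketch below replays that decode/encode flow outside the converter; the `decode`/`encode` helper names, the little-endian byte order, and the demo values are assumptions made up for the example, while the `float16` and `bfloat16` calls are the same ones the patch uses.

	package main

	import (
		"bytes"
		"encoding/binary"
		"fmt"
		"io"

		"github.com/d4l3k/go-bfloat16"
		"github.com/x448/float16"
	)

	// decode normalizes raw tensor bytes of one dtype to []float32, the
	// single intermediate representation the patch converges on.
	func decode(dtype string, raw []byte) ([]float32, error) {
		switch dtype {
		case "F32":
			f32s := make([]float32, len(raw)/4)
			if err := binary.Read(bytes.NewReader(raw), binary.LittleEndian, f32s); err != nil {
				return nil, err
			}
			return f32s, nil
		case "F16":
			u16s := make([]uint16, len(raw)/2)
			if err := binary.Read(bytes.NewReader(raw), binary.LittleEndian, u16s); err != nil {
				return nil, err
			}
			f32s := make([]float32, 0, len(u16s))
			for _, b := range u16s {
				f32s = append(f32s, float16.Frombits(b).Float32())
			}
			return f32s, nil
		case "BF16":
			return bfloat16.DecodeFloat32(raw), nil
		default:
			return nil, fmt.Errorf("unknown data type: %s", dtype)
		}
	}

	// encode writes f32s as the output kind: 0 keeps F32, 1 rounds each
	// value to IEEE half precision, mirroring the new WriteTo tail.
	func encode(w io.Writer, kind uint32, f32s []float32) error {
		switch kind {
		case 0:
			return binary.Write(w, binary.LittleEndian, f32s)
		case 1:
			f16s := make([]uint16, len(f32s))
			for i := range f32s {
				f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
			}
			return binary.Write(w, binary.LittleEndian, f16s)
		default:
			return fmt.Errorf("unknown storage type: %d", kind)
		}
	}

	func main() {
		// Round-trip three exactly representable values through F16 bytes,
		// the float32 intermediate, and back out as F32.
		var raw []byte
		for _, v := range []float32{1, 0.5, -2} {
			raw = binary.LittleEndian.AppendUint16(raw, float16.Fromfloat32(v).Bits())
		}

		f32s, err := decode("F16", raw)
		if err != nil {
			panic(err)
		}

		var out bytes.Buffer
		if err := encode(&out, 0, f32s); err != nil {
			panic(err)
		}
		fmt.Println(f32s, out.Len()) // [1 0.5 -2] 12
	}

Working in float32 also fixes the bug the subject line names: the old safetensors path decoded every input as bfloat16 regardless of the tensor's actual dtype, which silently corrupted F16 and F32 checkpoints.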
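Note (reviewer commentary, not part of the patch): the new `llamaRepack` is one float32 implementation of what the removed `repack`/`llamaRepack` helpers each did on uint16 data, and it is now shared by the Llama, Mistral, and Mixtral converters. It splits the rows of an `attn_q.weight`/`attn_k.weight` tensor into (heads, 2, rows/heads/2), swaps the two middle axes, and flattens back — in the words of the removed comment, it "swap[s] axes 1 and 2 to unpack the layer for gguf". Below is a toy run of the same call sequence on a made-up 4x2 weight with 2 heads; the tensor calls (`New`, `Reshape`, `T`, `Transpose`, `native.SelectF32`) are exactly the ones the patch uses, but the shape and data are invented for the demo.

	package main

	import (
		"fmt"

		"github.com/pdevine/tensor"
		"github.com/pdevine/tensor/native"
	)

	func main() {
		heads := 2
		dims := []int{4, 2}                       // rows x cols, invented for the demo
		data := []float32{0, 1, 2, 3, 4, 5, 6, 7} // row-major backing

		n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))

		// Split the rows into (heads, 2, rows/heads/2), keeping columns last.
		if err := n.Reshape(append([]int{heads, 2, dims[0] / heads / 2}, dims[1:]...)...); err != nil {
			panic(err)
		}

		// Swap the pair axis with the per-head row axis, then restore the
		// original 2D shape: the same T/Reshape/Transpose sequence as the patch.
		if err := n.T(0, 2, 1, 3); err != nil {
			panic(err)
		}
		if err := n.Reshape(dims...); err != nil {
			panic(err)
		}
		if err := n.Transpose(); err != nil {
			panic(err)
		}

		// Flatten back to a single []float32, as llamaRepack does.
		rows, err := native.SelectF32(n, 1)
		if err != nil {
			panic(err)
		}

		var f32s []float32
		for _, r := range rows {
			f32s = append(f32s, r...)
		}
		fmt.Println(f32s) // the repacked order of the 8 original values
	}

Routing this through the shared float32 `repacker` hook is also what lets the Gemma converter reuse the same mechanism for its much simpler transform (`addOnes` on norm weights) without a second I/O path.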