ollama/convert/reader.go

80 lines
1.4 KiB
Go
Raw Normal View History

2024-05-31 23:00:49 -04:00
package convert
import (
"errors"
"io"
2024-06-29 19:53:59 -04:00
"io/fs"
2024-05-31 23:00:49 -04:00
"strings"
)
// Tensor is the view of a single tensor that the rest of the converter works
// with: its identifying metadata plus the ability to serialize itself.
type Tensor interface {
	// Name returns the tensor's name.
	Name() string
	// Shape returns the tensor's dimensions.
	Shape() []uint64
	// Kind returns the numeric kind code for the tensor's data.
	Kind() uint32
	// SetRepacker installs a repacker on the tensor.
	SetRepacker(repacker)

	// WriteTo(io.Writer) (int64, error) writes the tensor to a writer.
	io.WriterTo
}
// tensorBase carries the fields common to tensors: a name, a shape, and an
// optional repacker. Presumably it is embedded by the concrete tensor types
// produced by the format-specific parsers; confirm against those types.
type tensorBase struct {
	// name is the tensor's name; Kind also inspects its suffix.
	name string
	// shape lists the tensor's dimensions; its length is the tensor's rank.
	shape []uint64

	// repacker is embedded, so it is reachable as the field t.repacker.
	// It is nil until SetRepacker is called.
	repacker
}
// Name returns the tensor's name.
func (t tensorBase) Name() string {
	return t.name
}
// Shape returns the tensor's dimensions. The slice is returned as stored, not
// copied, so it shares its backing array with the tensor.
func (t tensorBase) Shape() []uint64 {
	return t.shape
}
2024-07-08 19:59:48 -04:00
// Tensor kind codes, as returned by Kind. The values are written out
// explicitly so that reordering the declarations cannot silently change them.
const (
	// tensorKindF32 is kind code 0.
	tensorKindF32 uint32 = 0
	// tensorKindF16 is kind code 1.
	tensorKindF16 uint32 = 1
)
2024-05-31 23:00:49 -04:00
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".block_sparse_moe.gate.weight") {
return 0
}
switch len(t.shape) {
case 0:
panic("invalid tensor shape")
case 1:
2024-07-08 19:59:48 -04:00
return tensorKindF32
2024-05-31 23:00:49 -04:00
default:
2024-07-08 19:59:48 -04:00
return tensorKindF16
2024-05-31 23:00:49 -04:00
}
}
// SetRepacker stores fn as the tensor's repacker, replacing any previous one.
// It has a pointer receiver so that the assignment is made on the caller's
// tensorBase rather than on a copy.
func (t *tensorBase) SetRepacker(fn repacker) {
	t.repacker = fn
}
// repacker is a hook that takes a string, a []float32 and a []uint64 and
// returns a new []float32 or an error.
//
// NOTE(review): the parameter names below are inferred from tensorBase's name
// and shape fields (presumably the tensor name, its values as float32, and
// its shape); confirm against the call sites.
type repacker func(name string, data []float32, shape []uint64) ([]float32, error)
2024-06-29 19:53:59 -04:00
func parseTensors(fsys fs.FS) ([]Tensor, error) {
patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
2024-05-31 23:00:49 -04:00
"model-*-of-*.safetensors": parseSafetensors,
"model.safetensors": parseSafetensors,
"pytorch_model-*-of-*.bin": parseTorch,
"pytorch_model.bin": parseTorch,
"consolidated.*.pth": parseTorch,
}
for pattern, parseFn := range patterns {
2024-06-29 19:53:59 -04:00
matches, err := fs.Glob(fsys, pattern)
2024-05-31 23:00:49 -04:00
if err != nil {
return nil, err
}
if len(matches) > 0 {
2024-06-29 19:53:59 -04:00
return parseFn(fsys, matches...)
2024-05-31 23:00:49 -04:00
}
}
return nil, errors.New("unknown tensor format")
}