diff --git a/llm/ggml.go b/llm/ggml.go index 18ae4bd6..c3e71b88 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -83,7 +83,7 @@ type model interface { type container interface { Name() string - Decode(*readOffset) (model, error) + Decode(*readSeekOffset) (model, error) } type containerGGML struct{} @@ -92,7 +92,7 @@ func (c *containerGGML) Name() string { return "ggml" } -func (c *containerGGML) Decode(ro *readOffset) (model, error) { +func (c *containerGGML) Decode(ro *readSeekOffset) (model, error) { return nil, nil } @@ -104,7 +104,7 @@ func (c *containerGGMF) Name() string { return "ggmf" } -func (c *containerGGMF) Decode(ro *readOffset) (model, error) { +func (c *containerGGMF) Decode(ro *readSeekOffset) (model, error) { var version uint32 binary.Read(ro, binary.LittleEndian, &version) @@ -126,7 +126,7 @@ func (c *containerGGJT) Name() string { return "ggjt" } -func (c *containerGGJT) Decode(ro *readOffset) (model, error) { +func (c *containerGGJT) Decode(ro *readSeekOffset) (model, error) { var version uint32 binary.Read(ro, binary.LittleEndian, &version) @@ -152,7 +152,7 @@ func (c *containerLORA) Name() string { return "ggla" } -func (c *containerLORA) Decode(ro *readOffset) (model, error) { +func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) { var version uint32 binary.Read(ro, binary.LittleEndian, &version) @@ -180,8 +180,8 @@ const ( FILE_MAGIC_GGUF_BE = 0x47475546 ) -func DecodeGGML(r io.Reader) (*GGML, error) { - ro := readOffset{Reader: r} +func DecodeGGML(r io.ReadSeeker) (*GGML, error) { + ro := readSeekOffset{ReadSeeker: r} var magic uint32 if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil { @@ -219,13 +219,23 @@ func DecodeGGML(r io.Reader) (*GGML, error) { }, nil } -type readOffset struct { - io.Reader +type readSeekOffset struct { + io.ReadSeeker offset int64 } -func (r *readOffset) Read(p []byte) (int, error) { - n, err := r.Reader.Read(p) - r.offset += int64(n) +func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) { + offset, err := rso.ReadSeeker.Seek(offset, whence) + if err != nil { + return 0, err + } + + rso.offset = offset + return offset, nil +} + +func (rso *readSeekOffset) Read(p []byte) (int, error) { + n, err := rso.ReadSeeker.Read(p) + rso.offset += int64(n) return n, err } diff --git a/llm/gguf.go b/llm/gguf.go index dc883187..f68b87b2 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -29,18 +29,18 @@ func (c *containerGGUF) Name() string { return "gguf" } -func (c *containerGGUF) Decode(ro *readOffset) (model, error) { - binary.Read(ro, c.bo, &c.Version) +func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) { + binary.Read(rso, c.bo, &c.Version) switch c.Version { case 1: - binary.Read(ro, c.bo, &c.V1) + binary.Read(rso, c.bo, &c.V1) default: - binary.Read(ro, c.bo, &c.V2) + binary.Read(rso, c.bo, &c.V2) } model := newGGUFModel(c) - if err := model.Decode(ro); err != nil { + if err := model.Decode(rso); err != nil { return nil, err } @@ -154,49 +154,49 @@ func (llm *ggufModel) FileType() string { return "unknown" } -func (llm *ggufModel) Decode(ro *readOffset) error { +func (llm *ggufModel) Decode(rso *readSeekOffset) error { // decode key-values for i := 0; uint64(i) < llm.NumKV(); i++ { - k, err := llm.readString(ro) + k, err := llm.readString(rso) if err != nil { return err } - vtype := llm.readU32(ro) + vtype := llm.readU32(rso) var v any switch vtype { case ggufTypeUint8: - v = llm.readU8(ro) + v = llm.readU8(rso) case ggufTypeInt8: - v = llm.readI8(ro) + v = llm.readI8(rso) case ggufTypeUint16: - v = llm.readU16(ro) + v = llm.readU16(rso) case ggufTypeInt16: - v = llm.readI16(ro) + v = llm.readI16(rso) case ggufTypeUint32: - v = llm.readU32(ro) + v = llm.readU32(rso) case ggufTypeInt32: - v = llm.readI32(ro) + v = llm.readI32(rso) case ggufTypeUint64: - v = llm.readU64(ro) + v = llm.readU64(rso) case ggufTypeInt64: - v = llm.readI64(ro) + v = llm.readI64(rso) case ggufTypeFloat32: - v = llm.readF32(ro) + v = llm.readF32(rso) case ggufTypeFloat64: - v = llm.readF64(ro) + v = llm.readF64(rso) case ggufTypeBool: - v = llm.readBool(ro) + v = llm.readBool(rso) case ggufTypeString: - s, err := llm.readString(ro) + s, err := llm.readString(rso) if err != nil { return err } v = s case ggufTypeArray: - a, err := llm.readArray(ro) + a, err := llm.readArray(rso) if err != nil { return err } @@ -211,20 +211,20 @@ func (llm *ggufModel) Decode(ro *readOffset) error { // decode tensors for i := 0; uint64(i) < llm.NumTensor(); i++ { - name, err := llm.readString(ro) + name, err := llm.readString(rso) if err != nil { return err } - dims := llm.readU32(ro) + dims := llm.readU32(rso) shape := [4]uint64{1, 1, 1, 1} for i := 0; uint32(i) < dims; i++ { - shape[i] = llm.readU64(ro) + shape[i] = llm.readU64(rso) } - kind := llm.readU32(ro) - offset := llm.readU64(ro) + kind := llm.readU32(rso) + offset := llm.readU64(rso) var blockSize uint64 switch { @@ -285,10 +285,10 @@ func (llm *ggufModel) Decode(ro *readOffset) error { alignment = 32 } - io.CopyN(io.Discard, ro, int64(alignment)-ro.offset%int64(alignment)) + rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) for _, tensor := range llm.tensors { padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1) - io.CopyN(io.Discard, ro, padded) + rso.Seek(padded, io.SeekCurrent) } return nil