aboutsummaryrefslogtreecommitdiff
path: root/forged/internal/bare/unmarshal.go
diff options
context:
space:
mode:
Diffstat (limited to 'forged/internal/bare/unmarshal.go')
-rw-r--r--forged/internal/bare/unmarshal.go362
1 files changed, 0 insertions, 362 deletions
diff --git a/forged/internal/bare/unmarshal.go b/forged/internal/bare/unmarshal.go
deleted file mode 100644
index d55f32c..0000000
--- a/forged/internal/bare/unmarshal.go
+++ /dev/null
@@ -1,362 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>
-
-package bare
-
-import (
- "bytes"
- "errors"
- "fmt"
- "io"
- "reflect"
- "sync"
-)
-
-// A type which implements this interface will be responsible for unmarshaling
-// itself when encountered.
-type Unmarshalable interface {
- Unmarshal(r *Reader) error
-}
-
-// Unmarshals a BARE message into val, which must be a pointer to a value of
-// the message type.
-func Unmarshal(data []byte, val interface{}) error {
- b := bytes.NewReader(data)
- r := NewReader(b)
- return UnmarshalBareReader(r, val)
-}
-
-// Unmarshals a BARE message into value (val, which must be a pointer), from a
-// reader. See Unmarshal for details.
-func UnmarshalReader(r io.Reader, val interface{}) error {
- r = newLimitedReader(r)
- return UnmarshalBareReader(NewReader(r), val)
-}
-
-type decodeFunc func(r *Reader, v reflect.Value) error
-
-var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc
-
-func UnmarshalBareReader(r *Reader, val interface{}) error {
- t := reflect.TypeOf(val)
- v := reflect.ValueOf(val)
- if t.Kind() != reflect.Ptr {
- return errors.New("Expected val to be pointer type")
- }
-
- return getDecoder(t.Elem())(r, v.Elem())
-}
-
-// get decoder from cache
-func getDecoder(t reflect.Type) decodeFunc {
- if f, ok := decodeFuncCache.Load(t); ok {
- return f.(decodeFunc)
- }
-
- f := decoderFunc(t)
- decodeFuncCache.Store(t, f)
- return f
-}
-
-var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()
-
-func decoderFunc(t reflect.Type) decodeFunc {
- if reflect.PointerTo(t).Implements(unmarshalableInterface) {
- return func(r *Reader, v reflect.Value) error {
- uv := v.Addr().Interface().(Unmarshalable)
- return uv.Unmarshal(r)
- }
- }
-
- if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
- return decodeUnion(t)
- }
-
- switch t.Kind() {
- case reflect.Ptr:
- return decodeOptional(t.Elem())
- case reflect.Struct:
- return decodeStruct(t)
- case reflect.Array:
- return decodeArray(t)
- case reflect.Slice:
- return decodeSlice(t)
- case reflect.Map:
- return decodeMap(t)
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- return decodeUint
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return decodeInt
- case reflect.Float32, reflect.Float64:
- return decodeFloat
- case reflect.Bool:
- return decodeBool
- case reflect.String:
- return decodeString
- }
-
- return func(r *Reader, v reflect.Value) error {
- return &UnsupportedTypeError{v.Type()}
- }
-}
-
-func decodeOptional(t reflect.Type) decodeFunc {
- return func(r *Reader, v reflect.Value) error {
- s, err := r.ReadU8()
- if err != nil {
- return err
- }
-
- if s > 1 {
- return fmt.Errorf("Invalid optional value: %#x", s)
- }
-
- if s == 0 {
- return nil
- }
-
- v.Set(reflect.New(t))
- return getDecoder(t)(r, v.Elem())
- }
-}
-
-func decodeStruct(t reflect.Type) decodeFunc {
- n := t.NumField()
- decoders := make([]decodeFunc, n)
- for i := 0; i < n; i++ {
- field := t.Field(i)
- if field.Tag.Get("bare") == "-" {
- continue
- }
- decoders[i] = getDecoder(field.Type)
- }
-
- return func(r *Reader, v reflect.Value) error {
- for i := 0; i < n; i++ {
- if decoders[i] == nil {
- continue
- }
- err := decoders[i](r, v.Field(i))
- if err != nil {
- return err
- }
- }
- return nil
- }
-}
-
-func decodeArray(t reflect.Type) decodeFunc {
- f := getDecoder(t.Elem())
- len := t.Len()
-
- return func(r *Reader, v reflect.Value) error {
- for i := 0; i < len; i++ {
- err := f(r, v.Index(i))
- if err != nil {
- return err
- }
- }
- return nil
- }
-}
-
-func decodeSlice(t reflect.Type) decodeFunc {
- elem := t.Elem()
- f := getDecoder(elem)
-
- return func(r *Reader, v reflect.Value) error {
- len, err := r.ReadUint()
- if err != nil {
- return err
- }
-
- if len > maxArrayLength {
- return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength)
- }
-
- v.Set(reflect.MakeSlice(t, int(len), int(len)))
-
- for i := 0; i < int(len); i++ {
- if err := f(r, v.Index(i)); err != nil {
- return err
- }
- }
- return nil
- }
-}
-
-func decodeMap(t reflect.Type) decodeFunc {
- keyType := t.Key()
- keyf := getDecoder(keyType)
-
- valueType := t.Elem()
- valf := getDecoder(valueType)
-
- return func(r *Reader, v reflect.Value) error {
- size, err := r.ReadUint()
- if err != nil {
- return err
- }
-
- if size > maxMapSize {
- return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
- }
-
- v.Set(reflect.MakeMapWithSize(t, int(size)))
-
- key := reflect.New(keyType).Elem()
- value := reflect.New(valueType).Elem()
-
- for i := uint64(0); i < size; i++ {
- if err := keyf(r, key); err != nil {
- return err
- }
-
- if v.MapIndex(key).Kind() > reflect.Invalid {
- return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
- }
-
- if err := valf(r, value); err != nil {
- return err
- }
-
- v.SetMapIndex(key, value)
- }
- return nil
- }
-}
-
-func decodeUnion(t reflect.Type) decodeFunc {
- ut, ok := unionRegistry[t]
- if !ok {
- return func(r *Reader, v reflect.Value) error {
- return fmt.Errorf("Union type %s is not registered", t.Name())
- }
- }
-
- decoders := make(map[uint64]decodeFunc)
- for tag, t := range ut.types {
- t := t
- f := getDecoder(t)
-
- decoders[tag] = func(r *Reader, v reflect.Value) error {
- nv := reflect.New(t)
- if err := f(r, nv.Elem()); err != nil {
- return err
- }
-
- v.Set(nv)
- return nil
- }
- }
-
- return func(r *Reader, v reflect.Value) error {
- tag, err := r.ReadUint()
- if err != nil {
- return err
- }
-
- if f, ok := decoders[tag]; ok {
- return f(r, v)
- }
-
- return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
- }
-}
-
-func decodeUint(r *Reader, v reflect.Value) error {
- var err error
- switch getIntKind(v.Type()) {
- case reflect.Uint:
- var u uint64
- u, err = r.ReadUint()
- v.SetUint(u)
-
- case reflect.Uint8:
- var u uint8
- u, err = r.ReadU8()
- v.SetUint(uint64(u))
-
- case reflect.Uint16:
- var u uint16
- u, err = r.ReadU16()
- v.SetUint(uint64(u))
- case reflect.Uint32:
- var u uint32
- u, err = r.ReadU32()
- v.SetUint(uint64(u))
-
- case reflect.Uint64:
- var u uint64
- u, err = r.ReadU64()
- v.SetUint(uint64(u))
-
- default:
- panic("not an uint")
- }
-
- return err
-}
-
-func decodeInt(r *Reader, v reflect.Value) error {
- var err error
- switch getIntKind(v.Type()) {
- case reflect.Int:
- var i int64
- i, err = r.ReadInt()
- v.SetInt(i)
-
- case reflect.Int8:
- var i int8
- i, err = r.ReadI8()
- v.SetInt(int64(i))
-
- case reflect.Int16:
- var i int16
- i, err = r.ReadI16()
- v.SetInt(int64(i))
- case reflect.Int32:
- var i int32
- i, err = r.ReadI32()
- v.SetInt(int64(i))
-
- case reflect.Int64:
- var i int64
- i, err = r.ReadI64()
- v.SetInt(int64(i))
-
- default:
- panic("not an int")
- }
-
- return err
-}
-
-func decodeFloat(r *Reader, v reflect.Value) error {
- var err error
- switch v.Type().Kind() {
- case reflect.Float32:
- var f float32
- f, err = r.ReadF32()
- v.SetFloat(float64(f))
- case reflect.Float64:
- var f float64
- f, err = r.ReadF64()
- v.SetFloat(f)
- default:
- panic("not a float")
- }
- return err
-}
-
-func decodeBool(r *Reader, v reflect.Value) error {
- b, err := r.ReadBool()
- v.SetBool(b)
- return err
-}
-
-func decodeString(r *Reader, v reflect.Value) error {
- s, err := r.ReadString()
- v.SetString(s)
- return err
-}