diff options
Diffstat (limited to 'forged/internal/bare/unmarshal.go')
-rw-r--r-- | forged/internal/bare/unmarshal.go | 362 |
1 files changed, 362 insertions, 0 deletions
diff --git a/forged/internal/bare/unmarshal.go b/forged/internal/bare/unmarshal.go new file mode 100644 index 0000000..af06529 --- /dev/null +++ b/forged/internal/bare/unmarshal.go @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "sync" +) + +// A type which implements this interface will be responsible for unmarshaling +// itself when encountered. +type Unmarshalable interface { + Unmarshal(r *Reader) error +} + +// Unmarshals a BARE message into val, which must be a pointer to a value of +// the message type. +func Unmarshal(data []byte, val interface{}) error { + b := bytes.NewReader(data) + r := NewReader(b) + return UnmarshalBareReader(r, val) +} + +// Unmarshals a BARE message into value (val, which must be a pointer), from a +// reader. See Unmarshal for details. +func UnmarshalReader(r io.Reader, val interface{}) error { + r = newLimitedReader(r) + return UnmarshalBareReader(NewReader(r), val) +} + +type decodeFunc func(r *Reader, v reflect.Value) error + +var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc + +func UnmarshalBareReader(r *Reader, val interface{}) error { + t := reflect.TypeOf(val) + v := reflect.ValueOf(val) + if t.Kind() != reflect.Ptr { + return errors.New("Expected val to be pointer type") + } + + return getDecoder(t.Elem())(r, v.Elem()) +} + +// get decoder from cache +func getDecoder(t reflect.Type) decodeFunc { + if f, ok := decodeFuncCache.Load(t); ok { + return f.(decodeFunc) + } + + f := decoderFunc(t) + decodeFuncCache.Store(t, f) + return f +} + +var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem() + +func decoderFunc(t reflect.Type) decodeFunc { + if reflect.PtrTo(t).Implements(unmarshalableInterface) { + return func(r *Reader, v reflect.Value) error { + uv := v.Addr().Interface().(Unmarshalable) + return uv.Unmarshal(r) + } + } + + if t.Kind() == reflect.Interface && t.Implements(unionInterface) { + return decodeUnion(t) + } + + switch t.Kind() { + case reflect.Ptr: + return decodeOptional(t.Elem()) + case reflect.Struct: + return decodeStruct(t) + case reflect.Array: + return decodeArray(t) + case reflect.Slice: + return decodeSlice(t) + case reflect.Map: + return decodeMap(t) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return decodeUint + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return decodeInt + case reflect.Float32, reflect.Float64: + return decodeFloat + case reflect.Bool: + return decodeBool + case reflect.String: + return decodeString + } + + return func(r *Reader, v reflect.Value) error { + return &UnsupportedTypeError{v.Type()} + } +} + +func decodeOptional(t reflect.Type) decodeFunc { + return func(r *Reader, v reflect.Value) error { + s, err := r.ReadU8() + if err != nil { + return err + } + + if s > 1 { + return fmt.Errorf("Invalid optional value: %#x", s) + } + + if s == 0 { + return nil + } + + v.Set(reflect.New(t)) + return getDecoder(t)(r, v.Elem()) + } +} + +func decodeStruct(t reflect.Type) decodeFunc { + n := t.NumField() + decoders := make([]decodeFunc, n) + for i := 0; i < n; i++ { + field := t.Field(i) + if field.Tag.Get("bare") == "-" { + continue + } + decoders[i] = getDecoder(field.Type) + } + + return func(r *Reader, v reflect.Value) error { + for i := 0; i < n; i++ { + if decoders[i] == nil { + continue + } + err := decoders[i](r, v.Field(i)) + if err != nil { + return err + } + } + return nil + } +} + +func decodeArray(t reflect.Type) decodeFunc { + f := getDecoder(t.Elem()) + len := t.Len() + + return func(r *Reader, v reflect.Value) error { + for i := 0; i < len; i++ { + err := f(r, v.Index(i)) + if err != nil { + return err + } + } + return nil + } +} + +func decodeSlice(t reflect.Type) decodeFunc { + elem := t.Elem() + f := getDecoder(elem) + + return func(r *Reader, v reflect.Value) error { + len, err := r.ReadUint() + if err != nil { + return err + } + + if len > maxArrayLength { + return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength) + } + + v.Set(reflect.MakeSlice(t, int(len), int(len))) + + for i := 0; i < int(len); i++ { + if err := f(r, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func decodeMap(t reflect.Type) decodeFunc { + keyType := t.Key() + keyf := getDecoder(keyType) + + valueType := t.Elem() + valf := getDecoder(valueType) + + return func(r *Reader, v reflect.Value) error { + size, err := r.ReadUint() + if err != nil { + return err + } + + if size > maxMapSize { + return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize) + } + + v.Set(reflect.MakeMapWithSize(t, int(size))) + + key := reflect.New(keyType).Elem() + value := reflect.New(valueType).Elem() + + for i := uint64(0); i < size; i++ { + if err := keyf(r, key); err != nil { + return err + } + + if v.MapIndex(key).Kind() > reflect.Invalid { + return fmt.Errorf("Encountered duplicate map key: %v", key.Interface()) + } + + if err := valf(r, value); err != nil { + return err + } + + v.SetMapIndex(key, value) + } + return nil + } +} + +func decodeUnion(t reflect.Type) decodeFunc { + ut, ok := unionRegistry[t] + if !ok { + return func(r *Reader, v reflect.Value) error { + return fmt.Errorf("Union type %s is not registered", t.Name()) + } + } + + decoders := make(map[uint64]decodeFunc) + for tag, t := range ut.types { + t := t + f := getDecoder(t) + + decoders[tag] = func(r *Reader, v reflect.Value) error { + nv := reflect.New(t) + if err := f(r, nv.Elem()); err != nil { + return err + } + + v.Set(nv) + return nil + } + } + + return func(r *Reader, v reflect.Value) error { + tag, err := r.ReadUint() + if err != nil { + return err + } + + if f, ok := decoders[tag]; ok { + return f(r, v) + } + + return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name()) + } +} + +func decodeUint(r *Reader, v reflect.Value) error { + var err error + switch getIntKind(v.Type()) { + case reflect.Uint: + var u uint64 + u, err = r.ReadUint() + v.SetUint(u) + + case reflect.Uint8: + var u uint8 + u, err = r.ReadU8() + v.SetUint(uint64(u)) + + case reflect.Uint16: + var u uint16 + u, err = r.ReadU16() + v.SetUint(uint64(u)) + case reflect.Uint32: + var u uint32 + u, err = r.ReadU32() + v.SetUint(uint64(u)) + + case reflect.Uint64: + var u uint64 + u, err = r.ReadU64() + v.SetUint(uint64(u)) + + default: + panic("not an uint") + } + + return err +} + +func decodeInt(r *Reader, v reflect.Value) error { + var err error + switch getIntKind(v.Type()) { + case reflect.Int: + var i int64 + i, err = r.ReadInt() + v.SetInt(i) + + case reflect.Int8: + var i int8 + i, err = r.ReadI8() + v.SetInt(int64(i)) + + case reflect.Int16: + var i int16 + i, err = r.ReadI16() + v.SetInt(int64(i)) + case reflect.Int32: + var i int32 + i, err = r.ReadI32() + v.SetInt(int64(i)) + + case reflect.Int64: + var i int64 + i, err = r.ReadI64() + v.SetInt(int64(i)) + + default: + panic("not an int") + } + + return err +} + +func decodeFloat(r *Reader, v reflect.Value) error { + var err error + switch v.Type().Kind() { + case reflect.Float32: + var f float32 + f, err = r.ReadF32() + v.SetFloat(float64(f)) + case reflect.Float64: + var f float64 + f, err = r.ReadF64() + v.SetFloat(f) + default: + panic("not a float") + } + return err +} + +func decodeBool(r *Reader, v reflect.Value) error { + b, err := r.ReadBool() + v.SetBool(b) + return err +} + +func decodeString(r *Reader, v reflect.Value) error { + s, err := r.ReadString() + v.SetString(s) + return err +} |