aboutsummaryrefslogtreecommitdiff
path: root/forged/internal/bare/unmarshal.go
diff options
context:
space:
mode:
Diffstat (limited to 'forged/internal/bare/unmarshal.go')
-rw-r--r--forged/internal/bare/unmarshal.go362
1 files changed, 362 insertions, 0 deletions
diff --git a/forged/internal/bare/unmarshal.go b/forged/internal/bare/unmarshal.go
new file mode 100644
index 0000000..af06529
--- /dev/null
+++ b/forged/internal/bare/unmarshal.go
@@ -0,0 +1,362 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>
+
+package bare
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "sync"
+)
+
+// A type which implements this interface will be responsible for unmarshaling
+// itself when encountered.
+type Unmarshalable interface {
+ Unmarshal(r *Reader) error
+}
+
+// Unmarshals a BARE message into val, which must be a pointer to a value of
+// the message type.
+func Unmarshal(data []byte, val interface{}) error {
+ b := bytes.NewReader(data)
+ r := NewReader(b)
+ return UnmarshalBareReader(r, val)
+}
+
+// Unmarshals a BARE message into value (val, which must be a pointer), from a
+// reader. See Unmarshal for details.
+func UnmarshalReader(r io.Reader, val interface{}) error {
+ r = newLimitedReader(r)
+ return UnmarshalBareReader(NewReader(r), val)
+}
+
+type decodeFunc func(r *Reader, v reflect.Value) error
+
+var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc
+
+func UnmarshalBareReader(r *Reader, val interface{}) error {
+ t := reflect.TypeOf(val)
+ v := reflect.ValueOf(val)
+ if t.Kind() != reflect.Ptr {
+ return errors.New("Expected val to be pointer type")
+ }
+
+ return getDecoder(t.Elem())(r, v.Elem())
+}
+
+// get decoder from cache
+func getDecoder(t reflect.Type) decodeFunc {
+ if f, ok := decodeFuncCache.Load(t); ok {
+ return f.(decodeFunc)
+ }
+
+ f := decoderFunc(t)
+ decodeFuncCache.Store(t, f)
+ return f
+}
+
+var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()
+
+func decoderFunc(t reflect.Type) decodeFunc {
+ if reflect.PtrTo(t).Implements(unmarshalableInterface) {
+ return func(r *Reader, v reflect.Value) error {
+ uv := v.Addr().Interface().(Unmarshalable)
+ return uv.Unmarshal(r)
+ }
+ }
+
+ if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
+ return decodeUnion(t)
+ }
+
+ switch t.Kind() {
+ case reflect.Ptr:
+ return decodeOptional(t.Elem())
+ case reflect.Struct:
+ return decodeStruct(t)
+ case reflect.Array:
+ return decodeArray(t)
+ case reflect.Slice:
+ return decodeSlice(t)
+ case reflect.Map:
+ return decodeMap(t)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return decodeUint
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return decodeInt
+ case reflect.Float32, reflect.Float64:
+ return decodeFloat
+ case reflect.Bool:
+ return decodeBool
+ case reflect.String:
+ return decodeString
+ }
+
+ return func(r *Reader, v reflect.Value) error {
+ return &UnsupportedTypeError{v.Type()}
+ }
+}
+
+func decodeOptional(t reflect.Type) decodeFunc {
+ return func(r *Reader, v reflect.Value) error {
+ s, err := r.ReadU8()
+ if err != nil {
+ return err
+ }
+
+ if s > 1 {
+ return fmt.Errorf("Invalid optional value: %#x", s)
+ }
+
+ if s == 0 {
+ return nil
+ }
+
+ v.Set(reflect.New(t))
+ return getDecoder(t)(r, v.Elem())
+ }
+}
+
+func decodeStruct(t reflect.Type) decodeFunc {
+ n := t.NumField()
+ decoders := make([]decodeFunc, n)
+ for i := 0; i < n; i++ {
+ field := t.Field(i)
+ if field.Tag.Get("bare") == "-" {
+ continue
+ }
+ decoders[i] = getDecoder(field.Type)
+ }
+
+ return func(r *Reader, v reflect.Value) error {
+ for i := 0; i < n; i++ {
+ if decoders[i] == nil {
+ continue
+ }
+ err := decoders[i](r, v.Field(i))
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func decodeArray(t reflect.Type) decodeFunc {
+ f := getDecoder(t.Elem())
+ len := t.Len()
+
+ return func(r *Reader, v reflect.Value) error {
+ for i := 0; i < len; i++ {
+ err := f(r, v.Index(i))
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func decodeSlice(t reflect.Type) decodeFunc {
+ elem := t.Elem()
+ f := getDecoder(elem)
+
+ return func(r *Reader, v reflect.Value) error {
+ len, err := r.ReadUint()
+ if err != nil {
+ return err
+ }
+
+ if len > maxArrayLength {
+ return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength)
+ }
+
+ v.Set(reflect.MakeSlice(t, int(len), int(len)))
+
+ for i := 0; i < int(len); i++ {
+ if err := f(r, v.Index(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func decodeMap(t reflect.Type) decodeFunc {
+ keyType := t.Key()
+ keyf := getDecoder(keyType)
+
+ valueType := t.Elem()
+ valf := getDecoder(valueType)
+
+ return func(r *Reader, v reflect.Value) error {
+ size, err := r.ReadUint()
+ if err != nil {
+ return err
+ }
+
+ if size > maxMapSize {
+ return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
+ }
+
+ v.Set(reflect.MakeMapWithSize(t, int(size)))
+
+ key := reflect.New(keyType).Elem()
+ value := reflect.New(valueType).Elem()
+
+ for i := uint64(0); i < size; i++ {
+ if err := keyf(r, key); err != nil {
+ return err
+ }
+
+ if v.MapIndex(key).Kind() > reflect.Invalid {
+ return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
+ }
+
+ if err := valf(r, value); err != nil {
+ return err
+ }
+
+ v.SetMapIndex(key, value)
+ }
+ return nil
+ }
+}
+
+func decodeUnion(t reflect.Type) decodeFunc {
+ ut, ok := unionRegistry[t]
+ if !ok {
+ return func(r *Reader, v reflect.Value) error {
+ return fmt.Errorf("Union type %s is not registered", t.Name())
+ }
+ }
+
+ decoders := make(map[uint64]decodeFunc)
+ for tag, t := range ut.types {
+ t := t
+ f := getDecoder(t)
+
+ decoders[tag] = func(r *Reader, v reflect.Value) error {
+ nv := reflect.New(t)
+ if err := f(r, nv.Elem()); err != nil {
+ return err
+ }
+
+ v.Set(nv)
+ return nil
+ }
+ }
+
+ return func(r *Reader, v reflect.Value) error {
+ tag, err := r.ReadUint()
+ if err != nil {
+ return err
+ }
+
+ if f, ok := decoders[tag]; ok {
+ return f(r, v)
+ }
+
+ return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
+ }
+}
+
+func decodeUint(r *Reader, v reflect.Value) error {
+ var err error
+ switch getIntKind(v.Type()) {
+ case reflect.Uint:
+ var u uint64
+ u, err = r.ReadUint()
+ v.SetUint(u)
+
+ case reflect.Uint8:
+ var u uint8
+ u, err = r.ReadU8()
+ v.SetUint(uint64(u))
+
+ case reflect.Uint16:
+ var u uint16
+ u, err = r.ReadU16()
+ v.SetUint(uint64(u))
+ case reflect.Uint32:
+ var u uint32
+ u, err = r.ReadU32()
+ v.SetUint(uint64(u))
+
+ case reflect.Uint64:
+ var u uint64
+ u, err = r.ReadU64()
+ v.SetUint(uint64(u))
+
+ default:
+ panic("not an uint")
+ }
+
+ return err
+}
+
+func decodeInt(r *Reader, v reflect.Value) error {
+ var err error
+ switch getIntKind(v.Type()) {
+ case reflect.Int:
+ var i int64
+ i, err = r.ReadInt()
+ v.SetInt(i)
+
+ case reflect.Int8:
+ var i int8
+ i, err = r.ReadI8()
+ v.SetInt(int64(i))
+
+ case reflect.Int16:
+ var i int16
+ i, err = r.ReadI16()
+ v.SetInt(int64(i))
+ case reflect.Int32:
+ var i int32
+ i, err = r.ReadI32()
+ v.SetInt(int64(i))
+
+ case reflect.Int64:
+ var i int64
+ i, err = r.ReadI64()
+ v.SetInt(int64(i))
+
+ default:
+ panic("not an int")
+ }
+
+ return err
+}
+
+func decodeFloat(r *Reader, v reflect.Value) error {
+ var err error
+ switch v.Type().Kind() {
+ case reflect.Float32:
+ var f float32
+ f, err = r.ReadF32()
+ v.SetFloat(float64(f))
+ case reflect.Float64:
+ var f float64
+ f, err = r.ReadF64()
+ v.SetFloat(f)
+ default:
+ panic("not a float")
+ }
+ return err
+}
+
+func decodeBool(r *Reader, v reflect.Value) error {
+ b, err := r.ReadBool()
+ v.SetBool(b)
+ return err
+}
+
+func decodeString(r *Reader, v reflect.Value) error {
+ s, err := r.ReadString()
+ v.SetString(s)
+ return err
+}