aboutsummaryrefslogtreecommitdiff
path: root/forged/internal/bare/marshal.go
diff options
context:
space:
mode:
Diffstat (limited to 'forged/internal/bare/marshal.go')
-rw-r--r--forged/internal/bare/marshal.go311
1 files changed, 311 insertions, 0 deletions
diff --git a/forged/internal/bare/marshal.go b/forged/internal/bare/marshal.go
new file mode 100644
index 0000000..13f7b77
--- /dev/null
+++ b/forged/internal/bare/marshal.go
@@ -0,0 +1,311 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>
+
+package bare
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "reflect"
+ "sync"
+)
+
+// A type which implements this interface will be responsible for marshaling
+// itself when encountered.
+type Marshalable interface {
+ Marshal(w *Writer) error
+}
+
+var encoderBufferPool = sync.Pool{
+ New: func() interface{} {
+ buf := &bytes.Buffer{}
+ buf.Grow(32)
+ return buf
+ },
+}
+
+// Marshals a value (val, which must be a pointer) into a BARE message.
+//
+// The encoding of each struct field can be customized by the format string
+// stored under the "bare" key in the struct field's tag.
+//
+// As a special case, if the field tag is "-", the field is always omitted.
+func Marshal(val interface{}) ([]byte, error) {
+ // reuse buffers from previous serializations
+ b := encoderBufferPool.Get().(*bytes.Buffer)
+ defer func() {
+ b.Reset()
+ encoderBufferPool.Put(b)
+ }()
+
+ w := NewWriter(b)
+ err := MarshalWriter(w, val)
+
+ msg := make([]byte, b.Len())
+ copy(msg, b.Bytes())
+
+ return msg, err
+}
+
+// Marshals a value (val, which must be a pointer) into a BARE message and
+// writes it to a Writer. See Marshal for details.
+func MarshalWriter(w *Writer, val interface{}) error {
+ t := reflect.TypeOf(val)
+ v := reflect.ValueOf(val)
+ if t.Kind() != reflect.Ptr {
+ return errors.New("Expected val to be pointer type")
+ }
+
+ return getEncoder(t.Elem())(w, v.Elem())
+}
+
+type encodeFunc func(w *Writer, v reflect.Value) error
+
+var encodeFuncCache sync.Map // map[reflect.Type]encodeFunc
+
+// get decoder from cache
+func getEncoder(t reflect.Type) encodeFunc {
+ if f, ok := encodeFuncCache.Load(t); ok {
+ return f.(encodeFunc)
+ }
+
+ f := encoderFunc(t)
+ encodeFuncCache.Store(t, f)
+ return f
+}
+
+var marshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()
+
+func encoderFunc(t reflect.Type) encodeFunc {
+ if reflect.PtrTo(t).Implements(marshalableInterface) {
+ return func(w *Writer, v reflect.Value) error {
+ uv := v.Addr().Interface().(Marshalable)
+ return uv.Marshal(w)
+ }
+ }
+
+ if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
+ return encodeUnion(t)
+ }
+
+ switch t.Kind() {
+ case reflect.Ptr:
+ return encodeOptional(t.Elem())
+ case reflect.Struct:
+ return encodeStruct(t)
+ case reflect.Array:
+ return encodeArray(t)
+ case reflect.Slice:
+ return encodeSlice(t)
+ case reflect.Map:
+ return encodeMap(t)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return encodeUint
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return encodeInt
+ case reflect.Float32, reflect.Float64:
+ return encodeFloat
+ case reflect.Bool:
+ return encodeBool
+ case reflect.String:
+ return encodeString
+ }
+
+ return func(w *Writer, v reflect.Value) error {
+ return &UnsupportedTypeError{v.Type()}
+ }
+}
+
+func encodeOptional(t reflect.Type) encodeFunc {
+ return func(w *Writer, v reflect.Value) error {
+ if v.IsNil() {
+ return w.WriteBool(false)
+ }
+
+ if err := w.WriteBool(true); err != nil {
+ return err
+ }
+
+ return getEncoder(t)(w, v.Elem())
+ }
+}
+
+func encodeStruct(t reflect.Type) encodeFunc {
+ n := t.NumField()
+ encoders := make([]encodeFunc, n)
+ for i := 0; i < n; i++ {
+ field := t.Field(i)
+ if field.Tag.Get("bare") == "-" {
+ continue
+ }
+ encoders[i] = getEncoder(field.Type)
+ }
+
+ return func(w *Writer, v reflect.Value) error {
+ for i := 0; i < n; i++ {
+ if encoders[i] == nil {
+ continue
+ }
+ err := encoders[i](w, v.Field(i))
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func encodeArray(t reflect.Type) encodeFunc {
+ f := getEncoder(t.Elem())
+ len := t.Len()
+
+ return func(w *Writer, v reflect.Value) error {
+ for i := 0; i < len; i++ {
+ if err := f(w, v.Index(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func encodeSlice(t reflect.Type) encodeFunc {
+ elem := t.Elem()
+ f := getEncoder(elem)
+
+ return func(w *Writer, v reflect.Value) error {
+ if err := w.WriteUint(uint64(v.Len())); err != nil {
+ return err
+ }
+
+ for i := 0; i < v.Len(); i++ {
+ if err := f(w, v.Index(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func encodeMap(t reflect.Type) encodeFunc {
+ keyType := t.Key()
+ keyf := getEncoder(keyType)
+
+ valueType := t.Elem()
+ valf := getEncoder(valueType)
+
+ return func(w *Writer, v reflect.Value) error {
+ if err := w.WriteUint(uint64(v.Len())); err != nil {
+ return err
+ }
+
+ iter := v.MapRange()
+ for iter.Next() {
+ if err := keyf(w, iter.Key()); err != nil {
+ return err
+ }
+ if err := valf(w, iter.Value()); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+func encodeUnion(t reflect.Type) encodeFunc {
+ ut, ok := unionRegistry[t]
+ if !ok {
+ return func(w *Writer, v reflect.Value) error {
+ return fmt.Errorf("Union type %s is not registered", t.Name())
+ }
+ }
+
+ encoders := make(map[uint64]encodeFunc)
+ for tag, t := range ut.types {
+ encoders[tag] = getEncoder(t)
+ }
+
+ return func(w *Writer, v reflect.Value) error {
+ t := v.Elem().Type()
+ if t.Kind() == reflect.Ptr {
+ // If T is a valid union value type, *T is valid too.
+ t = t.Elem()
+ v = v.Elem()
+ }
+ tag, ok := ut.tags[t]
+ if !ok {
+ return fmt.Errorf("Invalid union value: %s", v.Elem().String())
+ }
+
+ if err := w.WriteUint(tag); err != nil {
+ return err
+ }
+
+ return encoders[tag](w, v.Elem())
+ }
+}
+
+func encodeUint(w *Writer, v reflect.Value) error {
+ switch getIntKind(v.Type()) {
+ case reflect.Uint:
+ return w.WriteUint(v.Uint())
+
+ case reflect.Uint8:
+ return w.WriteU8(uint8(v.Uint()))
+
+ case reflect.Uint16:
+ return w.WriteU16(uint16(v.Uint()))
+
+ case reflect.Uint32:
+ return w.WriteU32(uint32(v.Uint()))
+
+ case reflect.Uint64:
+ return w.WriteU64(uint64(v.Uint()))
+ }
+
+ panic("not uint")
+}
+
+func encodeInt(w *Writer, v reflect.Value) error {
+ switch getIntKind(v.Type()) {
+ case reflect.Int:
+ return w.WriteInt(v.Int())
+
+ case reflect.Int8:
+ return w.WriteI8(int8(v.Int()))
+
+ case reflect.Int16:
+ return w.WriteI16(int16(v.Int()))
+
+ case reflect.Int32:
+ return w.WriteI32(int32(v.Int()))
+
+ case reflect.Int64:
+ return w.WriteI64(int64(v.Int()))
+ }
+
+ panic("not int")
+}
+
+func encodeFloat(w *Writer, v reflect.Value) error {
+ switch v.Type().Kind() {
+ case reflect.Float32:
+ return w.WriteF32(float32(v.Float()))
+ case reflect.Float64:
+ return w.WriteF64(v.Float())
+ }
+
+ panic("not float")
+}
+
+func encodeBool(w *Writer, v reflect.Value) error {
+ return w.WriteBool(v.Bool())
+}
+
+func encodeString(w *Writer, v reflect.Value) error {
+ if v.Kind() != reflect.String {
+ panic("not string")
+ }
+ return w.WriteString(v.String())
+}