aboutsummaryrefslogtreecommitdiff
path: root/scfg/unmarshal.go
blob: a99910fb4fc1cda7f02df406235f5ebc047b9085 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
package scfg

import (
	"encoding"
	"fmt"
	"io"
	"reflect"
	"strconv"
)

// Decoder reads and decodes an scfg document from an input stream.
//
// Use NewDecoder to create a Decoder and Decode to read a document into a Go
// value.
type Decoder struct {
	r io.Reader // input stream handed to Read by Decode
}

// NewDecoder creates a Decoder which takes its input from r.
func NewDecoder(r io.Reader) *Decoder {
	dec := new(Decoder)
	dec.r = r
	return dec
}

// Decode reads an scfg document from the decoder's input and stores it in the
// value pointed to by v.
//
// v must be a non-nil pointer, otherwise Decode returns an error. The input
// is read before v is inspected, so a read error takes precedence.
//
// A block can be unmarshaled into:
//
//   - A map. Each directive becomes a map entry. The map key must be a
//     string.
//   - A struct. Each directive is stored in a struct field.
//
// A directive must not appear more than once, unless the receiving struct
// field or map value is a slice of values representing a directive: structs
// or maps.
//
// A directive can be unmarshaled into:
//
//   - A map. The children block is unmarshaled into the map. Parameters are
//     not allowed.
//   - A struct. The children block is unmarshaled into the struct. Parameters
//     are allowed if one of the struct fields carries the "param" option in
//     its tag.
//   - A slice. The parameters are unmarshaled into the slice. A children
//     block is not allowed.
//   - An array. The parameters are unmarshaled into the array, and their
//     number must match the array length exactly. A children block is not
//     allowed.
//   - A string, boolean, integer, floating-point value or a value
//     implementing encoding.TextUnmarshaler. Exactly one parameter is allowed
//     and is unmarshaled into the value. A children block is not allowed.
//
// Decoding of a struct field can be customized through the format string
// stored under the "scfg" key of the field's tag. The tag holds the name of
// the field, optionally followed by a comma-separated list of options. The
// name may be left empty to give options without overriding the default field
// name. As a special case, a field named "-" is ignored. The "param" option
// marks the field which receives the directive parameters (the name must be
// empty).
func (dec *Decoder) Decode(v interface{}) error {
	block, err := Read(dec.r)
	if err != nil {
		return err
	}

	// A nil interface yields an invalid reflect.Value whose kind is not Ptr,
	// so it is rejected by the first half of the condition.
	target := reflect.ValueOf(v)
	if isPtr := target.Kind() == reflect.Ptr; !isPtr || target.IsNil() {
		return fmt.Errorf("scfg: invalid value for unmarshaling")
	}
	return unmarshalBlock(block, target)
}

// unmarshalBlock stores the directives of block into v, which must be a map
// or a struct, possibly behind pointers (nil pointers are allocated by
// unwrapPointers).
//
// Directives are grouped by name and every group is handed to
// unmarshalDirectiveList. Groups are processed in order of first appearance
// in block, so the error returned for a document containing several invalid
// directives is deterministic.
//
// For maps, the key type must have string kind (named string types are
// accepted); a nil map is allocated and a non-empty map is cleared before
// decoding. For structs, each directive name is looked up in the struct info
// returned by getStructInfo, and a name without a matching field yields an
// "unknown directive" error. Any other kind is rejected with an error.
func unmarshalBlock(block Block, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	// Remember the order in which names first appear: ranging over the map
	// directly would visit the groups in random order.
	names := make([]string, 0, len(block))
	dirsByName := make(map[string][]*Directive, len(block))
	for _, dir := range block {
		if _, seen := dirsByName[dir.Name]; !seen {
			names = append(names, dir.Name)
		}
		dirsByName[dir.Name] = append(dirsByName[dir.Name], dir)
	}

	switch v.Kind() {
	case reflect.Map:
		keyType := t.Key()
		if keyType.Kind() != reflect.String {
			return fmt.Errorf("scfg: map key type must be string")
		}
		if v.IsNil() {
			v.Set(reflect.MakeMap(t))
		} else if v.Len() > 0 {
			clearMap(v)
		}

		for _, name := range names {
			mv := reflect.New(t.Elem()).Elem()
			if err := unmarshalDirectiveList(dirsByName[name], mv); err != nil {
				return err
			}
			// The key may be a named string type, to which a plain string
			// is not assignable; convert explicitly to avoid a panic.
			v.SetMapIndex(reflect.ValueOf(name).Convert(keyType), mv)
		}
	case reflect.Struct:
		si, err := getStructInfo(t)
		if err != nil {
			return err
		}

		for _, name := range names {
			dirs := dirsByName[name]
			fieldIndex, ok := si.children[name]
			if !ok {
				return newUnmarshalDirectiveError(dirs[0], "unknown directive")
			}
			if err := unmarshalDirectiveList(dirs, v.Field(fieldIndex)); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("scfg: unsupported type for unmarshaling blocks: %v", t)
	}

	return nil
}

// unmarshalDirectiveList stores the directives sharing one name into v.
//
// When v (after following pointers) is a slice whose element type is a
// directive type, one element is decoded per directive and the resulting
// slice replaces v's content. For every other type at most one directive is
// accepted and it is decoded directly into v. dirs must not be empty.
func unmarshalDirectiveList(dirs []*Directive, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	repeatable := v.Kind() == reflect.Slice && isDirectiveType(t.Elem())
	if !repeatable {
		if len(dirs) > 1 {
			// Report the first duplicate rather than the first occurrence.
			return newUnmarshalDirectiveError(dirs[1], "directive must not be specified more than once")
		}
		return unmarshalDirective(dirs[0], v)
	}

	// Decode into a fresh slice and only assign it once every directive has
	// been decoded successfully.
	n := len(dirs)
	list := reflect.MakeSlice(t, n, n)
	for i := 0; i < n; i++ {
		if err := unmarshalDirective(dirs[i], list.Index(i)); err != nil {
			return err
		}
	}
	v.Set(list)
	return nil
}

// isDirectiveType reports whether values of type t can only be unmarshaled
// from a whole directive and never from a single parameter. Pointers are
// followed first. Structs and maps qualify, except when a pointer to the type
// implements encoding.TextUnmarshaler. The set of accepted types is kept
// deliberately small, because accepting more would introduce ambiguities, see:
// https://lists.sr.ht/~emersion/public-inbox/%3C20230629132458.152205-1-contact%40emersion.fr%3E#%3Ch4Y2peS_YBqY3ar4XlmPDPiNBFpYGns3EBYUx3_6zWEhV2o8_-fBQveRujGADWYhVVCucHBEryFGoPtpC3d3mQ-x10pWnFogfprbQTSvtxc=@emersion.fr%3E
func isDirectiveType(t reflect.Type) bool {
	elem := t
	for elem.Kind() == reflect.Ptr {
		elem = elem.Elem()
	}

	// Obtain the interface type through a typed nil pointer, as an interface
	// type cannot be passed to reflect.TypeOf directly.
	var probe *encoding.TextUnmarshaler
	if reflect.PtrTo(elem).Implements(reflect.TypeOf(probe).Elem()) {
		return false
	}

	kind := elem.Kind()
	return kind == reflect.Struct || kind == reflect.Map
}

// unmarshalDirective stores a single directive into v, after following
// pointers.
//
// Values implementing encoding.TextUnmarshaler and values which are neither
// maps nor structs are filled from the directive parameters and must not have
// children. Maps receive the children block and accept no parameters. Structs
// receive the children block as well; parameters are stored in the field
// designated by the struct info's param index and are rejected when no such
// field exists.
func unmarshalDirective(dir *Directive, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	// The TextUnmarshaler check comes first so that structs and maps which
	// can parse themselves are treated as parameter values.
	textual := false
	if v.CanAddr() {
		_, textual = v.Addr().Interface().(encoding.TextUnmarshaler)
	}

	kind := v.Kind()
	if textual || (kind != reflect.Map && kind != reflect.Struct) {
		if len(dir.Children) != 0 {
			return newUnmarshalDirectiveError(dir, "directive requires zero children")
		}
		return unmarshalParamList(dir, v)
	}

	paramsStored := false
	if kind == reflect.Struct {
		si, err := getStructInfo(t)
		if err != nil {
			return err
		}
		if si.param >= 0 {
			if err := unmarshalParamList(dir, v.Field(si.param)); err != nil {
				return err
			}
			paramsStored = true
		}
	}
	if !paramsStored && len(dir.Params) > 0 {
		return newUnmarshalDirectiveError(dir, "directive requires zero parameters")
	}

	return unmarshalBlock(dir.Children, v)
}

// unmarshalParamList stores the parameters of dir into v.
//
//   - A value whose address implements encoding.TextUnmarshaler takes exactly
//     one parameter, whatever its underlying kind. This is checked first so
//     that types such as net.IP (a byte slice) are parsed from their textual
//     form instead of being filled element by element.
//   - A slice is replaced by a new slice holding one element per parameter.
//   - An array requires exactly as many parameters as it has elements.
//   - Any other value takes exactly one parameter.
//
// A wrong number of parameters is reported through
// newUnmarshalDirectiveError; a failure to decode an individual parameter is
// reported through newUnmarshalParamError together with the parameter index.
func unmarshalParamList(dir *Directive, v reflect.Value) error {
	isText := false
	if v.CanAddr() {
		_, isText = v.Addr().Interface().(encoding.TextUnmarshaler)
	}

	switch {
	case !isText && v.Kind() == reflect.Slice:
		// Decode into a fresh slice so that v is left untouched on failure.
		sv := reflect.MakeSlice(v.Type(), len(dir.Params), len(dir.Params))
		for i, param := range dir.Params {
			if err := unmarshalParam(param, sv.Index(i)); err != nil {
				return newUnmarshalParamError(dir, i, err)
			}
		}
		v.Set(sv)
	case !isText && v.Kind() == reflect.Array:
		if len(dir.Params) != v.Len() {
			return newUnmarshalDirectiveError(dir, fmt.Sprintf("directive requires exactly %v parameters", v.Len()))
		}
		for i, param := range dir.Params {
			if err := unmarshalParam(param, v.Index(i)); err != nil {
				return newUnmarshalParamError(dir, i, err)
			}
		}
	default:
		if len(dir.Params) != 1 {
			return newUnmarshalDirectiveError(dir, "directive requires exactly one parameter")
		}
		if err := unmarshalParam(dir.Params[0], v); err != nil {
			return newUnmarshalParamError(dir, 0, err)
		}
	}

	return nil
}

// unmarshalParam parses a single parameter and stores it in v, after
// following pointers.
//
// If the address of the value implements encoding.TextUnmarshaler, its
// UnmarshalText method does the parsing. Otherwise strings, booleans ("true"
// or "false"), signed and unsigned integers (base 10, range-checked against
// the size of the destination type) and floating-point values are supported,
// including named types with one of these underlying kinds. Every other type
// yields an error.
//
// Errors returned by strconv are wrapped, so they remain reachable through
// errors.Is and errors.As.
func unmarshalParam(param string, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	// TODO: improve our logic following:
	// https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/encoding/json/decode.go;drc=b9b8cecbfc72168ca03ad586cc2ed52b0e8db409;l=421
	if v.CanAddr() {
		if tu, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
			return tu.UnmarshalText([]byte(param))
		}
	}

	// The kind-specific setters are used instead of Set(reflect.ValueOf(x)):
	// the latter panics for named types such as "type Level string", because
	// a plain string is not assignable to them.
	switch v.Kind() {
	case reflect.String:
		v.SetString(param)
	case reflect.Bool:
		switch param {
		case "true":
			v.SetBool(true)
		case "false":
			v.SetBool(false)
		default:
			return fmt.Errorf("invalid bool parameter %q", param)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		i, err := strconv.ParseInt(param, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %w", t, err)
		}
		v.SetInt(i)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		u, err := strconv.ParseUint(param, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %w", t, err)
		}
		v.SetUint(u)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(param, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %w", t, err)
		}
		v.SetFloat(f)
	default:
		return fmt.Errorf("unsupported type for unmarshaling parameter: %v", t)
	}

	return nil
}

// unwrapPointers follows v through any number of pointer levels and returns
// the first value which is not a pointer. Nil pointers met on the way are set
// to a newly allocated zero value of their element type, so they must be
// settable. A non-pointer value is returned unchanged.
func unwrapPointers(v reflect.Value) reflect.Value {
	cur := v
	for {
		if cur.Kind() != reflect.Ptr {
			return cur
		}
		if cur.IsNil() {
			cur.Set(reflect.New(cur.Type().Elem()))
		}
		cur = cur.Elem()
	}
}

// clearMap removes every entry from the map held by v.
func clearMap(v reflect.Value) {
	// Passing the zero reflect.Value to SetMapIndex deletes the key.
	var absent reflect.Value
	keys := v.MapKeys()
	for i := range keys {
		v.SetMapIndex(keys[i], absent)
	}
}

// unmarshalDirectiveError is the error reported when a directive as a whole
// cannot be unmarshaled, for instance because it is unknown, duplicated, or
// has an unexpected number of parameters or children. Values are created by
// newUnmarshalDirectiveError.
type unmarshalDirectiveError struct {
	lineno int    // copied from the directive's lineno field, printed as "line N"
	name   string // name of the offending directive
	msg    string // description of the problem
}

// newUnmarshalDirectiveError builds an unmarshalDirectiveError for dir,
// recording the directive's name and line number next to msg.
func newUnmarshalDirectiveError(dir *Directive, msg string) *unmarshalDirectiveError {
	dirErr := new(unmarshalDirectiveError)
	dirErr.lineno = dir.lineno
	dirErr.name = dir.Name
	dirErr.msg = msg
	return dirErr
}

// Error implements the error interface. The message names the line and the
// (quoted) directive, followed by the description of the problem.
func (err *unmarshalDirectiveError) Error() string {
	const layout = "line %v, directive %q: %v"
	return fmt.Sprintf(layout, err.lineno, err.name, err.msg)
}

// unmarshalParamError is the error reported when a single parameter of a
// directive cannot be unmarshaled. It records the position of the parameter
// together with the underlying error. Values are created by
// newUnmarshalParamError.
type unmarshalParamError struct {
	lineno     int    // copied from the directive's lineno field, printed as "line N"
	directive  string // name of the directive owning the parameter
	paramIndex int    // zero-based index of the parameter; Error prints it one-based
	err        error  // underlying error returned while decoding the parameter
}

// newUnmarshalParamError builds an unmarshalParamError for the parameter of
// dir at the zero-based index paramIndex, keeping err as the underlying
// error.
func newUnmarshalParamError(dir *Directive, paramIndex int, err error) *unmarshalParamError {
	paramErr := new(unmarshalParamError)
	paramErr.lineno = dir.lineno
	paramErr.directive = dir.Name
	paramErr.paramIndex = paramIndex
	paramErr.err = err
	return paramErr
}

// Error implements the error interface. The message names the line, the
// (quoted) directive and the one-based position of the parameter, followed by
// the underlying error.
func (err *unmarshalParamError) Error() string {
	return fmt.Sprintf("line %v, directive %q, parameter %v: %v", err.lineno, err.directive, err.paramIndex+1, err.err)
}

// Unwrap returns the underlying error, so that callers can inspect the cause
// of a parameter failure with errors.Is and errors.As.
func (err *unmarshalParamError) Unwrap() error {
	return err.err
}