diff options
Diffstat (limited to 'forged/internal')
71 files changed, 7065 insertions, 0 deletions
diff --git a/forged/internal/common/ansiec/colors.go b/forged/internal/common/ansiec/colors.go new file mode 100644 index 0000000..8be2a0c --- /dev/null +++ b/forged/internal/common/ansiec/colors.go @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package ansiec + +// ANSI color codes +const ( + Black = "\x1b[30m" + Red = "\x1b[31m" + Green = "\x1b[32m" + Yellow = "\x1b[33m" + Blue = "\x1b[34m" + Magenta = "\x1b[35m" + Cyan = "\x1b[36m" + White = "\x1b[37m" + BrightBlack = "\x1b[30;1m" + BrightRed = "\x1b[31;1m" + BrightGreen = "\x1b[32;1m" + BrightYellow = "\x1b[33;1m" + BrightBlue = "\x1b[34;1m" + BrightMagenta = "\x1b[35;1m" + BrightCyan = "\x1b[36;1m" + BrightWhite = "\x1b[37;1m" +) diff --git a/forged/internal/common/ansiec/doc.go b/forged/internal/common/ansiec/doc.go new file mode 100644 index 0000000..542c564 --- /dev/null +++ b/forged/internal/common/ansiec/doc.go @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// Package ansiec provides definitions for ANSI escape sequences. +package ansiec diff --git a/forged/internal/common/ansiec/reset.go b/forged/internal/common/ansiec/reset.go new file mode 100644 index 0000000..51bb312 --- /dev/null +++ b/forged/internal/common/ansiec/reset.go @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package ansiec + +// Reset the colors and styles +const Reset = "\x1b[0m" diff --git a/forged/internal/common/ansiec/style.go b/forged/internal/common/ansiec/style.go new file mode 100644 index 0000000..95edbbe --- /dev/null +++ b/forged/internal/common/ansiec/style.go @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package ansiec + +// ANSI text styles +const ( + Bold = "\x1b[1m" + Underline = "\x1b[4m" + Reversed = "\x1b[7m" + Italic = "\x1b[3m" +) diff --git a/forged/internal/common/argon2id/LICENSE b/forged/internal/common/argon2id/LICENSE new file mode 100644 index 0000000..3649823 --- /dev/null +++ b/forged/internal/common/argon2id/LICENSE @@ -0,0 +1,18 @@ +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/forged/internal/common/argon2id/argon2id.go b/forged/internal/common/argon2id/argon2id.go new file mode 100644 index 0000000..88df8f6 --- /dev/null +++ b/forged/internal/common/argon2id/argon2id.go @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2018 Alex Edwards + +// Package argon2id provides a wrapper around Go's golang.org/x/crypto/argon2. +package argon2id + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "runtime" + "strings" + + "golang.org/x/crypto/argon2" +) + +var ( + // ErrInvalidHash in returned by ComparePasswordAndHash if the provided + // hash isn't in the expected format. + ErrInvalidHash = errors.New("argon2id: hash is not in the correct format") + + // ErrIncompatibleVariant is returned by ComparePasswordAndHash if the + // provided hash was created using a unsupported variant of Argon2. + // Currently only argon2id is supported by this package. + ErrIncompatibleVariant = errors.New("argon2id: incompatible variant of argon2") + + // ErrIncompatibleVersion is returned by ComparePasswordAndHash if the + // provided hash was created using a different version of Argon2. + ErrIncompatibleVersion = errors.New("argon2id: incompatible version of argon2") +) + +// DefaultParams provides some sane default parameters for hashing passwords. +// +// Follows recommendations given by the Argon2 RFC: +// "The Argon2id variant with t=1 and maximum available memory is RECOMMENDED as a +// default setting for all environments. This setting is secure against side-channel +// attacks and maximizes adversarial costs on dedicated bruteforce hardware."" +// +// The default parameters should generally be used for development/testing purposes +// only. Custom parameters should be set for production applications depending on +// available memory/CPU resources and business requirements. +var DefaultParams = &Params{ + Memory: 64 * 1024, + Iterations: 1, + Parallelism: uint8(runtime.NumCPU()), + SaltLength: 16, + KeyLength: 32, +} + +// Params describes the input parameters used by the Argon2id algorithm. The +// Memory and Iterations parameters control the computational cost of hashing +// the password. The higher these figures are, the greater the cost of generating +// the hash and the longer the runtime. It also follows that the greater the cost +// will be for any attacker trying to guess the password. If the code is running +// on a machine with multiple cores, then you can decrease the runtime without +// reducing the cost by increasing the Parallelism parameter. This controls the +// number of threads that the work is spread across. Important note: Changing the +// value of the Parallelism parameter changes the hash output. +// +// For guidance and an outline process for choosing appropriate parameters see +// https://tools.ietf.org/html/draft-irtf-cfrg-argon2-04#section-4 +type Params struct { + // The amount of memory used by the algorithm (in kibibytes). + Memory uint32 + + // The number of iterations over the memory. + Iterations uint32 + + // The number of threads (or lanes) used by the algorithm. + // Recommended value is between 1 and runtime.NumCPU(). + Parallelism uint8 + + // Length of the random salt. 16 bytes is recommended for password hashing. + SaltLength uint32 + + // Length of the generated key. 16 bytes or more is recommended. + KeyLength uint32 +} + +// CreateHash returns an Argon2id hash of a plain-text password using the +// provided algorithm parameters. The returned hash follows the format used by +// the Argon2 reference C implementation and contains the base64-encoded Argon2id d +// derived key prefixed by the salt and parameters. It looks like this: +// +// $argon2id$v=19$m=65536,t=3,p=2$c29tZXNhbHQ$RdescudvJCsgt3ub+b+dWRWJTmaaJObG +func CreateHash(password string, params *Params) (hash string, err error) { + salt, err := generateRandomBytes(params.SaltLength) + if err != nil { + return "", err + } + + key := argon2.IDKey([]byte(password), salt, params.Iterations, params.Memory, params.Parallelism, params.KeyLength) + + b64Salt := base64.RawStdEncoding.EncodeToString(salt) + b64Key := base64.RawStdEncoding.EncodeToString(key) + + hash = fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, params.Memory, params.Iterations, params.Parallelism, b64Salt, b64Key) + return hash, nil +} + +// ComparePasswordAndHash performs a constant-time comparison between a +// plain-text password and Argon2id hash, using the parameters and salt +// contained in the hash. It returns true if they match, otherwise it returns +// false. +func ComparePasswordAndHash(password, hash string) (match bool, err error) { + match, _, err = CheckHash(password, hash) + return match, err +} + +// CheckHash is like ComparePasswordAndHash, except it also returns the params that the hash was +// created with. This can be useful if you want to update your hash params over time (which you +// should). +func CheckHash(password, hash string) (match bool, params *Params, err error) { + params, salt, key, err := DecodeHash(hash) + if err != nil { + return false, nil, err + } + + otherKey := argon2.IDKey([]byte(password), salt, params.Iterations, params.Memory, params.Parallelism, params.KeyLength) + + keyLen := int32(len(key)) + otherKeyLen := int32(len(otherKey)) + + if subtle.ConstantTimeEq(keyLen, otherKeyLen) == 0 { + return false, params, nil + } + if subtle.ConstantTimeCompare(key, otherKey) == 1 { + return true, params, nil + } + return false, params, nil +} + +func generateRandomBytes(n uint32) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + + return b, nil +} + +// DecodeHash expects a hash created from this package, and parses it to return the params used to +// create it, as well as the salt and key (password hash). +func DecodeHash(hash string) (params *Params, salt, key []byte, err error) { + vals := strings.Split(hash, "$") + if len(vals) != 6 { + return nil, nil, nil, ErrInvalidHash + } + + if vals[1] != "argon2id" { + return nil, nil, nil, ErrIncompatibleVariant + } + + var version int + _, err = fmt.Sscanf(vals[2], "v=%d", &version) + if err != nil { + return nil, nil, nil, err + } + if version != argon2.Version { + return nil, nil, nil, ErrIncompatibleVersion + } + + params = &Params{} + _, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", ¶ms.Memory, ¶ms.Iterations, ¶ms.Parallelism) + if err != nil { + return nil, nil, nil, err + } + + salt, err = base64.RawStdEncoding.Strict().DecodeString(vals[4]) + if err != nil { + return nil, nil, nil, err + } + params.SaltLength = uint32(len(salt)) + + key, err = base64.RawStdEncoding.Strict().DecodeString(vals[5]) + if err != nil { + return nil, nil, nil, err + } + params.KeyLength = uint32(len(key)) + + return params, salt, key, nil +} diff --git a/forged/internal/common/bare/LICENSE b/forged/internal/common/bare/LICENSE new file mode 100644 index 0000000..6b0b127 --- /dev/null +++ b/forged/internal/common/bare/LICENSE @@ -0,0 +1,203 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/forged/internal/common/bare/doc.go b/forged/internal/common/bare/doc.go new file mode 100644 index 0000000..2f12f55 --- /dev/null +++ b/forged/internal/common/bare/doc.go @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// Package bare provides primitives to encode and decode BARE messages. +// +// There is no guarantee that this is compatible with the upstream +// implementation at https://git.sr.ht/~sircmpwn/go-bare. +package bare diff --git a/forged/internal/common/bare/errors.go b/forged/internal/common/bare/errors.go new file mode 100644 index 0000000..4634f0c --- /dev/null +++ b/forged/internal/common/bare/errors.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "errors" + "fmt" + "reflect" +) + +var ErrInvalidStr = errors.New("string contains invalid UTF-8 sequences") + +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return fmt.Sprintf("unsupported type for marshaling: %s\n", e.Type.String()) +} diff --git a/forged/internal/common/bare/limit.go b/forged/internal/common/bare/limit.go new file mode 100644 index 0000000..7eece8c --- /dev/null +++ b/forged/internal/common/bare/limit.go @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "errors" + "io" +) + +var ( + maxUnmarshalBytes uint64 = 1024 * 1024 * 32 /* 32 MiB */ + maxArrayLength uint64 = 1024 * 4 /* 4096 elements */ + maxMapSize uint64 = 1024 +) + +// MaxUnmarshalBytes sets the maximum size of a message decoded by unmarshal. +// By default, this is set to 32 MiB. +func MaxUnmarshalBytes(bytes uint64) { + maxUnmarshalBytes = bytes +} + +// MaxArrayLength sets maximum number of elements in array. Defaults to 4096 elements +func MaxArrayLength(length uint64) { + maxArrayLength = length +} + +// MaxMapSize sets maximum size of map. Defaults to 1024 key/value pairs +func MaxMapSize(size uint64) { + maxMapSize = size +} + +// Use MaxUnmarshalBytes to prevent this error from occuring on messages which +// are large by design. +var ErrLimitExceeded = errors.New("maximum message size exceeded") + +// Identical to io.LimitedReader, except it returns our custom error instead of +// EOF if the limit is reached. +type limitedReader struct { + R io.Reader + N uint64 +} + +func (l *limitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, ErrLimitExceeded + } + if uint64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= uint64(n) + return +} + +func newLimitedReader(r io.Reader) *limitedReader { + return &limitedReader{r, maxUnmarshalBytes} +} diff --git a/forged/internal/common/bare/marshal.go b/forged/internal/common/bare/marshal.go new file mode 100644 index 0000000..d4c338e --- /dev/null +++ b/forged/internal/common/bare/marshal.go @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "sync" +) + +// A type which implements this interface will be responsible for marshaling +// itself when encountered. +type Marshalable interface { + Marshal(w *Writer) error +} + +var encoderBufferPool = sync.Pool{ + New: func() interface{} { + buf := &bytes.Buffer{} + buf.Grow(32) + return buf + }, +} + +// Marshals a value (val, which must be a pointer) into a BARE message. +// +// The encoding of each struct field can be customized by the format string +// stored under the "bare" key in the struct field's tag. +// +// As a special case, if the field tag is "-", the field is always omitted. +func Marshal(val interface{}) ([]byte, error) { + // reuse buffers from previous serializations + b := encoderBufferPool.Get().(*bytes.Buffer) + defer func() { + b.Reset() + encoderBufferPool.Put(b) + }() + + w := NewWriter(b) + err := MarshalWriter(w, val) + + msg := make([]byte, b.Len()) + copy(msg, b.Bytes()) + + return msg, err +} + +// Marshals a value (val, which must be a pointer) into a BARE message and +// writes it to a Writer. See Marshal for details. +func MarshalWriter(w *Writer, val interface{}) error { + t := reflect.TypeOf(val) + v := reflect.ValueOf(val) + if t.Kind() != reflect.Ptr { + return errors.New("expected val to be pointer type") + } + + return getEncoder(t.Elem())(w, v.Elem()) +} + +type encodeFunc func(w *Writer, v reflect.Value) error + +var encodeFuncCache sync.Map // map[reflect.Type]encodeFunc + +// get decoder from cache +func getEncoder(t reflect.Type) encodeFunc { + if f, ok := encodeFuncCache.Load(t); ok { + return f.(encodeFunc) + } + + f := encoderFunc(t) + encodeFuncCache.Store(t, f) + return f +} + +var marshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem() + +func encoderFunc(t reflect.Type) encodeFunc { + if reflect.PointerTo(t).Implements(marshalableInterface) { + return func(w *Writer, v reflect.Value) error { + uv := v.Addr().Interface().(Marshalable) + return uv.Marshal(w) + } + } + + if t.Kind() == reflect.Interface && t.Implements(unionInterface) { + return encodeUnion(t) + } + + switch t.Kind() { + case reflect.Ptr: + return encodeOptional(t.Elem()) + case reflect.Struct: + return encodeStruct(t) + case reflect.Array: + return encodeArray(t) + case reflect.Slice: + return encodeSlice(t) + case reflect.Map: + return encodeMap(t) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return encodeUint + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt + case reflect.Float32, reflect.Float64: + return encodeFloat + case reflect.Bool: + return encodeBool + case reflect.String: + return encodeString + } + + return func(w *Writer, v reflect.Value) error { + return &UnsupportedTypeError{v.Type()} + } +} + +func encodeOptional(t reflect.Type) encodeFunc { + return func(w *Writer, v reflect.Value) error { + if v.IsNil() { + return w.WriteBool(false) + } + + if err := w.WriteBool(true); err != nil { + return err + } + + return getEncoder(t)(w, v.Elem()) + } +} + +func encodeStruct(t reflect.Type) encodeFunc { + n := t.NumField() + encoders := make([]encodeFunc, n) + for i := 0; i < n; i++ { + field := t.Field(i) + if field.Tag.Get("bare") == "-" { + continue + } + encoders[i] = getEncoder(field.Type) + } + + return func(w *Writer, v reflect.Value) error { + for i := 0; i < n; i++ { + if encoders[i] == nil { + continue + } + err := encoders[i](w, v.Field(i)) + if err != nil { + return err + } + } + return nil + } +} + +func encodeArray(t reflect.Type) encodeFunc { + f := getEncoder(t.Elem()) + len := t.Len() + + return func(w *Writer, v reflect.Value) error { + for i := 0; i < len; i++ { + if err := f(w, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func encodeSlice(t reflect.Type) encodeFunc { + elem := t.Elem() + f := getEncoder(elem) + + return func(w *Writer, v reflect.Value) error { + if err := w.WriteUint(uint64(v.Len())); err != nil { + return err + } + + for i := 0; i < v.Len(); i++ { + if err := f(w, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func encodeMap(t reflect.Type) encodeFunc { + keyType := t.Key() + keyf := getEncoder(keyType) + + valueType := t.Elem() + valf := getEncoder(valueType) + + return func(w *Writer, v reflect.Value) error { + if err := w.WriteUint(uint64(v.Len())); err != nil { + return err + } + + iter := v.MapRange() + for iter.Next() { + if err := keyf(w, iter.Key()); err != nil { + return err + } + if err := valf(w, iter.Value()); err != nil { + return err + } + } + return nil + } +} + +func encodeUnion(t reflect.Type) encodeFunc { + ut, ok := unionRegistry[t] + if !ok { + return func(w *Writer, v reflect.Value) error { + return fmt.Errorf("Union type %s is not registered", t.Name()) + } + } + + encoders := make(map[uint64]encodeFunc) + for tag, t := range ut.types { + encoders[tag] = getEncoder(t) + } + + return func(w *Writer, v reflect.Value) error { + t := v.Elem().Type() + if t.Kind() == reflect.Ptr { + // If T is a valid union value type, *T is valid too. + t = t.Elem() + v = v.Elem() + } + tag, ok := ut.tags[t] + if !ok { + return fmt.Errorf("Invalid union value: %s", v.Elem().String()) + } + + if err := w.WriteUint(tag); err != nil { + return err + } + + return encoders[tag](w, v.Elem()) + } +} + +func encodeUint(w *Writer, v reflect.Value) error { + switch getIntKind(v.Type()) { + case reflect.Uint: + return w.WriteUint(v.Uint()) + + case reflect.Uint8: + return w.WriteU8(uint8(v.Uint())) + + case reflect.Uint16: + return w.WriteU16(uint16(v.Uint())) + + case reflect.Uint32: + return w.WriteU32(uint32(v.Uint())) + + case reflect.Uint64: + return w.WriteU64(uint64(v.Uint())) + } + + panic("not uint") +} + +func encodeInt(w *Writer, v reflect.Value) error { + switch getIntKind(v.Type()) { + case reflect.Int: + return w.WriteInt(v.Int()) + + case reflect.Int8: + return w.WriteI8(int8(v.Int())) + + case reflect.Int16: + return w.WriteI16(int16(v.Int())) + + case reflect.Int32: + return w.WriteI32(int32(v.Int())) + + case reflect.Int64: + return w.WriteI64(int64(v.Int())) + } + + panic("not int") +} + +func encodeFloat(w *Writer, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Float32: + return w.WriteF32(float32(v.Float())) + case reflect.Float64: + return w.WriteF64(v.Float()) + } + + panic("not float") +} + +func encodeBool(w *Writer, v reflect.Value) error { + return w.WriteBool(v.Bool()) +} + +func encodeString(w *Writer, v reflect.Value) error { + if v.Kind() != reflect.String { + panic("not string") + } + return w.WriteString(v.String()) +} diff --git a/forged/internal/common/bare/reader.go b/forged/internal/common/bare/reader.go new file mode 100644 index 0000000..7e872f4 --- /dev/null +++ b/forged/internal/common/bare/reader.go @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "unicode/utf8" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" +) + +type byteReader interface { + io.Reader + io.ByteReader +} + +// A Reader for BARE primitive types. +type Reader struct { + base byteReader + scratch [8]byte +} + +type simpleByteReader struct { + io.Reader + scratch [1]byte +} + +func (r simpleByteReader) ReadByte() (byte, error) { + // using reference type here saves us allocations + _, err := r.Read(r.scratch[:]) + return r.scratch[0], err +} + +// Returns a new BARE primitive reader wrapping the given io.Reader. +func NewReader(base io.Reader) *Reader { + br, ok := base.(byteReader) + if !ok { + br = simpleByteReader{Reader: base} + } + return &Reader{base: br} +} + +func (r *Reader) ReadUint() (uint64, error) { + x, err := binary.ReadUvarint(r.base) + if err != nil { + return x, err + } + return x, nil +} + +func (r *Reader) ReadU8() (uint8, error) { + return r.base.ReadByte() +} + +func (r *Reader) ReadU16() (uint16, error) { + var i uint16 + if _, err := io.ReadAtLeast(r.base, r.scratch[:2], 2); err != nil { + return i, err + } + return binary.LittleEndian.Uint16(r.scratch[:]), nil +} + +func (r *Reader) ReadU32() (uint32, error) { + var i uint32 + if _, err := io.ReadAtLeast(r.base, r.scratch[:4], 4); err != nil { + return i, err + } + return binary.LittleEndian.Uint32(r.scratch[:]), nil +} + +func (r *Reader) ReadU64() (uint64, error) { + var i uint64 + if _, err := io.ReadAtLeast(r.base, r.scratch[:8], 8); err != nil { + return i, err + } + return binary.LittleEndian.Uint64(r.scratch[:]), nil +} + +func (r *Reader) ReadInt() (int64, error) { + return binary.ReadVarint(r.base) +} + +func (r *Reader) ReadI8() (int8, error) { + b, err := r.base.ReadByte() + return int8(b), err +} + +func (r *Reader) ReadI16() (int16, error) { + var i int16 + if _, err := io.ReadAtLeast(r.base, r.scratch[:2], 2); err != nil { + return i, err + } + return int16(binary.LittleEndian.Uint16(r.scratch[:])), nil +} + +func (r *Reader) ReadI32() (int32, error) { + var i int32 + if _, err := io.ReadAtLeast(r.base, r.scratch[:4], 4); err != nil { + return i, err + } + return int32(binary.LittleEndian.Uint32(r.scratch[:])), nil +} + +func (r *Reader) ReadI64() (int64, error) { + var i int64 + if _, err := io.ReadAtLeast(r.base, r.scratch[:], 8); err != nil { + return i, err + } + return int64(binary.LittleEndian.Uint64(r.scratch[:])), nil +} + +func (r *Reader) ReadF32() (float32, error) { + u, err := r.ReadU32() + f := math.Float32frombits(u) + if math.IsNaN(float64(f)) { + return 0.0, fmt.Errorf("NaN is not permitted in BARE floats") + } + return f, err +} + +func (r *Reader) ReadF64() (float64, error) { + u, err := r.ReadU64() + f := math.Float64frombits(u) + if math.IsNaN(f) { + return 0.0, fmt.Errorf("NaN is not permitted in BARE floats") + } + return f, err +} + +func (r *Reader) ReadBool() (bool, error) { + b, err := r.ReadU8() + if err != nil { + return false, err + } + + if b > 1 { + return false, fmt.Errorf("Invalid bool value: %#x", b) + } + + return b == 1, nil +} + +func (r *Reader) ReadString() (string, error) { + buf, err := r.ReadData() + if err != nil { + return "", err + } + if !utf8.Valid(buf) { + return "", ErrInvalidStr + } + return misc.BytesToString(buf), nil +} + +// Reads a fixed amount of arbitrary data, defined by the length of the slice. +func (r *Reader) ReadDataFixed(dest []byte) error { + var amt int + for amt < len(dest) { + n, err := r.base.Read(dest[amt:]) + if err != nil { + return err + } + amt += n + } + return nil +} + +// Reads arbitrary data whose length is read from the message. +func (r *Reader) ReadData() ([]byte, error) { + l, err := r.ReadUint() + if err != nil { + return nil, err + } + if l >= maxUnmarshalBytes { + return nil, ErrLimitExceeded + } + buf := make([]byte, l) + var amt uint64 = 0 + for amt < l { + n, err := r.base.Read(buf[amt:]) + if err != nil { + return nil, err + } + amt += uint64(n) + } + return buf, nil +} diff --git a/forged/internal/common/bare/unions.go b/forged/internal/common/bare/unions.go new file mode 100644 index 0000000..1020fa0 --- /dev/null +++ b/forged/internal/common/bare/unions.go @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "fmt" + "reflect" +) + +// Any type which is a union member must implement this interface. You must +// also call RegisterUnion for go-bare to marshal or unmarshal messages which +// utilize your union type. +type Union interface { + IsUnion() +} + +type UnionTags struct { + iface reflect.Type + tags map[reflect.Type]uint64 + types map[uint64]reflect.Type +} + +var ( + unionInterface = reflect.TypeOf((*Union)(nil)).Elem() + unionRegistry map[reflect.Type]*UnionTags +) + +func init() { + unionRegistry = make(map[reflect.Type]*UnionTags) +} + +// Registers a union type in this context. Pass the union interface and the +// list of types associated with it, sorted ascending by their union tag. +func RegisterUnion(iface interface{}) *UnionTags { + ity := reflect.TypeOf(iface).Elem() + if _, ok := unionRegistry[ity]; ok { + panic(fmt.Errorf("Type %s has already been registered", ity.Name())) + } + + if !ity.Implements(reflect.TypeOf((*Union)(nil)).Elem()) { + panic(fmt.Errorf("Type %s does not implement bare.Union", ity.Name())) + } + + utypes := &UnionTags{ + iface: ity, + tags: make(map[reflect.Type]uint64), + types: make(map[uint64]reflect.Type), + } + unionRegistry[ity] = utypes + return utypes +} + +func (ut *UnionTags) Member(t interface{}, tag uint64) *UnionTags { + ty := reflect.TypeOf(t) + if !ty.AssignableTo(ut.iface) { + panic(fmt.Errorf("Type %s does not implement interface %s", + ty.Name(), ut.iface.Name())) + } + if _, ok := ut.tags[ty]; ok { + panic(fmt.Errorf("Type %s is already registered for union %s", + ty.Name(), ut.iface.Name())) + } + if _, ok := ut.types[tag]; ok { + panic(fmt.Errorf("Tag %d is already registered for union %s", + tag, ut.iface.Name())) + } + ut.tags[ty] = tag + ut.types[tag] = ty + return ut +} + +func (ut *UnionTags) TagFor(v interface{}) (uint64, bool) { + tag, ok := ut.tags[reflect.TypeOf(v)] + return tag, ok +} + +func (ut *UnionTags) TypeFor(tag uint64) (reflect.Type, bool) { + t, ok := ut.types[tag] + return t, ok +} diff --git a/forged/internal/common/bare/unmarshal.go b/forged/internal/common/bare/unmarshal.go new file mode 100644 index 0000000..d55f32c --- /dev/null +++ b/forged/internal/common/bare/unmarshal.go @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "sync" +) + +// A type which implements this interface will be responsible for unmarshaling +// itself when encountered. +type Unmarshalable interface { + Unmarshal(r *Reader) error +} + +// Unmarshals a BARE message into val, which must be a pointer to a value of +// the message type. +func Unmarshal(data []byte, val interface{}) error { + b := bytes.NewReader(data) + r := NewReader(b) + return UnmarshalBareReader(r, val) +} + +// Unmarshals a BARE message into value (val, which must be a pointer), from a +// reader. See Unmarshal for details. +func UnmarshalReader(r io.Reader, val interface{}) error { + r = newLimitedReader(r) + return UnmarshalBareReader(NewReader(r), val) +} + +type decodeFunc func(r *Reader, v reflect.Value) error + +var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc + +func UnmarshalBareReader(r *Reader, val interface{}) error { + t := reflect.TypeOf(val) + v := reflect.ValueOf(val) + if t.Kind() != reflect.Ptr { + return errors.New("Expected val to be pointer type") + } + + return getDecoder(t.Elem())(r, v.Elem()) +} + +// get decoder from cache +func getDecoder(t reflect.Type) decodeFunc { + if f, ok := decodeFuncCache.Load(t); ok { + return f.(decodeFunc) + } + + f := decoderFunc(t) + decodeFuncCache.Store(t, f) + return f +} + +var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem() + +func decoderFunc(t reflect.Type) decodeFunc { + if reflect.PointerTo(t).Implements(unmarshalableInterface) { + return func(r *Reader, v reflect.Value) error { + uv := v.Addr().Interface().(Unmarshalable) + return uv.Unmarshal(r) + } + } + + if t.Kind() == reflect.Interface && t.Implements(unionInterface) { + return decodeUnion(t) + } + + switch t.Kind() { + case reflect.Ptr: + return decodeOptional(t.Elem()) + case reflect.Struct: + return decodeStruct(t) + case reflect.Array: + return decodeArray(t) + case reflect.Slice: + return decodeSlice(t) + case reflect.Map: + return decodeMap(t) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return decodeUint + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return decodeInt + case reflect.Float32, reflect.Float64: + return decodeFloat + case reflect.Bool: + return decodeBool + case reflect.String: + return decodeString + } + + return func(r *Reader, v reflect.Value) error { + return &UnsupportedTypeError{v.Type()} + } +} + +func decodeOptional(t reflect.Type) decodeFunc { + return func(r *Reader, v reflect.Value) error { + s, err := r.ReadU8() + if err != nil { + return err + } + + if s > 1 { + return fmt.Errorf("Invalid optional value: %#x", s) + } + + if s == 0 { + return nil + } + + v.Set(reflect.New(t)) + return getDecoder(t)(r, v.Elem()) + } +} + +func decodeStruct(t reflect.Type) decodeFunc { + n := t.NumField() + decoders := make([]decodeFunc, n) + for i := 0; i < n; i++ { + field := t.Field(i) + if field.Tag.Get("bare") == "-" { + continue + } + decoders[i] = getDecoder(field.Type) + } + + return func(r *Reader, v reflect.Value) error { + for i := 0; i < n; i++ { + if decoders[i] == nil { + continue + } + err := decoders[i](r, v.Field(i)) + if err != nil { + return err + } + } + return nil + } +} + +func decodeArray(t reflect.Type) decodeFunc { + f := getDecoder(t.Elem()) + len := t.Len() + + return func(r *Reader, v reflect.Value) error { + for i := 0; i < len; i++ { + err := f(r, v.Index(i)) + if err != nil { + return err + } + } + return nil + } +} + +func decodeSlice(t reflect.Type) decodeFunc { + elem := t.Elem() + f := getDecoder(elem) + + return func(r *Reader, v reflect.Value) error { + len, err := r.ReadUint() + if err != nil { + return err + } + + if len > maxArrayLength { + return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength) + } + + v.Set(reflect.MakeSlice(t, int(len), int(len))) + + for i := 0; i < int(len); i++ { + if err := f(r, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func decodeMap(t reflect.Type) decodeFunc { + keyType := t.Key() + keyf := getDecoder(keyType) + + valueType := t.Elem() + valf := getDecoder(valueType) + + return func(r *Reader, v reflect.Value) error { + size, err := r.ReadUint() + if err != nil { + return err + } + + if size > maxMapSize { + return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize) + } + + v.Set(reflect.MakeMapWithSize(t, int(size))) + + key := reflect.New(keyType).Elem() + value := reflect.New(valueType).Elem() + + for i := uint64(0); i < size; i++ { + if err := keyf(r, key); err != nil { + return err + } + + if v.MapIndex(key).Kind() > reflect.Invalid { + return fmt.Errorf("Encountered duplicate map key: %v", key.Interface()) + } + + if err := valf(r, value); err != nil { + return err + } + + v.SetMapIndex(key, value) + } + return nil + } +} + +func decodeUnion(t reflect.Type) decodeFunc { + ut, ok := unionRegistry[t] + if !ok { + return func(r *Reader, v reflect.Value) error { + return fmt.Errorf("Union type %s is not registered", t.Name()) + } + } + + decoders := make(map[uint64]decodeFunc) + for tag, t := range ut.types { + t := t + f := getDecoder(t) + + decoders[tag] = func(r *Reader, v reflect.Value) error { + nv := reflect.New(t) + if err := f(r, nv.Elem()); err != nil { + return err + } + + v.Set(nv) + return nil + } + } + + return func(r *Reader, v reflect.Value) error { + tag, err := r.ReadUint() + if err != nil { + return err + } + + if f, ok := decoders[tag]; ok { + return f(r, v) + } + + return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name()) + } +} + +func decodeUint(r *Reader, v reflect.Value) error { + var err error + switch getIntKind(v.Type()) { + case reflect.Uint: + var u uint64 + u, err = r.ReadUint() + v.SetUint(u) + + case reflect.Uint8: + var u uint8 + u, err = r.ReadU8() + v.SetUint(uint64(u)) + + case reflect.Uint16: + var u uint16 + u, err = r.ReadU16() + v.SetUint(uint64(u)) + case reflect.Uint32: + var u uint32 + u, err = r.ReadU32() + v.SetUint(uint64(u)) + + case reflect.Uint64: + var u uint64 + u, err = r.ReadU64() + v.SetUint(uint64(u)) + + default: + panic("not an uint") + } + + return err +} + +func decodeInt(r *Reader, v reflect.Value) error { + var err error + switch getIntKind(v.Type()) { + case reflect.Int: + var i int64 + i, err = r.ReadInt() + v.SetInt(i) + + case reflect.Int8: + var i int8 + i, err = r.ReadI8() + v.SetInt(int64(i)) + + case reflect.Int16: + var i int16 + i, err = r.ReadI16() + v.SetInt(int64(i)) + case reflect.Int32: + var i int32 + i, err = r.ReadI32() + v.SetInt(int64(i)) + + case reflect.Int64: + var i int64 + i, err = r.ReadI64() + v.SetInt(int64(i)) + + default: + panic("not an int") + } + + return err +} + +func decodeFloat(r *Reader, v reflect.Value) error { + var err error + switch v.Type().Kind() { + case reflect.Float32: + var f float32 + f, err = r.ReadF32() + v.SetFloat(float64(f)) + case reflect.Float64: + var f float64 + f, err = r.ReadF64() + v.SetFloat(f) + default: + panic("not a float") + } + return err +} + +func decodeBool(r *Reader, v reflect.Value) error { + b, err := r.ReadBool() + v.SetBool(b) + return err +} + +func decodeString(r *Reader, v reflect.Value) error { + s, err := r.ReadString() + v.SetString(s) + return err +} diff --git a/forged/internal/common/bare/varint.go b/forged/internal/common/bare/varint.go new file mode 100644 index 0000000..a185ac8 --- /dev/null +++ b/forged/internal/common/bare/varint.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "reflect" +) + +// Int is a variable-length encoded signed integer. +type Int int64 + +// Uint is a variable-length encoded unsigned integer. +type Uint uint64 + +var ( + intType = reflect.TypeOf(Int(0)) + uintType = reflect.TypeOf(Uint(0)) +) + +func getIntKind(t reflect.Type) reflect.Kind { + switch t { + case intType: + return reflect.Int + case uintType: + return reflect.Uint + default: + return t.Kind() + } +} diff --git a/forged/internal/common/bare/writer.go b/forged/internal/common/bare/writer.go new file mode 100644 index 0000000..1b23c9f --- /dev/null +++ b/forged/internal/common/bare/writer.go @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "encoding/binary" + "fmt" + "io" + "math" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" +) + +// A Writer for BARE primitive types. +type Writer struct { + base io.Writer + scratch [binary.MaxVarintLen64]byte +} + +// Returns a new BARE primitive writer wrapping the given io.Writer. +func NewWriter(base io.Writer) *Writer { + return &Writer{base: base} +} + +func (w *Writer) WriteUint(i uint64) error { + n := binary.PutUvarint(w.scratch[:], i) + _, err := w.base.Write(w.scratch[:n]) + return err +} + +func (w *Writer) WriteU8(i uint8) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteU16(i uint16) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteU32(i uint32) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteU64(i uint64) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteInt(i int64) error { + var buf [binary.MaxVarintLen64]byte + n := binary.PutVarint(buf[:], i) + _, err := w.base.Write(buf[:n]) + return err +} + +func (w *Writer) WriteI8(i int8) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteI16(i int16) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteI32(i int32) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteI64(i int64) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteF32(f float32) error { + if math.IsNaN(float64(f)) { + return fmt.Errorf("NaN is not permitted in BARE floats") + } + return binary.Write(w.base, binary.LittleEndian, f) +} + +func (w *Writer) WriteF64(f float64) error { + if math.IsNaN(f) { + return fmt.Errorf("NaN is not permitted in BARE floats") + } + return binary.Write(w.base, binary.LittleEndian, f) +} + +func (w *Writer) WriteBool(b bool) error { + return binary.Write(w.base, binary.LittleEndian, b) +} + +func (w *Writer) WriteString(str string) error { + return w.WriteData(misc.StringToBytes(str)) +} + +// Writes a fixed amount of arbitrary data, defined by the length of the slice. +func (w *Writer) WriteDataFixed(data []byte) error { + var amt int + for amt < len(data) { + n, err := w.base.Write(data[amt:]) + if err != nil { + return err + } + amt += n + } + return nil +} + +// Writes arbitrary data whose length is encoded into the message. +func (w *Writer) WriteData(data []byte) error { + err := w.WriteUint(uint64(len(data))) + if err != nil { + return err + } + var amt int + for amt < len(data) { + n, err := w.base.Write(data[amt:]) + if err != nil { + return err + } + amt += n + } + return nil +} diff --git a/forged/internal/common/cmap/LICENSE b/forged/internal/common/cmap/LICENSE new file mode 100644 index 0000000..d5dfee8 --- /dev/null +++ b/forged/internal/common/cmap/LICENSE @@ -0,0 +1,22 @@ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS "AS IS" AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/forged/internal/common/cmap/comparable_map.go b/forged/internal/common/cmap/comparable_map.go new file mode 100644 index 0000000..e89175c --- /dev/null +++ b/forged/internal/common/cmap/comparable_map.go @@ -0,0 +1,539 @@ +// Inspired by github.com/SaveTheRbtz/generic-sync-map-go but technically +// written from scratch with Go 1.23's sync.Map. +// Copyright 2024 Runxi Yu (porting it to generics) +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cmap + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +// ComparableMap[K comparable, V comparable] is like a Go map[K]V but is safe for concurrent use +// by multiple goroutines without additional locking or coordination. Loads, +// stores, and deletes run in amortized constant time. +// +// The ComparableMap type is optimized for two common use cases: (1) when the comparableEntry for a given +// key is only ever written once but read many times, as in caches that only grow, +// or (2) when multiple goroutines read, write, and overwrite entries for disjoint +// sets of keys. In these two cases, use of a ComparableMap may significantly reduce lock +// contention compared to a Go map paired with a separate [Mutex] or [RWMutex]. +// +// The zero ComparableMap is empty and ready for use. A ComparableMap must not be copied after first use. +// +// In the terminology of [the Go memory model], ComparableMap arranges that a write operation +// “synchronizes before” any read operation that observes the effect of the write, where +// read and write operations are defined as follows. +// [ComparableMap.Load], [ComparableMap.LoadAndDelete], [ComparableMap.LoadOrStore], [ComparableMap.Swap], [ComparableMap.CompareAndSwap], +// and [ComparableMap.CompareAndDelete] are read operations; +// [ComparableMap.Delete], [ComparableMap.LoadAndDelete], [ComparableMap.Store], and [ComparableMap.Swap] are write operations; +// [ComparableMap.LoadOrStore] is a write operation when it returns loaded set to false; +// [ComparableMap.CompareAndSwap] is a write operation when it returns swapped set to true; +// and [ComparableMap.CompareAndDelete] is a write operation when it returns deleted set to true. +// +// [the Go memory model]: https://go.dev/ref/mem +type ComparableMap[K comparable, V comparable] struct { + mu sync.Mutex + + // read contains the portion of the map's contents that are safe for + // concurrent access (with or without mu held). + // + // The read field itself is always safe to load, but must only be stored with + // mu held. + // + // Entries stored in read may be updated concurrently without mu, but updating + // a previously-comparableExpunged comparableEntry requires that the comparableEntry be copied to the dirty + // map and uncomparableExpunged with mu held. + read atomic.Pointer[comparableReadOnly[K, V]] + + // dirty contains the portion of the map's contents that require mu to be + // held. To ensure that the dirty map can be promoted to the read map quickly, + // it also includes all of the non-comparableExpunged entries in the read map. + // + // Expunged entries are not stored in the dirty map. An comparableExpunged comparableEntry in the + // clean map must be uncomparableExpunged and added to the dirty map before a new value + // can be stored to it. + // + // If the dirty map is nil, the next write to the map will initialize it by + // making a shallow copy of the clean map, omitting stale entries. + dirty map[K]*comparableEntry[V] + + // misses counts the number of loads since the read map was last updated that + // needed to lock mu to determine whether the key was present. + // + // Once enough misses have occurred to cover the cost of copying the dirty + // map, the dirty map will be promoted to the read map (in the unamended + // state) and the next store to the map will make a new dirty copy. + misses int +} + +// comparableReadOnly is an immutable struct stored atomically in the ComparableMap.read field. +type comparableReadOnly[K comparable, V comparable] struct { + m map[K]*comparableEntry[V] + amended bool // true if the dirty map contains some key not in m. +} + +// comparableExpunged is an arbitrary pointer that marks entries which have been deleted +// from the dirty map. +var comparableExpunged = unsafe.Pointer(new(any)) + +// An comparableEntry is a slot in the map corresponding to a particular key. +type comparableEntry[V comparable] struct { + // p points to the value stored for the comparableEntry. + // + // If p == nil, the comparableEntry has been deleted, and either m.dirty == nil or + // m.dirty[key] is e. + // + // If p == comparableExpunged, the comparableEntry has been deleted, m.dirty != nil, and the comparableEntry + // is missing from m.dirty. + // + // Otherwise, the comparableEntry is valid and recorded in m.read.m[key] and, if m.dirty + // != nil, in m.dirty[key]. + // + // An comparableEntry can be deleted by atomic replacement with nil: when m.dirty is + // next created, it will atomically replace nil with comparableExpunged and leave + // m.dirty[key] unset. + // + // An comparableEntry's associated value can be updated by atomic replacement, provided + // p != comparableExpunged. If p == comparableExpunged, an comparableEntry's associated value can be updated + // only after first setting m.dirty[key] = e so that lookups using the dirty + // map find the comparableEntry. + p unsafe.Pointer +} + +func newComparableEntry[V comparable](i V) *comparableEntry[V] { + return &comparableEntry[V]{p: unsafe.Pointer(&i)} +} + +func (m *ComparableMap[K, V]) loadReadOnly() comparableReadOnly[K, V] { + if p := m.read.Load(); p != nil { + return *p + } + return comparableReadOnly[K, V]{} +} + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (m *ComparableMap[K, V]) Load(key K) (value V, ok bool) { + read := m.loadReadOnly() + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + // Avoid reporting a spurious miss if m.dirty got promoted while we were + // blocked on m.mu. (If further loads of the same key will not miss, it's + // not worth copying the dirty map for this key.) + read = m.loadReadOnly() + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + // Regardless of whether the comparableEntry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + if !ok { + return *new(V), false + } + return e.load() +} + +func (e *comparableEntry[V]) load() (value V, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == nil || p == comparableExpunged { + return value, false + } + return *(*V)(p), true +} + +// Store sets the value for a key. +func (m *ComparableMap[K, V]) Store(key K, value V) { + _, _ = m.Swap(key, value) +} + +// Clear deletes all the entries, resulting in an empty ComparableMap. +func (m *ComparableMap[K, V]) Clear() { + read := m.loadReadOnly() + if len(read.m) == 0 && !read.amended { + // Avoid allocating a new comparableReadOnly when the map is already clear. + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + read = m.loadReadOnly() + if len(read.m) > 0 || read.amended { + m.read.Store(&comparableReadOnly[K, V]{}) + } + + clear(m.dirty) + // Don't immediately promote the newly-cleared dirty map on the next operation. + m.misses = 0 +} + +// tryCompareAndSwap compare the comparableEntry with the given old value and swaps +// it with a new value if the comparableEntry is equal to the old value, and the comparableEntry +// has not been comparableExpunged. +// +// If the comparableEntry is comparableExpunged, tryCompareAndSwap returns false and leaves +// the comparableEntry unchanged. +func (e *comparableEntry[V]) tryCompareAndSwap(old V, new V) bool { + p := atomic.LoadPointer(&e.p) + if p == nil || p == comparableExpunged || *(*V)(p) != old { // XXX + return false + } + + // Copy the pointer after the first load to make this method more amenable + // to escape analysis: if the comparison fails from the start, we shouldn't + // bother heap-allocating a pointer to store. + nc := new + for { + if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(&nc)) { + return true + } + p = atomic.LoadPointer(&e.p) + if p == nil || p == comparableExpunged || *(*V)(p) != old { + return false + } + } +} + +// unexpungeLocked ensures that the comparableEntry is not marked as comparableExpunged. +// +// If the comparableEntry was previously comparableExpunged, it must be added to the dirty map +// before m.mu is unlocked. +func (e *comparableEntry[V]) unexpungeLocked() (wasExpunged bool) { + return atomic.CompareAndSwapPointer(&e.p, comparableExpunged, nil) +} + +// swapLocked unconditionally swaps a value into the comparableEntry. +// +// The comparableEntry must be known not to be comparableExpunged. +func (e *comparableEntry[V]) swapLocked(i *V) *V { + return (*V)(atomic.SwapPointer(&e.p, unsafe.Pointer(i))) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *ComparableMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + // Avoid locking if it's a clean hit. + read := m.loadReadOnly() + if e, ok := read.m[key]; ok { + actual, loaded, ok := e.tryLoadOrStore(value) + if ok { + return actual, loaded + } + } + + m.mu.Lock() + read = m.loadReadOnly() + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + m.dirty[key] = e + } + actual, loaded, _ = e.tryLoadOrStore(value) + } else if e, ok := m.dirty[key]; ok { + actual, loaded, _ = e.tryLoadOrStore(value) + m.missLocked() + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(&comparableReadOnly[K, V]{m: read.m, amended: true}) + } + m.dirty[key] = newComparableEntry(value) + actual, loaded = value, false + } + m.mu.Unlock() + + return actual, loaded +} + +// tryLoadOrStore atomically loads or stores a value if the comparableEntry is not +// comparableExpunged. +// +// If the comparableEntry is comparableExpunged, tryLoadOrStore leaves the comparableEntry unchanged and +// returns with ok==false. +func (e *comparableEntry[V]) tryLoadOrStore(i V) (actual V, loaded, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == comparableExpunged { + return actual, false, false + } + if p != nil { + return *(*V)(p), true, true + } + + // Copy the pointer after the first load to make this method more amenable + // to escape analysis: if we hit the "load" path or the comparableEntry is comparableExpunged, we + // shouldn't bother heap-allocating. + ic := i + for { + if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) { + return i, false, true + } + p = atomic.LoadPointer(&e.p) + if p == comparableExpunged { + return actual, false, false + } + if p != nil { + return *(*V)(p), true, true + } + } +} + +// LoadAndDelete deletes the value for a key, returning the previous value if any. +// The loaded result reports whether the key was present. +func (m *ComparableMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + read := m.loadReadOnly() + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + read = m.loadReadOnly() + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + delete(m.dirty, key) + // Regardless of whether the comparableEntry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + if ok { + return e.delete() + } + return value, false +} + +// Delete deletes the value for a key. +func (m *ComparableMap[K, V]) Delete(key K) { + m.LoadAndDelete(key) +} + +func (e *comparableEntry[V]) delete() (value V, ok bool) { + for { + p := atomic.LoadPointer(&e.p) + if p == nil || p == comparableExpunged { + return value, false + } + if atomic.CompareAndSwapPointer(&e.p, p, nil) { + return *(*V)(p), true + } + } +} + +// trySwap swaps a value if the comparableEntry has not been comparableExpunged. +// +// If the comparableEntry is comparableExpunged, trySwap returns false and leaves the comparableEntry +// unchanged. +func (e *comparableEntry[V]) trySwap(i *V) (*V, bool) { + for { + p := atomic.LoadPointer(&e.p) + if p == comparableExpunged { + return nil, false + } + if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) { + return (*V)(p), true + } + } +} + +// Swap swaps the value for a key and returns the previous value if any. +// The loaded result reports whether the key was present. +func (m *ComparableMap[K, V]) Swap(key K, value V) (previous V, loaded bool) { + read := m.loadReadOnly() + if e, ok := read.m[key]; ok { + if v, ok := e.trySwap(&value); ok { + if v == nil { + return previous, false + } + return *v, true + } + } + + m.mu.Lock() + read = m.loadReadOnly() + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + // The comparableEntry was previously comparableExpunged, which implies that there is a + // non-nil dirty map and this comparableEntry is not in it. + m.dirty[key] = e + } + if v := e.swapLocked(&value); v != nil { + loaded = true + previous = *v + } + } else if e, ok := m.dirty[key]; ok { + if v := e.swapLocked(&value); v != nil { + loaded = true + previous = *v + } + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(&comparableReadOnly[K, V]{m: read.m, amended: true}) + } + m.dirty[key] = newComparableEntry(value) + } + m.mu.Unlock() + return previous, loaded +} + +// CompareAndSwap swaps the old and new values for key +// if the value stored in the map is equal to old. +// The old value must be of a comparable type. +func (m *ComparableMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) { + read := m.loadReadOnly() + if e, ok := read.m[key]; ok { + return e.tryCompareAndSwap(old, new) + } else if !read.amended { + return false // No existing value for key. + } + + m.mu.Lock() + defer m.mu.Unlock() + read = m.loadReadOnly() + swapped = false + if e, ok := read.m[key]; ok { + swapped = e.tryCompareAndSwap(old, new) + } else if e, ok := m.dirty[key]; ok { + swapped = e.tryCompareAndSwap(old, new) + // We needed to lock mu in order to load the comparableEntry for key, + // and the operation didn't change the set of keys in the map + // (so it would be made more efficient by promoting the dirty + // map to read-only). + // Count it as a miss so that we will eventually switch to the + // more efficient steady state. + m.missLocked() + } + return swapped +} + +// CompareAndDelete deletes the comparableEntry for key if its value is equal to old. +// The old value must be of a comparable type. +// +// If there is no current value for key in the map, CompareAndDelete +// returns false (even if the old value is a nil pointer). +func (m *ComparableMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + read := m.loadReadOnly() + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + read = m.loadReadOnly() + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + // Don't delete key from m.dirty: we still need to do the “compare” part + // of the operation. The comparableEntry will eventually be comparableExpunged when the + // dirty map is promoted to the read map. + // + // Regardless of whether the comparableEntry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + for ok { + p := atomic.LoadPointer(&e.p) + if p == nil || p == comparableExpunged || *(*V)(p) != old { + return false + } + if atomic.CompareAndSwapPointer(&e.p, p, nil) { + return true + } + } + return false +} + +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot of the ComparableMap's +// contents: no key will be visited more than once, but if the value for any key +// is stored or deleted concurrently (including by f), Range may reflect any +// mapping for that key from any point during the Range call. Range does not +// block other methods on the receiver; even f itself may call any method on m. +// +// Range may be O(N) with the number of elements in the map even if f returns +// false after a constant number of calls. +func (m *ComparableMap[K, V]) Range(f func(key K, value V) bool) { + // We need to be able to iterate over all of the keys that were already + // present at the start of the call to Range. + // If read.amended is false, then read.m satisfies that property without + // requiring us to hold m.mu for a long time. + read := m.loadReadOnly() + if read.amended { + // m.dirty contains keys not in read.m. Fortunately, Range is already O(N) + // (assuming the caller does not break out early), so a call to Range + // amortizes an entire copy of the map: we can promote the dirty copy + // immediately! + m.mu.Lock() + read = m.loadReadOnly() + if read.amended { + read = comparableReadOnly[K, V]{m: m.dirty} + copyRead := read + m.read.Store(©Read) + m.dirty = nil + m.misses = 0 + } + m.mu.Unlock() + } + + for k, e := range read.m { + v, ok := e.load() + if !ok { + continue + } + if !f(k, v) { + break + } + } +} + +func (m *ComparableMap[K, V]) missLocked() { + m.misses++ + if m.misses < len(m.dirty) { + return + } + m.read.Store(&comparableReadOnly[K, V]{m: m.dirty}) + m.dirty = nil + m.misses = 0 +} + +func (m *ComparableMap[K, V]) dirtyLocked() { + if m.dirty != nil { + return + } + + read := m.loadReadOnly() + m.dirty = make(map[K]*comparableEntry[V], len(read.m)) + for k, e := range read.m { + if !e.tryExpungeLocked() { + m.dirty[k] = e + } + } +} + +func (e *comparableEntry[V]) tryExpungeLocked() (isExpunged bool) { + p := atomic.LoadPointer(&e.p) + for p == nil { + if atomic.CompareAndSwapPointer(&e.p, nil, comparableExpunged) { + return true + } + p = atomic.LoadPointer(&e.p) + } + return p == comparableExpunged +} diff --git a/forged/internal/common/cmap/map.go b/forged/internal/common/cmap/map.go new file mode 100644 index 0000000..7a1fe5b --- /dev/null +++ b/forged/internal/common/cmap/map.go @@ -0,0 +1,446 @@ +// Inspired by github.com/SaveTheRbtz/generic-sync-map-go but technically +// written from scratch with Go 1.23's sync.Map. +// Copyright 2024 Runxi Yu (porting it to generics) +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cmap provides a generic Map safe for concurrent use. +package cmap + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +// Map[K comparable, V any] is like a Go map[K]V but is safe for concurrent use +// by multiple goroutines without additional locking or coordination. Loads, +// stores, and deletes run in amortized constant time. +// +// The Map type is optimized for two common use cases: (1) when the entry for a given +// key is only ever written once but read many times, as in caches that only grow, +// or (2) when multiple goroutines read, write, and overwrite entries for disjoint +// sets of keys. In these two cases, use of a Map may significantly reduce lock +// contention compared to a Go map paired with a separate [Mutex] or [RWMutex]. +// +// The zero Map is empty and ready for use. A Map must not be copied after first use. +// +// In the terminology of [the Go memory model], Map arranges that a write operation +// “synchronizes before” any read operation that observes the effect of the write, where +// read and write operations are defined as follows. +// [Map.Load], [Map.LoadAndDelete], [Map.LoadOrStore], [Map.Swap], [Map.CompareAndSwap], +// and [Map.CompareAndDelete] are read operations; +// [Map.Delete], [Map.LoadAndDelete], [Map.Store], and [Map.Swap] are write operations; +// [Map.LoadOrStore] is a write operation when it returns loaded set to false; +// [Map.CompareAndSwap] is a write operation when it returns swapped set to true; +// and [Map.CompareAndDelete] is a write operation when it returns deleted set to true. +// +// [the Go memory model]: https://go.dev/ref/mem +type Map[K comparable, V any] struct { + mu sync.Mutex + + // read contains the portion of the map's contents that are safe for + // concurrent access (with or without mu held). + // + // The read field itself is always safe to load, but must only be stored with + // mu held. + // + // Entries stored in read may be updated concurrently without mu, but updating + // a previously-expunged entry requires that the entry be copied to the dirty + // map and unexpunged with mu held. + read atomic.Pointer[readOnly[K, V]] + + // dirty contains the portion of the map's contents that require mu to be + // held. To ensure that the dirty map can be promoted to the read map quickly, + // it also includes all of the non-expunged entries in the read map. + // + // Expunged entries are not stored in the dirty map. An expunged entry in the + // clean map must be unexpunged and added to the dirty map before a new value + // can be stored to it. + // + // If the dirty map is nil, the next write to the map will initialize it by + // making a shallow copy of the clean map, omitting stale entries. + dirty map[K]*entry[V] + + // misses counts the number of loads since the read map was last updated that + // needed to lock mu to determine whether the key was present. + // + // Once enough misses have occurred to cover the cost of copying the dirty + // map, the dirty map will be promoted to the read map (in the unamended + // state) and the next store to the map will make a new dirty copy. + misses int +} + +// readOnly is an immutable struct stored atomically in the Map.read field. +type readOnly[K comparable, V any] struct { + m map[K]*entry[V] + amended bool // true if the dirty map contains some key not in m. +} + +// expunged is an arbitrary pointer that marks entries which have been deleted +// from the dirty map. +var expunged = unsafe.Pointer(new(any)) + +// An entry is a slot in the map corresponding to a particular key. +type entry[V any] struct { + // p points to the value stored for the entry. + // + // If p == nil, the entry has been deleted, and either m.dirty == nil or + // m.dirty[key] is e. + // + // If p == expunged, the entry has been deleted, m.dirty != nil, and the entry + // is missing from m.dirty. + // + // Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty + // != nil, in m.dirty[key]. + // + // An entry can be deleted by atomic replacement with nil: when m.dirty is + // next created, it will atomically replace nil with expunged and leave + // m.dirty[key] unset. + // + // An entry's associated value can be updated by atomic replacement, provided + // p != expunged. If p == expunged, an entry's associated value can be updated + // only after first setting m.dirty[key] = e so that lookups using the dirty + // map find the entry. + p unsafe.Pointer +} + +func newEntry[V any](i V) *entry[V] { + return &entry[V]{p: unsafe.Pointer(&i)} +} + +func (m *Map[K, V]) loadReadOnly() readOnly[K, V] { + if p := m.read.Load(); p != nil { + return *p + } + return readOnly[K, V]{} +} + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + read := m.loadReadOnly() + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + // Avoid reporting a spurious miss if m.dirty got promoted while we were + // blocked on m.mu. (If further loads of the same key will not miss, it's + // not worth copying the dirty map for this key.) + read = m.loadReadOnly() + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + // Regardless of whether the entry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + if !ok { + return *new(V), false + } + return e.load() +} + +func (e *entry[V]) load() (value V, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == nil || p == expunged { + return value, false + } + return *(*V)(p), true +} + +// Store sets the value for a key. +func (m *Map[K, V]) Store(key K, value V) { + _, _ = m.Swap(key, value) +} + +// Clear deletes all the entries, resulting in an empty Map. +func (m *Map[K, V]) Clear() { + read := m.loadReadOnly() + if len(read.m) == 0 && !read.amended { + // Avoid allocating a new readOnly when the map is already clear. + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + read = m.loadReadOnly() + if len(read.m) > 0 || read.amended { + m.read.Store(&readOnly[K, V]{}) + } + + clear(m.dirty) + // Don't immediately promote the newly-cleared dirty map on the next operation. + m.misses = 0 +} + +// unexpungeLocked ensures that the entry is not marked as expunged. +// +// If the entry was previously expunged, it must be added to the dirty map +// before m.mu is unlocked. +func (e *entry[V]) unexpungeLocked() (wasExpunged bool) { + return atomic.CompareAndSwapPointer(&e.p, expunged, nil) +} + +// swapLocked unconditionally swaps a value into the entry. +// +// The entry must be known not to be expunged. +func (e *entry[V]) swapLocked(i *V) *V { + return (*V)(atomic.SwapPointer(&e.p, unsafe.Pointer(i))) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + // Avoid locking if it's a clean hit. + read := m.loadReadOnly() + if e, ok := read.m[key]; ok { + actual, loaded, ok := e.tryLoadOrStore(value) + if ok { + return actual, loaded + } + } + + m.mu.Lock() + read = m.loadReadOnly() + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + m.dirty[key] = e + } + actual, loaded, _ = e.tryLoadOrStore(value) + } else if e, ok := m.dirty[key]; ok { + actual, loaded, _ = e.tryLoadOrStore(value) + m.missLocked() + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(&readOnly[K, V]{m: read.m, amended: true}) + } + m.dirty[key] = newEntry(value) + actual, loaded = value, false + } + m.mu.Unlock() + + return actual, loaded +} + +// tryLoadOrStore atomically loads or stores a value if the entry is not +// expunged. +// +// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and +// returns with ok==false. +func (e *entry[V]) tryLoadOrStore(i V) (actual V, loaded, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == expunged { + return actual, false, false + } + if p != nil { + return *(*V)(p), true, true + } + + // Copy the pointer after the first load to make this method more amenable + // to escape analysis: if we hit the "load" path or the entry is expunged, we + // shouldn't bother heap-allocating. + ic := i + for { + if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) { + return i, false, true + } + p = atomic.LoadPointer(&e.p) + if p == expunged { + return actual, false, false + } + if p != nil { + return *(*V)(p), true, true + } + } +} + +// LoadAndDelete deletes the value for a key, returning the previous value if any. +// The loaded result reports whether the key was present. +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + read := m.loadReadOnly() + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + read = m.loadReadOnly() + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + delete(m.dirty, key) + // Regardless of whether the entry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + if ok { + return e.delete() + } + return value, false +} + +// Delete deletes the value for a key. +func (m *Map[K, V]) Delete(key K) { + m.LoadAndDelete(key) +} + +func (e *entry[V]) delete() (value V, ok bool) { + for { + p := atomic.LoadPointer(&e.p) + if p == nil || p == expunged { + return value, false + } + if atomic.CompareAndSwapPointer(&e.p, p, nil) { + return *(*V)(p), true + } + } +} + +// trySwap swaps a value if the entry has not been expunged. +// +// If the entry is expunged, trySwap returns false and leaves the entry +// unchanged. +func (e *entry[V]) trySwap(i *V) (*V, bool) { + for { + p := atomic.LoadPointer(&e.p) + if p == expunged { + return nil, false + } + if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) { + return (*V)(p), true + } + } +} + +// Swap swaps the value for a key and returns the previous value if any. +// The loaded result reports whether the key was present. +func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) { + read := m.loadReadOnly() + if e, ok := read.m[key]; ok { + if v, ok := e.trySwap(&value); ok { + if v == nil { + return previous, false + } + return *v, true + } + } + + m.mu.Lock() + read = m.loadReadOnly() + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + // The entry was previously expunged, which implies that there is a + // non-nil dirty map and this entry is not in it. + m.dirty[key] = e + } + if v := e.swapLocked(&value); v != nil { + loaded = true + previous = *v + } + } else if e, ok := m.dirty[key]; ok { + if v := e.swapLocked(&value); v != nil { + loaded = true + previous = *v + } + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(&readOnly[K, V]{m: read.m, amended: true}) + } + m.dirty[key] = newEntry(value) + } + m.mu.Unlock() + return previous, loaded +} + +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot of the Map's +// contents: no key will be visited more than once, but if the value for any key +// is stored or deleted concurrently (including by f), Range may reflect any +// mapping for that key from any point during the Range call. Range does not +// block other methods on the receiver; even f itself may call any method on m. +// +// Range may be O(N) with the number of elements in the map even if f returns +// false after a constant number of calls. +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + // We need to be able to iterate over all of the keys that were already + // present at the start of the call to Range. + // If read.amended is false, then read.m satisfies that property without + // requiring us to hold m.mu for a long time. + read := m.loadReadOnly() + if read.amended { + // m.dirty contains keys not in read.m. Fortunately, Range is already O(N) + // (assuming the caller does not break out early), so a call to Range + // amortizes an entire copy of the map: we can promote the dirty copy + // immediately! + m.mu.Lock() + read = m.loadReadOnly() + if read.amended { + read = readOnly[K, V]{m: m.dirty} + copyRead := read + m.read.Store(©Read) + m.dirty = nil + m.misses = 0 + } + m.mu.Unlock() + } + + for k, e := range read.m { + v, ok := e.load() + if !ok { + continue + } + if !f(k, v) { + break + } + } +} + +func (m *Map[K, V]) missLocked() { + m.misses++ + if m.misses < len(m.dirty) { + return + } + m.read.Store(&readOnly[K, V]{m: m.dirty}) + m.dirty = nil + m.misses = 0 +} + +func (m *Map[K, V]) dirtyLocked() { + if m.dirty != nil { + return + } + + read := m.loadReadOnly() + m.dirty = make(map[K]*entry[V], len(read.m)) + for k, e := range read.m { + if !e.tryExpungeLocked() { + m.dirty[k] = e + } + } +} + +func (e *entry[V]) tryExpungeLocked() (isExpunged bool) { + p := atomic.LoadPointer(&e.p) + for p == nil { + if atomic.CompareAndSwapPointer(&e.p, nil, expunged) { + return true + } + p = atomic.LoadPointer(&e.p) + } + return p == expunged +} diff --git a/forged/internal/common/humanize/bytes.go b/forged/internal/common/humanize/bytes.go new file mode 100644 index 0000000..bea504c --- /dev/null +++ b/forged/internal/common/humanize/bytes.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: Copyright (c) 2005-2008 Dustin Sallings <dustin@spy.net> +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// Package humanize provides functions to convert numbers into human-readable formats. +package humanize + +import ( + "fmt" + "math" +) + +// IBytes produces a human readable representation of an IEC size. +func IBytes(s uint64) string { + sizes := []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"} + return humanateBytes(s, 1024, sizes) +} + +func humanateBytes(s uint64, base float64, sizes []string) string { + if s < 10 { + return fmt.Sprintf("%d B", s) + } + e := math.Floor(logn(float64(s), base)) + suffix := sizes[int(e)] + val := math.Floor(float64(s)/math.Pow(base, e)*10+0.5) / 10 + f := "%.0f %s" + if val < 10 { + f = "%.1f %s" + } + + return fmt.Sprintf(f, val, suffix) +} + +func logn(n, b float64) float64 { + return math.Log(n) / math.Log(b) +} diff --git a/forged/internal/common/misc/back.go b/forged/internal/common/misc/back.go new file mode 100644 index 0000000..5351359 --- /dev/null +++ b/forged/internal/common/misc/back.go @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package misc + +// ErrorBack wraps a value and a channel for communicating an associated error. +// Typically used to get an error response after sending data across a channel. +type ErrorBack[T any] struct { + Content T + ErrorChan chan error +} diff --git a/forged/internal/common/misc/iter.go b/forged/internal/common/misc/iter.go new file mode 100644 index 0000000..61a96f4 --- /dev/null +++ b/forged/internal/common/misc/iter.go @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package misc + +import "iter" + +// iterSeqLimit returns an iterator equivalent to the supplied one, but stops +// after n iterations. +func IterSeqLimit[T any](s iter.Seq[T], n uint) iter.Seq[T] { + return func(yield func(T) bool) { + var iterations uint + for v := range s { + if iterations > n-1 { + return + } + if !yield(v) { + return + } + iterations++ + } + } +} diff --git a/forged/internal/common/misc/misc.go b/forged/internal/common/misc/misc.go new file mode 100644 index 0000000..e9e10ab --- /dev/null +++ b/forged/internal/common/misc/misc.go @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// Package misc provides miscellaneous functions and other definitions. +package misc diff --git a/forged/internal/common/misc/net.go b/forged/internal/common/misc/net.go new file mode 100644 index 0000000..967ea77 --- /dev/null +++ b/forged/internal/common/misc/net.go @@ -0,0 +1,42 @@ +package misc + +import ( + "context" + "errors" + "fmt" + "net" + "syscall" +) + +func ListenUnixSocket(ctx context.Context, path string) (listener net.Listener, replaced bool, err error) { + listenConfig := net.ListenConfig{} //exhaustruct:ignore + listener, err = listenConfig.Listen(ctx, "unix", path) + if errors.Is(err, syscall.EADDRINUSE) { + replaced = true + unlinkErr := syscall.Unlink(path) + if unlinkErr != nil { + return listener, false, fmt.Errorf("remove existing socket %q: %w", path, unlinkErr) + } + listener, err = listenConfig.Listen(ctx, "unix", path) + } + if err != nil { + return listener, replaced, fmt.Errorf("listen on unix socket %q: %w", path, err) + } + return listener, replaced, nil +} + +func Listen(ctx context.Context, net_, addr string) (listener net.Listener, err error) { + if net_ == "unix" { + listener, _, err = ListenUnixSocket(ctx, addr) + if err != nil { + return listener, fmt.Errorf("listen unix socket for web: %w", err) + } + } else { + listenConfig := net.ListenConfig{} //exhaustruct:ignore + listener, err = listenConfig.Listen(ctx, net_, addr) + if err != nil { + return listener, fmt.Errorf("listen %s for web: %w", net_, err) + } + } + return listener, nil +} diff --git a/forged/internal/common/misc/slices.go b/forged/internal/common/misc/slices.go new file mode 100644 index 0000000..3ad0211 --- /dev/null +++ b/forged/internal/common/misc/slices.go @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package misc + +import "strings" + +// sliceContainsNewlines returns true if and only if the given slice contains +// one or more strings that contains newlines. +func SliceContainsNewlines(s []string) bool { + for _, v := range s { + if strings.Contains(v, "\n") { + return true + } + } + return false +} diff --git a/forged/internal/common/misc/trivial.go b/forged/internal/common/misc/trivial.go new file mode 100644 index 0000000..83901e0 --- /dev/null +++ b/forged/internal/common/misc/trivial.go @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package misc + +import ( + "net/url" + "strings" +) + +// These are all trivial functions that are intended to be used in HTML +// templates. + +// FirstLine returns the first line of a string. +func FirstLine(s string) string { + before, _, _ := strings.Cut(s, "\n") + return before +} + +// PathEscape escapes the input as an URL path segment. +func PathEscape(s string) string { + return url.PathEscape(s) +} + +// QueryEscape escapes the input as an URL query segment. +func QueryEscape(s string) string { + return url.QueryEscape(s) +} + +// Dereference dereferences a pointer. +func Dereference[T any](p *T) T { //nolint:ireturn + return *p +} + +// DereferenceOrZero dereferences a pointer. If the pointer is nil, the zero +// value of its associated type is returned instead. +func DereferenceOrZero[T any](p *T) T { //nolint:ireturn + if p != nil { + return *p + } + var z T + return z +} + +// Minus subtracts two numbers. +func Minus(a, b int) int { + return a - b +} diff --git a/forged/internal/common/misc/unsafe.go b/forged/internal/common/misc/unsafe.go new file mode 100644 index 0000000..d827e7f --- /dev/null +++ b/forged/internal/common/misc/unsafe.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package misc + +import "unsafe" + +// StringToBytes converts a string to a byte slice without copying the string. +// Memory is borrowed from the string. +// The resulting byte slice must not be modified in any form. +func StringToBytes(s string) (bytes []byte) { + return unsafe.Slice(unsafe.StringData(s), len(s)) //#nosec G103 +} + +// BytesToString converts a byte slice to a string without copying the bytes. +// Memory is borrowed from the byte slice. +// The source byte slice must not be modified. +func BytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) //#nosec G103 +} diff --git a/forged/internal/common/misc/url.go b/forged/internal/common/misc/url.go new file mode 100644 index 0000000..346ff76 --- /dev/null +++ b/forged/internal/common/misc/url.go @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package misc + +import ( + "net/http" + "net/url" + "strings" +) + +// ParseReqURI parses an HTTP request URL, and returns a slice of path segments +// and the query parameters. It handles %2F correctly. +func ParseReqURI(requestURI string) (segments []string, params url.Values, err error) { + path, paramsStr, _ := strings.Cut(requestURI, "?") + + segments, err = PathToSegments(path) + if err != nil { + return + } + + params, err = url.ParseQuery(paramsStr) + return +} + +// PathToSegments splits a path into unescaped segments. It handles %2F correctly. +func PathToSegments(path string) (segments []string, err error) { + segments = strings.Split(strings.TrimPrefix(path, "/"), "/") + + for i, segment := range segments { + segments[i], err = url.PathUnescape(segment) + if err != nil { + return + } + } + + return +} + +// RedirectDir returns true and redirects the user to a version of the URL with +// a trailing slash, if and only if the request URL does not already have a +// trailing slash. +func RedirectDir(writer http.ResponseWriter, request *http.Request) bool { + requestURI := request.RequestURI + + pathEnd := strings.IndexAny(requestURI, "?#") + var path, rest string + if pathEnd == -1 { + path = requestURI + } else { + path = requestURI[:pathEnd] + rest = requestURI[pathEnd:] + } + + if !strings.HasSuffix(path, "/") { + http.Redirect(writer, request, path+"/"+rest, http.StatusSeeOther) + return true + } + return false +} + +// RedirectNoDir returns true and redirects the user to a version of the URL +// without a trailing slash, if and only if the request URL has a trailing +// slash. +func RedirectNoDir(writer http.ResponseWriter, request *http.Request) bool { + requestURI := request.RequestURI + + pathEnd := strings.IndexAny(requestURI, "?#") + var path, rest string + if pathEnd == -1 { + path = requestURI + } else { + path = requestURI[:pathEnd] + rest = requestURI[pathEnd:] + } + + if strings.HasSuffix(path, "/") { + http.Redirect(writer, request, strings.TrimSuffix(path, "/")+rest, http.StatusSeeOther) + return true + } + return false +} + +// RedirectUnconditionally unconditionally redirects the user back to the +// current page while preserving query parameters. +func RedirectUnconditionally(writer http.ResponseWriter, request *http.Request) { + requestURI := request.RequestURI + + pathEnd := strings.IndexAny(requestURI, "?#") + var path, rest string + if pathEnd == -1 { + path = requestURI + } else { + path = requestURI[:pathEnd] + rest = requestURI[pathEnd:] + } + + http.Redirect(writer, request, path+rest, http.StatusSeeOther) +} + +// SegmentsToURL joins URL segments to the path component of a URL. +// Each segment is escaped properly first. +func SegmentsToURL(segments []string) string { + for i, segment := range segments { + segments[i] = url.PathEscape(segment) + } + return strings.Join(segments, "/") +} + +// AnyContain returns true if and only if ss contains a string that contains c. +func AnyContain(ss []string, c string) bool { + for _, s := range ss { + if strings.Contains(s, c) { + return true + } + } + return false +} diff --git a/forged/internal/common/scfg/.golangci.yaml b/forged/internal/common/scfg/.golangci.yaml new file mode 100644 index 0000000..59f1970 --- /dev/null +++ b/forged/internal/common/scfg/.golangci.yaml @@ -0,0 +1,26 @@ +linters: + enable-all: true + disable: + - perfsprint + - wsl + - varnamelen + - nlreturn + - exhaustruct + - wrapcheck + - lll + - exhaustive + - intrange + - godox + - nestif + - err113 + - staticcheck + - errorlint + - cyclop + - nonamedreturns + - funlen + - gochecknoglobals + - tenv + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 diff --git a/forged/internal/common/scfg/LICENSE b/forged/internal/common/scfg/LICENSE new file mode 100644 index 0000000..3649823 --- /dev/null +++ b/forged/internal/common/scfg/LICENSE @@ -0,0 +1,18 @@ +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/forged/internal/common/scfg/reader.go b/forged/internal/common/scfg/reader.go new file mode 100644 index 0000000..b0e2cc0 --- /dev/null +++ b/forged/internal/common/scfg/reader.go @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2020 Simon Ser <https://emersion.fr> + +package scfg + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" +) + +// This limits the max block nesting depth to prevent stack overflows. +const maxNestingDepth = 1000 + +// Load loads a configuration file. +func Load(path string) (block Block, err error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer func() { + if cerr := f.Close(); err == nil && cerr != nil { + err = cerr + } + }() + + return Read(f) +} + +// Read parses a configuration file from an io.Reader. +func Read(r io.Reader) (Block, error) { + scanner := bufio.NewScanner(r) + + dec := decoder{scanner: scanner} + block, closingBrace, err := dec.readBlock() + if err != nil { + return nil, err + } else if closingBrace { + return nil, fmt.Errorf("line %v: unexpected '}'", dec.lineno) + } + + return block, scanner.Err() +} + +type decoder struct { + scanner *bufio.Scanner + lineno int + blockDepth int +} + +// readBlock reads a block. closingBrace is true if parsing stopped on '}' +// (otherwise, it stopped on Scanner.Scan). +func (dec *decoder) readBlock() (block Block, closingBrace bool, err error) { + dec.blockDepth++ + defer func() { + dec.blockDepth-- + }() + + if dec.blockDepth >= maxNestingDepth { + return nil, false, fmt.Errorf("exceeded max block depth") + } + + for dec.scanner.Scan() { + dec.lineno++ + + l := dec.scanner.Text() + words, err := splitWords(l) + if err != nil { + return nil, false, fmt.Errorf("line %v: %v", dec.lineno, err) + } else if len(words) == 0 { + continue + } + + if len(words) == 1 && l[len(l)-1] == '}' { + closingBrace = true + break + } + + var d *Directive + if words[len(words)-1] == "{" && l[len(l)-1] == '{' { + words = words[:len(words)-1] + + var name string + params := words + if len(words) > 0 { + name, params = words[0], words[1:] + } + + startLineno := dec.lineno + childBlock, childClosingBrace, err := dec.readBlock() + if err != nil { + return nil, false, err + } else if !childClosingBrace { + return nil, false, fmt.Errorf("line %v: unterminated block", startLineno) + } + + // Allows callers to tell apart "no block" and "empty block" + if childBlock == nil { + childBlock = Block{} + } + + d = &Directive{Name: name, Params: params, Children: childBlock, lineno: dec.lineno} + } else { + d = &Directive{Name: words[0], Params: words[1:], lineno: dec.lineno} + } + block = append(block, d) + } + + return block, closingBrace, nil +} + +func splitWords(l string) ([]string, error) { + var ( + words []string + sb strings.Builder + escape bool + quote rune + wantWSP bool + ) + for _, ch := range l { + switch { + case escape: + sb.WriteRune(ch) + escape = false + case wantWSP && (ch != ' ' && ch != '\t'): + return words, fmt.Errorf("atom not allowed after quoted string") + case ch == '\\': + escape = true + case quote != 0 && ch == quote: + quote = 0 + wantWSP = true + if sb.Len() == 0 { + words = append(words, "") + } + case quote == 0 && len(words) == 0 && sb.Len() == 0 && ch == '#': + return nil, nil + case quote == 0 && (ch == '\'' || ch == '"'): + if sb.Len() > 0 { + return words, fmt.Errorf("quoted string not allowed after atom") + } + quote = ch + case quote == 0 && (ch == ' ' || ch == '\t'): + if sb.Len() > 0 { + words = append(words, sb.String()) + } + sb.Reset() + wantWSP = false + default: + sb.WriteRune(ch) + } + } + if quote != 0 { + return words, fmt.Errorf("unterminated quoted string") + } + if sb.Len() > 0 { + words = append(words, sb.String()) + } + return words, nil +} diff --git a/forged/internal/common/scfg/scfg.go b/forged/internal/common/scfg/scfg.go new file mode 100644 index 0000000..4533e63 --- /dev/null +++ b/forged/internal/common/scfg/scfg.go @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2020 Simon Ser <https://emersion.fr> + +// Package scfg parses and formats configuration files. +// Note that this fork of scfg behaves differently from upstream scfg. +package scfg + +import ( + "fmt" +) + +// Block is a list of directives. +type Block []*Directive + +// GetAll returns a list of directives with the provided name. +func (blk Block) GetAll(name string) []*Directive { + l := make([]*Directive, 0, len(blk)) + for _, child := range blk { + if child.Name == name { + l = append(l, child) + } + } + return l +} + +// Get returns the first directive with the provided name. +func (blk Block) Get(name string) *Directive { + for _, child := range blk { + if child.Name == name { + return child + } + } + return nil +} + +// Directive is a configuration directive. +type Directive struct { + Name string + Params []string + + Children Block + + lineno int +} + +// ParseParams extracts parameters from the directive. It errors out if the +// user hasn't provided enough parameters. +func (d *Directive) ParseParams(params ...*string) error { + if len(d.Params) < len(params) { + return fmt.Errorf("directive %q: want %v params, got %v", d.Name, len(params), len(d.Params)) + } + for i, ptr := range params { + if ptr == nil { + continue + } + *ptr = d.Params[i] + } + return nil +} diff --git a/forged/internal/common/scfg/struct.go b/forged/internal/common/scfg/struct.go new file mode 100644 index 0000000..98ec943 --- /dev/null +++ b/forged/internal/common/scfg/struct.go @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2020 Simon Ser <https://emersion.fr> + +package scfg + +import ( + "fmt" + "reflect" + "strings" + "sync" +) + +// structInfo contains scfg metadata for structs. +type structInfo struct { + param int // index of field storing parameters + children map[string]int // indices of fields storing child directives +} + +var ( + structCacheMutex sync.Mutex + structCache = make(map[reflect.Type]*structInfo) +) + +func getStructInfo(t reflect.Type) (*structInfo, error) { + structCacheMutex.Lock() + defer structCacheMutex.Unlock() + + if info := structCache[t]; info != nil { + return info, nil + } + + info := &structInfo{ + param: -1, + children: make(map[string]int), + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Anonymous { + return nil, fmt.Errorf("scfg: anonymous struct fields are not supported") + } else if !f.IsExported() { + continue + } + + tag := f.Tag.Get("scfg") + parts := strings.Split(tag, ",") + k, options := parts[0], parts[1:] + if k == "-" { + continue + } else if k == "" { + k = f.Name + } + + isParam := false + for _, opt := range options { + switch opt { + case "param": + isParam = true + default: + return nil, fmt.Errorf("scfg: invalid option %q in struct tag", opt) + } + } + + if isParam { + if info.param >= 0 { + return nil, fmt.Errorf("scfg: param option specified multiple times in struct tag in %v", t) + } + if parts[0] != "" { + return nil, fmt.Errorf("scfg: name must be empty when param option is specified in struct tag in %v", t) + } + info.param = i + } else { + if _, ok := info.children[k]; ok { + return nil, fmt.Errorf("scfg: key %q specified multiple times in struct tag in %v", k, t) + } + info.children[k] = i + } + } + + structCache[t] = info + return info, nil +} diff --git a/forged/internal/common/scfg/unmarshal.go b/forged/internal/common/scfg/unmarshal.go new file mode 100644 index 0000000..8befc10 --- /dev/null +++ b/forged/internal/common/scfg/unmarshal.go @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2020 Simon Ser <https://emersion.fr> +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package scfg + +import ( + "encoding" + "fmt" + "io" + "reflect" + "strconv" +) + +// Decoder reads and decodes an scfg document from an input stream. +type Decoder struct { + r io.Reader + unknownDirectives []*Directive +} + +// NewDecoder returns a new decoder which reads from r. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: r} +} + +// UnknownDirectives returns a slice of all unknown directives encountered +// during Decode. +func (dec *Decoder) UnknownDirectives() []*Directive { + return dec.unknownDirectives +} + +// Decode reads scfg document from the input and stores it in the value pointed +// to by v. +// +// If v is nil or not a pointer, Decode returns an error. +// +// Blocks can be unmarshaled to: +// +// - Maps. Each directive is unmarshaled into a map entry. The map key must +// be a string. +// - Structs. Each directive is unmarshaled into a struct field. +// +// Duplicate directives are not allowed, unless the struct field or map value +// is a slice of values representing a directive: structs or maps. +// +// Directives can be unmarshaled to: +// +// - Maps. The children block is unmarshaled into the map. Parameters are not +// allowed. +// - Structs. The children block is unmarshaled into the struct. Parameters +// are allowed if one of the struct fields contains the "param" option in +// its tag. +// - Slices. Parameters are unmarshaled into the slice. Children blocks are +// not allowed. +// - Arrays. Parameters are unmarshaled into the array. The number of +// parameters must match exactly the length of the array. Children blocks +// are not allowed. +// - Strings, booleans, integers, floating-point values, values implementing +// encoding.TextUnmarshaler. Only a single parameter is allowed and is +// unmarshaled into the value. Children blocks are not allowed. +// +// The decoding of each struct field can be customized by the format string +// stored under the "scfg" key in the struct field's tag. The tag contains the +// name of the field possibly followed by a comma-separated list of options. +// The name may be empty in order to specify options without overriding the +// default field name. As a special case, if the field name is "-", the field +// is ignored. The "param" option specifies that directive parameters are +// stored in this field (the name must be empty). +func (dec *Decoder) Decode(v interface{}) error { + block, err := Read(dec.r) + if err != nil { + return err + } + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("scfg: invalid value for unmarshaling") + } + + return dec.unmarshalBlock(block, rv) +} + +func (dec *Decoder) unmarshalBlock(block Block, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + dirsByName := make(map[string][]*Directive, len(block)) + for _, dir := range block { + dirsByName[dir.Name] = append(dirsByName[dir.Name], dir) + } + + switch v.Kind() { + case reflect.Map: + if t.Key().Kind() != reflect.String { + return fmt.Errorf("scfg: map key type must be string") + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } else if v.Len() > 0 { + clearMap(v) + } + + for name, dirs := range dirsByName { + mv := reflect.New(t.Elem()).Elem() + if err := dec.unmarshalDirectiveList(dirs, mv); err != nil { + return err + } + v.SetMapIndex(reflect.ValueOf(name), mv) + } + + case reflect.Struct: + si, err := getStructInfo(t) + if err != nil { + return err + } + + seen := make(map[int]bool) + + for name, dirs := range dirsByName { + fieldIndex, ok := si.children[name] + if !ok { + dec.unknownDirectives = append(dec.unknownDirectives, dirs...) + continue + } + fv := v.Field(fieldIndex) + if err := dec.unmarshalDirectiveList(dirs, fv); err != nil { + return err + } + seen[fieldIndex] = true + } + + for name, fieldIndex := range si.children { + if fieldIndex == si.param { + continue + } + if _, ok := seen[fieldIndex]; !ok { + return fmt.Errorf("scfg: missing required directive %q", name) + } + } + + default: + return fmt.Errorf("scfg: unsupported type for unmarshaling blocks: %v", t) + } + + return nil +} + +func (dec *Decoder) unmarshalDirectiveList(dirs []*Directive, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + if v.Kind() != reflect.Slice || !isDirectiveType(t.Elem()) { + if len(dirs) > 1 { + return newUnmarshalDirectiveError(dirs[1], "directive must not be specified more than once") + } + return dec.unmarshalDirective(dirs[0], v) + } + + sv := reflect.MakeSlice(t, len(dirs), len(dirs)) + for i, dir := range dirs { + if err := dec.unmarshalDirective(dir, sv.Index(i)); err != nil { + return err + } + } + v.Set(sv) + return nil +} + +// isDirectiveType checks whether a type can only be unmarshaled as a +// directive, not as a parameter. Accepting too many types here would result in +// ambiguities, see: +// https://lists.sr.ht/~emersion/public-inbox/%3C20230629132458.152205-1-contact%40emersion.fr%3E#%3Ch4Y2peS_YBqY3ar4XlmPDPiNBFpYGns3EBYUx3_6zWEhV2o8_-fBQveRujGADWYhVVCucHBEryFGoPtpC3d3mQ-x10pWnFogfprbQTSvtxc=@emersion.fr%3E +func isDirectiveType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + textUnmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + if reflect.PointerTo(t).Implements(textUnmarshalerType) { + return false + } + + switch t.Kind() { + case reflect.Struct, reflect.Map: + return true + default: + return false + } +} + +func (dec *Decoder) unmarshalDirective(dir *Directive, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + if v.CanAddr() { + if _, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok { + if len(dir.Children) != 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero children") + } + return unmarshalParamList(dir, v) + } + } + + switch v.Kind() { + case reflect.Map: + if len(dir.Params) > 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero parameters") + } + if err := dec.unmarshalBlock(dir.Children, v); err != nil { + return err + } + case reflect.Struct: + si, err := getStructInfo(t) + if err != nil { + return err + } + + if si.param >= 0 { + if err := unmarshalParamList(dir, v.Field(si.param)); err != nil { + return err + } + } else { + if len(dir.Params) > 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero parameters") + } + } + + if err := dec.unmarshalBlock(dir.Children, v); err != nil { + return err + } + default: + if len(dir.Children) != 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero children") + } + if err := unmarshalParamList(dir, v); err != nil { + return err + } + } + return nil +} + +func unmarshalParamList(dir *Directive, v reflect.Value) error { + switch v.Kind() { + case reflect.Slice: + t := v.Type() + sv := reflect.MakeSlice(t, len(dir.Params), len(dir.Params)) + for i, param := range dir.Params { + if err := unmarshalParam(param, sv.Index(i)); err != nil { + return newUnmarshalParamError(dir, i, err) + } + } + v.Set(sv) + case reflect.Array: + if len(dir.Params) != v.Len() { + return newUnmarshalDirectiveError(dir, fmt.Sprintf("directive requires exactly %v parameters", v.Len())) + } + for i, param := range dir.Params { + if err := unmarshalParam(param, v.Index(i)); err != nil { + return newUnmarshalParamError(dir, i, err) + } + } + default: + if len(dir.Params) != 1 { + return newUnmarshalDirectiveError(dir, "directive requires exactly one parameter") + } + if err := unmarshalParam(dir.Params[0], v); err != nil { + return newUnmarshalParamError(dir, 0, err) + } + } + + return nil +} + +func unmarshalParam(param string, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + // TODO: improve our logic following: + // https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/encoding/json/decode.go;drc=b9b8cecbfc72168ca03ad586cc2ed52b0e8db409;l=421 + if v.CanAddr() { + if v, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok { + return v.UnmarshalText([]byte(param)) + } + } + + switch v.Kind() { + case reflect.String: + v.Set(reflect.ValueOf(param)) + case reflect.Bool: + switch param { + case "true": + v.Set(reflect.ValueOf(true)) + case "false": + v.Set(reflect.ValueOf(false)) + default: + return fmt.Errorf("invalid bool parameter %q", param) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(param, 10, t.Bits()) + if err != nil { + return fmt.Errorf("invalid %v parameter: %v", t, err) + } + v.Set(reflect.ValueOf(i).Convert(t)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u, err := strconv.ParseUint(param, 10, t.Bits()) + if err != nil { + return fmt.Errorf("invalid %v parameter: %v", t, err) + } + v.Set(reflect.ValueOf(u).Convert(t)) + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(param, t.Bits()) + if err != nil { + return fmt.Errorf("invalid %v parameter: %v", t, err) + } + v.Set(reflect.ValueOf(f).Convert(t)) + default: + return fmt.Errorf("unsupported type for unmarshaling parameter: %v", t) + } + + return nil +} + +func unwrapPointers(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +func clearMap(v reflect.Value) { + for _, k := range v.MapKeys() { + v.SetMapIndex(k, reflect.Value{}) + } +} + +type unmarshalDirectiveError struct { + lineno int + name string + msg string +} + +func newUnmarshalDirectiveError(dir *Directive, msg string) *unmarshalDirectiveError { + return &unmarshalDirectiveError{ + name: dir.Name, + lineno: dir.lineno, + msg: msg, + } +} + +func (err *unmarshalDirectiveError) Error() string { + return fmt.Sprintf("line %v, directive %q: %v", err.lineno, err.name, err.msg) +} + +type unmarshalParamError struct { + lineno int + directive string + paramIndex int + err error +} + +func newUnmarshalParamError(dir *Directive, paramIndex int, err error) *unmarshalParamError { + return &unmarshalParamError{ + directive: dir.Name, + lineno: dir.lineno, + paramIndex: paramIndex, + err: err, + } +} + +func (err *unmarshalParamError) Error() string { + return fmt.Sprintf("line %v, directive %q, parameter %v: %v", err.lineno, err.directive, err.paramIndex+1, err.err) +} diff --git a/forged/internal/common/scfg/writer.go b/forged/internal/common/scfg/writer.go new file mode 100644 index 0000000..02a07fe --- /dev/null +++ b/forged/internal/common/scfg/writer.go @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2020 Simon Ser <https://emersion.fr> + +package scfg + +import ( + "errors" + "io" + "strings" +) + +var errDirEmptyName = errors.New("scfg: directive with empty name") + +// Write writes a parsed configuration to the provided io.Writer. +func Write(w io.Writer, blk Block) error { + enc := newEncoder(w) + err := enc.encodeBlock(blk) + return err +} + +// encoder write SCFG directives to an output stream. +type encoder struct { + w io.Writer + lvl int + err error +} + +// newEncoder returns a new encoder that writes to w. +func newEncoder(w io.Writer) *encoder { + return &encoder{w: w} +} + +func (enc *encoder) push() { + enc.lvl++ +} + +func (enc *encoder) pop() { + enc.lvl-- +} + +func (enc *encoder) writeIndent() { + for i := 0; i < enc.lvl; i++ { + enc.write([]byte("\t")) + } +} + +func (enc *encoder) write(p []byte) { + if enc.err != nil { + return + } + _, enc.err = enc.w.Write(p) +} + +func (enc *encoder) encodeBlock(blk Block) error { + for _, dir := range blk { + if err := enc.encodeDir(*dir); err != nil { + return err + } + } + return enc.err +} + +func (enc *encoder) encodeDir(dir Directive) error { + if enc.err != nil { + return enc.err + } + + if dir.Name == "" { + enc.err = errDirEmptyName + return enc.err + } + + enc.writeIndent() + enc.write([]byte(maybeQuote(dir.Name))) + for _, p := range dir.Params { + enc.write([]byte(" ")) + enc.write([]byte(maybeQuote(p))) + } + + if len(dir.Children) > 0 { + enc.write([]byte(" {\n")) + enc.push() + if err := enc.encodeBlock(dir.Children); err != nil { + return err + } + enc.pop() + + enc.writeIndent() + enc.write([]byte("}")) + } + enc.write([]byte("\n")) + + return enc.err +} + +const specialChars = "\"\\\r\n'{} \t" + +func maybeQuote(s string) string { + if s == "" || strings.ContainsAny(s, specialChars) { + var sb strings.Builder + sb.WriteByte('"') + for _, ch := range s { + if strings.ContainsRune(`"\`, ch) { + sb.WriteByte('\\') + } + sb.WriteRune(ch) + } + sb.WriteByte('"') + return sb.String() + } + return s +} diff --git a/forged/internal/config/config.go b/forged/internal/config/config.go new file mode 100644 index 0000000..1825882 --- /dev/null +++ b/forged/internal/config/config.go @@ -0,0 +1,111 @@ +package config + +import ( + "bufio" + "fmt" + "log/slog" + "os" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/scfg" +) + +type Config struct { + DB DB `scfg:"db"` + Web Web `scfg:"web"` + Hooks Hooks `scfg:"hooks"` + LMTP LMTP `scfg:"lmtp"` + SSH SSH `scfg:"ssh"` + IRC IRC `scfg:"irc"` + Git Git `scfg:"git"` + General General `scfg:"general"` + Pprof Pprof `scfg:"pprof"` +} + +type DB struct { + Conn string `scfg:"conn"` +} + +type Web struct { + Net string `scfg:"net"` + Addr string `scfg:"addr"` + Root string `scfg:"root"` + CookieExpiry int `scfg:"cookie_expiry"` + ReadTimeout uint32 `scfg:"read_timeout"` + WriteTimeout uint32 `scfg:"write_timeout"` + IdleTimeout uint32 `scfg:"idle_timeout"` + MaxHeaderBytes int `scfg:"max_header_bytes"` + ReverseProxy bool `scfg:"reverse_proxy"` + ShutdownTimeout uint32 `scfg:"shutdown_timeout"` + TemplatesPath string `scfg:"templates_path"` + StaticPath string `scfg:"static_path"` +} + +type Hooks struct { + Socket string `scfg:"socket"` + Execs string `scfg:"execs"` +} + +type LMTP struct { + Socket string `scfg:"socket"` + Domain string `scfg:"domain"` + MaxSize int64 `scfg:"max_size"` + WriteTimeout uint32 `scfg:"write_timeout"` + ReadTimeout uint32 `scfg:"read_timeout"` +} + +type SSH struct { + Net string `scfg:"net"` + Addr string `scfg:"addr"` + Key string `scfg:"key"` + Root string `scfg:"root"` + ShutdownTimeout uint32 `scfg:"shutdown_timeout"` +} + +type IRC struct { + Net string `scfg:"net"` + Addr string `scfg:"addr"` + TLS bool `scfg:"tls"` + SendQ uint `scfg:"sendq"` + Nick string `scfg:"nick"` + User string `scfg:"user"` + Gecos string `scfg:"gecos"` +} + +type Git struct { + RepoDir string `scfg:"repo_dir"` + Socket string `scfg:"socket"` +} + +type General struct { + Title string `scfg:"title"` +} + +type Pprof struct { + Net string `scfg:"net"` + Addr string `scfg:"addr"` +} + +func Open(path string) (config Config, err error) { + var configFile *os.File + + configFile, err = os.Open(path) //#nosec G304 + if err != nil { + err = fmt.Errorf("open config file: %w", err) + return config, err + } + defer func() { + _ = configFile.Close() + }() + + decoder := scfg.NewDecoder(bufio.NewReader(configFile)) + err = decoder.Decode(&config) + if err != nil { + err = fmt.Errorf("decode config file: %w", err) + return config, err + } + for _, u := range decoder.UnknownDirectives() { + slog.Warn("unknown configuration directive", "directive", u) + } + + return config, err +} diff --git a/forged/internal/database/database.go b/forged/internal/database/database.go new file mode 100644 index 0000000..d96af6b --- /dev/null +++ b/forged/internal/database/database.go @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// Package database provides stubs and wrappers for databases. +package database + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" +) + +type Database struct { + *pgxpool.Pool +} + +func Open(ctx context.Context, conn string) (Database, error) { + db, err := pgxpool.New(ctx, conn) + if err != nil { + err = fmt.Errorf("create pgxpool: %w", err) + } + return Database{db}, err +} diff --git a/forged/internal/database/queries/.gitignore b/forged/internal/database/queries/.gitignore new file mode 100644 index 0000000..1307f6d --- /dev/null +++ b/forged/internal/database/queries/.gitignore @@ -0,0 +1 @@ +/*.go diff --git a/forged/internal/global/global.go b/forged/internal/global/global.go new file mode 100644 index 0000000..99f85e7 --- /dev/null +++ b/forged/internal/global/global.go @@ -0,0 +1,18 @@ +package global + +import ( + "go.lindenii.runxiyu.org/forge/forged/internal/config" + "go.lindenii.runxiyu.org/forge/forged/internal/database" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" +) + +type Global struct { + ForgeTitle string // should be removed since it's in Config + ForgeVersion string + SSHPubkey string + SSHFingerprint string + + Config *config.Config + Queries *queries.Queries + DB *database.Database +} diff --git a/forged/internal/incoming/hooks/hooks.go b/forged/internal/incoming/hooks/hooks.go new file mode 100644 index 0000000..effd104 --- /dev/null +++ b/forged/internal/incoming/hooks/hooks.go @@ -0,0 +1,81 @@ +package hooks + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/gliderlabs/ssh" + "go.lindenii.runxiyu.org/forge/forged/internal/common/cmap" + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/global" +) + +type Server struct { + hookMap cmap.Map[string, hookInfo] + socketPath string + executablesPath string + global *global.Global +} +type hookInfo struct { + session ssh.Session + pubkey string + directAccess bool + repoPath string + userID int + userType string + repoID int + groupPath []string + repoName string + contribReq string +} + +func New(global *global.Global) (server *Server) { + cfg := global.Config.Hooks + return &Server{ + socketPath: cfg.Socket, + executablesPath: cfg.Execs, + hookMap: cmap.Map[string, hookInfo]{}, + global: global, + } +} + +func (server *Server) Run(ctx context.Context) error { + listener, _, err := misc.ListenUnixSocket(ctx, server.socketPath) + if err != nil { + return fmt.Errorf("listen unix socket for hooks: %w", err) + } + defer func() { + _ = listener.Close() + }() + + stop := context.AfterFunc(ctx, func() { + _ = listener.Close() + }) + defer stop() + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { + return nil + } + return fmt.Errorf("accept conn: %w", err) + } + + go server.handleConn(ctx, conn) + } +} + +func (server *Server) handleConn(ctx context.Context, conn net.Conn) { + defer func() { + _ = conn.Close() + }() + unblock := context.AfterFunc(ctx, func() { + _ = conn.SetDeadline(time.Now()) + _ = conn.Close() + }) + defer unblock() +} diff --git a/forged/internal/incoming/lmtp/lmtp.go b/forged/internal/incoming/lmtp/lmtp.go new file mode 100644 index 0000000..c8918f8 --- /dev/null +++ b/forged/internal/incoming/lmtp/lmtp.go @@ -0,0 +1,71 @@ +package lmtp + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/global" +) + +type Server struct { + socket string + domain string + maxSize int64 + writeTimeout uint32 + readTimeout uint32 + global *global.Global +} + +func New(global *global.Global) (server *Server) { + cfg := global.Config.LMTP + return &Server{ + socket: cfg.Socket, + domain: cfg.Domain, + maxSize: cfg.MaxSize, + writeTimeout: cfg.WriteTimeout, + readTimeout: cfg.ReadTimeout, + global: global, + } +} + +func (server *Server) Run(ctx context.Context) error { + listener, _, err := misc.ListenUnixSocket(ctx, server.socket) + if err != nil { + return fmt.Errorf("listen unix socket for LMTP: %w", err) + } + defer func() { + _ = listener.Close() + }() + + stop := context.AfterFunc(ctx, func() { + _ = listener.Close() + }) + defer stop() + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { + return nil + } + return fmt.Errorf("accept conn: %w", err) + } + + go server.handleConn(ctx, conn) + } +} + +func (server *Server) handleConn(ctx context.Context, conn net.Conn) { + defer func() { + _ = conn.Close() + }() + unblock := context.AfterFunc(ctx, func() { + _ = conn.SetDeadline(time.Now()) + _ = conn.Close() + }) + defer unblock() +} diff --git a/forged/internal/incoming/ssh/ssh.go b/forged/internal/incoming/ssh/ssh.go new file mode 100644 index 0000000..1f27be2 --- /dev/null +++ b/forged/internal/incoming/ssh/ssh.go @@ -0,0 +1,90 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + gliderssh "github.com/gliderlabs/ssh" + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/global" + gossh "golang.org/x/crypto/ssh" +) + +type Server struct { + gliderServer *gliderssh.Server + privkey gossh.Signer + net string + addr string + root string + shutdownTimeout uint32 + global *global.Global +} + +func New(global *global.Global) (server *Server, err error) { + cfg := global.Config.SSH + server = &Server{ + net: cfg.Net, + addr: cfg.Addr, + root: cfg.Root, + shutdownTimeout: cfg.ShutdownTimeout, + global: global, + } //exhaustruct:ignore + + var privkeyBytes []byte + + privkeyBytes, err = os.ReadFile(cfg.Key) + if err != nil { + return server, fmt.Errorf("read SSH private key: %w", err) + } + + server.privkey, err = gossh.ParsePrivateKey(privkeyBytes) + if err != nil { + return server, fmt.Errorf("parse SSH private key: %w", err) + } + + server.global.SSHPubkey = misc.BytesToString(gossh.MarshalAuthorizedKey(server.privkey.PublicKey())) + server.global.SSHFingerprint = gossh.FingerprintSHA256(server.privkey.PublicKey()) + + server.gliderServer = &gliderssh.Server{ + Handler: handle, + PublicKeyHandler: func(ctx gliderssh.Context, key gliderssh.PublicKey) bool { return true }, + KeyboardInteractiveHandler: func(ctx gliderssh.Context, challenge gossh.KeyboardInteractiveChallenge) bool { return true }, + } //exhaustruct:ignore + server.gliderServer.AddHostKey(server.privkey) + + return server, nil +} + +func (server *Server) Run(ctx context.Context) (err error) { + listener, err := misc.Listen(ctx, server.net, server.addr) + if err != nil { + return fmt.Errorf("listen for SSH: %w", err) + } + defer func() { + _ = listener.Close() + }() + + stop := context.AfterFunc(ctx, func() { + shCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(server.shutdownTimeout)*time.Second) + defer cancel() + _ = server.gliderServer.Shutdown(shCtx) + _ = listener.Close() + }) + defer stop() + + err = server.gliderServer.Serve(listener) + if err != nil { + if errors.Is(err, gliderssh.ErrServerClosed) || ctx.Err() != nil { + return nil + } + return fmt.Errorf("serve SSH: %w", err) + } + panic("unreachable") +} + +func handle(session gliderssh.Session) { + panic("SSH server handler not implemented yet") +} diff --git a/forged/internal/incoming/web/authn.go b/forged/internal/incoming/web/authn.go new file mode 100644 index 0000000..9754eb1 --- /dev/null +++ b/forged/internal/incoming/web/authn.go @@ -0,0 +1,33 @@ +package web + +import ( + "crypto/sha256" + "errors" + "fmt" + "net/http" + + "github.com/jackc/pgx/v5" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" +) + +func userResolver(r *http.Request) (string, string, error) { + cookie, err := r.Cookie("session") + if err != nil { + if errors.Is(err, http.ErrNoCookie) { + return "", "", nil + } + return "", "", err + } + + tokenHash := sha256.Sum256([]byte(cookie.Value)) + + session, err := types.Base(r).Global.Queries.GetUserFromSession(r.Context(), tokenHash[:]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", "", nil + } + return "", "", err + } + + return fmt.Sprint(session.UserID), session.Username, nil +} diff --git a/forged/internal/incoming/web/handler.go b/forged/internal/incoming/web/handler.go new file mode 100644 index 0000000..394469a --- /dev/null +++ b/forged/internal/incoming/web/handler.go @@ -0,0 +1,69 @@ +package web + +import ( + "html/template" + "net/http" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/global" + handlers "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/handlers" + repoHandlers "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/handlers/repo" + specialHandlers "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/handlers/special" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/templates" +) + +type handler struct { + r *Router +} + +func NewHandler(global *global.Global) *handler { + cfg := global.Config.Web + h := &handler{r: NewRouter().ReverseProxy(cfg.ReverseProxy).Global(global).UserResolver(userResolver)} + + staticFS := http.FileServer(http.Dir(cfg.StaticPath)) + h.r.ANYHTTP("-/static/*rest", + http.StripPrefix("/-/static/", staticFS), + WithDirIfEmpty("rest"), + ) + + funcs := template.FuncMap{ + "path_escape": misc.PathEscape, + "query_escape": misc.QueryEscape, + "minus": misc.Minus, + "first_line": misc.FirstLine, + "dereference_error": misc.DereferenceOrZero[error], + } + t := templates.MustParseDir(cfg.TemplatesPath, funcs) + renderer := templates.New(t) + + indexHTTP := handlers.NewIndexHTTP(renderer) + loginHTTP := specialHandlers.NewLoginHTTP(renderer, cfg.CookieExpiry) + groupHTTP := handlers.NewGroupHTTP(renderer) + repoHTTP := repoHandlers.NewHTTP(renderer) + notImpl := handlers.NewNotImplementedHTTP(renderer) + + h.r.GET("/", indexHTTP.Index) + + h.r.ANY("-/login", loginHTTP.Login) + h.r.ANY("-/users", notImpl.Handle) + + h.r.GET("@group/", groupHTTP.Index) + h.r.POST("@group/", groupHTTP.Post) + + h.r.GET("@group/-/repos/:repo/", repoHTTP.Index) + h.r.ANY("@group/-/repos/:repo/info", notImpl.Handle) + h.r.ANY("@group/-/repos/:repo/git-upload-pack", notImpl.Handle) + h.r.GET("@group/-/repos/:repo/branches/", repoHTTP.Branches) + h.r.GET("@group/-/repos/:repo/log/", repoHTTP.Log) + h.r.GET("@group/-/repos/:repo/commit/:commit", repoHTTP.Commit) + h.r.GET("@group/-/repos/:repo/tree/*rest", repoHTTP.Tree, WithDirIfEmpty("rest")) + h.r.GET("@group/-/repos/:repo/raw/*rest", repoHTTP.Raw, WithDirIfEmpty("rest")) + h.r.GET("@group/-/repos/:repo/contrib/", notImpl.Handle) + h.r.GET("@group/-/repos/:repo/contrib/:mr", notImpl.Handle) + + return h +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.r.ServeHTTP(w, r) +} diff --git a/forged/internal/incoming/web/handlers/group.go b/forged/internal/incoming/web/handlers/group.go new file mode 100644 index 0000000..4823cb7 --- /dev/null +++ b/forged/internal/incoming/web/handlers/group.go @@ -0,0 +1,156 @@ +package handlers + +import ( + "fmt" + "log/slog" + "net/http" + "path/filepath" + "strconv" + + "github.com/jackc/pgx/v5" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/templates" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +type GroupHTTP struct { + r templates.Renderer +} + +func NewGroupHTTP(r templates.Renderer) *GroupHTTP { + return &GroupHTTP{ + r: r, + } +} + +func (h *GroupHTTP) Index(w http.ResponseWriter, r *http.Request, _ wtypes.Vars) { + base := wtypes.Base(r) + userID, err := strconv.ParseInt(base.UserID, 10, 64) + if err != nil { + userID = 0 + } + + queryParams := queries.GetGroupByPathParams{ + Column1: base.URLSegments, + UserID: userID, + } + p, err := base.Global.Queries.GetGroupByPath(r.Context(), queryParams) + if err != nil { + slog.Error("failed to get group ID by path", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + subgroups, err := base.Global.Queries.GetSubgroups(r.Context(), &p.ID) + if err != nil { + slog.Error("failed to get subgroups", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + // TODO: gracefully fail this part of the page + } + repos, err := base.Global.Queries.GetReposInGroup(r.Context(), p.ID) + if err != nil { + slog.Error("failed to get repos in group", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + // TODO: gracefully fail this part of the page + } + err = h.r.Render(w, "group", struct { + BaseData *wtypes.BaseData + Subgroups []queries.GetSubgroupsRow + Repos []queries.GetReposInGroupRow + Description string + DirectAccess bool + }{ + BaseData: base, + Subgroups: subgroups, + Repos: repos, + Description: p.Description, + DirectAccess: p.HasRole, + }) + if err != nil { + slog.Error("failed to render index page", "error", err) + } +} + +func (h *GroupHTTP) Post(w http.ResponseWriter, r *http.Request, _ wtypes.Vars) { + base := wtypes.Base(r) + userID, err := strconv.ParseInt(base.UserID, 10, 64) + if err != nil { + userID = 0 + } + + queryParams := queries.GetGroupByPathParams{ + Column1: base.URLSegments, + UserID: userID, + } + p, err := base.Global.Queries.GetGroupByPath(r.Context(), queryParams) + if err != nil { + slog.Error("failed to get group ID by path", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + if !p.HasRole { + http.Error(w, "You do not have the necessary permissions to create repositories in this group.", http.StatusForbidden) + return + } + + name := r.PostFormValue("repo_name") + desc := r.PostFormValue("repo_desc") + contrib := r.PostFormValue("repo_contrib") + if name == "" { + http.Error(w, "Repo name is required", http.StatusBadRequest) + return + } + + if contrib == "" || contrib == "public" { + contrib = "open" + } + + tx, err := base.Global.DB.BeginTx(r.Context(), pgx.TxOptions{}) + if err != nil { + slog.Error("begin tx failed", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer func() { _ = tx.Rollback(r.Context()) }() + + txq := base.Global.Queries.WithTx(tx) + var descPtr *string + if desc != "" { + descPtr = &desc + } + repoID, err := txq.InsertRepo(r.Context(), queries.InsertRepoParams{ + GroupID: p.ID, + Name: name, + Description: descPtr, + ContribRequirements: contrib, + }) + if err != nil { + slog.Error("insert repo failed", "error", err) + http.Error(w, "Failed to create repository", http.StatusInternalServerError) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoID)) + + gitc, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err != nil { + slog.Error("git2d connect failed", "error", err) + http.Error(w, "Failed to initialize repository (backend)", http.StatusInternalServerError) + return + } + defer func() { _ = gitc.Close() }() + if err = gitc.InitRepo(repoPath, base.Global.Config.Hooks.Execs); err != nil { + slog.Error("git2d init failed", "error", err) + http.Error(w, "Failed to initialize repository", http.StatusInternalServerError) + return + } + + if err = tx.Commit(r.Context()); err != nil { + slog.Error("commit tx failed", "error", err) + http.Error(w, "Failed to finalize repository creation", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) +} diff --git a/forged/internal/incoming/web/handlers/index.go b/forged/internal/incoming/web/handlers/index.go new file mode 100644 index 0000000..a758b07 --- /dev/null +++ b/forged/internal/incoming/web/handlers/index.go @@ -0,0 +1,39 @@ +package handlers + +import ( + "log" + "net/http" + + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/templates" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" +) + +type IndexHTTP struct { + r templates.Renderer +} + +func NewIndexHTTP(r templates.Renderer) *IndexHTTP { + return &IndexHTTP{ + r: r, + } +} + +func (h *IndexHTTP) Index(w http.ResponseWriter, r *http.Request, _ wtypes.Vars) { + groups, err := wtypes.Base(r).Global.Queries.GetRootGroups(r.Context()) + if err != nil { + http.Error(w, "failed to get root groups", http.StatusInternalServerError) + log.Println("failed to get root groups", "error", err) + return + } + err = h.r.Render(w, "index", struct { + BaseData *wtypes.BaseData + Groups []queries.GetRootGroupsRow + }{ + BaseData: wtypes.Base(r), + Groups: groups, + }) + if err != nil { + log.Println("failed to render index page", "error", err) + } +} diff --git a/forged/internal/incoming/web/handlers/not_implemented.go b/forged/internal/incoming/web/handlers/not_implemented.go new file mode 100644 index 0000000..6813c88 --- /dev/null +++ b/forged/internal/incoming/web/handlers/not_implemented.go @@ -0,0 +1,22 @@ +package handlers + +import ( + "net/http" + + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/templates" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" +) + +type NotImplementedHTTP struct { + r templates.Renderer +} + +func NewNotImplementedHTTP(r templates.Renderer) *NotImplementedHTTP { + return &NotImplementedHTTP{ + r: r, + } +} + +func (h *NotImplementedHTTP) Handle(w http.ResponseWriter, _ *http.Request, _ wtypes.Vars) { + http.Error(w, "not implemented", http.StatusNotImplemented) +} diff --git a/forged/internal/incoming/web/handlers/repo/branches.go b/forged/internal/incoming/web/handlers/repo/branches.go new file mode 100644 index 0000000..26f3b04 --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/branches.go @@ -0,0 +1,68 @@ +package repo + +import ( + "fmt" + "log/slog" + "net/http" + "net/url" + "path/filepath" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +func (h *HTTP) Branches(w http.ResponseWriter, r *http.Request, v wtypes.Vars) { + base := wtypes.Base(r) + repoName := v["repo"] + + var userID int64 + if base.UserID != "" { + _, _ = fmt.Sscan(base.UserID, &userID) + } + grp, err := base.Global.Queries.GetGroupByPath(r.Context(), queries.GetGroupByPathParams{Column1: base.GroupPath, UserID: userID}) + if err != nil { + slog.Error("get group by path", "error", err) + http.Error(w, "Group not found", http.StatusNotFound) + return + } + repoRow, err := base.Global.Queries.GetRepoByGroupAndName(r.Context(), queries.GetRepoByGroupAndNameParams{GroupID: grp.ID, Name: repoName}) + if err != nil { + slog.Error("get repo by name", "error", err) + http.Error(w, "Repository not found", http.StatusNotFound) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoRow.ID)) + client, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err != nil { + slog.Error("git2d connect failed", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer func() { _ = client.Close() }() + + branches, err := client.ListBranches(repoPath) + if err != nil { + slog.Error("list branches failed", "error", err) + branches = nil + } + + repoURLRoot := "/" + misc.SegmentsToURL(base.GroupPath) + "/-/repos/" + url.PathEscape(repoRow.Name) + "/" + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "repo_url_root": repoURLRoot, + "branches": branches, + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_branches", data); err != nil { + slog.Error("render repo branches", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} diff --git a/forged/internal/incoming/web/handlers/repo/commit.go b/forged/internal/incoming/web/handlers/repo/commit.go new file mode 100644 index 0000000..0a27f3b --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/commit.go @@ -0,0 +1,239 @@ +package repo + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "log/slog" + "net/http" + "net/url" + "path/filepath" + "strings" + "time" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +type commitPerson struct { + Name string + Email string + When time.Time +} + +type commitObject struct { + Hash string + Message string + Author commitPerson + Committer commitPerson +} + +type usableChunk struct { + Operation int + Content string +} + +type diffFileMeta struct { + Hash string + Mode string + Path string +} + +type usableFilePatch struct { + From diffFileMeta + To diffFileMeta + Chunks []usableChunk +} + +func shortHash(s string) string { + if s == "" { + return "" + } + b := sha1.Sum([]byte(s)) + return hex.EncodeToString(b[:8]) +} + +func parseUnifiedPatch(p string) []usableFilePatch { + lines := strings.Split(p, "\n") + patches := []usableFilePatch{} + var cur *usableFilePatch + flush := func() { + if cur != nil { + patches = append(patches, *cur) + cur = nil + } + } + appendChunk := func(op int, buf *[]string) { + if len(*buf) == 0 || cur == nil { + return + } + content := strings.Join(*buf, "\n") + *buf = (*buf)[:0] + cur.Chunks = append(cur.Chunks, usableChunk{Operation: op, Content: content}) + } + var bufSame, bufAdd, bufDel []string + + for _, ln := range lines { + if strings.HasPrefix(ln, "diff --git ") { + appendChunk(0, &bufSame) + appendChunk(1, &bufAdd) + appendChunk(2, &bufDel) + flush() + parts := strings.SplitN(strings.TrimPrefix(ln, "diff --git "), " ", 2) + from := strings.TrimPrefix(strings.TrimSpace(parts[0]), "a/") + to := from + if len(parts) > 1 { + to = strings.TrimPrefix(strings.TrimSpace(strings.TrimPrefix(parts[1], "b/")), "b/") + } + cur = &usableFilePatch{ + From: diffFileMeta{Path: from, Hash: shortHash(from)}, + To: diffFileMeta{Path: to, Hash: shortHash(to)}, + } + continue + } + if cur == nil { + continue + } + switch { + case strings.HasPrefix(ln, "+"): + appendChunk(0, &bufSame) + appendChunk(2, &bufDel) + bufAdd = append(bufAdd, ln) + case strings.HasPrefix(ln, "-"): + appendChunk(0, &bufSame) + appendChunk(1, &bufAdd) + bufDel = append(bufDel, ln) + default: + appendChunk(1, &bufAdd) + appendChunk(2, &bufDel) + bufSame = append(bufSame, ln) + } + } + if cur != nil { + appendChunk(0, &bufSame) + appendChunk(1, &bufAdd) + appendChunk(2, &bufDel) + flush() + } + return patches +} + +func (h *HTTP) Commit(w http.ResponseWriter, r *http.Request, v wtypes.Vars) { + base := wtypes.Base(r) + repoName := v["repo"] + commitSpec := v["commit"] + wantPatch := strings.HasSuffix(commitSpec, ".patch") + commitSpec = strings.TrimSuffix(commitSpec, ".patch") + + var userID int64 + if base.UserID != "" { + _, _ = fmt.Sscan(base.UserID, &userID) + } + grp, err := base.Global.Queries.GetGroupByPath(r.Context(), queries.GetGroupByPathParams{Column1: base.GroupPath, UserID: userID}) + if err != nil { + slog.Error("get group by path", "error", err) + http.Error(w, "Group not found", http.StatusNotFound) + return + } + repoRow, err := base.Global.Queries.GetRepoByGroupAndName(r.Context(), queries.GetRepoByGroupAndNameParams{GroupID: grp.ID, Name: repoName}) + if err != nil { + slog.Error("get repo by name", "error", err) + http.Error(w, "Repository not found", http.StatusNotFound) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoRow.ID)) + client, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err != nil { + slog.Error("git2d connect failed", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer func() { _ = client.Close() }() + + resolved := commitSpec + if len(commitSpec) < 40 { + if list, lerr := client.Log(repoPath, commitSpec, 1); lerr == nil && len(list) > 0 { + resolved = list[0].Hash + } + } + if !wantPatch && resolved != "" && resolved != commitSpec { + u := *r.URL + basePath := strings.TrimSuffix(u.EscapedPath(), commitSpec) + u.Path = basePath + resolved + http.Redirect(w, r, u.String(), http.StatusSeeOther) + return + } + + if wantPatch { + patchStr, perr := client.FormatPatch(repoPath, resolved) + if perr != nil { + slog.Error("format patch failed", "error", perr) + http.Error(w, "Failed to format patch", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + _, _ = w.Write([]byte(patchStr)) + return + } + + info, derr := client.CommitInfo(repoPath, resolved) + if derr != nil { + slog.Error("commit info failed", "error", derr) + http.Error(w, "Failed to get commit info", http.StatusInternalServerError) + return + } + + toTime := func(sec, minoff int64) time.Time { + loc := time.FixedZone("", int(minoff*60)) + return time.Unix(sec, 0).In(loc) + } + co := commitObject{ + Hash: info.Hash, + Message: info.Message, + Author: commitPerson{Name: info.AuthorName, Email: info.AuthorEmail, When: toTime(info.AuthorWhen, info.AuthorTZMin)}, + Committer: commitPerson{Name: info.CommitterName, Email: info.CommitterEmail, When: toTime(info.CommitterWhen, info.CommitterTZMin)}, + } + + toUsable := func(files []git2c.FileDiff) []usableFilePatch { + out := make([]usableFilePatch, 0, len(files)) + for _, f := range files { + u := usableFilePatch{ + From: diffFileMeta{Path: f.FromPath, Mode: fmt.Sprintf("%06o", f.FromMode), Hash: shortHash(f.FromPath)}, + To: diffFileMeta{Path: f.ToPath, Mode: fmt.Sprintf("%06o", f.ToMode), Hash: shortHash(f.ToPath)}, + } + for _, ch := range f.Chunks { + u.Chunks = append(u.Chunks, usableChunk{Operation: int(ch.Op), Content: ch.Content}) + } + out = append(out, u) + } + return out + } + filePatches := toUsable(info.Files) + parentHex := "" + if len(info.Parents) > 0 { + parentHex = info.Parents[0] + } + + repoURLRoot := "/" + misc.SegmentsToURL(base.GroupPath) + "/-/repos/" + url.PathEscape(repoRow.Name) + "/" + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "repo_url_root": repoURLRoot, + "commit_object": co, + "commit_id": co.Hash, + "parent_commit_hash": parentHex, + "file_patches": filePatches, + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_commit", data); err != nil { + slog.Error("render repo commit", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} diff --git a/forged/internal/incoming/web/handlers/repo/handler.go b/forged/internal/incoming/web/handlers/repo/handler.go new file mode 100644 index 0000000..2881d7d --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/handler.go @@ -0,0 +1,15 @@ +package repo + +import ( + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/templates" +) + +type HTTP struct { + r templates.Renderer +} + +func NewHTTP(r templates.Renderer) *HTTP { + return &HTTP{ + r: r, + } +} diff --git a/forged/internal/incoming/web/handlers/repo/index.go b/forged/internal/incoming/web/handlers/repo/index.go new file mode 100644 index 0000000..c2cb24a --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/index.go @@ -0,0 +1,132 @@ +package repo + +import ( + "bytes" + "fmt" + "html/template" + "log/slog" + "net/http" + "net/url" + "path/filepath" + "strings" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/extension" + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +func (h *HTTP) Index(w http.ResponseWriter, r *http.Request, v wtypes.Vars) { + base := wtypes.Base(r) + repoName := v["repo"] + slog.Info("repo index", "group_path", base.GroupPath, "repo", repoName) + + var userID int64 + if base.UserID != "" { + _, _ = fmt.Sscan(base.UserID, &userID) + } + grp, err := base.Global.Queries.GetGroupByPath(r.Context(), queries.GetGroupByPathParams{ + Column1: base.GroupPath, + UserID: userID, + }) + if err != nil { + slog.Error("get group by path", "error", err) + http.Error(w, "Group not found", http.StatusNotFound) + return + } + + repoRow, err := base.Global.Queries.GetRepoByGroupAndName(r.Context(), queries.GetRepoByGroupAndNameParams{ + GroupID: grp.ID, + Name: repoName, + }) + if err != nil { + slog.Error("get repo by name", "error", err) + http.Error(w, "Repository not found", http.StatusNotFound) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoRow.ID)) + + var commits []git2c.Commit + var readme template.HTML + var commitsErr error + var readmeFile *git2c.FilenameContents + var cerr error + client, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err == nil { + defer func() { _ = client.Close() }() + commits, readmeFile, cerr = client.CmdIndex(repoPath) + if cerr != nil { + commitsErr = cerr + slog.Error("git2d CmdIndex failed", "error", cerr, "path", repoPath) + } else if readmeFile != nil { + nameLower := strings.ToLower(readmeFile.Filename) + if strings.HasSuffix(nameLower, ".md") || strings.HasSuffix(nameLower, ".markdown") || nameLower == "readme" { + md := goldmark.New( + goldmark.WithExtensions(extension.GFM), + ) + var buf bytes.Buffer + if err := md.Convert(readmeFile.Content, &buf); err == nil { + readme = template.HTML(buf.String()) + } else { + readme = template.HTML(template.HTMLEscapeString(string(readmeFile.Content))) + } + } else { + readme = template.HTML(template.HTMLEscapeString(string(readmeFile.Content))) + } + } + } else { + commitsErr = err + slog.Error("git2d connect failed", "error", err) + } + + sshRoot := strings.TrimSuffix(base.Global.Config.SSH.Root, "/") + httpRoot := strings.TrimSuffix(base.Global.Config.Web.Root, "/") + pathPart := misc.SegmentsToURL(base.GroupPath) + "/-/repos/" + url.PathEscape(repoRow.Name) + sshURL := "" + httpURL := "" + if sshRoot != "" { + sshURL = sshRoot + "/" + pathPart + } + if httpRoot != "" { + httpURL = httpRoot + "/" + pathPart + } + + var notes []string + if len(commits) == 0 && commitsErr == nil { + notes = append(notes, "This repository has no commits yet.") + } + if readme == template.HTML("") { + notes = append(notes, "No README found in the default branch.") + } + if sshURL == "" && httpURL == "" { + notes = append(notes, "Clone URLs not configured (missing SSH root and HTTP root).") + } + + cloneURL := sshURL + if cloneURL == "" { + cloneURL = httpURL + } + + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "ssh_clone_url": cloneURL, + "ref_name": base.RefName, + "commits": commits, + "commits_err": &commitsErr, + "readme": readme, + "notes": notes, + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_index", data); err != nil { + slog.Error("render repo index", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} diff --git a/forged/internal/incoming/web/handlers/repo/log.go b/forged/internal/incoming/web/handlers/repo/log.go new file mode 100644 index 0000000..9a1a6b8 --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/log.go @@ -0,0 +1,107 @@ +package repo + +import ( + "fmt" + "log/slog" + "net/http" + "net/url" + "path/filepath" + "time" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +type logAuthor struct { + Name string + Email string + When time.Time +} + +type logCommit struct { + Hash string + Message string + Author logAuthor +} + +func (h *HTTP) Log(w http.ResponseWriter, r *http.Request, v wtypes.Vars) { + base := wtypes.Base(r) + repoName := v["repo"] + + var userID int64 + if base.UserID != "" { + _, _ = fmt.Sscan(base.UserID, &userID) + } + grp, err := base.Global.Queries.GetGroupByPath(r.Context(), queries.GetGroupByPathParams{Column1: base.GroupPath, UserID: userID}) + if err != nil { + slog.Error("get group by path", "error", err) + http.Error(w, "Group not found", http.StatusNotFound) + return + } + repoRow, err := base.Global.Queries.GetRepoByGroupAndName(r.Context(), queries.GetRepoByGroupAndNameParams{GroupID: grp.ID, Name: repoName}) + if err != nil { + slog.Error("get repo by name", "error", err) + http.Error(w, "Repository not found", http.StatusNotFound) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoRow.ID)) + client, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err != nil { + slog.Error("git2d connect failed", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer func() { _ = client.Close() }() + + var refspec string + if base.RefType == "" { + refspec = "" + } else { + hex, rerr := client.ResolveRef(repoPath, base.RefType, base.RefName) + if rerr != nil { + slog.Error("resolve ref failed", "error", rerr) + refspec = "" + } else { + refspec = hex + } + } + + var rawCommits []git2c.Commit + rawCommits, err = client.Log(repoPath, refspec, 0) + var commitsErr error + if err != nil { + commitsErr = err + slog.Error("git2d log failed", "error", err) + } + commits := make([]logCommit, 0, len(rawCommits)) + for _, c := range rawCommits { + when, _ := time.Parse("2006-01-02 15:04:05", c.Date) + commits = append(commits, logCommit{ + Hash: c.Hash, + Message: c.Message, + Author: logAuthor{Name: c.Author, Email: c.Email, When: when}, + }) + } + + repoURLRoot := "/" + misc.SegmentsToURL(base.GroupPath) + "/-/repos/" + url.PathEscape(repoRow.Name) + "/" + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "repo_url_root": repoURLRoot, + "ref_name": base.RefName, + "commits": commits, + "commits_err": &commitsErr, + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_log", data); err != nil { + slog.Error("render repo log", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} diff --git a/forged/internal/incoming/web/handlers/repo/raw.go b/forged/internal/incoming/web/handlers/repo/raw.go new file mode 100644 index 0000000..6d5db1e --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/raw.go @@ -0,0 +1,90 @@ +package repo + +import ( + "fmt" + "log/slog" + "net/http" + "net/url" + "path/filepath" + "strings" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +func (h *HTTP) Raw(w http.ResponseWriter, r *http.Request, v wtypes.Vars) { + base := wtypes.Base(r) + repoName := v["repo"] + rawPathSpec := v["rest"] + pathSpec := strings.TrimSuffix(rawPathSpec, "/") + + var userID int64 + if base.UserID != "" { + _, _ = fmt.Sscan(base.UserID, &userID) + } + grp, err := base.Global.Queries.GetGroupByPath(r.Context(), queries.GetGroupByPathParams{Column1: base.GroupPath, UserID: userID}) + if err != nil { + slog.Error("get group by path", "error", err) + http.Error(w, "Group not found", http.StatusNotFound) + return + } + repoRow, err := base.Global.Queries.GetRepoByGroupAndName(r.Context(), queries.GetRepoByGroupAndNameParams{GroupID: grp.ID, Name: repoName}) + if err != nil { + slog.Error("get repo by name", "error", err) + http.Error(w, "Repository not found", http.StatusNotFound) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoRow.ID)) + + client, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err != nil { + slog.Error("git2d connect failed", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer func() { _ = client.Close() }() + + files, content, err := client.CmdTreeRaw(repoPath, pathSpec) + if err != nil { + slog.Error("git2d CmdTreeRaw failed", "error", err, "path", repoPath, "spec", pathSpec) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + repoURLRoot := "/" + misc.SegmentsToURL(base.GroupPath) + "/-/repos/" + url.PathEscape(repoRow.Name) + "/" + + switch { + case files != nil: + if !base.DirMode && misc.RedirectDir(w, r) { + return + } + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "repo_url_root": repoURLRoot, + "ref_name": base.RefName, + "path_spec": pathSpec, + "files": files, + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_raw_dir", data); err != nil { + slog.Error("render repo raw dir", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + case content != "": + if base.DirMode && misc.RedirectNoDir(w, r) { + return + } + w.Header().Set("Content-Type", "application/octet-stream") + _, _ = w.Write([]byte(content)) + default: + http.Error(w, "Unknown object type", http.StatusInternalServerError) + } +} diff --git a/forged/internal/incoming/web/handlers/repo/tree.go b/forged/internal/incoming/web/handlers/repo/tree.go new file mode 100644 index 0000000..627c998 --- /dev/null +++ b/forged/internal/incoming/web/handlers/repo/tree.go @@ -0,0 +1,110 @@ +package repo + +import ( + "fmt" + "html/template" + "log/slog" + "net/http" + "net/url" + "path/filepath" + "strings" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" + "go.lindenii.runxiyu.org/forge/forged/internal/ipc/git2c" +) + +func (h *HTTP) Tree(w http.ResponseWriter, r *http.Request, v wtypes.Vars) { + base := wtypes.Base(r) + repoName := v["repo"] + rawPathSpec := v["rest"] + pathSpec := strings.TrimSuffix(rawPathSpec, "/") + + var userID int64 + if base.UserID != "" { + _, _ = fmt.Sscan(base.UserID, &userID) + } + grp, err := base.Global.Queries.GetGroupByPath(r.Context(), queries.GetGroupByPathParams{Column1: base.GroupPath, UserID: userID}) + if err != nil { + slog.Error("get group by path", "error", err) + http.Error(w, "Group not found", http.StatusNotFound) + return + } + repoRow, err := base.Global.Queries.GetRepoByGroupAndName(r.Context(), queries.GetRepoByGroupAndNameParams{GroupID: grp.ID, Name: repoName}) + if err != nil { + slog.Error("get repo by name", "error", err) + http.Error(w, "Repository not found", http.StatusNotFound) + return + } + + repoPath := filepath.Join(base.Global.Config.Git.RepoDir, fmt.Sprintf("%d.git", repoRow.ID)) + + client, err := git2c.NewClient(r.Context(), base.Global.Config.Git.Socket) + if err != nil { + slog.Error("git2d connect failed", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer func() { _ = client.Close() }() + + files, content, err := client.CmdTreeRaw(repoPath, pathSpec) + if err != nil { + slog.Error("git2d CmdTreeRaw failed", "error", err, "path", repoPath, "spec", pathSpec) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + repoURLRoot := "/" + misc.SegmentsToURL(base.GroupPath) + "/-/repos/" + url.PathEscape(repoRow.Name) + "/" + + switch { + case files != nil: + if !base.DirMode && misc.RedirectDir(w, r) { + return + } + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "repo_url_root": repoURLRoot, + "ref_name": base.RefName, + "path_spec": pathSpec, + "files": files, + "readme_filename": "README.md", + "readme": template.HTML("<p>README rendering here is WIP.</p>"), + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_tree_dir", data); err != nil { + slog.Error("render repo tree dir", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + case content != "": + if base.DirMode && misc.RedirectNoDir(w, r) { + return + } + escaped := template.HTMLEscapeString(content) + rendered := template.HTML("<pre class=\"chroma\"><code>" + escaped + "</code></pre>") + data := map[string]any{ + "BaseData": base, + "group_path": base.GroupPath, + "repo_name": repoRow.Name, + "repo_description": repoRow.Description, + "repo_url_root": repoURLRoot, + "ref_name": base.RefName, + "path_spec": pathSpec, + "file_contents": rendered, + "global": map[string]any{ + "forge_title": base.Global.ForgeTitle, + }, + } + if err := h.r.Render(w, "repo_tree_file", data); err != nil { + slog.Error("render repo tree file", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + default: + http.Error(w, "Unknown object type", http.StatusInternalServerError) + } +} diff --git a/forged/internal/incoming/web/handlers/special/login.go b/forged/internal/incoming/web/handlers/special/login.go new file mode 100644 index 0000000..5672f1f --- /dev/null +++ b/forged/internal/incoming/web/handlers/special/login.go @@ -0,0 +1,119 @@ +package handlers + +import ( + "crypto/rand" + "crypto/sha256" + "errors" + "log" + "net/http" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "go.lindenii.runxiyu.org/forge/forged/internal/common/argon2id" + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/templates" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" +) + +type LoginHTTP struct { + r templates.Renderer + cookieExpiry int +} + +func NewLoginHTTP(r templates.Renderer, cookieExpiry int) *LoginHTTP { + return &LoginHTTP{ + r: r, + cookieExpiry: cookieExpiry, + } +} + +func (h *LoginHTTP) Login(w http.ResponseWriter, r *http.Request, _ wtypes.Vars) { + renderLoginPage := func(loginError string) bool { + err := h.r.Render(w, "login", struct { + BaseData *wtypes.BaseData + LoginError string + }{ + BaseData: wtypes.Base(r), + LoginError: loginError, + }) + if err != nil { + log.Println("failed to render login page", "error", err) + http.Error(w, "Failed to render login page", http.StatusInternalServerError) + return true + } + return false + } + + if r.Method == http.MethodGet { + renderLoginPage("") + return + } + + username := r.PostFormValue("username") + password := r.PostFormValue("password") + + userCreds, err := wtypes.Base(r).Global.Queries.GetUserCreds(r.Context(), &username) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + renderLoginPage("User not found") + return + } + log.Println("failed to get user credentials", "error", err) + http.Error(w, "Failed to get user credentials", http.StatusInternalServerError) + return + } + + if userCreds.PasswordHash == "" { + renderLoginPage("No password set for this user") + return + } + + passwordMatches, err := argon2id.ComparePasswordAndHash(password, userCreds.PasswordHash) + if err != nil { + log.Println("failed to compare password and hash", "error", err) + http.Error(w, "Failed to verify password", http.StatusInternalServerError) + return + } + + if !passwordMatches { + renderLoginPage("Invalid password") + return + } + + cookieValue := rand.Text() + + now := time.Now() + expiry := now.Add(time.Duration(h.cookieExpiry) * time.Second) + + cookie := &http.Cookie{ + Name: "session", + Value: cookieValue, + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: false, // TODO + Expires: expiry, + Path: "/", + } //exhaustruct:ignore + + http.SetCookie(w, cookie) + + tokenHash := sha256.Sum256(misc.StringToBytes(cookieValue)) + + err = wtypes.Base(r).Global.Queries.InsertSession(r.Context(), queries.InsertSessionParams{ + UserID: userCreds.ID, + TokenHash: tokenHash[:], + ExpiresAt: pgtype.Timestamptz{ + Time: expiry, + Valid: true, + }, + }) + if err != nil { + log.Println("failed to insert session", "error", err) + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} diff --git a/forged/internal/incoming/web/router.go b/forged/internal/incoming/web/router.go new file mode 100644 index 0000000..3809afb --- /dev/null +++ b/forged/internal/incoming/web/router.go @@ -0,0 +1,419 @@ +package web + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + + "go.lindenii.runxiyu.org/forge/forged/internal/global" + wtypes "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web/types" +) + +type UserResolver func(*http.Request) (id string, username string, err error) + +type ErrorRenderers struct { + BadRequest func(http.ResponseWriter, *wtypes.BaseData, string) + BadRequestColon func(http.ResponseWriter, *wtypes.BaseData) + NotFound func(http.ResponseWriter, *wtypes.BaseData) + ServerError func(http.ResponseWriter, *wtypes.BaseData, string) +} + +type dirPolicy int + +const ( + dirIgnore dirPolicy = iota + dirRequire + dirForbid + dirRequireIfEmpty +) + +type patKind uint8 + +const ( + lit patKind = iota + param + splat + group // @group, must be first token +) + +type patSeg struct { + kind patKind + lit string + key string +} + +type route struct { + method string + rawPattern string + wantDir dirPolicy + ifEmptyKey string + segs []patSeg + h wtypes.HandlerFunc + hh http.Handler + priority int +} + +type Router struct { + routes []route + errors ErrorRenderers + user UserResolver + global *global.Global + reverseProxy bool +} + +func NewRouter() *Router { return &Router{} } + +func (r *Router) Global(g *global.Global) *Router { + r.global = g + return r +} +func (r *Router) ReverseProxy(enabled bool) *Router { r.reverseProxy = enabled; return r } +func (r *Router) Errors(e ErrorRenderers) *Router { r.errors = e; return r } +func (r *Router) UserResolver(u UserResolver) *Router { r.user = u; return r } + +type RouteOption func(*route) + +func WithDir() RouteOption { return func(rt *route) { rt.wantDir = dirRequire } } +func WithoutDir() RouteOption { return func(rt *route) { rt.wantDir = dirForbid } } +func WithDirIfEmpty(param string) RouteOption { + return func(rt *route) { rt.wantDir = dirRequireIfEmpty; rt.ifEmptyKey = param } +} + +func (r *Router) GET(pattern string, f wtypes.HandlerFunc, opts ...RouteOption) { + r.handle("GET", pattern, f, nil, opts...) +} + +func (r *Router) POST(pattern string, f wtypes.HandlerFunc, opts ...RouteOption) { + r.handle("POST", pattern, f, nil, opts...) +} + +func (r *Router) ANY(pattern string, f wtypes.HandlerFunc, opts ...RouteOption) { + r.handle("", pattern, f, nil, opts...) +} + +func (r *Router) ANYHTTP(pattern string, hh http.Handler, opts ...RouteOption) { + r.handle("", pattern, nil, hh, opts...) +} + +func (r *Router) handle(method, pattern string, f wtypes.HandlerFunc, hh http.Handler, opts ...RouteOption) { + want := dirIgnore + if strings.HasSuffix(pattern, "/") { + want = dirRequire + pattern = strings.TrimSuffix(pattern, "/") + } else if pattern != "" { + want = dirForbid + } + segs, prio := compilePattern(pattern) + rt := route{ + method: method, + rawPattern: pattern, + wantDir: want, + segs: segs, + h: f, + hh: hh, + priority: prio, + } + for _, o := range opts { + o(&rt) + } + r.routes = append(r.routes, rt) + + sort.SliceStable(r.routes, func(i, j int) bool { + return r.routes[i].priority > r.routes[j].priority + }) +} + +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + segments, dirMode, err := splitAndUnescapePath(req.URL.EscapedPath()) + if err != nil { + r.err400(w, &wtypes.BaseData{Global: r.global}, "Error parsing request URI: "+err.Error()) + return + } + for _, s := range segments { + if strings.Contains(s, ":") { + r.err400Colon(w, &wtypes.BaseData{Global: r.global}) + return + } + } + + bd := &wtypes.BaseData{ + Global: r.global, + URLSegments: segments, + DirMode: dirMode, + } + req = req.WithContext(wtypes.WithBaseData(req.Context(), bd)) + + bd.RefType, bd.RefName, err = GetParamRefTypeName(req) + if err != nil { + r.err400(w, bd, "Error parsing ref query parameters: "+err.Error()) + return + } + + if r.user != nil { + uid, uname, uerr := r.user(req) + if uerr != nil { + r.err500(w, bd, "Error getting user info from request: "+uerr.Error()) + return + } + bd.UserID = uid + bd.Username = uname + } + + method := req.Method + var pathMatched bool + var matchedRaw string + + for _, rt := range r.routes { + ok, vars, sepIdx := match(rt.segs, segments) + if !ok { + continue + } + pathMatched = true + matchedRaw = rt.rawPattern + + switch rt.wantDir { + case dirRequire: + if !dirMode && redirectAddSlash(w, req) { + return + } + case dirForbid: + if dirMode && redirectDropSlash(w, req) { + return + } + case dirRequireIfEmpty: + if v := vars[rt.ifEmptyKey]; v == "" && !dirMode && redirectAddSlash(w, req) { + return + } + } + + bd.SeparatorIndex = sepIdx + if g := vars["group"]; g == "" { + bd.GroupPath = []string{} + } else { + bd.GroupPath = strings.Split(g, "/") + } + + if rt.method != "" && rt.method != method && (method != http.MethodHead || rt.method != http.MethodGet) { + continue + } + + if rt.h != nil { + rt.h(w, req, wtypes.Vars(vars)) + } else if rt.hh != nil { + rt.hh.ServeHTTP(w, req) + } else { + r.err500(w, bd, "route has no handler") + } + return + } + + if pathMatched { + w.Header().Set("Allow", allowForPattern(r.routes, matchedRaw)) + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + r.err404(w, bd) +} + +func compilePattern(pat string) ([]patSeg, int) { + if pat == "" || pat == "/" { + return nil, 1000 + } + pat = strings.Trim(pat, "/") + raw := strings.Split(pat, "/") + + segs := make([]patSeg, 0, len(raw)) + prio := 0 + for i, t := range raw { + switch { + case t == "@group": + if i != 0 { + segs = append(segs, patSeg{kind: lit, lit: t}) + prio += 10 + continue + } + segs = append(segs, patSeg{kind: group}) + prio += 1 + case strings.HasPrefix(t, ":"): + segs = append(segs, patSeg{kind: param, key: t[1:]}) + prio += 5 + case strings.HasPrefix(t, "*"): + segs = append(segs, patSeg{kind: splat, key: t[1:]}) + default: + segs = append(segs, patSeg{kind: lit, lit: t}) + prio += 10 + } + } + return segs, prio +} + +func match(pat []patSeg, segs []string) (bool, map[string]string, int) { + vars := make(map[string]string) + i := 0 + sepIdx := -1 + for pi := 0; pi < len(pat); pi++ { + ps := pat[pi] + switch ps.kind { + case group: + start := i + for i < len(segs) && segs[i] != "-" { + i++ + } + if start < i { + vars["group"] = strings.Join(segs[start:i], "/") + } else { + vars["group"] = "" + } + if i < len(segs) && segs[i] == "-" { + sepIdx = i + } + case lit: + if i >= len(segs) || segs[i] != ps.lit { + return false, nil, -1 + } + i++ + case param: + if i >= len(segs) { + return false, nil, -1 + } + vars[ps.key] = segs[i] + i++ + case splat: + if i < len(segs) { + vars[ps.key] = strings.Join(segs[i:], "/") + i = len(segs) + } else { + vars[ps.key] = "" + } + pi = len(pat) + } + } + if i != len(segs) { + return false, nil, -1 + } + return true, vars, sepIdx +} + +func splitAndUnescapePath(escaped string) ([]string, bool, error) { + if escaped == "" { + return nil, false, nil + } + dir := strings.HasSuffix(escaped, "/") + path := strings.Trim(escaped, "/") + if path == "" { + return []string{}, dir, nil + } + raw := strings.Split(path, "/") + out := make([]string, 0, len(raw)) + for _, seg := range raw { + u, err := url.PathUnescape(seg) + if err != nil { + return nil, dir, err + } + if u != "" { + out = append(out, u) + } + } + return out, dir, nil +} + +func redirectAddSlash(w http.ResponseWriter, r *http.Request) bool { + u := *r.URL + u.Path = u.EscapedPath() + "/" + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) + return true +} + +func redirectDropSlash(w http.ResponseWriter, r *http.Request) bool { + u := *r.URL + u.Path = strings.TrimRight(u.EscapedPath(), "/") + if u.Path == "" { + u.Path = "/" + } + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) + return true +} + +func allowForPattern(routes []route, raw string) string { + seen := map[string]struct{}{} + out := make([]string, 0, 4) + for _, rt := range routes { + if rt.rawPattern != raw || rt.method == "" { + continue + } + if _, ok := seen[rt.method]; ok { + continue + } + seen[rt.method] = struct{}{} + out = append(out, rt.method) + } + sort.Strings(out) + return strings.Join(out, ", ") +} + +func (r *Router) err400(w http.ResponseWriter, b *wtypes.BaseData, msg string) { + if r.errors.BadRequest != nil { + r.errors.BadRequest(w, b, msg) + return + } + http.Error(w, msg, http.StatusBadRequest) +} + +func (r *Router) err400Colon(w http.ResponseWriter, b *wtypes.BaseData) { + if r.errors.BadRequestColon != nil { + r.errors.BadRequestColon(w, b) + return + } + http.Error(w, "bad request", http.StatusBadRequest) +} + +func (r *Router) err404(w http.ResponseWriter, b *wtypes.BaseData) { + if r.errors.NotFound != nil { + r.errors.NotFound(w, b) + return + } + http.NotFound(w, nil) +} + +func (r *Router) err500(w http.ResponseWriter, b *wtypes.BaseData, msg string) { + if r.errors.ServerError != nil { + r.errors.ServerError(w, b, msg) + return + } + http.Error(w, msg, http.StatusInternalServerError) +} + +func GetParamRefTypeName(request *http.Request) (retRefType, retRefName string, err error) { + rawQuery := request.URL.RawQuery + queryValues, err := url.ParseQuery(rawQuery) + if err != nil { + return + } + done := false + for _, refType := range []string{"commit", "branch", "tag"} { + refName, ok := queryValues[refType] + if ok { + if done { + err = errDupRefSpec + return + } + done = true + if len(refName) != 1 { + err = errDupRefSpec + return + } + retRefName = refName[0] + retRefType = refType + } + } + if !done { + retRefType = "" + retRefName = "" + err = nil + } + return +} + +var errDupRefSpec = fmt.Errorf("duplicate ref specifications") diff --git a/forged/internal/incoming/web/server.go b/forged/internal/incoming/web/server.go new file mode 100644 index 0000000..ab70aec --- /dev/null +++ b/forged/internal/incoming/web/server.go @@ -0,0 +1,70 @@ +package web + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "time" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" + "go.lindenii.runxiyu.org/forge/forged/internal/global" +) + +type Server struct { + net string + addr string + root string + httpServer *http.Server + shutdownTimeout uint32 + global *global.Global +} + +func New(global *global.Global) *Server { + cfg := global.Config.Web + httpServer := &http.Server{ + Handler: NewHandler(global), + ReadTimeout: time.Duration(cfg.ReadTimeout) * time.Second, + WriteTimeout: time.Duration(cfg.WriteTimeout) * time.Second, + IdleTimeout: time.Duration(cfg.IdleTimeout) * time.Second, + MaxHeaderBytes: cfg.MaxHeaderBytes, + } //exhaustruct:ignore + return &Server{ + net: cfg.Net, + addr: cfg.Addr, + root: cfg.Root, + shutdownTimeout: cfg.ShutdownTimeout, + httpServer: httpServer, + global: global, + } +} + +func (server *Server) Run(ctx context.Context) (err error) { + server.httpServer.BaseContext = func(_ net.Listener) context.Context { return ctx } + + listener, err := misc.Listen(ctx, server.net, server.addr) + if err != nil { + return fmt.Errorf("listen for web: %w", err) + } + defer func() { + _ = listener.Close() + }() + + stop := context.AfterFunc(ctx, func() { + shCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(server.shutdownTimeout)*time.Second) + defer cancel() + _ = server.httpServer.Shutdown(shCtx) + _ = listener.Close() + }) + defer stop() + + err = server.httpServer.Serve(listener) + if err != nil { + if errors.Is(err, http.ErrServerClosed) || ctx.Err() != nil { + return nil + } + return fmt.Errorf("serve web: %w", err) + } + panic("unreachable") +} diff --git a/forged/internal/incoming/web/templates/load.go b/forged/internal/incoming/web/templates/load.go new file mode 100644 index 0000000..4a6fc49 --- /dev/null +++ b/forged/internal/incoming/web/templates/load.go @@ -0,0 +1,31 @@ +package templates + +import ( + "html/template" + "io/fs" + "os" + "path/filepath" +) + +func MustParseDir(dir string, funcs template.FuncMap) *template.Template { + base := template.New("").Funcs(funcs) + + err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + b, err := os.ReadFile(path) + if err != nil { + return err + } + _, err = base.Parse(string(b)) + return err + }) + if err != nil { + panic(err) + } + return base +} diff --git a/forged/internal/incoming/web/templates/renderer.go b/forged/internal/incoming/web/templates/renderer.go new file mode 100644 index 0000000..350e9ec --- /dev/null +++ b/forged/internal/incoming/web/templates/renderer.go @@ -0,0 +1,35 @@ +package templates + +import ( + "bytes" + "html/template" + "log/slog" + "net/http" +) + +type Renderer interface { + Render(w http.ResponseWriter, name string, data any) error +} + +type tmplRenderer struct { + t *template.Template +} + +func New(t *template.Template) Renderer { + return &tmplRenderer{t: t} +} + +func (r *tmplRenderer) Render(w http.ResponseWriter, name string, data any) error { + var buf bytes.Buffer + if err := r.t.ExecuteTemplate(&buf, name, data); err != nil { + slog.Error("template render failed", "name", name, "error", err) + return err + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + n, err := w.Write(buf.Bytes()) + if err != nil { + return err + } + slog.Info("template rendered", "name", name, "bytes", n) + return nil +} diff --git a/forged/internal/incoming/web/types/types.go b/forged/internal/incoming/web/types/types.go new file mode 100644 index 0000000..4b9a65a --- /dev/null +++ b/forged/internal/incoming/web/types/types.go @@ -0,0 +1,37 @@ +package types + +import ( + "context" + "net/http" + + "go.lindenii.runxiyu.org/forge/forged/internal/global" +) + +type BaseData struct { + UserID string + Username string + URLSegments []string + DirMode bool + GroupPath []string + SeparatorIndex int + RefType string + RefName string + Global *global.Global +} + +type ctxKey struct{} + +func WithBaseData(ctx context.Context, b *BaseData) context.Context { + return context.WithValue(ctx, ctxKey{}, b) +} + +func Base(r *http.Request) *BaseData { + if v, ok := r.Context().Value(ctxKey{}).(*BaseData); ok && v != nil { + return v + } + return &BaseData{} +} + +type Vars map[string]string + +type HandlerFunc func(http.ResponseWriter, *http.Request, Vars) diff --git a/forged/internal/ipc/git2c/build.go b/forged/internal/ipc/git2c/build.go new file mode 100644 index 0000000..3d1b7a0 --- /dev/null +++ b/forged/internal/ipc/git2c/build.go @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import ( + "encoding/hex" + "fmt" + "path" + "sort" + "strings" +) + +func (c *Client) BuildTreeRecursive(repoPath, baseTreeHex string, updates map[string]string) (string, error) { + treeCache := make(map[string][]TreeEntryRaw) + var walk func(prefix, hexid string) error + walk = func(prefix, hexid string) error { + ents, err := c.TreeListByOID(repoPath, hexid) + if err != nil { + return err + } + treeCache[prefix] = ents + for _, e := range ents { + if e.Mode == 40000 { + sub := path.Join(prefix, e.Name) + if err := walk(sub, e.OID); err != nil { + return err + } + } + } + return nil + } + if err := walk("", baseTreeHex); err != nil { + return "", err + } + + for p, blob := range updates { + parts := strings.Split(p, "/") + dir := strings.Join(parts[:len(parts)-1], "/") + name := parts[len(parts)-1] + entries := treeCache[dir] + found := false + for i := range entries { + if entries[i].Name == name { + if blob == "" { + entries = append(entries[:i], entries[i+1:]...) + } else { + entries[i].Mode = 0o100644 + entries[i].OID = blob + } + found = true + break + } + } + if !found && blob != "" { + entries = append(entries, TreeEntryRaw{Mode: 0o100644, Name: name, OID: blob}) + } + treeCache[dir] = entries + } + + built := make(map[string]string) + var build func(prefix string) (string, error) + build = func(prefix string) (string, error) { + entries := treeCache[prefix] + for i := range entries { + if entries[i].Mode == 0o40000 || entries[i].Mode == 40000 { + sub := path.Join(prefix, entries[i].Name) + var ok bool + var oid string + if oid, ok = built[sub]; !ok { + var err error + oid, err = build(sub) + if err != nil { + return "", err + } + } + entries[i].Mode = 0o40000 + entries[i].OID = oid + } + } + sort.Slice(entries, func(i, j int) bool { + ni, nj := entries[i].Name, entries[j].Name + if ni == nj { + return entries[i].Mode != 0o40000 && entries[j].Mode == 0o40000 + } + if strings.HasPrefix(nj, ni) && len(ni) < len(nj) { + return entries[i].Mode != 0o40000 + } + if strings.HasPrefix(ni, nj) && len(nj) < len(ni) { + return entries[j].Mode == 0o40000 + } + return ni < nj + }) + wr := make([]TreeEntryRaw, 0, len(entries)) + for _, e := range entries { + if e.OID == "" { + continue + } + if e.Mode == 40000 { + e.Mode = 0o40000 + } + if _, err := hex.DecodeString(e.OID); err != nil { + return "", fmt.Errorf("invalid OID hex for %s/%s: %w", prefix, e.Name, err) + } + wr = append(wr, TreeEntryRaw{Mode: e.Mode, Name: e.Name, OID: e.OID}) + } + id, err := c.WriteTree(repoPath, wr) + if err != nil { + return "", err + } + built[prefix] = id + return id, nil + } + root, err := build("") + if err != nil { + return "", err + } + return root, nil +} diff --git a/forged/internal/ipc/git2c/client.go b/forged/internal/ipc/git2c/client.go new file mode 100644 index 0000000..79c2024 --- /dev/null +++ b/forged/internal/ipc/git2c/client.go @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import ( + "context" + "fmt" + "net" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/bare" +) + +type Client struct { + socketPath string + conn net.Conn + writer *bare.Writer + reader *bare.Reader +} + +func NewClient(ctx context.Context, socketPath string) (*Client, error) { + dialer := &net.Dialer{} //exhaustruct:ignore + conn, err := dialer.DialContext(ctx, "unix", socketPath) + if err != nil { + return nil, fmt.Errorf("git2d connection failed: %w", err) + } + + writer := bare.NewWriter(conn) + reader := bare.NewReader(conn) + + return &Client{ + socketPath: socketPath, + conn: conn, + writer: writer, + reader: reader, + }, nil +} + +func (c *Client) Close() (err error) { + if c.conn != nil { + err = c.conn.Close() + if err != nil { + return fmt.Errorf("close underlying socket: %w", err) + } + } + return nil +} diff --git a/forged/internal/ipc/git2c/cmd_index.go b/forged/internal/ipc/git2c/cmd_index.go new file mode 100644 index 0000000..44a0845 --- /dev/null +++ b/forged/internal/ipc/git2c/cmd_index.go @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import ( + "encoding/hex" + "errors" + "fmt" + "io" +) + +func (c *Client) CmdIndex(repoPath string) ([]Commit, *FilenameContents, error) { + err := c.writer.WriteData([]byte(repoPath)) + if err != nil { + return nil, nil, fmt.Errorf("sending repo path failed: %w", err) + } + err = c.writer.WriteUint(1) + if err != nil { + return nil, nil, fmt.Errorf("sending command failed: %w", err) + } + + status, err := c.reader.ReadUint() + if err != nil { + return nil, nil, fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return nil, nil, fmt.Errorf("git2d error: %d", status) + } + + // README + readmeRaw, err := c.reader.ReadData() + if err != nil { + readmeRaw = nil + } + + readmeFilename := "README.md" // TODO + readme := &FilenameContents{Filename: readmeFilename, Content: readmeRaw} + + // Commits + var commits []Commit + for { + id, err := c.reader.ReadData() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, nil, fmt.Errorf("reading commit ID failed: %w", err) + } + title, _ := c.reader.ReadData() + authorName, _ := c.reader.ReadData() + authorEmail, _ := c.reader.ReadData() + authorDate, _ := c.reader.ReadData() + + commits = append(commits, Commit{ + Hash: hex.EncodeToString(id), + Author: string(authorName), + Email: string(authorEmail), + Date: string(authorDate), + Message: string(title), + }) + } + + return commits, readme, nil +} diff --git a/forged/internal/ipc/git2c/cmd_init_repo.go b/forged/internal/ipc/git2c/cmd_init_repo.go new file mode 100644 index 0000000..ae1e92a --- /dev/null +++ b/forged/internal/ipc/git2c/cmd_init_repo.go @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import "fmt" + +func (c *Client) InitRepo(repoPath, hooksPath string) error { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(15); err != nil { + return fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(hooksPath)); err != nil { + return fmt.Errorf("sending hooks path failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return Perror(status) + } + return nil +} diff --git a/forged/internal/ipc/git2c/cmd_treeraw.go b/forged/internal/ipc/git2c/cmd_treeraw.go new file mode 100644 index 0000000..d2d5ac2 --- /dev/null +++ b/forged/internal/ipc/git2c/cmd_treeraw.go @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import ( + "errors" + "fmt" + "io" +) + +func (c *Client) CmdTreeRaw(repoPath, pathSpec string) ([]TreeEntry, string, error) { + err := c.writer.WriteData([]byte(repoPath)) + if err != nil { + return nil, "", fmt.Errorf("sending repo path failed: %w", err) + } + err = c.writer.WriteUint(2) + if err != nil { + return nil, "", fmt.Errorf("sending command failed: %w", err) + } + err = c.writer.WriteData([]byte(pathSpec)) + if err != nil { + return nil, "", fmt.Errorf("sending path failed: %w", err) + } + + status, err := c.reader.ReadUint() + if err != nil { + return nil, "", fmt.Errorf("reading status failed: %w", err) + } + + switch status { + case 0: + kind, err := c.reader.ReadUint() + if err != nil { + return nil, "", fmt.Errorf("reading object kind failed: %w", err) + } + + switch kind { + case 1: + // Tree + count, err := c.reader.ReadUint() + if err != nil { + return nil, "", fmt.Errorf("reading entry count failed: %w", err) + } + + var files []TreeEntry + for range count { + typeCode, err := c.reader.ReadUint() + if err != nil { + return nil, "", fmt.Errorf("error reading entry type: %w", err) + } + mode, err := c.reader.ReadUint() + if err != nil { + return nil, "", fmt.Errorf("error reading entry mode: %w", err) + } + size, err := c.reader.ReadUint() + if err != nil { + return nil, "", fmt.Errorf("error reading entry size: %w", err) + } + name, err := c.reader.ReadData() + if err != nil { + return nil, "", fmt.Errorf("error reading entry name: %w", err) + } + + files = append(files, TreeEntry{ + Name: string(name), + Mode: fmt.Sprintf("%06o", mode), + Size: size, + IsFile: typeCode == 2, + IsSubtree: typeCode == 1, + }) + } + + return files, "", nil + + case 2: + // Blob + content, err := c.reader.ReadData() + if err != nil && !errors.Is(err, io.EOF) { + return nil, "", fmt.Errorf("error reading file content: %w", err) + } + + return nil, string(content), nil + + default: + return nil, "", fmt.Errorf("unknown kind: %d", kind) + } + + case 3: + return nil, "", fmt.Errorf("path not found: %s", pathSpec) + + default: + return nil, "", fmt.Errorf("unknown status code: %d", status) + } +} diff --git a/forged/internal/ipc/git2c/doc.go b/forged/internal/ipc/git2c/doc.go new file mode 100644 index 0000000..e14dae0 --- /dev/null +++ b/forged/internal/ipc/git2c/doc.go @@ -0,0 +1,2 @@ +// Package git2c provides routines to interact with the git2d backend daemon. +package git2c diff --git a/forged/internal/ipc/git2c/extra.go b/forged/internal/ipc/git2c/extra.go new file mode 100644 index 0000000..1a3e3a6 --- /dev/null +++ b/forged/internal/ipc/git2c/extra.go @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import ( + "encoding/hex" + "fmt" + "time" +) + +type DiffChunk struct { + Op uint64 + Content string +} + +type FileDiff struct { + FromMode uint64 + ToMode uint64 + FromPath string + ToPath string + Chunks []DiffChunk +} + +type CommitInfo struct { + Hash string + AuthorName string + AuthorEmail string + AuthorWhen int64 // unix secs + AuthorTZMin int64 // minutes ofs + CommitterName string + CommitterEmail string + CommitterWhen int64 + CommitterTZMin int64 + Message string + Parents []string // hex + Files []FileDiff +} + +func (c *Client) ResolveRef(repoPath, refType, refName string) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(3); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(refType)); err != nil { + return "", fmt.Errorf("sending ref type failed: %w", err) + } + if err := c.writer.WriteData([]byte(refName)); err != nil { + return "", fmt.Errorf("sending ref name failed: %w", err) + } + + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + id, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading oid failed: %w", err) + } + return hex.EncodeToString(id), nil +} + +func (c *Client) ListBranches(repoPath string) ([]string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return nil, fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(4); err != nil { + return nil, fmt.Errorf("sending command failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return nil, Perror(status) + } + count, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading count failed: %w", err) + } + branches := make([]string, 0, count) + for range count { + name, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading branch name failed: %w", err) + } + branches = append(branches, string(name)) + } + return branches, nil +} + +func (c *Client) FormatPatch(repoPath, commitHex string) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(5); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(commitHex)); err != nil { + return "", fmt.Errorf("sending commit failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + buf, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading patch failed: %w", err) + } + return string(buf), nil +} + +func (c *Client) MergeBase(repoPath, hexA, hexB string) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(7); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(hexA)); err != nil { + return "", fmt.Errorf("sending oid A failed: %w", err) + } + if err := c.writer.WriteData([]byte(hexB)); err != nil { + return "", fmt.Errorf("sending oid B failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + base, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading base oid failed: %w", err) + } + return hex.EncodeToString(base), nil +} + +func (c *Client) Log(repoPath, refSpec string, n uint) ([]Commit, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return nil, fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(8); err != nil { + return nil, fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(refSpec)); err != nil { + return nil, fmt.Errorf("sending refspec failed: %w", err) + } + if err := c.writer.WriteUint(uint64(n)); err != nil { + return nil, fmt.Errorf("sending limit failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return nil, Perror(status) + } + var out []Commit + for { + id, err := c.reader.ReadData() + if err != nil { + break + } + title, _ := c.reader.ReadData() + authorName, _ := c.reader.ReadData() + authorEmail, _ := c.reader.ReadData() + date, _ := c.reader.ReadData() + out = append(out, Commit{ + Hash: hex.EncodeToString(id), + Author: string(authorName), + Email: string(authorEmail), + Date: string(date), + Message: string(title), + }) + } + return out, nil +} + +func (c *Client) CommitTreeOID(repoPath, commitHex string) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(12); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(commitHex)); err != nil { + return "", fmt.Errorf("sending oid failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + id, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading tree oid failed: %w", err) + } + return hex.EncodeToString(id), nil +} + +func (c *Client) CommitCreate(repoPath, treeHex string, parents []string, authorName, authorEmail string, when time.Time, message string) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(13); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(treeHex)); err != nil { + return "", fmt.Errorf("sending tree oid failed: %w", err) + } + if err := c.writer.WriteUint(uint64(len(parents))); err != nil { + return "", fmt.Errorf("sending parents count failed: %w", err) + } + for _, p := range parents { + if err := c.writer.WriteData([]byte(p)); err != nil { + return "", fmt.Errorf("sending parent oid failed: %w", err) + } + } + if err := c.writer.WriteData([]byte(authorName)); err != nil { + return "", fmt.Errorf("sending author name failed: %w", err) + } + if err := c.writer.WriteData([]byte(authorEmail)); err != nil { + return "", fmt.Errorf("sending author email failed: %w", err) + } + if err := c.writer.WriteInt(when.Unix()); err != nil { + return "", fmt.Errorf("sending when failed: %w", err) + } + _, offset := when.Zone() + if err := c.writer.WriteInt(int64(offset / 60)); err != nil { + return "", fmt.Errorf("sending tz offset failed: %w", err) + } + if err := c.writer.WriteData([]byte(message)); err != nil { + return "", fmt.Errorf("sending message failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + id, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading commit oid failed: %w", err) + } + return hex.EncodeToString(id), nil +} + +func (c *Client) UpdateRef(repoPath, refName, commitHex string) error { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(14); err != nil { + return fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(refName)); err != nil { + return fmt.Errorf("sending ref name failed: %w", err) + } + if err := c.writer.WriteData([]byte(commitHex)); err != nil { + return fmt.Errorf("sending commit oid failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return Perror(status) + } + return nil +} + +func (c *Client) CommitInfo(repoPath, commitHex string) (*CommitInfo, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return nil, fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(6); err != nil { + return nil, fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(commitHex)); err != nil { + return nil, fmt.Errorf("sending commit failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return nil, Perror(status) + } + id, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading id failed: %w", err) + } + aname, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading author name failed: %w", err) + } + aemail, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading author email failed: %w", err) + } + awhen, err := c.reader.ReadI64() + if err != nil { + return nil, fmt.Errorf("reading author time failed: %w", err) + } + aoff, err := c.reader.ReadI64() + if err != nil { + return nil, fmt.Errorf("reading author tz failed: %w", err) + } + cname, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading committer name failed: %w", err) + } + cemail, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading committer email failed: %w", err) + } + cwhen, err := c.reader.ReadI64() + if err != nil { + return nil, fmt.Errorf("reading committer time failed: %w", err) + } + coff, err := c.reader.ReadI64() + if err != nil { + return nil, fmt.Errorf("reading committer tz failed: %w", err) + } + msg, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading message failed: %w", err) + } + pcnt, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading parents count failed: %w", err) + } + parents := make([]string, 0, pcnt) + for i := uint64(0); i < pcnt; i++ { + praw, perr := c.reader.ReadData() + if perr != nil { + return nil, fmt.Errorf("reading parent failed: %w", perr) + } + parents = append(parents, hex.EncodeToString(praw)) + } + fcnt, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading file count failed: %w", err) + } + files := make([]FileDiff, 0, fcnt) + for i := uint64(0); i < fcnt; i++ { + fromMode, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading from mode failed: %w", err) + } + toMode, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading to mode failed: %w", err) + } + fromPath, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading from path failed: %w", err) + } + toPath, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading to path failed: %w", err) + } + ccnt, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading chunk count failed: %w", err) + } + chunks := make([]DiffChunk, 0, ccnt) + for j := uint64(0); j < ccnt; j++ { + op, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading chunk op failed: %w", err) + } + content, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading chunk content failed: %w", err) + } + chunks = append(chunks, DiffChunk{Op: op, Content: string(content)}) + } + files = append(files, FileDiff{ + FromMode: fromMode, + ToMode: toMode, + FromPath: string(fromPath), + ToPath: string(toPath), + Chunks: chunks, + }) + } + return &CommitInfo{ + Hash: hex.EncodeToString(id), + AuthorName: string(aname), + AuthorEmail: string(aemail), + AuthorWhen: awhen, + AuthorTZMin: aoff, + CommitterName: string(cname), + CommitterEmail: string(cemail), + CommitterWhen: cwhen, + CommitterTZMin: coff, + Message: string(msg), + Parents: parents, + Files: files, + }, nil +} diff --git a/forged/internal/ipc/git2c/git_types.go b/forged/internal/ipc/git2c/git_types.go new file mode 100644 index 0000000..da685bf --- /dev/null +++ b/forged/internal/ipc/git2c/git_types.go @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +type Commit struct { + Hash string + Author string + Email string + Date string + Message string +} + +type FilenameContents struct { + Filename string + Content []byte +} + +type TreeEntry struct { + Name string + Mode string + Size uint64 + IsFile bool + IsSubtree bool +} diff --git a/forged/internal/ipc/git2c/perror.go b/forged/internal/ipc/git2c/perror.go new file mode 100644 index 0000000..4be2a07 --- /dev/null +++ b/forged/internal/ipc/git2c/perror.go @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// TODO: Make the C part report detailed error messages too + +package git2c + +import "errors" + +var ( + ErrUnknown = errors.New("git2c: unknown error") + ErrPath = errors.New("git2c: get tree entry by path failed") + ErrRevparse = errors.New("git2c: revparse failed") + ErrReadme = errors.New("git2c: no readme") + ErrBlobExpected = errors.New("git2c: blob expected") + ErrEntryToObject = errors.New("git2c: tree entry to object conversion failed") + ErrBlobRawContent = errors.New("git2c: get blob raw content failed") + ErrRevwalk = errors.New("git2c: revwalk failed") + ErrRevwalkPushHead = errors.New("git2c: revwalk push head failed") + ErrBareProto = errors.New("git2c: bare protocol error") + ErrRefResolve = errors.New("git2c: ref resolve failed") + ErrBranches = errors.New("git2c: list branches failed") + ErrCommitLookup = errors.New("git2c: commit lookup failed") + ErrDiff = errors.New("git2c: diff failed") + ErrMergeBaseNone = errors.New("git2c: no merge base found") + ErrMergeBase = errors.New("git2c: merge base failed") + ErrCommitCreate = errors.New("git2c: commit create failed") + ErrUpdateRef = errors.New("git2c: update ref failed") + ErrCommitTree = errors.New("git2c: commit tree lookup failed") + ErrInitRepoCreate = errors.New("git2c: init repo: create failed") + ErrInitRepoConfig = errors.New("git2c: init repo: open config failed") + ErrInitRepoSetHooksPath = errors.New("git2c: init repo: set core.hooksPath failed") + ErrInitRepoSetAdvertisePushOptions = errors.New("git2c: init repo: set receive.advertisePushOptions failed") + ErrInitRepoMkdir = errors.New("git2c: init repo: create directory failed") +) + +func Perror(errno uint64) error { + switch errno { + case 0: + return nil + case 3: + return ErrPath + case 4: + return ErrRevparse + case 5: + return ErrReadme + case 6: + return ErrBlobExpected + case 7: + return ErrEntryToObject + case 8: + return ErrBlobRawContent + case 9: + return ErrRevwalk + case 10: + return ErrRevwalkPushHead + case 11: + return ErrBareProto + case 12: + return ErrRefResolve + case 13: + return ErrBranches + case 14: + return ErrCommitLookup + case 15: + return ErrDiff + case 16: + return ErrMergeBaseNone + case 17: + return ErrMergeBase + case 18: + return ErrUpdateRef + case 19: + return ErrCommitCreate + case 20: + return ErrInitRepoCreate + case 21: + return ErrInitRepoConfig + case 22: + return ErrInitRepoSetHooksPath + case 23: + return ErrInitRepoSetAdvertisePushOptions + case 24: + return ErrInitRepoMkdir + } + return ErrUnknown +} diff --git a/forged/internal/ipc/git2c/tree.go b/forged/internal/ipc/git2c/tree.go new file mode 100644 index 0000000..f598e14 --- /dev/null +++ b/forged/internal/ipc/git2c/tree.go @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +package git2c + +import ( + "encoding/hex" + "fmt" +) + +type TreeEntryRaw struct { + Mode uint64 + Name string + OID string // hex +} + +func (c *Client) TreeListByOID(repoPath, treeHex string) ([]TreeEntryRaw, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return nil, fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(9); err != nil { + return nil, fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData([]byte(treeHex)); err != nil { + return nil, fmt.Errorf("sending tree oid failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return nil, Perror(status) + } + count, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading count failed: %w", err) + } + entries := make([]TreeEntryRaw, 0, count) + for range count { + mode, err := c.reader.ReadUint() + if err != nil { + return nil, fmt.Errorf("reading mode failed: %w", err) + } + name, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading name failed: %w", err) + } + id, err := c.reader.ReadData() + if err != nil { + return nil, fmt.Errorf("reading oid failed: %w", err) + } + entries = append(entries, TreeEntryRaw{Mode: mode, Name: string(name), OID: hex.EncodeToString(id)}) + } + return entries, nil +} + +func (c *Client) WriteTree(repoPath string, entries []TreeEntryRaw) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(10); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteUint(uint64(len(entries))); err != nil { + return "", fmt.Errorf("sending count failed: %w", err) + } + for _, e := range entries { + if err := c.writer.WriteUint(e.Mode); err != nil { + return "", fmt.Errorf("sending mode failed: %w", err) + } + if err := c.writer.WriteData([]byte(e.Name)); err != nil { + return "", fmt.Errorf("sending name failed: %w", err) + } + raw, err := hex.DecodeString(e.OID) + if err != nil { + return "", fmt.Errorf("decode oid hex: %w", err) + } + if err := c.writer.WriteDataFixed(raw); err != nil { + return "", fmt.Errorf("sending oid failed: %w", err) + } + } + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + id, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading oid failed: %w", err) + } + return hex.EncodeToString(id), nil +} + +func (c *Client) WriteBlob(repoPath string, content []byte) (string, error) { + if err := c.writer.WriteData([]byte(repoPath)); err != nil { + return "", fmt.Errorf("sending repo path failed: %w", err) + } + if err := c.writer.WriteUint(11); err != nil { + return "", fmt.Errorf("sending command failed: %w", err) + } + if err := c.writer.WriteData(content); err != nil { + return "", fmt.Errorf("sending blob content failed: %w", err) + } + status, err := c.reader.ReadUint() + if err != nil { + return "", fmt.Errorf("reading status failed: %w", err) + } + if status != 0 { + return "", Perror(status) + } + id, err := c.reader.ReadData() + if err != nil { + return "", fmt.Errorf("reading oid failed: %w", err) + } + return hex.EncodeToString(id), nil +} diff --git a/forged/internal/server/server.go b/forged/internal/server/server.go new file mode 100644 index 0000000..39a6823 --- /dev/null +++ b/forged/internal/server/server.go @@ -0,0 +1,87 @@ +package server + +import ( + "context" + "fmt" + + "go.lindenii.runxiyu.org/forge/forged/internal/config" + "go.lindenii.runxiyu.org/forge/forged/internal/database" + "go.lindenii.runxiyu.org/forge/forged/internal/database/queries" + "go.lindenii.runxiyu.org/forge/forged/internal/global" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/hooks" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/lmtp" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/ssh" + "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web" + "golang.org/x/sync/errgroup" +) + +type Server struct { + config config.Config + + database database.Database + hookServer *hooks.Server + lmtpServer *lmtp.Server + webServer *web.Server + sshServer *ssh.Server + + global global.Global +} + +func New(configPath string) (server *Server, err error) { + server = &Server{} //exhaustruct:ignore + + server.config, err = config.Open(configPath) + if err != nil { + return server, fmt.Errorf("open config: %w", err) + } + + queries := queries.New(&server.database) + + server.global.ForgeVersion = "unknown" // TODO + server.global.ForgeTitle = server.config.General.Title + server.global.Config = &server.config + server.global.Queries = queries + + server.hookServer = hooks.New(&server.global) + server.lmtpServer = lmtp.New(&server.global) + server.webServer = web.New(&server.global) + server.sshServer, err = ssh.New(&server.global) + if err != nil { + return server, fmt.Errorf("create SSH server: %w", err) + } + + return server, nil +} + +func (server *Server) Run(ctx context.Context) (err error) { + // TODO: Not running git2d because it should be run separately. + // This needs to be documented somewhere, hence a TODO here for now. + + g, gctx := errgroup.WithContext(ctx) + + server.database, err = database.Open(gctx, server.config.DB.Conn) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + defer server.database.Close() + + // TODO: neater way to do this for transactions in querypool? + server.global.DB = &server.database + + g.Go(func() error { return server.hookServer.Run(gctx) }) + g.Go(func() error { return server.lmtpServer.Run(gctx) }) + g.Go(func() error { return server.webServer.Run(gctx) }) + g.Go(func() error { return server.sshServer.Run(gctx) }) + + err = g.Wait() + if err != nil { + return fmt.Errorf("server error: %w", err) + } + + err = ctx.Err() + if err != nil { + return fmt.Errorf("context exceeded: %w", err) + } + + return nil +} |