aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRunxi Yu <me@runxiyu.org>2025-04-05 20:26:57 +0800
committerRunxi Yu <me@runxiyu.org>2025-04-05 20:26:57 +0800
commit20b4fe0c59357a433042732d46e38da9c3d14c3b (patch)
tree537d7d450701a839802b1d57a82bd07324dcba90
parentmisc: Move utils.go's string function to misc (diff)
downloadforge-20b4fe0c59357a433042732d46e38da9c3d14c3b.tar.gz
forge-20b4fe0c59357a433042732d46e38da9c3d14c3b.tar.zst
forge-20b4fe0c59357a433042732d46e38da9c3d14c3b.zip
database shall no longer be a global variable
-rw-r--r--acl.go2
-rw-r--r--config.go2
-rw-r--r--database.go8
-rw-r--r--fedauth.go2
-rw-r--r--git_hooks_handle_linux.go6
-rw-r--r--git_hooks_handle_other.go6
-rw-r--r--git_misc.go4
-rw-r--r--http_auth.go4
-rw-r--r--http_handle_group_index.go12
-rw-r--r--http_handle_login.go4
-rw-r--r--http_handle_repo_contrib_index.go4
-rw-r--r--http_handle_repo_contrib_one.go4
-rw-r--r--http_handle_repo_info.go4
-rw-r--r--http_handle_repo_upload_pack.go2
-rw-r--r--http_server.go10
-rw-r--r--lmtp_handle_patch.go4
-rw-r--r--lmtp_server.go2
-rw-r--r--server.go7
-rw-r--r--ssh_handle_receive_pack.go2
-rw-r--r--users.go4
20 files changed, 47 insertions, 46 deletions
diff --git a/acl.go b/acl.go
index 44cd04b..dfe128a 100644
--- a/acl.go
+++ b/acl.go
@@ -14,7 +14,7 @@ import (
//
// TODO: Revamp.
func (s *server) getRepoInfo(ctx context.Context, groupPath []string, repoName, sshPubkey string) (repoID int, fsPath string, access bool, contribReq, userType string, userID int, err error) {
- err = database.QueryRow(ctx, `
+ err = s.database.QueryRow(ctx, `
WITH RECURSIVE group_path_cte AS (
-- Start: match the first name in the path where parent_group IS NULL
SELECT
diff --git a/config.go b/config.go
index 1bbc3a1..773a223 100644
--- a/config.go
+++ b/config.go
@@ -92,7 +92,7 @@ func (s *server) loadConfig(path string) (err error) {
return errors.New("unsupported database type")
}
- if database, err = pgxpool.New(context.Background(), s.config.DB.Conn); err != nil {
+ if s.database, err = pgxpool.New(context.Background(), s.config.DB.Conn); err != nil {
return err
}
diff --git a/database.go b/database.go
index 18e753f..1ea0753 100644
--- a/database.go
+++ b/database.go
@@ -7,7 +7,6 @@ import (
"context"
"github.com/jackc/pgx/v5"
- "github.com/jackc/pgx/v5/pgxpool"
)
// TODO: All database handling logic in all request handlers must be revamped.
@@ -16,18 +15,13 @@ import (
// at a single point. A failure to do so may cause things as serious as
// privilege escalation.
-// database serves as the primary database handle for this entire application.
-// Transactions or single reads may be used from it. A [pgxpool.Pool] is
-// necessary to safely use pgx concurrently; pgx.Conn, etc. are insufficient.
-var database *pgxpool.Pool
-
// queryNameDesc is a helper function that executes a query and returns a
// list of nameDesc results. The query must return two string arguments, i.e. a
// name and a description.
func (s *server) queryNameDesc(ctx context.Context, query string, args ...any) (result []nameDesc, err error) {
var rows pgx.Rows
- if rows, err = database.Query(ctx, query, args...); err != nil {
+ if rows, err = s.database.Query(ctx, query, args...); err != nil {
return nil, err
}
defer rows.Close()
diff --git a/fedauth.go b/fedauth.go
index 46290e5..43cb4e3 100644
--- a/fedauth.go
+++ b/fedauth.go
@@ -77,7 +77,7 @@ func (s *server) fedauth(ctx context.Context, userID int, service, remoteUsernam
}
var txn pgx.Tx
- if txn, err = database.Begin(ctx); err != nil {
+ if txn, err = s.database.Begin(ctx); err != nil {
return false, err
}
defer func() {
diff --git a/git_hooks_handle_linux.go b/git_hooks_handle_linux.go
index 37afba1..ca262e3 100644
--- a/git_hooks_handle_linux.go
+++ b/git_hooks_handle_linux.go
@@ -233,12 +233,12 @@ func (s *server) hooksHandler(conn net.Conn) {
var newMRLocalID int
if packPass.userID != 0 {
- err = database.QueryRow(ctx,
+ err = s.database.QueryRow(ctx,
"INSERT INTO merge_requests (repo_id, creator, source_ref, status) VALUES ($1, $2, $3, 'open') RETURNING repo_local_id",
packPass.repoID, packPass.userID, strings.TrimPrefix(refName, "refs/heads/"),
).Scan(&newMRLocalID)
} else {
- err = database.QueryRow(ctx,
+ err = s.database.QueryRow(ctx,
"INSERT INTO merge_requests (repo_id, source_ref, status) VALUES ($1, $2, 'open') RETURNING repo_local_id",
packPass.repoID, strings.TrimPrefix(refName, "refs/heads/"),
).Scan(&newMRLocalID)
@@ -259,7 +259,7 @@ func (s *server) hooksHandler(conn net.Conn) {
var existingMRUser int
var isAncestor bool
- err = database.QueryRow(ctx,
+ err = s.database.QueryRow(ctx,
"SELECT COALESCE(creator, 0) FROM merge_requests WHERE source_ref = $1 AND repo_id = $2",
strings.TrimPrefix(refName, "refs/heads/"), packPass.repoID,
).Scan(&existingMRUser)
diff --git a/git_hooks_handle_other.go b/git_hooks_handle_other.go
index 6d5b08d..ed75e7a 100644
--- a/git_hooks_handle_other.go
+++ b/git_hooks_handle_other.go
@@ -211,12 +211,12 @@ func (s *server) hooksHandler(conn net.Conn) {
var newMRLocalID int
if packPass.userID != 0 {
- err = database.QueryRow(ctx,
+ err = s.database.QueryRow(ctx,
"INSERT INTO merge_requests (repo_id, creator, source_ref, status) VALUES ($1, $2, $3, 'open') RETURNING repo_local_id",
packPass.repoID, packPass.userID, strings.TrimPrefix(refName, "refs/heads/"),
).Scan(&newMRLocalID)
} else {
- err = database.QueryRow(ctx,
+ err = s.database.QueryRow(ctx,
"INSERT INTO merge_requests (repo_id, source_ref, status) VALUES ($1, $2, 'open') RETURNING repo_local_id",
packPass.repoID, strings.TrimPrefix(refName, "refs/heads/"),
).Scan(&newMRLocalID)
@@ -237,7 +237,7 @@ func (s *server) hooksHandler(conn net.Conn) {
var existingMRUser int
var isAncestor bool
- err = database.QueryRow(ctx,
+ err = s.database.QueryRow(ctx,
"SELECT COALESCE(creator, 0) FROM merge_requests WHERE source_ref = $1 AND repo_id = $2",
strings.TrimPrefix(refName, "refs/heads/"), packPass.repoID,
).Scan(&existingMRUser)
diff --git a/git_misc.go b/git_misc.go
index 8e72d0c..8dda01c 100644
--- a/git_misc.go
+++ b/git_misc.go
@@ -22,8 +22,8 @@ import (
// TODO: This should be deprecated in favor of doing it in the relevant
// request/router context in the future, as it cannot cover the nuance of
// fields needed.
-func openRepo(ctx context.Context, groupPath []string, repoName string) (repo *git.Repository, description string, repoID int, fsPath string, err error) {
- err = database.QueryRow(ctx, `
+func (s *server) openRepo(ctx context.Context, groupPath []string, repoName string) (repo *git.Repository, description string, repoID int, fsPath string, err error) {
+ err = s.database.QueryRow(ctx, `
WITH RECURSIVE group_path_cte AS (
-- Start: match the first name in the path where parent_group IS NULL
SELECT
diff --git a/http_auth.go b/http_auth.go
index 03b7e2b..5f0dc66 100644
--- a/http_auth.go
+++ b/http_auth.go
@@ -9,14 +9,14 @@ import (
// getUserFromRequest returns the user ID and username associated with the
// session cookie in a given [http.Request].
-func getUserFromRequest(request *http.Request) (id int, username string, err error) {
+func (s *server) getUserFromRequest(request *http.Request) (id int, username string, err error) {
var sessionCookie *http.Cookie
if sessionCookie, err = request.Cookie("session"); err != nil {
return
}
- err = database.QueryRow(
+ err = s.database.QueryRow(
request.Context(),
"SELECT user_id, COALESCE(username, '') FROM users u JOIN sessions s ON u.id = s.user_id WHERE s.session_id = $1;",
sessionCookie.Value,
diff --git a/http_handle_group_index.go b/http_handle_group_index.go
index 16120a8..46f1f6a 100644
--- a/http_handle_group_index.go
+++ b/http_handle_group_index.go
@@ -28,7 +28,7 @@ func (s *server) httpHandleGroupIndex(writer http.ResponseWriter, request *http.
groupPath = params["group_path"].([]string)
// The group itself
- err = database.QueryRow(request.Context(), `
+ err = s.database.QueryRow(request.Context(), `
WITH RECURSIVE group_path_cte AS (
SELECT
id,
@@ -69,7 +69,7 @@ func (s *server) httpHandleGroupIndex(writer http.ResponseWriter, request *http.
// ACL
var count int
- err = database.QueryRow(request.Context(), `
+ err = s.database.QueryRow(request.Context(), `
SELECT COUNT(*)
FROM user_group_roles
WHERE user_id = $1
@@ -96,7 +96,7 @@ func (s *server) httpHandleGroupIndex(writer http.ResponseWriter, request *http.
}
var newRepoID int
- err := database.QueryRow(
+ err := s.database.QueryRow(
request.Context(),
`INSERT INTO repos (name, description, group_id, contrib_requirements)
VALUES ($1, $2, $3, $4)
@@ -113,7 +113,7 @@ func (s *server) httpHandleGroupIndex(writer http.ResponseWriter, request *http.
filePath := filepath.Join(s.config.Git.RepoDir, strconv.Itoa(newRepoID)+".git")
- _, err = database.Exec(
+ _, err = s.database.Exec(
request.Context(),
`UPDATE repos
SET filesystem_path = $1
@@ -137,7 +137,7 @@ func (s *server) httpHandleGroupIndex(writer http.ResponseWriter, request *http.
// Repos
var rows pgx.Rows
- rows, err = database.Query(request.Context(), `
+ rows, err = s.database.Query(request.Context(), `
SELECT name, COALESCE(description, '')
FROM repos
WHERE group_id = $1
@@ -162,7 +162,7 @@ func (s *server) httpHandleGroupIndex(writer http.ResponseWriter, request *http.
}
// Subgroups
- rows, err = database.Query(request.Context(), `
+ rows, err = s.database.Query(request.Context(), `
SELECT name, COALESCE(description, '')
FROM groups
WHERE parent_group = $1
diff --git a/http_handle_login.go b/http_handle_login.go
index ea1dbae..10bfdcd 100644
--- a/http_handle_login.go
+++ b/http_handle_login.go
@@ -35,7 +35,7 @@ func (s *server) httpHandleLogin(writer http.ResponseWriter, request *http.Reque
username = request.PostFormValue("username")
password = request.PostFormValue("password")
- err = database.QueryRow(request.Context(),
+ err = s.database.QueryRow(request.Context(),
"SELECT id, COALESCE(password, '') FROM users WHERE username = $1",
username,
).Scan(&userID, &passwordHash)
@@ -85,7 +85,7 @@ func (s *server) httpHandleLogin(writer http.ResponseWriter, request *http.Reque
http.SetCookie(writer, &cookie)
- _, err = database.Exec(request.Context(), "INSERT INTO sessions (user_id, session_id) VALUES ($1, $2)", userID, cookieValue)
+ _, err = s.database.Exec(request.Context(), "INSERT INTO sessions (user_id, session_id) VALUES ($1, $2)", userID, cookieValue)
if err != nil {
errorPage500(writer, params, "Error inserting session: "+err.Error())
return
diff --git a/http_handle_repo_contrib_index.go b/http_handle_repo_contrib_index.go
index ee7b956..e0c8478 100644
--- a/http_handle_repo_contrib_index.go
+++ b/http_handle_repo_contrib_index.go
@@ -18,12 +18,12 @@ type idTitleStatus struct {
}
// httpHandleRepoContribIndex provides an index to merge requests of a repo.
-func httpHandleRepoContribIndex(writer http.ResponseWriter, request *http.Request, params map[string]any) {
+func (s *server) httpHandleRepoContribIndex(writer http.ResponseWriter, request *http.Request, params map[string]any) {
var rows pgx.Rows
var result []idTitleStatus
var err error
- if rows, err = database.Query(request.Context(),
+ if rows, err = s.database.Query(request.Context(),
"SELECT repo_local_id, COALESCE(title, 'Untitled'), status FROM merge_requests WHERE repo_id = $1",
params["repo_id"],
); err != nil {
diff --git a/http_handle_repo_contrib_one.go b/http_handle_repo_contrib_one.go
index dcd0e0d..0df7491 100644
--- a/http_handle_repo_contrib_one.go
+++ b/http_handle_repo_contrib_one.go
@@ -14,7 +14,7 @@ import (
// httpHandleRepoContribOne provides an interface to each merge request of a
// repo.
-func httpHandleRepoContribOne(writer http.ResponseWriter, request *http.Request, params map[string]any) {
+func (s *server) httpHandleRepoContribOne(writer http.ResponseWriter, request *http.Request, params map[string]any) {
var mrIDStr string
var mrIDInt int
var err error
@@ -33,7 +33,7 @@ func httpHandleRepoContribOne(writer http.ResponseWriter, request *http.Request,
}
mrIDInt = int(mrIDInt64)
- if err = database.QueryRow(request.Context(),
+ if err = s.database.QueryRow(request.Context(),
"SELECT COALESCE(title, ''), status, source_ref, COALESCE(destination_branch, '') FROM merge_requests WHERE repo_id = $1 AND repo_local_id = $2",
params["repo_id"], mrIDInt,
).Scan(&title, &status, &srcRefStr, &dstBranchStr); err != nil {
diff --git a/http_handle_repo_info.go b/http_handle_repo_info.go
index 3f1787e..b7b7438 100644
--- a/http_handle_repo_info.go
+++ b/http_handle_repo_info.go
@@ -16,12 +16,12 @@ import (
// HTTP protocol.
//
// TODO: Reject access from web browsers.
-func httpHandleRepoInfo(writer http.ResponseWriter, request *http.Request, params map[string]any) (err error) {
+func (s *server) httpHandleRepoInfo(writer http.ResponseWriter, request *http.Request, params map[string]any) (err error) {
groupPath := params["group_path"].([]string)
repoName := params["repo_name"].(string)
var repoPath string
- if err := database.QueryRow(request.Context(), `
+ if err := s.database.QueryRow(request.Context(), `
WITH RECURSIVE group_path_cte AS (
-- Start: match the first name in the path where parent_group IS NULL
SELECT
diff --git a/http_handle_repo_upload_pack.go b/http_handle_repo_upload_pack.go
index 3d9170c..a6580a7 100644
--- a/http_handle_repo_upload_pack.go
+++ b/http_handle_repo_upload_pack.go
@@ -24,7 +24,7 @@ func (s *server) httpHandleUploadPack(writer http.ResponseWriter, request *http.
groupPath, repoName = params["group_path"].([]string), params["repo_name"].(string)
- if err := database.QueryRow(request.Context(), `
+ if err := s.database.QueryRow(request.Context(), `
WITH RECURSIVE group_path_cte AS (
-- Start: match the first name in the path where parent_group IS NULL
SELECT
diff --git a/http_server.go b/http_server.go
index 5c78533..ae82241 100644
--- a/http_server.go
+++ b/http_server.go
@@ -52,7 +52,7 @@ func (s *server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
params["dir_mode"] = dirMode
params["global"] = globalData
var userID int // 0 for none
- userID, params["username"], err = getUserFromRequest(request)
+ userID, params["username"], err = s.getUserFromRequest(request)
params["user_id"] = userID
if err != nil && !errors.Is(err, http.ErrNoCookie) && !errors.Is(err, pgx.ErrNoRows) {
errorPage500(writer, params, "Error getting user info from request: "+err.Error())
@@ -152,7 +152,7 @@ func (s *server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if len(segments) > sepIndex+3 {
switch segments[sepIndex+3] {
case "info":
- if err = httpHandleRepoInfo(writer, request, params); err != nil {
+ if err = s.httpHandleRepoInfo(writer, request, params); err != nil {
errorPage500(writer, params, err.Error())
}
return
@@ -173,7 +173,7 @@ func (s *server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
}
- if params["repo"], params["repo_description"], params["repo_id"], _, err = openRepo(request.Context(), groupPath, moduleName); err != nil {
+ if params["repo"], params["repo_description"], params["repo_id"], _, err = s.openRepo(request.Context(), groupPath, moduleName); err != nil {
errorPage500(writer, params, "Error opening repo: "+err.Error())
return
}
@@ -256,10 +256,10 @@ func (s *server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
switch len(segments) {
case sepIndex + 4:
- httpHandleRepoContribIndex(writer, request, params)
+ s.httpHandleRepoContribIndex(writer, request, params)
case sepIndex + 5:
params["mr_id"] = segments[sepIndex+4]
- httpHandleRepoContribOne(writer, request, params)
+ s.httpHandleRepoContribOne(writer, request, params)
default:
errorPage400(writer, params, "Too many parameters")
}
diff --git a/lmtp_handle_patch.go b/lmtp_handle_patch.go
index 45d146a..ab846aa 100644
--- a/lmtp_handle_patch.go
+++ b/lmtp_handle_patch.go
@@ -19,7 +19,7 @@ import (
"go.lindenii.runxiyu.org/forge/misc"
)
-func lmtpHandlePatch(session *lmtpSession, groupPath []string, repoName string, mbox io.Reader) (err error) {
+func (s *server) lmtpHandlePatch(session *lmtpSession, groupPath []string, repoName string, mbox io.Reader) (err error) {
var diffFiles []*gitdiff.File
var preamble string
if diffFiles, preamble, err = gitdiff.Parse(mbox); err != nil {
@@ -33,7 +33,7 @@ func lmtpHandlePatch(session *lmtpSession, groupPath []string, repoName string,
var repo *git.Repository
var fsPath string
- repo, _, _, fsPath, err = openRepo(session.ctx, groupPath, repoName)
+ repo, _, _, fsPath, err = s.openRepo(session.ctx, groupPath, repoName)
if err != nil {
return fmt.Errorf("failed to open repo: %w", err)
}
diff --git a/lmtp_server.go b/lmtp_server.go
index e97ca55..8191766 100644
--- a/lmtp_server.go
+++ b/lmtp_server.go
@@ -177,7 +177,7 @@ func (session *lmtpSession) Data(r io.Reader) error {
moduleName := segments[sepIndex+2]
switch moduleType {
case "repos":
- err = lmtpHandlePatch(session, groupPath, moduleName, &mbox)
+ err = session.s.lmtpHandlePatch(session, groupPath, moduleName, &mbox)
if err != nil {
slog.Error("error handling patch", "error", err)
goto end
diff --git a/server.go b/server.go
index 8f35913..1113740 100644
--- a/server.go
+++ b/server.go
@@ -1,5 +1,12 @@
package main
+import "github.com/jackc/pgx/v5/pgxpool"
+
type server struct {
config Config
+
+ // database serves as the primary database handle for this entire application.
+ // Transactions or single reads may be used from it. A [pgxpool.Pool] is
+ // necessary to safely use pgx concurrently; pgx.Conn, etc. are insufficient.
+ database *pgxpool.Pool
}
diff --git a/ssh_handle_receive_pack.go b/ssh_handle_receive_pack.go
index ed7ef40..317609f 100644
--- a/ssh_handle_receive_pack.go
+++ b/ssh_handle_receive_pack.go
@@ -76,7 +76,7 @@ func (s *server) sshHandleRecvPack(session gliderSSH.Session, pubkey, repoIdenti
return errors.New("you need to have an SSH public key to push to this repo")
}
if userType == "" {
- userID, err = addUserSSH(session.Context(), pubkey)
+ userID, err = s.addUserSSH(session.Context(), pubkey)
if err != nil {
return err
}
diff --git a/users.go b/users.go
index f0dabce..1b31f3a 100644
--- a/users.go
+++ b/users.go
@@ -12,10 +12,10 @@ import (
// addUserSSH adds a new user solely based on their SSH public key.
//
// TODO: Audit all users of this function.
-func addUserSSH(ctx context.Context, pubkey string) (userID int, err error) {
+func (s *server) addUserSSH(ctx context.Context, pubkey string) (userID int, err error) {
var txn pgx.Tx
- if txn, err = database.Begin(ctx); err != nil {
+ if txn, err = s.database.Begin(ctx); err != nil {
return
}
defer func() {