aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--acl.go8
-rw-r--r--git_hooks_handle.go37
-rw-r--r--schema.sql8
-rw-r--r--ssh_handle_receive_pack.go8
-rw-r--r--ssh_handle_upload_pack.go2
-rw-r--r--ssh_utils.go12
6 files changed, 54 insertions, 21 deletions
diff --git a/acl.go b/acl.go
index 7ad48fb..414e102 100644
--- a/acl.go
+++ b/acl.go
@@ -6,16 +6,18 @@ import (
// get_path_perm_by_group_repo_key returns the filesystem path and direct
// access permission for a given repo and a provided ssh public key.
-func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name, ssh_pubkey string) (filesystem_path string, access bool, contrib_requirements string, user_type string, err error) {
+func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name, ssh_pubkey string) (repo_id int, filesystem_path string, access bool, contrib_requirements string, user_type string, user_id int, err error) {
err = database.QueryRow(ctx,
`SELECT
+ r.id,
r.filesystem_path,
CASE
WHEN ugr.user_id IS NOT NULL THEN TRUE
ELSE FALSE
END AS has_role_in_group,
r.contrib_requirements,
- COALESCE(u.type, '')
+ COALESCE(u.type, ''),
+ COALESCE(u.id, 0)
FROM
groups g
JOIN
@@ -30,6 +32,6 @@ func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name,
g.name = $1
AND r.name = $2;`,
group_name, repo_name, ssh_pubkey,
- ).Scan(&filesystem_path, &access, &contrib_requirements, &user_type)
+ ).Scan(&repo_id, &filesystem_path, &access, &contrib_requirements, &user_type, &user_id)
return
}
diff --git a/git_hooks_handle.go b/git_hooks_handle.go
index b047eb9..9dc3ed6 100644
--- a/git_hooks_handle.go
+++ b/git_hooks_handle.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -12,6 +13,7 @@ import (
"strings"
"syscall"
+ "github.com/jackc/pgx/v5"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
)
@@ -25,6 +27,8 @@ var (
// unix socket.
func hooks_handle_connection(conn net.Conn) {
defer conn.Close()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
// There aren't reasonable cases where someone would run this as
// another user.
@@ -124,12 +128,35 @@ func hooks_handle_connection(conn net.Conn) {
if strings.HasPrefix(ref_name, "refs/heads/contrib/") {
if all_zero_num_string(old_oid) { // New branch
fmt.Fprintln(ssh_stderr, "Acceptable push to new contrib branch: "+ref_name)
- // TODO: Create a merge request. If that fails,
- // we should just reject this entire push
- // immediately.
+ _, err = database.Exec(ctx,
+ "INSERT INTO merge_requests (repo_id, creator, source_ref, status) VALUES ($1, $2, $3, 'open')",
+ pack_to_hook.repo_id, pack_to_hook.user_id, strings.TrimPrefix(ref_name, "refs/heads/contrib/"),
+ )
+ if err != nil {
+ fmt.Fprintln(ssh_stderr, "Error creating merge request:", err.Error())
+ return 1
+ }
} else { // Existing contrib branch
- // TODO: Check if the current user is authorized
- // to push to this contrib branch.
+ var existing_merge_request_user_id int
+ err = database.QueryRow(ctx,
+ "SELECT creator FROM merge_requests WHERE source_ref = $1 AND repo_id = $2",
+ strings.TrimPrefix(ref_name, "refs/heads/contrib/"), pack_to_hook.repo_id,
+ ).Scan(&existing_merge_request_user_id)
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ fmt.Fprintln(ssh_stderr, "No existing merge request for existing contrib branch:", err.Error())
+ } else {
+ fmt.Fprintln(ssh_stderr, "Error querying for existing merge request:", err.Error())
+ }
+ return 1
+ }
+
+ if existing_merge_request_user_id != pack_to_hook.user_id {
+ all_ok = false
+ fmt.Fprintln(ssh_stderr, "Rejecting push to existing contrib branch owned by another user:", ref_name)
+ continue
+ }
+
repo, err := git.PlainOpen(pack_to_hook.repo_path)
if err != nil {
fmt.Fprintln(ssh_stderr, "Daemon failed to open repo:", err.Error())
diff --git a/schema.sql b/schema.sql
index 3db8967..684c32d 100644
--- a/schema.sql
+++ b/schema.sql
@@ -66,14 +66,14 @@ CREATE TABLE sessions (
UNIQUE(user_id, session_id)
);
-// TODO:
+-- TODO:
CREATE TABLE merge_requests (
id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
- title TEXT NOT NULL,
+ title TEXT,
repo_id INTEGER NOT NULL REFERENCES repos(id) ON DELETE CASCADE,
- creator INTEGER NOT NULL REFERENCES users(id) ON DELETE SET NULL,
+ creator INTEGER REFERENCES users(id) ON DELETE SET NULL,
source_ref TEXT NOT NULL,
- destination_branch TEXT NOT NULL,
+ destination_branch TEXT,
status TEXT NOT NULL CHECK (status IN ('open', 'merged', 'closed')),
UNIQUE (repo_id, source_ref, destination_branch),
UNIQUE (repo_id, id)
diff --git a/ssh_handle_receive_pack.go b/ssh_handle_receive_pack.go
index 293bb36..8803151 100644
--- a/ssh_handle_receive_pack.go
+++ b/ssh_handle_receive_pack.go
@@ -15,13 +15,15 @@ type pack_to_hook_t struct {
pubkey string
direct_access bool
repo_path string
+ user_id int
+ repo_id int
}
var pack_to_hook_by_cookie = cmap.Map[string, pack_to_hook_t]{}
// ssh_handle_receive_pack handles attempts to push to repos.
func ssh_handle_receive_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) {
- repo_path, direct_access, contrib_requirements, user_type, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
+ repo_id, repo_path, direct_access, contrib_requirements, user_type, user_id, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
if err != nil {
return err
}
@@ -41,7 +43,7 @@ func ssh_handle_receive_pack(session glider_ssh.Session, pubkey string, repo_ide
return errors.New("You need to have an SSH public key to push to this repo.")
}
if user_type == "" {
- user_id, err := add_user_ssh(session.Context(), pubkey)
+ user_id, err = add_user_ssh(session.Context(), pubkey)
if err != nil {
return err
}
@@ -63,6 +65,8 @@ func ssh_handle_receive_pack(session glider_ssh.Session, pubkey string, repo_ide
pubkey: pubkey,
direct_access: direct_access,
repo_path: repo_path,
+ user_id: user_id,
+ repo_id: repo_id,
})
defer pack_to_hook_by_cookie.Delete(cookie)
// The Delete won't execute until proc.Wait returns unless something
diff --git a/ssh_handle_upload_pack.go b/ssh_handle_upload_pack.go
index 7435668..8efcd28 100644
--- a/ssh_handle_upload_pack.go
+++ b/ssh_handle_upload_pack.go
@@ -11,7 +11,7 @@ import (
// ssh_handle_upload_pack handles clones/fetches. It just uses git-upload-pack
// and has no ACL checks.
func ssh_handle_upload_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) {
- repo_path, _, _, _, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
+ _, repo_path, _, _, _, _, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
if err != nil {
return err
}
diff --git a/ssh_utils.go b/ssh_utils.go
index fb8f920..7f3188f 100644
--- a/ssh_utils.go
+++ b/ssh_utils.go
@@ -9,19 +9,19 @@ import (
var err_ssh_illegal_endpoint = errors.New("illegal endpoint during SSH access")
-func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path string, ssh_pubkey string) (repo_path string, direct_access bool, contrib_requirements string, user_type string, err error) {
+func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path string, ssh_pubkey string) (repo_id int, repo_path string, direct_access bool, contrib_requirements string, user_type string, user_id int, err error) {
segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/")
for i, segment := range segments {
var err error
segments[i], err = url.PathUnescape(segment)
if err != nil {
- return "", false, "", "", err
+ return 0, "", false, "", "", 0, err
}
}
if segments[0] == ":" {
- return "", false, "", "", err_ssh_illegal_endpoint
+ return 0, "", false, "", "", 0, err_ssh_illegal_endpoint
}
separator_index := -1
@@ -37,9 +37,9 @@ func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path stri
switch {
case separator_index == -1:
- return "", false, "", "", err_ssh_illegal_endpoint
+ return 0, "", false, "", "", 0, err_ssh_illegal_endpoint
case len(segments) <= separator_index+2:
- return "", false, "", "", err_ssh_illegal_endpoint
+ return 0, "", false, "", "", 0, err_ssh_illegal_endpoint
}
group_name := segments[0]
@@ -49,6 +49,6 @@ func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path stri
case "repos":
return get_path_perm_by_group_repo_key(ctx, group_name, module_name, ssh_pubkey)
default:
- return "", false, "", "", err_ssh_illegal_endpoint
+ return 0, "", false, "", "", 0, err_ssh_illegal_endpoint
}
}