From de1b961fbf54601f25c54c1618f11978f6618858 Mon Sep 17 00:00:00 2001
From: Runxi Yu <me@runxiyu.org>
Date: Wed, 19 Feb 2025 20:14:20 +0800
Subject: ssh/recv, schema: Add repos.contrib_requirements

---
 acl.go                     | 11 ++++++++---
 schema.sql                 |  1 +
 ssh_handle_receive_pack.go | 23 ++++++++++++++++++++++-
 ssh_handle_upload_pack.go  |  2 +-
 ssh_utils.go               | 12 ++++++------
 5 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/acl.go b/acl.go
index 22c8a4c..095c1f1 100644
--- a/acl.go
+++ b/acl.go
@@ -6,14 +6,19 @@ import (
 
 // get_path_perm_by_group_repo_key returns the filesystem path and direct
 // access permission for a given repo and a provided ssh public key.
-func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name, ssh_pubkey string) (filesystem_path string, access bool, err error) {
+func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name, ssh_pubkey string) (filesystem_path string, access bool, contrib_requirements string, is_registered_user bool, err error) {
 	err = database.QueryRow(ctx,
 		`SELECT 
 		r.filesystem_path,
 		CASE
 			WHEN ugr.user_id IS NOT NULL THEN TRUE
 			ELSE FALSE
-		END AS has_role_in_group
+		END AS has_role_in_group,
+		r.contrib_requirements,
+		CASE
+			WHEN u.id IS NOT NULL THEN TRUE
+			ELSE FALSE
+		END
 		FROM 
 			groups g
 		JOIN 
@@ -28,6 +33,6 @@ func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name,
 			g.name = $1
 		AND r.name = $2;`,
 		group_name, repo_name, ssh_pubkey,
-	).Scan(&filesystem_path, &access)
+	).Scan(&filesystem_path, &access, &contrib_requirements, &is_registered_user)
 	return
 }
diff --git a/schema.sql b/schema.sql
index ee32bd6..2589a07 100644
--- a/schema.sql
+++ b/schema.sql
@@ -7,6 +7,7 @@ CREATE TABLE groups (
 CREATE TABLE repos (
 	id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
 	group_id INTEGER NOT NULL REFERENCES groups(id) ON DELETE RESTRICT, -- I mean, should be CASCADE but deleting Git repos on disk also needs to be considered
+	contrib_requirements TEXT NOT NULL CHECK (contrib_requirements IN ('closed', 'registered_user', 'ssh_pubkey', 'public')),
 	name TEXT NOT NULL,
 	UNIQUE(group_id, name),
 	description TEXT,
diff --git a/ssh_handle_receive_pack.go b/ssh_handle_receive_pack.go
index 93bf25b..8bc296e 100644
--- a/ssh_handle_receive_pack.go
+++ b/ssh_handle_receive_pack.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"errors"
 	"fmt"
 	"os"
 	"os/exec"
@@ -24,11 +25,31 @@ func ssh_handle_receive_pack(session glider_ssh.Session, pubkey string, repo_ide
 	// necessarily mean the push is declined. This decision is delegated to
 	// the pre-receive hook, which is then handled by git_hooks_handle.go
 	// while being aware of the refs to be updated.
-	repo_path, access, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
+	repo_path, access, contrib_requirements, is_registered_user, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
 	if err != nil {
 		return err
 	}
 
+	if !access {
+		switch contrib_requirements {
+		case "closed":
+			if !access {
+				return errors.New("You need direct access to push to this repo.")
+			}
+		case "registered_user":
+			if !is_registered_user {
+				return errors.New("You need to be a registered user to push to this repo.")
+			}
+		case "ssh_pubkey":
+			if pubkey == "" {
+				return errors.New("You need to have an SSH public key to push to this repo.")
+			}
+		case "public":
+		default:
+			panic("unknown contrib_requirements value " + contrib_requirements)
+		}
+	}
+
 	cookie, err := random_urlsafe_string(16)
 	if err != nil {
 		fmt.Fprintln(session.Stderr(), "Error while generating cookie:", err)
diff --git a/ssh_handle_upload_pack.go b/ssh_handle_upload_pack.go
index afc9900..7435668 100644
--- a/ssh_handle_upload_pack.go
+++ b/ssh_handle_upload_pack.go
@@ -11,7 +11,7 @@ import (
 // ssh_handle_upload_pack handles clones/fetches. It just uses git-upload-pack
 // and has no ACL checks.
 func ssh_handle_upload_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) {
-	repo_path, _, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
+	repo_path, _, _, _, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
 	if err != nil {
 		return err
 	}
diff --git a/ssh_utils.go b/ssh_utils.go
index 757fbc5..bf8bf5e 100644
--- a/ssh_utils.go
+++ b/ssh_utils.go
@@ -9,19 +9,19 @@ import (
 
 var err_ssh_illegal_endpoint = errors.New("illegal endpoint during SSH access")
 
-func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path string, ssh_pubkey string) (repo_path string, direct_access bool, err error) {
+func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path string, ssh_pubkey string) (repo_path string, direct_access bool, contrib_requirements string, is_registered_user bool, err error) {
 	segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/")
 
 	for i, segment := range segments {
 		var err error
 		segments[i], err = url.PathUnescape(segment)
 		if err != nil {
-			return "", false, err
+			return "", false, "", false, err
 		}
 	}
 
 	if segments[0] == ":" {
-		return "", false, err_ssh_illegal_endpoint
+		return "", false, "", false, err_ssh_illegal_endpoint
 	}
 
 	separator_index := -1
@@ -37,9 +37,9 @@ func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path stri
 
 	switch {
 	case separator_index == -1:
-		return "", false, err_ssh_illegal_endpoint
+		return "", false, "", false, err_ssh_illegal_endpoint
 	case len(segments) <= separator_index+2:
-		return "", false, err_ssh_illegal_endpoint
+		return "", false, "", false, err_ssh_illegal_endpoint
 	}
 
 	group_name := segments[0]
@@ -49,6 +49,6 @@ func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path stri
 	case "repos":
 		return get_path_perm_by_group_repo_key(ctx, group_name, module_name, ssh_pubkey)
 	default:
-		return "", false, err_ssh_illegal_endpoint
+		return "", false, "", false, err_ssh_illegal_endpoint
 	}
 }
-- 
cgit v1.2.3