aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRunxi Yu <me@runxiyu.org>2025-02-16 01:48:39 +0800
committerRunxi Yu <me@runxiyu.org>2025-02-16 01:52:47 +0800
commitd212c4606a6eb470067d5302b2350d288d4d9c88 (patch)
tree8eac51da018f6bfbfbae1356968ff8908b887ab6
parentschema.sql: Fix public keys and add basic group ACL (diff)
downloadforge-d212c4606a6eb470067d5302b2350d288d4d9c88.tar.gz
forge-d212c4606a6eb470067d5302b2350d288d4d9c88.tar.zst
forge-d212c4606a6eb470067d5302b2350d288d4d9c88.zip
{ssh_*,acl}.go: Check ACL when receiving packs
-rw-r--r--acl.go31
-rw-r--r--ssh_handle_receive_pack.go9
-rw-r--r--ssh_handle_upload_pack.go2
-rw-r--r--ssh_utils.go16
4 files changed, 47 insertions, 11 deletions
diff --git a/acl.go b/acl.go
new file mode 100644
index 0000000..99cd5fb
--- /dev/null
+++ b/acl.go
@@ -0,0 +1,31 @@
+package main
+
+import (
+ "context"
+)
+
+func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name, ssh_pubkey string) (filesystem_path string, access bool, err error) {
+ err = database.QueryRow(ctx,
+ `SELECT
+ r.filesystem_path,
+ CASE
+ WHEN ugr.user_id IS NOT NULL THEN TRUE
+ ELSE FALSE
+ END AS has_role_in_group
+ FROM
+ groups g
+ JOIN
+ repos r ON r.group_id = g.id
+ LEFT JOIN
+ ssh_public_keys s ON s.key_string = $3
+ LEFT JOIN
+ users u ON u.id = s.user_id
+ LEFT JOIN
+ user_group_roles ugr ON ugr.group_id = g.id AND ugr.user_id = u.id
+ WHERE
+ g.name = $1
+ AND r.name = $2;`,
+ group_name, repo_name, ssh_pubkey,
+ ).Scan(&filesystem_path, &access)
+ return
+}
diff --git a/ssh_handle_receive_pack.go b/ssh_handle_receive_pack.go
index 3395e24..30825ad 100644
--- a/ssh_handle_receive_pack.go
+++ b/ssh_handle_receive_pack.go
@@ -1,6 +1,8 @@
package main
import (
+ "errors"
+
glider_ssh "github.com/gliderlabs/ssh"
"github.com/go-git/go-billy/v5/osfs"
"github.com/go-git/go-git/v5/plumbing/protocol/packp"
@@ -8,11 +10,16 @@ import (
transport_server "github.com/go-git/go-git/v5/plumbing/transport/server"
)
+var err_unauthorized_push = errors.New("You are not authorized to push to this repository")
+
func ssh_handle_receive_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) {
- repo_path, err := get_repo_path_from_ssh_path(session.Context(), repo_identifier)
+ repo_path, access, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
if err != nil {
return err
}
+ if !access {
+ return err_unauthorized_push
+ }
endpoint, err := transport.NewEndpoint("/")
if err != nil {
return err
diff --git a/ssh_handle_upload_pack.go b/ssh_handle_upload_pack.go
index 7812f1a..8281cbd 100644
--- a/ssh_handle_upload_pack.go
+++ b/ssh_handle_upload_pack.go
@@ -9,7 +9,7 @@ import (
)
func ssh_handle_upload_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) {
- repo_path, err := get_repo_path_from_ssh_path(session.Context(), repo_identifier)
+ repo_path, _, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey)
if err != nil {
return err
}
diff --git a/ssh_utils.go b/ssh_utils.go
index 8eaaebd..cf96b21 100644
--- a/ssh_utils.go
+++ b/ssh_utils.go
@@ -9,19 +9,19 @@ import (
var err_ssh_illegal_endpoint = errors.New("Illegal endpoint during SSH access")
-func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_path string, err error) {
+func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path string, ssh_pubkey string) (repo_path string, access bool, err error) {
segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/")
for i, segment := range segments {
var err error
segments[i], err = url.PathUnescape(segment)
if err != nil {
- return "", err
+ return "", false, err
}
}
if segments[0] == ":" {
- return "", err_ssh_illegal_endpoint
+ return "", false, err_ssh_illegal_endpoint
}
separator_index := -1
@@ -37,9 +37,9 @@ func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_pat
switch {
case separator_index == -1:
- return "", err_ssh_illegal_endpoint
+ return "", false, err_ssh_illegal_endpoint
case len(segments) <= separator_index+2:
- return "", err_ssh_illegal_endpoint
+ return "", false, err_ssh_illegal_endpoint
}
group_name := segments[0]
@@ -47,10 +47,8 @@ func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_pat
module_name := segments[separator_index+2]
switch module_type {
case "repos":
- var fs_path string
- err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, module_name).Scan(&fs_path)
- return fs_path, err
+ return get_path_perm_by_group_repo_key(ctx, group_name, module_name, ssh_pubkey)
default:
- return "", err_ssh_illegal_endpoint
+ return "", false, err_ssh_illegal_endpoint
}
}