diff options
author | Runxi Yu <me@runxiyu.org> | 2025-02-16 01:48:39 +0800 |
---|---|---|
committer | Runxi Yu <me@runxiyu.org> | 2025-02-16 01:52:47 +0800 |
commit | d212c4606a6eb470067d5302b2350d288d4d9c88 (patch) | |
tree | 8eac51da018f6bfbfbae1356968ff8908b887ab6 | |
parent | schema.sql: Fix public keys and add basic group ACL (diff) | |
download | forge-d212c4606a6eb470067d5302b2350d288d4d9c88.tar.gz forge-d212c4606a6eb470067d5302b2350d288d4d9c88.tar.zst forge-d212c4606a6eb470067d5302b2350d288d4d9c88.zip |
{ssh_*,acl}.go: Check ACL when receiving packs
-rw-r--r-- | acl.go | 31 | ||||
-rw-r--r-- | ssh_handle_receive_pack.go | 9 | ||||
-rw-r--r-- | ssh_handle_upload_pack.go | 2 | ||||
-rw-r--r-- | ssh_utils.go | 16 |
4 files changed, 47 insertions, 11 deletions
@@ -0,0 +1,31 @@ +package main + +import ( + "context" +) + +func get_path_perm_by_group_repo_key(ctx context.Context, group_name, repo_name, ssh_pubkey string) (filesystem_path string, access bool, err error) { + err = database.QueryRow(ctx, + `SELECT + r.filesystem_path, + CASE + WHEN ugr.user_id IS NOT NULL THEN TRUE + ELSE FALSE + END AS has_role_in_group + FROM + groups g + JOIN + repos r ON r.group_id = g.id + LEFT JOIN + ssh_public_keys s ON s.key_string = $3 + LEFT JOIN + users u ON u.id = s.user_id + LEFT JOIN + user_group_roles ugr ON ugr.group_id = g.id AND ugr.user_id = u.id + WHERE + g.name = $1 + AND r.name = $2;`, + group_name, repo_name, ssh_pubkey, + ).Scan(&filesystem_path, &access) + return +} diff --git a/ssh_handle_receive_pack.go b/ssh_handle_receive_pack.go index 3395e24..30825ad 100644 --- a/ssh_handle_receive_pack.go +++ b/ssh_handle_receive_pack.go @@ -1,6 +1,8 @@ package main import ( + "errors" + glider_ssh "github.com/gliderlabs/ssh" "github.com/go-git/go-billy/v5/osfs" "github.com/go-git/go-git/v5/plumbing/protocol/packp" @@ -8,11 +10,16 @@ import ( transport_server "github.com/go-git/go-git/v5/plumbing/transport/server" ) +var err_unauthorized_push = errors.New("You are not authorized to push to this repository") + func ssh_handle_receive_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) { - repo_path, err := get_repo_path_from_ssh_path(session.Context(), repo_identifier) + repo_path, access, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey) if err != nil { return err } + if !access { + return err_unauthorized_push + } endpoint, err := transport.NewEndpoint("/") if err != nil { return err diff --git a/ssh_handle_upload_pack.go b/ssh_handle_upload_pack.go index 7812f1a..8281cbd 100644 --- a/ssh_handle_upload_pack.go +++ b/ssh_handle_upload_pack.go @@ -9,7 +9,7 @@ import ( ) func ssh_handle_upload_pack(session glider_ssh.Session, pubkey string, repo_identifier string) (err error) { - repo_path, err := get_repo_path_from_ssh_path(session.Context(), repo_identifier) + repo_path, _, err := get_repo_path_perms_from_ssh_path_pubkey(session.Context(), repo_identifier, pubkey) if err != nil { return err } diff --git a/ssh_utils.go b/ssh_utils.go index 8eaaebd..cf96b21 100644 --- a/ssh_utils.go +++ b/ssh_utils.go @@ -9,19 +9,19 @@ import ( var err_ssh_illegal_endpoint = errors.New("Illegal endpoint during SSH access") -func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_path string, err error) { +func get_repo_path_perms_from_ssh_path_pubkey(ctx context.Context, ssh_path string, ssh_pubkey string) (repo_path string, access bool, err error) { segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/") for i, segment := range segments { var err error segments[i], err = url.PathUnescape(segment) if err != nil { - return "", err + return "", false, err } } if segments[0] == ":" { - return "", err_ssh_illegal_endpoint + return "", false, err_ssh_illegal_endpoint } separator_index := -1 @@ -37,9 +37,9 @@ func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_pat switch { case separator_index == -1: - return "", err_ssh_illegal_endpoint + return "", false, err_ssh_illegal_endpoint case len(segments) <= separator_index+2: - return "", err_ssh_illegal_endpoint + return "", false, err_ssh_illegal_endpoint } group_name := segments[0] @@ -47,10 +47,8 @@ func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_pat module_name := segments[separator_index+2] switch module_type { case "repos": - var fs_path string - err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, module_name).Scan(&fs_path) - return fs_path, err + return get_path_perm_by_group_repo_key(ctx, group_name, module_name, ssh_pubkey) default: - return "", err_ssh_illegal_endpoint + return "", false, err_ssh_illegal_endpoint } } |