diff options
author | Runxi Yu <me@runxiyu.org> | 2025-02-12 19:16:41 +0800 |
---|---|---|
committer | Runxi Yu <me@runxiyu.org> | 2025-02-12 19:16:41 +0800 |
commit | f828acac387aacadd2884837402b0e32b2368470 (patch) | |
tree | 475c374f5e9b1205136cf03171867ead7fad4291 | |
parent | http_router.go: Move from router.go and fix conditional placement bug (diff) | |
download | forge-f828acac387aacadd2884837402b0e32b2368470.tar.gz forge-f828acac387aacadd2884837402b0e32b2368470.tar.zst forge-f828acac387aacadd2884837402b0e32b2368470.zip |
*.go: Use the database for repo info, and fix ssh cloning repo
-rw-r--r-- | git_misc.go | 17 | ||||
-rw-r--r-- | handle_group_index.go | 29 | ||||
-rw-r--r-- | handle_repo_commit.go | 2 | ||||
-rw-r--r-- | handle_repo_index.go | 2 | ||||
-rw-r--r-- | handle_repo_log.go | 2 | ||||
-rw-r--r-- | handle_repo_raw.go | 2 | ||||
-rw-r--r-- | handle_repo_tree.go | 2 | ||||
-rw-r--r-- | router_ssh.go | 56 | ||||
-rw-r--r-- | ssh.go | 10 | ||||
-rw-r--r-- | url_misc.go | 5 |
10 files changed, 98 insertions, 29 deletions
diff --git a/git_misc.go b/git_misc.go index 882c631..2d4c4d3 100644 --- a/git_misc.go +++ b/git_misc.go @@ -1,9 +1,9 @@ package main import ( + "context" "errors" "io" - "path/filepath" "strings" "github.com/go-git/go-git/v5" @@ -19,16 +19,13 @@ var ( err_getting_parent_commit_object = errors.New("Error getting parent commit object") ) -func open_git_repo(group_name, repo_name string) (*git.Repository, error) { - group_name, group_name_ok := misc.Sanitize_path(group_name) - if !group_name_ok { - return nil, err_unsafe_path - } - repo_name, repo_name_ok := misc.Sanitize_path(repo_name) - if !repo_name_ok { - return nil, err_unsafe_path +func open_git_repo(ctx context.Context, group_name, repo_name string) (*git.Repository, error) { + var fs_path string + err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, repo_name).Scan(&fs_path) + if err != nil { + return nil, err } - return git.PlainOpen(filepath.Join(config.Git.Root, group_name, repo_name+".git")) + return git.PlainOpen(fs_path) } type display_git_tree_entry_t struct { diff --git a/handle_group_index.go b/handle_group_index.go index bc7a7f4..0bb4a57 100644 --- a/handle_group_index.go +++ b/handle_group_index.go @@ -2,29 +2,36 @@ package main import ( "net/http" - "os" - "path/filepath" - "strings" ) func handle_group_repos(w http.ResponseWriter, r *http.Request, params map[string]string) { data := make(map[string]any) group_name := params["group_name"] data["group_name"] = group_name - entries, err := os.ReadDir(filepath.Join(config.Git.Root, group_name)) + + var names []string + rows, err := database.Query(r.Context(), "SELECT r.name FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1;", group_name) if err != nil { - _, _ = w.Write([]byte("Error listing repos: " + err.Error())) + _, _ = w.Write([]byte("Error getting groups: " + err.Error())) return } + defer rows.Close() - repos := []string{} - for _, entry := range entries { - this_name := entry.Name() - if strings.HasSuffix(this_name, ".git") { - repos = append(repos, strings.TrimSuffix(this_name, ".git")) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + _, _ = w.Write([]byte("Error scanning row: " + err.Error())) + return } + names = append(names, name) } - data["repos"] = repos + + if err := rows.Err(); err != nil { + _, _ = w.Write([]byte("Error iterating over rows: " + err.Error())) + return + } + + data["repos"] = names err = templates.ExecuteTemplate(w, "group_repos", data) if err != nil { diff --git a/handle_repo_commit.go b/handle_repo_commit.go index aefd58b..b567baa 100644 --- a/handle_repo_commit.go +++ b/handle_repo_commit.go @@ -20,7 +20,7 @@ func handle_repo_commit(w http.ResponseWriter, r *http.Request, params map[strin data := make(map[string]any) group_name, repo_name, commit_id_specified_string := params["group_name"], params["repo_name"], params["commit_id"] data["group_name"], data["repo_name"] = group_name, repo_name - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_index.go b/handle_repo_index.go index 6372b03..c0bef4a 100644 --- a/handle_repo_index.go +++ b/handle_repo_index.go @@ -8,7 +8,7 @@ func handle_repo_index(w http.ResponseWriter, r *http.Request, params map[string data := make(map[string]any) group_name, repo_name := params["group_name"], params["repo_name"] data["group_name"], data["repo_name"] = group_name, repo_name - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_log.go b/handle_repo_log.go index eff5859..1c32862 100644 --- a/handle_repo_log.go +++ b/handle_repo_log.go @@ -11,7 +11,7 @@ func handle_repo_log(w http.ResponseWriter, r *http.Request, params map[string]s data := make(map[string]any) group_name, repo_name, ref_name := params["group_name"], params["repo_name"], params["ref"] data["group_name"], data["repo_name"], data["ref"] = group_name, repo_name, ref_name - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_raw.go b/handle_repo_raw.go index d335f6a..4cf7d1a 100644 --- a/handle_repo_raw.go +++ b/handle_repo_raw.go @@ -26,7 +26,7 @@ func handle_repo_raw(w http.ResponseWriter, r *http.Request, params map[string]s data["ref_type"], data["ref"], data["group_name"], data["repo_name"], data["path_spec"] = ref_type, ref_name, group_name, repo_name, path_spec - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_tree.go b/handle_repo_tree.go index f95e945..8076ed6 100644 --- a/handle_repo_tree.go +++ b/handle_repo_tree.go @@ -28,7 +28,7 @@ func handle_repo_tree(w http.ResponseWriter, r *http.Request, params map[string] } } data["ref_type"], data["ref"], data["group_name"], data["repo_name"], data["path_spec"] = ref_type, ref_name, group_name, repo_name, path_spec - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/router_ssh.go b/router_ssh.go new file mode 100644 index 0000000..6b5280b --- /dev/null +++ b/router_ssh.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "errors" + "net/url" + "strings" +) + +var err_ssh_illegal_endpoint = errors.New("Illegal endpoint during SSH access") + +func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_path string, err error) { + segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/") + + for i, segment := range segments { + var err error + segments[i], err = url.QueryUnescape(segment) + if err != nil { + return "", err + } + } + + if segments[0] == ":" { + return "", err_ssh_illegal_endpoint + } + + separator_index := -1 + for i, part := range segments { + if part == ":" { + separator_index = i + break + } + } + if segments[len(segments)-1] == "" { + segments = segments[:len(segments)-1] + } + + switch { + case separator_index == -1: + return "", err_ssh_illegal_endpoint + case len(segments) <= separator_index+2: + return "", err_ssh_illegal_endpoint + } + + group_name := segments[0] + module_type := segments[separator_index+1] + module_name := segments[separator_index+2] + switch module_type { + case "repos": + var fs_path string + err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, module_name).Scan(&fs_path) + return fs_path, err + default: + return "", err_ssh_illegal_endpoint + } +} @@ -43,12 +43,18 @@ func serve_ssh() error { return } - proc := exec.CommandContext(session.Context(), cmd[0], "/home/runxiyu/git/forge.git") + fs_path, err := get_repo_path_from_ssh_path(session.Context(), cmd[1]) + if err != nil { + fmt.Fprintln(session.Stderr(), "Error while getting repo path:", err) + return + } + + proc := exec.CommandContext(session.Context(), cmd[0], fs_path) proc.Stdin = session proc.Stdout = session proc.Stderr = session.Stderr() - err := proc.Start() + err = proc.Start() if err != nil { fmt.Fprintln(session.Stderr(), "Error while starting process:", err) return diff --git a/url_misc.go b/url_misc.go index e4bfd92..7dc0ad5 100644 --- a/url_misc.go +++ b/url_misc.go @@ -50,7 +50,10 @@ func parse_request_uri(request_uri string) (segments []string, params url.Values segments = strings.Split(strings.TrimPrefix(path, "/"), "/") for i, segment := range segments { - segments[i], _ = url.QueryUnescape(segment) + segments[i], err = url.QueryUnescape(segment) + if err != nil { + return nil, nil, misc.Wrap_one_error(err_bad_request, err) + } } params, err = url.ParseQuery(params_string) |