aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--git_misc.go17
-rw-r--r--handle_group_index.go29
-rw-r--r--handle_repo_commit.go2
-rw-r--r--handle_repo_index.go2
-rw-r--r--handle_repo_log.go2
-rw-r--r--handle_repo_raw.go2
-rw-r--r--handle_repo_tree.go2
-rw-r--r--router_ssh.go56
-rw-r--r--ssh.go10
-rw-r--r--url_misc.go5
10 files changed, 98 insertions, 29 deletions
diff --git a/git_misc.go b/git_misc.go
index 882c631..2d4c4d3 100644
--- a/git_misc.go
+++ b/git_misc.go
@@ -1,9 +1,9 @@
package main
import (
+ "context"
"errors"
"io"
- "path/filepath"
"strings"
"github.com/go-git/go-git/v5"
@@ -19,16 +19,13 @@ var (
err_getting_parent_commit_object = errors.New("Error getting parent commit object")
)
-func open_git_repo(group_name, repo_name string) (*git.Repository, error) {
- group_name, group_name_ok := misc.Sanitize_path(group_name)
- if !group_name_ok {
- return nil, err_unsafe_path
- }
- repo_name, repo_name_ok := misc.Sanitize_path(repo_name)
- if !repo_name_ok {
- return nil, err_unsafe_path
+func open_git_repo(ctx context.Context, group_name, repo_name string) (*git.Repository, error) {
+ var fs_path string
+ err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, repo_name).Scan(&fs_path)
+ if err != nil {
+ return nil, err
}
- return git.PlainOpen(filepath.Join(config.Git.Root, group_name, repo_name+".git"))
+ return git.PlainOpen(fs_path)
}
type display_git_tree_entry_t struct {
diff --git a/handle_group_index.go b/handle_group_index.go
index bc7a7f4..0bb4a57 100644
--- a/handle_group_index.go
+++ b/handle_group_index.go
@@ -2,29 +2,36 @@ package main
import (
"net/http"
- "os"
- "path/filepath"
- "strings"
)
func handle_group_repos(w http.ResponseWriter, r *http.Request, params map[string]string) {
data := make(map[string]any)
group_name := params["group_name"]
data["group_name"] = group_name
- entries, err := os.ReadDir(filepath.Join(config.Git.Root, group_name))
+
+ var names []string
+ rows, err := database.Query(r.Context(), "SELECT r.name FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1;", group_name)
if err != nil {
- _, _ = w.Write([]byte("Error listing repos: " + err.Error()))
+ _, _ = w.Write([]byte("Error getting groups: " + err.Error()))
return
}
+ defer rows.Close()
- repos := []string{}
- for _, entry := range entries {
- this_name := entry.Name()
- if strings.HasSuffix(this_name, ".git") {
- repos = append(repos, strings.TrimSuffix(this_name, ".git"))
+ for rows.Next() {
+ var name string
+ if err := rows.Scan(&name); err != nil {
+ _, _ = w.Write([]byte("Error scanning row: " + err.Error()))
+ return
}
+ names = append(names, name)
}
- data["repos"] = repos
+
+ if err := rows.Err(); err != nil {
+ _, _ = w.Write([]byte("Error iterating over rows: " + err.Error()))
+ return
+ }
+
+ data["repos"] = names
err = templates.ExecuteTemplate(w, "group_repos", data)
if err != nil {
diff --git a/handle_repo_commit.go b/handle_repo_commit.go
index aefd58b..b567baa 100644
--- a/handle_repo_commit.go
+++ b/handle_repo_commit.go
@@ -20,7 +20,7 @@ func handle_repo_commit(w http.ResponseWriter, r *http.Request, params map[strin
data := make(map[string]any)
group_name, repo_name, commit_id_specified_string := params["group_name"], params["repo_name"], params["commit_id"]
data["group_name"], data["repo_name"] = group_name, repo_name
- repo, err := open_git_repo(group_name, repo_name)
+ repo, err := open_git_repo(r.Context(), group_name, repo_name)
if err != nil {
_, _ = w.Write([]byte("Error opening repo: " + err.Error()))
return
diff --git a/handle_repo_index.go b/handle_repo_index.go
index 6372b03..c0bef4a 100644
--- a/handle_repo_index.go
+++ b/handle_repo_index.go
@@ -8,7 +8,7 @@ func handle_repo_index(w http.ResponseWriter, r *http.Request, params map[string
data := make(map[string]any)
group_name, repo_name := params["group_name"], params["repo_name"]
data["group_name"], data["repo_name"] = group_name, repo_name
- repo, err := open_git_repo(group_name, repo_name)
+ repo, err := open_git_repo(r.Context(), group_name, repo_name)
if err != nil {
_, _ = w.Write([]byte("Error opening repo: " + err.Error()))
return
diff --git a/handle_repo_log.go b/handle_repo_log.go
index eff5859..1c32862 100644
--- a/handle_repo_log.go
+++ b/handle_repo_log.go
@@ -11,7 +11,7 @@ func handle_repo_log(w http.ResponseWriter, r *http.Request, params map[string]s
data := make(map[string]any)
group_name, repo_name, ref_name := params["group_name"], params["repo_name"], params["ref"]
data["group_name"], data["repo_name"], data["ref"] = group_name, repo_name, ref_name
- repo, err := open_git_repo(group_name, repo_name)
+ repo, err := open_git_repo(r.Context(), group_name, repo_name)
if err != nil {
_, _ = w.Write([]byte("Error opening repo: " + err.Error()))
return
diff --git a/handle_repo_raw.go b/handle_repo_raw.go
index d335f6a..4cf7d1a 100644
--- a/handle_repo_raw.go
+++ b/handle_repo_raw.go
@@ -26,7 +26,7 @@ func handle_repo_raw(w http.ResponseWriter, r *http.Request, params map[string]s
data["ref_type"], data["ref"], data["group_name"], data["repo_name"], data["path_spec"] = ref_type, ref_name, group_name, repo_name, path_spec
- repo, err := open_git_repo(group_name, repo_name)
+ repo, err := open_git_repo(r.Context(), group_name, repo_name)
if err != nil {
_, _ = w.Write([]byte("Error opening repo: " + err.Error()))
return
diff --git a/handle_repo_tree.go b/handle_repo_tree.go
index f95e945..8076ed6 100644
--- a/handle_repo_tree.go
+++ b/handle_repo_tree.go
@@ -28,7 +28,7 @@ func handle_repo_tree(w http.ResponseWriter, r *http.Request, params map[string]
}
}
data["ref_type"], data["ref"], data["group_name"], data["repo_name"], data["path_spec"] = ref_type, ref_name, group_name, repo_name, path_spec
- repo, err := open_git_repo(group_name, repo_name)
+ repo, err := open_git_repo(r.Context(), group_name, repo_name)
if err != nil {
_, _ = w.Write([]byte("Error opening repo: " + err.Error()))
return
diff --git a/router_ssh.go b/router_ssh.go
new file mode 100644
index 0000000..6b5280b
--- /dev/null
+++ b/router_ssh.go
@@ -0,0 +1,56 @@
+package main
+
+import (
+ "context"
+ "errors"
+ "net/url"
+ "strings"
+)
+
+var err_ssh_illegal_endpoint = errors.New("Illegal endpoint during SSH access")
+
+func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_path string, err error) {
+ segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/")
+
+ for i, segment := range segments {
+ var err error
+ segments[i], err = url.QueryUnescape(segment)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ if segments[0] == ":" {
+ return "", err_ssh_illegal_endpoint
+ }
+
+ separator_index := -1
+ for i, part := range segments {
+ if part == ":" {
+ separator_index = i
+ break
+ }
+ }
+ if segments[len(segments)-1] == "" {
+ segments = segments[:len(segments)-1]
+ }
+
+ switch {
+ case separator_index == -1:
+ return "", err_ssh_illegal_endpoint
+ case len(segments) <= separator_index+2:
+ return "", err_ssh_illegal_endpoint
+ }
+
+ group_name := segments[0]
+ module_type := segments[separator_index+1]
+ module_name := segments[separator_index+2]
+ switch module_type {
+ case "repos":
+ var fs_path string
+ err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, module_name).Scan(&fs_path)
+ return fs_path, err
+ default:
+ return "", err_ssh_illegal_endpoint
+ }
+}
diff --git a/ssh.go b/ssh.go
index e1b9ff1..4d49fc9 100644
--- a/ssh.go
+++ b/ssh.go
@@ -43,12 +43,18 @@ func serve_ssh() error {
return
}
- proc := exec.CommandContext(session.Context(), cmd[0], "/home/runxiyu/git/forge.git")
+ fs_path, err := get_repo_path_from_ssh_path(session.Context(), cmd[1])
+ if err != nil {
+ fmt.Fprintln(session.Stderr(), "Error while getting repo path:", err)
+ return
+ }
+
+ proc := exec.CommandContext(session.Context(), cmd[0], fs_path)
proc.Stdin = session
proc.Stdout = session
proc.Stderr = session.Stderr()
- err := proc.Start()
+ err = proc.Start()
if err != nil {
fmt.Fprintln(session.Stderr(), "Error while starting process:", err)
return
diff --git a/url_misc.go b/url_misc.go
index e4bfd92..7dc0ad5 100644
--- a/url_misc.go
+++ b/url_misc.go
@@ -50,7 +50,10 @@ func parse_request_uri(request_uri string) (segments []string, params url.Values
segments = strings.Split(strings.TrimPrefix(path, "/"), "/")
for i, segment := range segments {
- segments[i], _ = url.QueryUnescape(segment)
+ segments[i], err = url.QueryUnescape(segment)
+ if err != nil {
+ return nil, nil, misc.Wrap_one_error(err_bad_request, err)
+ }
}
params, err = url.ParseQuery(params_string)