aboutsummaryrefslogtreecommitdiff
path: root/ssh_server.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--ssh_server.go124
1 files changed, 83 insertions, 41 deletions
diff --git a/ssh_server.go b/ssh_server.go
index 8eaaebd..fb23db6 100644
--- a/ssh_server.go
+++ b/ssh_server.go
@@ -1,56 +1,98 @@
package main
import (
- "context"
- "errors"
- "net/url"
- "strings"
-)
+ "fmt"
+ "net"
+ "os"
+ "os/exec"
-var err_ssh_illegal_endpoint = errors.New("Illegal endpoint during SSH access")
+ glider_ssh "github.com/gliderlabs/ssh"
+ "go.lindenii.runxiyu.org/lindenii-common/clog"
+ go_ssh "golang.org/x/crypto/ssh"
+)
-func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_path string, err error) {
- segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/")
+var (
+ server_public_key_string string
+ server_public_key_fingerprint string
+ server_public_key go_ssh.PublicKey
+)
- for i, segment := range segments {
- var err error
- segments[i], err = url.PathUnescape(segment)
- if err != nil {
- return "", err
- }
+func serve_ssh() error {
+ host_key_bytes, err := os.ReadFile(config.SSH.Key)
+ if err != nil {
+ return err
}
- if segments[0] == ":" {
- return "", err_ssh_illegal_endpoint
+ host_key, err := go_ssh.ParsePrivateKey(host_key_bytes)
+ if err != nil {
+ return err
}
- separator_index := -1
- for i, part := range segments {
- if part == ":" {
- separator_index = i
- break
- }
- }
- if segments[len(segments)-1] == "" {
- segments = segments[:len(segments)-1]
- }
+ server_public_key = host_key.PublicKey()
+ server_public_key_string = string(go_ssh.MarshalAuthorizedKey(server_public_key))
+ server_public_key_fingerprint = string(go_ssh.FingerprintSHA256(server_public_key))
+
+ server := &glider_ssh.Server{
+ Handler: func(session glider_ssh.Session) {
+ client_public_key := session.PublicKey()
+ var client_public_key_string string
+ if client_public_key != nil {
+ client_public_key_string = string(go_ssh.MarshalAuthorizedKey(client_public_key))
+ }
+ _ = client_public_key_string
+
+ cmd := session.Command()
- switch {
- case separator_index == -1:
- return "", err_ssh_illegal_endpoint
- case len(segments) <= separator_index+2:
- return "", err_ssh_illegal_endpoint
+ if len(cmd) < 2 {
+ fmt.Fprintln(session.Stderr(), "Insufficient arguments")
+ return
+ }
+
+ if cmd[0] != "git-upload-pack" {
+ fmt.Fprintln(session.Stderr(), "Unsupported command")
+ return
+ }
+
+ fs_path, err := get_repo_path_from_ssh_path(session.Context(), cmd[1])
+ if err != nil {
+ fmt.Fprintln(session.Stderr(), "Error while getting repo path:", err)
+ return
+ }
+
+ proc := exec.CommandContext(session.Context(), cmd[0], fs_path)
+ proc.Stdin = session
+ proc.Stdout = session
+ proc.Stderr = session.Stderr()
+
+ err = proc.Start()
+ if err != nil {
+ fmt.Fprintln(session.Stderr(), "Error while starting process:", err)
+ return
+ }
+ err = proc.Wait()
+ if exit_error, ok := err.(*exec.ExitError); ok {
+ fmt.Fprintln(session.Stderr(), "Process exited with error", exit_error.ExitCode())
+ } else if err != nil {
+ fmt.Fprintln(session.Stderr(), "Error while waiting for process:", err)
+ }
+ },
+ PublicKeyHandler: func(ctx glider_ssh.Context, key glider_ssh.PublicKey) bool { return true },
+ KeyboardInteractiveHandler: func(ctx glider_ssh.Context, challenge go_ssh.KeyboardInteractiveChallenge) bool { return true },
}
- group_name := segments[0]
- module_type := segments[separator_index+1]
- module_name := segments[separator_index+2]
- switch module_type {
- case "repos":
- var fs_path string
- err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, module_name).Scan(&fs_path)
- return fs_path, err
- default:
- return "", err_ssh_illegal_endpoint
+ server.AddHostKey(host_key)
+
+ listener, err := net.Listen(config.SSH.Net, config.SSH.Addr)
+ if err != nil {
+ return err
}
+
+ go func() {
+ err = server.Serve(listener)
+ if err != nil {
+ clog.Fatal(1, "Serving SSH: "+err.Error())
+ }
+ }()
+
+ return nil
}