aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md10
-rw-r--r--main.go224
-rw-r--r--main_test.go90
3 files changed, 313 insertions, 11 deletions
diff --git a/README.md b/README.md
index c95ab3e..7e9a48e 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,15 @@ scraper bots.
## Credits
-Inspired by [Anubis](https://github.com/TecharoHQ/anubis).
+Inspired by [Anubis](https://github.com/TecharoHQ/anubis). But much simpler.
+
+## Bugs
+
+- If a user is attempting to submit a POST request but their powxy cookie is
+ invalid, powxy would redirect them to a challenge, and their POST data will
+ be lost.
+- It does not work when duplex connections are needed, e.g. with Git's Smart
+ HTTP protocol.
## License
diff --git a/main.go b/main.go
index 3a0df60..3235a8b 100644
--- a/main.go
+++ b/main.go
@@ -1,33 +1,237 @@
package main
import (
+ "crypto/rand"
+ "crypto/hmac"
+ "crypto/sha256"
+ "crypto/subtle"
+ "encoding/base64"
+ "encoding/binary"
+ "errors"
+ "flag"
+ "html/template"
"io"
"log"
"maps"
"net/http"
+ "strings"
+ "time"
+ "unsafe"
)
+var (
+ difficulty uint
+ listenAddr string
+ destHost string
+)
+
+func init() {
+ flag.UintVar(&difficulty, "difficulty", 20, "leading zero bits required for the challenge")
+ flag.StringVar(&listenAddr, "listen", ":8081", "address to listen on")
+ flag.StringVar(&destHost, "host", "127.0.0.1:8080", "destination host to proxy to")
+ flag.Parse()
+}
+
var client = http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse },
}
+var (
+ privkey = make([]byte, 32)
+ privkeyHash = make([]byte, 0, sha256.Size)
+)
+
+func init() {
+ if _, err := rand.Read(privkey); err != nil {
+ log.Fatal(err)
+ }
+ h := sha256.New()
+ h.Write(privkey)
+ privkeyHash = h.Sum(nil)
+}
+
+var tmpl *template.Template
+
+func init() {
+ var err error
+ tmpl, err = template.New("powxy").Parse(`
+<!DOCTYPE html>
+<html>
+<head>
+<title>Proof of Work Challenge</title>
+</head>
+<body>
+<h1>Proof of Work Challenge</h1>
+<p>You must complete this proof of work challenge before you could access this site.</p>
+{{- if .Message }}
+<p><strong>{{ .Message }}</strong></p>
+{{- end }}
+<p>Select a value, such that when it is appended to the decoded form of the following base64 string, and a SHA-256 hash is taken as a whole, the first {{ .NeedBits }} bits of the SHA-256 hash are zeros. Within one octet, higher bits are considered to be in front of lower bits.</p>
+<p>{{ .UnsignedTokenBase64 }}</p>
+<form method="POST">
+<p>
+Encode your selected value in base64 and submit it below:
+</p>
+<input name="powxy" type="text" />
+<input type="submit" value="Submit" />
+</form>
+</body>
+</html>
+`)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
+
+type tparams struct {
+ UnsignedTokenBase64 string
+ NeedBits uint
+ Message string
+}
+
func main() {
- log.Fatal(http.ListenAndServe(":8081", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
+ log.Fatal(http.ListenAndServe(listenAddr, http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log.Println(request.RemoteAddr, request.RequestURI)
- request.Host = "127.0.0.1:8080"
- request.URL.Host = "127.0.0.1:8080"
- request.URL.Scheme = "http"
- request.RequestURI = ""
+ cookie, err := request.Cookie("powxy")
+ if err != nil {
+ if !errors.Is(err, http.ErrNoCookie) {
+ http.Error(writer, "error fetching cookie", http.StatusInternalServerError)
+ }
+ }
+
+ expectedToken := makeSignedToken(request)
+
+ if validateCookie(cookie, expectedToken) {
+ proxyRequest(writer, request)
+ return
+ }
+
+ authPage := func(message string) {
+ tmpl.Execute(writer, tparams{
+ UnsignedTokenBase64: base64.StdEncoding.EncodeToString(expectedToken[:sha256.Size]),
+ Message: message,
+ NeedBits: difficulty,
+ })
+ }
+
+ if request.ParseForm() != nil {
+ authPage("You submitted a malformed form.")
+ return
+ }
+
+ formValues, ok := request.PostForm["powxy"]
+ if !ok {
+ authPage("")
+ return
+ } else if len(formValues) != 1 {
+ authPage("You submitted an invalid number of form values.")
+ return
+ }
- response, err := client.Do(request)
+ nonce, err := base64.StdEncoding.DecodeString(formValues[0])
if err != nil {
- http.Error(writer, err.Error(), http.StatusBadGateway)
+ authPage("Your submission was improperly encoded.")
+ return
+ }
+
+ h := sha256.New()
+ h.Write(expectedToken[:sha256.Size])
+ h.Write(nonce)
+ ck := h.Sum(nil)
+ if !validateBitZeros(ck, difficulty) {
+ authPage("Your submission was incorrect.")
return
}
+
+ http.SetCookie(writer, &http.Cookie{
+ Name: "powxy",
+ Value: base64.StdEncoding.EncodeToString(expectedToken),
+ })
- maps.Copy(writer.Header(), response.Header)
- writer.WriteHeader(response.StatusCode)
- io.Copy(writer, response.Body)
+ http.Redirect(writer, request, "", http.StatusSeeOther)
})))
}
+
+func validateCookie(cookie *http.Cookie, expectedToken []byte) bool {
+ if cookie == nil {
+ return false
+ }
+
+ gotToken, err := base64.StdEncoding.DecodeString(cookie.Value)
+ if err != nil {
+ return false
+ }
+
+ return subtle.ConstantTimeCompare(gotToken, expectedToken) == 1
+}
+func makeSignedToken(request *http.Request) []byte {
+ buf := make([]byte, 0, 2 * sha256.Size)
+
+ timeBuf := make([]byte, binary.MaxVarintLen64)
+ binary.PutVarint(timeBuf, time.Now().Unix() / 604800)
+
+ remoteAddr, _, _ := strings.Cut(request.RemoteAddr, ":")
+
+ h := sha256.New()
+ h.Write(timeBuf)
+ h.Write(stringToBytes(remoteAddr))
+ h.Write(stringToBytes(request.Header.Get("User-Agent")))
+ h.Write(stringToBytes(request.Header.Get("Accept-Encoding")))
+ h.Write(stringToBytes(request.Header.Get("Accept-Language")))
+ h.Write(privkeyHash)
+ buf = h.Sum(buf)
+ if len(buf) != sha256.Size {
+ panic("unexpected buffer length after hashing contents")
+ }
+
+ mac := hmac.New(sha256.New, privkey)
+ mac.Write(buf)
+ buf = mac.Sum(buf)
+ if len(buf) != 2 * sha256.Size {
+ panic("unexpected buffer length after hmac")
+ }
+
+ return buf
+}
+
+func proxyRequest(writer http.ResponseWriter, request *http.Request) {
+ request.Host = destHost
+ request.URL.Host = destHost
+ request.URL.Scheme = "http"
+ request.RequestURI = ""
+
+ response, err := client.Do(request)
+ if err != nil {
+ http.Error(writer, err.Error(), http.StatusBadGateway)
+ return
+ }
+
+ maps.Copy(writer.Header(), response.Header)
+ writer.WriteHeader(response.StatusCode)
+ io.Copy(writer, response.Body)
+}
+
+func stringToBytes(s string) (bytes []byte) {
+ return unsafe.Slice(unsafe.StringData(s), len(s))
+}
+
+func validateBitZeros(bs []byte, n uint) bool {
+ q := n / 8
+ r := n % 8
+
+ for i := uint(0); i < q; i++ {
+ if bs[i] != 0 {
+ return false
+ }
+ }
+
+ if r > 0 {
+ mask := byte(0xFF << (8 - r))
+ if bs[q]&mask != 0 {
+ return false
+ }
+ }
+
+ return true
+}
diff --git a/main_test.go b/main_test.go
new file mode 100644
index 0000000..394d407
--- /dev/null
+++ b/main_test.go
@@ -0,0 +1,90 @@
+package main
+
+import (
+ "testing"
+)
+
+func TestValidateBitZeros(t *testing.T) {
+ tests := []struct {
+ name string
+ bs []byte
+ n uint
+ expected bool
+ }{
+ {
+ name: "First 8 bits are zeros",
+ bs: []byte{0x00, 0x01},
+ n: 8,
+ expected: true,
+ },
+ {
+ name: "First 8 bits are not all zeros",
+ bs: []byte{0x01, 0x00},
+ n: 8,
+ expected: false,
+ },
+ {
+ name: "First 16 bits are zeros",
+ bs: []byte{0x00, 0x00, 0x01},
+ n: 16,
+ expected: true,
+ },
+ {
+ name: "First 16 bits are not all zeros",
+ bs: []byte{0x01, 0x00, 0x00},
+ n: 16,
+ expected: false,
+ },
+ {
+ name: "First 9 bits are zeros",
+ bs: []byte{0x00, 0x01},
+ n: 9,
+ expected: true,
+ },
+ {
+ name: "First 9 bits are not all zeros",
+ bs: []byte{0x01, 0x01},
+ n: 9,
+ expected: false,
+ },
+ {
+ name: "First 10 bits are zeros",
+ bs: []byte{0x00, 0x20},
+ n: 10,
+ expected: true,
+ },
+ {
+ name: "First 10 bits are not all zeros",
+ bs: []byte{0x00, 0x40},
+ n: 10,
+ expected: false,
+ },
+ {
+ name: "First 24 bits are zeros",
+ bs: []byte{0x00, 0x00, 0x00, 0x01},
+ n: 24,
+ expected: true,
+ },
+ {
+ name: "First 24 bits are not all zeros",
+ bs: []byte{0x00, 0x01, 0x00, 0x00},
+ n: 24,
+ expected: false,
+ },
+ {
+ name: "Checking zero bits",
+ bs: []byte{0xFF, 0xFF},
+ n: 0,
+ expected: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := validateBitZeros(tt.bs, tt.n)
+ if got != tt.expected {
+ t.Errorf("validateBitZeros(%v, %v) = %v; want %v", tt.bs, tt.n, got, tt.expected)
+ }
+ })
+ }
+}