diff options
author | Junegunn Choi <junegunn.c@gmail.com> | 2022-12-22 20:44:49 +0900 |
---|---|---|
committer | Junegunn Choi <junegunn.c@gmail.com> | 2022-12-22 20:44:49 +0900 |
commit | 1a9761736ed2190aa7536d4b6754f617f46cfd9c (patch) | |
tree | 047e9e4f4f9ff94b71569479db926a7e78cae9d4 /src/server.go | |
parent | fd1f7665a77f3bd062586f0f3b54d4ff9863dd4b (diff) |
Add time and size limit to remote requests
Diffstat (limited to 'src/server.go')
-rw-r--r-- | src/server.go | 46 |
1 files changed, 29 insertions, 17 deletions
diff --git a/src/server.go b/src/server.go index 75af23ff..421bc20b 100644 --- a/src/server.go +++ b/src/server.go @@ -8,12 +8,15 @@ import ( "net" "strconv" "strings" + "time" ) const ( - crlf = "\r\n" - httpOk = "HTTP/1.1 200 OK" + crlf - httpBadRequest = "HTTP/1.1 400 Bad Request" + crlf + crlf = "\r\n" + httpOk = "HTTP/1.1 200 OK" + crlf + httpBadRequest = "HTTP/1.1 400 Bad Request" + crlf + httpReadTimeout = 10 * time.Second + maxContentLength = 1024 * 1024 ) func startHttpServer(port int, channel chan []*action) error { @@ -52,14 +55,13 @@ func startHttpServer(port int, channel chan []*action) error { // * --listen with net/http: 5.7MB // * --listen w/o net/http: 3.3MB func handleHttpRequest(conn net.Conn, channel chan []*action) string { - line := 0 - headerRead := false contentLength := 0 body := "" bad := func(message string) string { message += "\n" return httpBadRequest + fmt.Sprintf("Content-Length: %d%s", len(message), crlf+crlf+message) } + conn.SetReadDeadline(time.Now().Add(httpReadTimeout)) scanner := bufio.NewScanner(conn) scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) { found := bytes.Index(data, []byte(crlf)) @@ -73,31 +75,41 @@ func handleHttpRequest(conn net.Conn, channel chan []*action) string { return 0, nil, nil }) + section := 0 for scanner.Scan() { text := scanner.Text() - if line == 0 && !strings.HasPrefix(text, "POST / HTTP") { - return bad("invalid request method") - } - if text == crlf { - headerRead = true - } - if !headerRead { + switch section { + case 0: + if !strings.HasPrefix(text, "POST / HTTP") { + return bad("invalid request method") + } + section++ + case 1: + if text == crlf { + if contentLength == 0 { + return bad("content-length header missing") + } + section++ + continue + } pair := strings.SplitN(text, ":", 2) if len(pair) == 2 && strings.ToLower(pair[0]) == "content-length" { length, err := strconv.Atoi(strings.TrimSpace(pair[1])) - if err != nil { + if err != nil || length <= 0 || length > maxContentLength { return bad("invalid content length") } contentLength = length } - } else if contentLength <= 0 { - break - } else { + case 2: body += text } - line++ } + if len(body) < contentLength { + return bad("incomplete request") + } + body = body[:contentLength] + errorMessage := "" actions := parseSingleActionList(strings.Trim(string(body), "\r\n"), func(message string) { errorMessage = message |