Pull request: 2305 limit message size

Merge in DNS/adguard-home from 2305-limit-message-size to master

Closes #2305.

Squashed commit of the following:

commit 6edd1e0521277a680f0053308efcf3d9cacc8e62
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Nov 23 14:03:36 2020 +0300

    aghio: fix final inaccuracies

commit 4dd382aaf25132b31eb269749a2cd36daf0cb792
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Nov 23 13:59:10 2020 +0300

    all: improve code quality

commit 060f923f6023d0e6f26441559b7023d5e5f96843
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Nov 23 13:10:57 2020 +0300

    aghio: add validation to constructor

commit f57a2f596f5dc578548241c315c68dce7fc93905
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:19:26 2020 +0300

    all: fix minor inaccuracies

commit 93462c71725d3d00655a4bd565b77e64451fff60
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:13:23 2020 +0300

    home: make test name follow convention

commit 4922986ad84481b054479c43b4133a1b97bee86b
Merge: 1f5472abc 046ec13fd
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:09:01 2020 +0300

    Merge branch 'master' into 2305-limit-message-size

commit 1f5472abcfa7427f389825fc59eb4253e1e2bfb7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:08:21 2020 +0300

    aghio: improve readability

commit 60dc706b093fa22bbf62f13b2341934364ddc4df
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 18:44:08 2020 +0300

    home: cover middleware with test

commit bedf436b947ca1fa4493af2fc94f1f40beec7c35
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 17:10:23 2020 +0300

    aghio: improved error informativeness

commit 682c5da9f21fa330fb3536bb1c112129c91b9990
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 13:37:51 2020 +0300

    all: limit readers for ReadAll dealing with miscellanious data.

commit 78c6dd8d90a0a43fe6ee3f9ed4d5fc637b15ba74
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Nov 19 20:07:43 2020 +0300

    all: handle ReadAll calls dealing with request's bodies.

commit bfe1a6faf6468eb44515e2b0ecffa8c51f90b7e8
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Nov 19 17:25:34 2020 +0300

    home: add middlewares

commit bbd1d491b318e6ba07f8af23ad546183383783a8
Merge: 7b77c2cad 62a8fe0b7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Nov 19 16:44:04 2020 +0300

    Merge branch 'master' into 2305-limit-message-size

commit 7b77c2cad03154177392460982e1d73ee2a30177
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Nov 17 15:33:33 2020 +0300

    aghio: create package
This commit is contained in:
Ainar Garipov 2020-11-23 14:14:08 +03:00
parent 046ec13fdc
commit c129361e55
15 changed files with 413 additions and 64 deletions

View File

@ -9,6 +9,19 @@ and this project adheres to
## [Unreleased] ## [Unreleased]
### Added
- HTTP API request body size limit [#2305].
[#2305]: https://github.com/AdguardTeam/AdGuardHome/issues/2305
### Changed
- Various internal improvements ([#2271], [#2297]).
[#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271
[#2297]: https://github.com/AdguardTeam/AdGuardHome/issues/2297
## [v0.104.3] - 2020-11-19 ## [v0.104.3] - 2020-11-19

View File

@ -1,4 +1,4 @@
# AdGuardHome Developer Guidelines # *AdGuardHome* Developer Guidelines
As of **2020-11-20**, this document is still a work-in-progress. Some of the As of **2020-11-20**, this document is still a work-in-progress. Some of the
rules aren't enforced, and others might change. Still, this is a good place to rules aren't enforced, and others might change. Still, this is a good place to
@ -6,7 +6,11 @@ find out about how we **want** our code to look like.
The rules are mostly sorted in the alphabetical order. The rules are mostly sorted in the alphabetical order.
## Git ## *Git*
* Call your branches either `NNNN-fix-foo` (where `NNNN` is the ID of the
*GitHub* issue you worked on in this branch) or just `fix-foo` if there was
no *GitHub* issue.
* Follow the commit message header format: * Follow the commit message header format:
@ -22,9 +26,10 @@ The rules are mostly sorted in the alphabetical order.
* Only use lowercase letters in your commit message headers. The rest of the * Only use lowercase letters in your commit message headers. The rest of the
message should follow the plain text conventions below. message should follow the plain text conventions below.
The only exception are direct mentions of identifiers from the source code. The only exceptions are direct mentions of identifiers from the source code
and filenames like `HACKING.md`.
## Go ## *Go*
* <https://github.com/golang/go/wiki/CodeReviewComments>. * <https://github.com/golang/go/wiki/CodeReviewComments>.
@ -32,6 +37,9 @@ The rules are mostly sorted in the alphabetical order.
* <https://go-proverbs.github.io/> * <https://go-proverbs.github.io/>
* Add an empty line before `break`, `continue`, and `return`, unless it's the
only statement in that block.
* Avoid `init` and use explicit initialization functions instead. * Avoid `init` and use explicit initialization functions instead.
* Avoid `new`, especially with structs. * Avoid `new`, especially with structs.
@ -53,6 +61,18 @@ The rules are mostly sorted in the alphabetical order.
* Eschew external dependencies, including transitive, unless * Eschew external dependencies, including transitive, unless
absolutely necessary. absolutely necessary.
* Name benchmarks and tests using the same convention as examples. For
example:
```go
func TestFunction(t *testing.T) { /* … */ }
func TestFunction_suffix(t *testing.T) { /* … */ }
func TestType_Method(t *testing.T) { /* … */ }
func TestType_Method_suffix(t *testing.T) { /* … */ }
```
* Name the deferred errors (e.g. when closing something) `cerr`.
* No `goto`. * No `goto`.
* No shadowing, since it can often lead to subtle bugs, especially with * No shadowing, since it can often lead to subtle bugs, especially with
@ -103,9 +123,9 @@ The rules are mostly sorted in the alphabetical order.
[constant errors]: https://dave.cheney.net/2016/04/07/constant-errors [constant errors]: https://dave.cheney.net/2016/04/07/constant-errors
[Linus said]: https://www.kernel.org/doc/html/v4.17/process/coding-style.html#indentation [Linus said]: https://www.kernel.org/doc/html/v4.17/process/coding-style.html#indentation
## Markdown ## *Markdown*
* **TODO(a.garipov):** Define our Markdown conventions. * **TODO(a.garipov):** Define our *Markdown* conventions.
## Text, Including Comments ## Text, Including Comments
@ -128,7 +148,7 @@ The rules are mostly sorted in the alphabetical order.
* Use double spacing between sentences to make sentence borders more clear. * Use double spacing between sentences to make sentence borders more clear.
* Use the serial comma (a.k.a. Oxford comma) to improve comprehension, * Use the serial comma (a.k.a. *Oxford* comma) to improve comprehension,
decrease ambiguity, and use a common standard. decrease ambiguity, and use a common standard.
* Write todos like this: * Write todos like this:
@ -143,16 +163,16 @@ The rules are mostly sorted in the alphabetical order.
// TODO(usr1, usr2): Fix the frobulation issue. // TODO(usr1, usr2): Fix the frobulation issue.
``` ```
## YAML ## *YAML*
* **TODO(a.garipov):** Define naming conventions for schema names in our * **TODO(a.garipov):** Define naming conventions for schema names in our
OpenAPI YAML file. And just generally OpenAPI conventions. *OpenAPI* *YAML* file. And just generally OpenAPI conventions.
* **TODO(a.garipov):** Find a YAML formatter or write our own. * **TODO(a.garipov):** Find a *YAML* formatter or write our own.
* All strings, including keys, must be quoted. Reason: the [NO-rway Law]. * All strings, including keys, must be quoted. Reason: the [*NO-rway Law*].
* Indent with two (**2**) spaces. YAML documents can get pretty * Indent with two (**2**) spaces. *YAML* documents can get pretty
deeply-nested. deeply-nested.
* No extra indentation in multiline arrays: * No extra indentation in multiline arrays:
@ -170,4 +190,4 @@ The rules are mostly sorted in the alphabetical order.
* Use `>` for multiline strings, unless you need to keep the line breaks. * Use `>` for multiline strings, unless you need to keep the line breaks.
[NO-rway Law]: https://news.ycombinator.com/item?id=17359376 [*NO-rway Law*]: https://news.ycombinator.com/item?id=17359376

View File

@ -0,0 +1,59 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
// dealing with agherr package.
type limitedReadCloser struct {
limit int64
n int64
rc io.ReadCloser
}
// Read implements Reader interface.
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
if lrc.n == 0 {
return 0, &LimitReachedError{
Limit: lrc.limit,
}
}
if int64(len(p)) > lrc.n {
p = p[0:lrc.n]
}
n, err = lrc.rc.Read(p)
lrc.n -= int64(n)
return n, err
}
// Close implements Closer interface.
func (lrc *limitedReadCloser) Close() error {
return lrc.rc.Close()
}
// LimitReadCloser wraps ReadCloser to make it's Reader stop with
// ErrLimitReached after n bytes read.
func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n)
}
return &limitedReadCloser{
limit: n,
n: n,
rc: rc,
}, nil
}

View File

@ -0,0 +1,108 @@
package aghio
import (
"fmt"
"io"
"io/ioutil"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLimitReadCloser(t *testing.T) {
testCases := []struct {
name string
n int64
want error
}{{
name: "positive",
n: 1,
want: nil,
}, {
name: "zero",
n: 0,
want: nil,
}, {
name: "negative",
n: -1,
want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := LimitReadCloser(nil, tc.n)
assert.Equal(t, tc.want, err)
})
}
}
func TestLimitedReadCloser_Read(t *testing.T) {
testCases := []struct {
name string
limit int64
rStr string
want int
err error
}{{
name: "perfectly_match",
limit: 3,
rStr: "abc",
want: 3,
err: nil,
}, {
name: "eof",
limit: 3,
rStr: "",
want: 0,
err: io.EOF,
}, {
name: "limit_reached",
limit: 0,
rStr: "abc",
want: 0,
err: &LimitReachedError{
Limit: 0,
},
}, {
name: "truncated",
limit: 2,
rStr: "abc",
want: 2,
err: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
readCloser := ioutil.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)
lreader, err := LimitReadCloser(readCloser, tc.limit)
assert.Nil(t, err)
n, err := lreader.Read(buf)
assert.Equal(t, n, tc.want)
assert.Equal(t, tc.err, err)
})
}
}
func TestLimitedReadCloser_LimitReachedError(t *testing.T) {
testCases := []struct {
name string
want string
err error
}{{
name: "simplest",
want: "attempted to read more than 0 bytes",
err: &LimitReachedError{
Limit: 0,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, tc.err.Error())
})
}
}

View File

@ -299,6 +299,7 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
// . Check if a static IP is configured for the network interface // . Check if a static IP is configured for the network interface
// Respond with results // Respond with results
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err) msg := fmt.Sprintf("failed to read request body: %s", err)

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"unsafe" "unsafe"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -18,8 +19,10 @@ var GLMode bool
var glFilePrefix = "/tmp/gl_token_" var glFilePrefix = "/tmp/gl_token_"
const glTokenTimeoutSeconds = 3600 const (
const glCookieName = "Admin-Token" glTokenTimeoutSeconds = 3600
glCookieName = "Admin-Token"
)
func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool { func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool {
if !GLMode { if !GLMode {
@ -71,14 +74,28 @@ func archIsLittleEndian() bool {
return (b == 0x04) return (b == 0x04)
} }
// MaxFileSize is a maximum file length in bytes.
const MaxFileSize = 1024 * 1024
func glGetTokenDate(file string) uint32 { func glGetTokenDate(file string) uint32 {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {
log.Error("os.Open: %s", err) log.Error("os.Open: %s", err)
return 0 return 0
} }
defer f.Close()
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize)
if err != nil {
log.Error("LimitReadCloser: %s", err)
return 0
}
defer fileReadCloser.Close()
var dateToken uint32 var dateToken uint32
bs, err := ioutil.ReadAll(f)
// This use of ReadAll is now safe, because we limited reader.
bs, err := ioutil.ReadAll(fileReadCloser)
if err != nil { if err != nil {
log.Error("ioutil.ReadAll: %s", err) log.Error("ioutil.ReadAll: %s", err)
return 0 return 0

View File

@ -3,7 +3,6 @@ package home
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
) )
@ -150,16 +149,11 @@ func clientHostToJSON(ip string, ch ClientHost) clientJSON {
// Add a new client // Add a new client
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
cj := clientJSON{} cj := clientJSON{}
err = json.Unmarshal(body, &cj) err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err) httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
return return
} }
@ -183,16 +177,17 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
// Remove client // Remove client
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
return return
} }
cj := clientJSON{} if len(cj.Name) == 0 {
err = json.Unmarshal(body, &cj) httpError(w, http.StatusBadRequest, "client's name must be non-empty")
if err != nil || len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return return
} }
@ -211,18 +206,14 @@ type updateJSON struct {
// Update client's properties // Update client's properties
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) dj := updateJSON{}
err := json.NewDecoder(r.Body).Decode(&dj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
return return
} }
var dj updateJSON
err = json.Unmarshal(body, &dj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
if len(dj.Name) == 0 { if len(dj.Name) == 0 {
httpError(w, http.StatusBadRequest, "Invalid request") httpError(w, http.StatusBadRequest, "Invalid request")
return return

View File

@ -214,6 +214,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
} }
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)

View File

@ -66,6 +66,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
} }
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err) msg := fmt.Sprintf("failed to read request body: %s", err)

View File

@ -0,0 +1,59 @@
package home
import (
"net/http"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/log"
)
// middlerware is a wrapper function signature.
type middleware func(http.Handler) http.Handler
// withMiddlewares consequently wraps h with all the middlewares.
func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Handler) {
wrapped = h
for _, mw := range middlewares {
wrapped = mw(wrapped)
}
return wrapped
}
// RequestBodySizeLimit is maximum request body length in bytes.
const RequestBodySizeLimit = 64 * 1024
// limitRequestBody wraps underlying handler h, making it's request's body Read
// method limited.
func limitRequestBody(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var err error
r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
if err != nil {
log.Error("limitRequestBody: %s", err)
return
}
h.ServeHTTP(w, r)
})
}
// TODO(a.garipov): We currently have to use this, because everything registers
// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP
// API initialization process and stop using the gosh darn http.DefaultServeMux
// for anything at all. Gosh darn global variables.
func filterPProf(h http.Handler) (filtered http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/debug/pprof") {
http.NotFound(w, r)
return
}
h.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,64 @@
package home
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/stretchr/testify/assert"
)
func TestLimitRequestBody(t *testing.T) {
errReqLimitReached := &aghio.LimitReachedError{
Limit: RequestBodySizeLimit,
}
testCases := []struct {
name string
body string
want []byte
wantErr error
}{{
name: "not_so_big",
body: "somestr",
want: []byte("somestr"),
wantErr: nil,
}, {
name: "so_big",
body: string(make([]byte, RequestBodySizeLimit+1)),
want: make([]byte, RequestBodySizeLimit),
wantErr: errReqLimitReached,
}, {
name: "empty",
body: "",
want: []byte(nil),
wantErr: nil,
}}
makeHandler := func(err *error) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte
b, *err = ioutil.ReadAll(r.Body)
w.Write(b)
})
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var err error
handler := makeHandler(&err)
lim := limitRequestBody(handler)
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
res := httptest.NewRecorder()
lim.ServeHTTP(res, req)
assert.Equal(t, tc.want, res.Body.Bytes())
assert.Equal(t, tc.wantErr, err)
})
}
}

View File

@ -7,7 +7,6 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/util"
@ -142,7 +141,7 @@ func (web *Web) Start() {
web.httpServer = &http.Server{ web.httpServer = &http.Server{
ErrorLog: web.errLogger, ErrorLog: web.errLogger,
Addr: address, Addr: address,
Handler: filterPPROF(http.DefaultServeMux), Handler: withMiddlewares(http.DefaultServeMux, filterPProf, limitRequestBody),
} }
err := web.httpServer.ListenAndServe() err := web.httpServer.ListenAndServe()
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
@ -153,22 +152,6 @@ func (web *Web) Start() {
} }
} }
// TODO(a.garipov): We currently have to use this, because everything registers
// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP
// API initialization process and stop using the gosh darn http.DefaultServeMux
// for anything at all. Gosh darn global variables.
func filterPPROF(h http.Handler) (filtered http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/debug/pprof") {
http.NotFound(w, r)
return
}
h.ServeHTTP(w, r)
})
}
// Close - stop HTTP server, possibly waiting for all active connections to be closed // Close - stop HTTP server, possibly waiting for all active connections to be closed
func (web *Web) Close() { func (web *Web) Close() {
log.Info("Stopping HTTP server...") log.Info("Stopping HTTP server...")

View File

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
@ -115,6 +116,9 @@ func whoisParse(data string) map[string]string {
return m return m
} }
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response // Send request to a server and receive the response
func (w *Whois) query(target, serverAddr string) (string, error) { func (w *Whois) query(target, serverAddr string) (string, error) {
addr, _, _ := net.SplitHostPort(serverAddr) addr, _, _ := net.SplitHostPort(serverAddr)
@ -127,13 +131,20 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
} }
defer conn.Close() defer conn.Close()
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize)
if err != nil {
return "", err
}
defer connReadCloser.Close()
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond)) _ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n")) _, err = conn.Write([]byte(target + "\r\n"))
if err != nil { if err != nil {
return "", err return "", err
} }
data, err := ioutil.ReadAll(conn) // This use of ReadAll is now safe, because we limited the conn Reader.
data, err := ioutil.ReadAll(connReadCloser)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -6,6 +6,8 @@ import (
"io/ioutil" "io/ioutil"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
) )
const versionCheckPeriod = 8 * 60 * 60 const versionCheckPeriod = 8 * 60 * 60
@ -19,6 +21,9 @@ type VersionInfo struct {
CanAutoUpdate bool // If true - we can auto-update CanAutoUpdate bool // If true - we can auto-update
} }
// MaxResponseSize is responses on server's requests maximum length in bytes.
const MaxResponseSize = 64 * 1024
// GetVersionResponse - downloads version.json (if needed) and deserializes it // GetVersionResponse - downloads version.json (if needed) and deserializes it
func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) { func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
if !forceRecheck && if !forceRecheck &&
@ -27,14 +32,19 @@ func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
} }
resp, err := u.Client.Get(u.VersionURL) resp, err := u.Client.Get(u.VersionURL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil { if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err) return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)
} }
defer resp.Body.Close()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
}
defer resp.Body.Close()
// This use of ReadAll is safe, because we just limited the appropriate
// ReadCloser.
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err) return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)

View File

@ -14,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -217,17 +218,27 @@ func (u *Updater) clean() {
_ = os.RemoveAll(u.updateDir) _ = os.RemoveAll(u.updateDir)
} }
// MaxPackageFileSize is a maximum package file length in bytes. The largest
// package whose size is limited by this constant currently has the size of
// approximately 9 MiB.
const MaxPackageFileSize = 32 * 1024 * 1024
// Download package file and save it to disk // Download package file and save it to disk
func (u *Updater) downloadPackageFile(url string, filename string) error { func (u *Updater) downloadPackageFile(url string, filename string) error {
resp, err := u.Client.Get(url) resp, err := u.Client.Get(url)
if err != nil { if err != nil {
return fmt.Errorf("http request failed: %w", err) return fmt.Errorf("http request failed: %w", err)
} }
if resp != nil && resp.Body != nil { defer resp.Body.Close()
defer resp.Body.Close()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
} }
defer resp.Body.Close()
log.Debug("updater: reading HTTP body") log.Debug("updater: reading HTTP body")
// This use of ReadAll is now safe, because we limited body's Reader.
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("ioutil.ReadAll() failed: %w", err) return fmt.Errorf("ioutil.ReadAll() failed: %w", err)