diff options
author | Bjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com> | 2021-06-05 12:44:45 +0200 |
---|---|---|
committer | Bjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com> | 2021-06-06 13:32:12 +0200 |
commit | fcd63de3a54fadcd30972654d8eb86dc4d889784 (patch) | |
tree | 5140863493b65783f73ecab8885f684fc623b1da /tpl | |
parent | 150d75738b54acddc485d363436757189144da6a (diff) |
tpl/data: Misc header improvements, tests, allow multiple headers of same key
Closes #5617
Diffstat (limited to 'tpl')
-rw-r--r-- | tpl/data/data.go | 99 | ||||
-rw-r--r-- | tpl/data/data_test.go | 281 | ||||
-rw-r--r-- | tpl/data/resources.go | 9 |
3 files changed, 241 insertions, 148 deletions
diff --git a/tpl/data/data.go b/tpl/data/data.go index 4cb8b5e78..e993ed140 100644 --- a/tpl/data/data.go +++ b/tpl/data/data.go @@ -23,6 +23,10 @@ import ( "net/http" "strings" + "github.com/gohugoio/hugo/common/maps" + + "github.com/gohugoio/hugo/common/types" + "github.com/gohugoio/hugo/common/constants" "github.com/gohugoio/hugo/common/loggers" @@ -59,14 +63,10 @@ type Namespace struct { // If you provide multiple parts for the URL they will be joined together to the final URL. // GetCSV returns nil or a slice slice to use in a short code. func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err error) { - url := joinURL(args) + url, headers := toURLAndHeaders(args) cache := ns.cacheGetCSV unmarshal := func(b []byte) (bool, error) { - if !bytes.Contains(b, []byte(sep)) { - return false, _errors.Errorf("cannot find separator %s in CSV for %s", sep, url) - } - if d, err = parseCSV(b, sep); err != nil { err = _errors.Wrapf(err, "failed to parse CSV file %s", url) @@ -82,17 +82,9 @@ func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err return nil, _errors.Wrapf(err, "failed to create request for getCSV for resource %s", url) } - req.Header.Add("Accept", "text/csv") - req.Header.Add("Accept", "text/plain") - - // Add custom user headers to the get request - finalArg := args[len(args)-1] - - if userHeaders, ok := finalArg.(map[string]interface{}); ok { - for key, val := range userHeaders { - req.Header.Add(key, val.(string)) - } - } + // Add custom user headers. + addUserProvidedHeaders(headers, req) + addDefaultHeaders(req, "text/csv", "text/plain") err = ns.getResource(cache, unmarshal, req) if err != nil { @@ -108,7 +100,7 @@ func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err // GetJSON returns nil or parsed JSON to use in a short code. func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) { var v interface{} - url := joinURL(args) + url, headers := toURLAndHeaders(args) cache := ns.cacheGetJSON req, err := http.NewRequest("GET", url, nil) @@ -124,17 +116,8 @@ func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) { return false, nil } - req.Header.Add("Accept", "application/json") - req.Header.Add("User-Agent", "Hugo Static Site Generator") - - // Add custom user headers to the get request - finalArg := args[len(args)-1] - - if userHeaders, ok := finalArg.(map[string]interface{}); ok { - for key, val := range userHeaders { - req.Header.Add(key, val.(string)) - } - } + addUserProvidedHeaders(headers, req) + addDefaultHeaders(req, "application/json") err = ns.getResource(cache, unmarshal, req) if err != nil { @@ -145,8 +128,64 @@ func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) { return v, nil } -func joinURL(urlParts []interface{}) string { - return strings.Join(cast.ToStringSlice(urlParts), "") +func addDefaultHeaders(req *http.Request, accepts ...string) { + for _, accept := range accepts { + if !hasHeaderValue(req.Header, "Accept", accept) { + req.Header.Add("Accept", accept) + } + } + if !hasHeaderKey(req.Header, "User-Agent") { + req.Header.Add("User-Agent", "Hugo Static Site Generator") + } +} + +func addUserProvidedHeaders(headers map[string]interface{}, req *http.Request) { + if headers == nil { + return + } + for key, val := range headers { + vals := types.ToStringSlicePreserveString(val) + for _, s := range vals { + req.Header.Add(key, s) + } + } +} + +func hasHeaderValue(m http.Header, key, value string) bool { + var s []string + var ok bool + + if s, ok = m[key]; !ok { + return false + } + + for _, v := range s { + if v == value { + return true + } + } + return false +} + +func hasHeaderKey(m http.Header, key string) bool { + _, ok := m[key] + return ok +} + +func toURLAndHeaders(urlParts []interface{}) (string, map[string]interface{}) { + if len(urlParts) == 0 { + return "", nil + } + + // The last argument may be a map. + headers, err := maps.ToStringMapE(urlParts[len(urlParts)-1]) + if err == nil { + urlParts = urlParts[:len(urlParts)-1] + } else { + headers = nil + } + + return strings.Join(cast.ToStringSlice(urlParts), ""), headers } // parseCSV parses bytes of CSV data into a slice slice string or an error diff --git a/tpl/data/data_test.go b/tpl/data/data_test.go index 6b62a2b0d..8a18a19e4 100644 --- a/tpl/data/data_test.go +++ b/tpl/data/data_test.go @@ -14,12 +14,16 @@ package data import ( + "bytes" + "html/template" "net/http" "net/http/httptest" "path/filepath" "strings" "testing" + "github.com/gohugoio/hugo/common/maps" + qt "github.com/frankban/quicktest" ) @@ -48,12 +52,6 @@ func TestGetCSV(t *testing.T) { }, { ",", - `http://error.no.sep/`, - "gomeetup;city\nyes;Sydney\nyes;San Francisco\nyes;Stockholm\n", - false, - }, - { - ",", `http://nofound/404`, ``, false, @@ -73,66 +71,54 @@ func TestGetCSV(t *testing.T) { false, }, } { - msg := qt.Commentf("Test %d", i) - ns := newTestNs() - - // Setup HTTP test server - var srv *httptest.Server - srv, ns.client = getTestServer(func(w http.ResponseWriter, r *http.Request) { - if !haveHeader(r.Header, "Accept", "text/csv") && !haveHeader(r.Header, "Accept", "text/plain") { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } + c.Run(test.url, func(c *qt.C) { + msg := qt.Commentf("Test %d", i) - if r.URL.Path == "/404" { - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) - return - } + ns := newTestNs() - w.Header().Add("Content-type", "text/csv") + // Setup HTTP test server + var srv *httptest.Server + srv, ns.client = getTestServer(func(w http.ResponseWriter, r *http.Request) { + if !hasHeaderValue(r.Header, "Accept", "text/csv") && !hasHeaderValue(r.Header, "Accept", "text/plain") { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } - w.Write([]byte(test.content)) - }) - defer func() { srv.Close() }() + if r.URL.Path == "/404" { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } - // Setup local test file for schema-less URLs - if !strings.Contains(test.url, ":") && !strings.HasPrefix(test.url, "fail/") { - f, err := ns.deps.Fs.Source.Create(filepath.Join(ns.deps.Cfg.GetString("workingDir"), test.url)) - c.Assert(err, qt.IsNil, msg) - f.WriteString(test.content) - f.Close() - } + w.Header().Add("Content-type", "text/csv") - // Get on with it - got, err := ns.GetCSV(test.sep, test.url) + w.Write([]byte(test.content)) + }) + defer func() { srv.Close() }() - if _, ok := test.expect.(bool); ok { - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) - // c.Assert(err, msg, qt.Not(qt.IsNil)) - c.Assert(got, qt.IsNil) - continue - } + // Setup local test file for schema-less URLs + if !strings.Contains(test.url, ":") && !strings.HasPrefix(test.url, "fail/") { + f, err := ns.deps.Fs.Source.Create(filepath.Join(ns.deps.Cfg.GetString("workingDir"), test.url)) + c.Assert(err, qt.IsNil, msg) + f.WriteString(test.content) + f.Close() + } - c.Assert(err, qt.IsNil, msg) - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0) - c.Assert(got, qt.Not(qt.IsNil), msg) - c.Assert(got, qt.DeepEquals, test.expect, msg) + // Get on with it + got, err := ns.GetCSV(test.sep, test.url) - // Test user-defined headers as well - gotHeader, _ := ns.GetCSV(test.sep, test.url, map[string]interface{}{"Accept-Charset": "utf-8", "Max-Forwards": "10"}) + if _, ok := test.expect.(bool); ok { + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) + c.Assert(got, qt.IsNil) + return + } - if _, ok := test.expect.(bool); ok { - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) - // c.Assert(err, msg, qt.Not(qt.IsNil)) - c.Assert(got, qt.IsNil) - continue - } + c.Assert(err, qt.IsNil, msg) + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0) + c.Assert(got, qt.Not(qt.IsNil), msg) + c.Assert(got, qt.DeepEquals, test.expect, msg) + }) - c.Assert(err, qt.IsNil, msg) - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0) - c.Assert(gotHeader, qt.Not(qt.IsNil), msg) - c.Assert(gotHeader, qt.DeepEquals, test.expect, msg) } } @@ -178,68 +164,153 @@ func TestGetJSON(t *testing.T) { }, } { - msg := qt.Commentf("Test %d", i) - ns := newTestNs() + c.Run(test.url, func(c *qt.C) { - // Setup HTTP test server - var srv *httptest.Server - srv, ns.client = getTestServer(func(w http.ResponseWriter, r *http.Request) { - if !haveHeader(r.Header, "Accept", "application/json") { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return + msg := qt.Commentf("Test %d", i) + ns := newTestNs() + + // Setup HTTP test server + var srv *httptest.Server + srv, ns.client = getTestServer(func(w http.ResponseWriter, r *http.Request) { + if !hasHeaderValue(r.Header, "Accept", "application/json") { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + if r.URL.Path == "/404" { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + + w.Header().Add("Content-type", "application/json") + + w.Write([]byte(test.content)) + }) + defer func() { srv.Close() }() + + // Setup local test file for schema-less URLs + if !strings.Contains(test.url, ":") && !strings.HasPrefix(test.url, "fail/") { + f, err := ns.deps.Fs.Source.Create(filepath.Join(ns.deps.Cfg.GetString("workingDir"), test.url)) + c.Assert(err, qt.IsNil, msg) + f.WriteString(test.content) + f.Close() } - if r.URL.Path == "/404" { - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + // Get on with it + got, _ := ns.GetJSON(test.url) + + if _, ok := test.expect.(bool); ok { + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) return } - w.Header().Add("Content-type", "application/json") + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg) + c.Assert(got, qt.Not(qt.IsNil), msg) + c.Assert(got, qt.DeepEquals, test.expect) - w.Write([]byte(test.content)) }) - defer func() { srv.Close() }() + } +} - // Setup local test file for schema-less URLs - if !strings.Contains(test.url, ":") && !strings.HasPrefix(test.url, "fail/") { - f, err := ns.deps.Fs.Source.Create(filepath.Join(ns.deps.Cfg.GetString("workingDir"), test.url)) - c.Assert(err, qt.IsNil, msg) - f.WriteString(test.content) - f.Close() - } +func TestHeaders(t *testing.T) { + t.Parallel() + c := qt.New(t) - // Get on with it - got, _ := ns.GetJSON(test.url) + for _, test := range []struct { + name string + headers interface{} + assert func(c *qt.C, headers string) + }{ + { + `Misc header variants`, + map[string]interface{}{ + "Accept-Charset": "utf-8", + "Max-forwards": "10", + "X-Int": 32, + "X-Templ": template.HTML("a"), + "X-Multiple": []string{"a", "b"}, + "X-MultipleInt": []int{3, 4}, + }, + func(c *qt.C, headers string) { + c.Assert(headers, qt.Contains, "Accept-Charset: utf-8") + c.Assert(headers, qt.Contains, "Max-Forwards: 10") + c.Assert(headers, qt.Contains, "X-Int: 32") + c.Assert(headers, qt.Contains, "X-Templ: a") + c.Assert(headers, qt.Contains, "X-Multiple: a") + c.Assert(headers, qt.Contains, "X-Multiple: b") + c.Assert(headers, qt.Contains, "X-Multipleint: 3") + c.Assert(headers, qt.Contains, "X-Multipleint: 4") + c.Assert(headers, qt.Contains, "User-Agent: Hugo Static Site Generator") + }, + }, + { + `Params`, + maps.Params{ + "Accept-Charset": "utf-8", + }, + func(c *qt.C, headers string) { + c.Assert(headers, qt.Contains, "Accept-Charset: utf-8") + }, + }, + { + `Override User-Agent`, + map[string]interface{}{ + "User-Agent": "007", + }, + func(c *qt.C, headers string) { + c.Assert(headers, qt.Contains, "User-Agent: 007") + }, + }, + } { - if _, ok := test.expect.(bool); ok { - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) - // c.Assert(err, msg, qt.Not(qt.IsNil)) - continue - } + c.Run(test.name, func(c *qt.C) { - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg) - c.Assert(got, qt.Not(qt.IsNil), msg) - c.Assert(got, qt.DeepEquals, test.expect) + ns := newTestNs() - // Test user-defined headers as well - gotHeader, _ := ns.GetJSON(test.url, map[string]interface{}{"Accept-Charset": "utf-8", "Max-Forwards": "10"}) + // Setup HTTP test server + var srv *httptest.Server + var headers bytes.Buffer + srv, ns.client = getTestServer(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.URL.String(), qt.Equals, "http://gohugo.io/api?foo") + w.Write([]byte("{}")) + r.Header.Write(&headers) - if _, ok := test.expect.(bool); ok { - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1) - // c.Assert(err, msg, qt.Not(qt.IsNil)) - continue - } + }) + defer func() { srv.Close() }() + + testFunc := func(fn func(args ...interface{}) error) { + defer headers.Reset() + err := fn("http://example.org/api", "?foo", test.headers) + + c.Assert(err, qt.IsNil) + c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0) + test.assert(c, headers.String()) + } + + testFunc(func(args ...interface{}) error { + _, err := ns.GetJSON(args...) + return err + }) + testFunc(func(args ...interface{}) error { + _, err := ns.GetCSV(",", args...) + return err + }) + + }) - c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg) - c.Assert(gotHeader, qt.Not(qt.IsNil), msg) - c.Assert(gotHeader, qt.DeepEquals, test.expect) } } -func TestJoinURL(t *testing.T) { +func TestToURLAndHeaders(t *testing.T) { t.Parallel() c := qt.New(t) - c.Assert(joinURL([]interface{}{"https://foo?id=", 32}), qt.Equals, "https://foo?id=32") + url, headers := toURLAndHeaders([]interface{}{"https://foo?id=", 32}) + c.Assert(url, qt.Equals, "https://foo?id=32") + c.Assert(headers, qt.IsNil) + + url, headers = toURLAndHeaders([]interface{}{"https://foo?id=", 32, map[string]interface{}{"a": "b"}}) + c.Assert(url, qt.Equals, "https://foo?id=32") + c.Assert(headers, qt.DeepEquals, map[string]interface{}{"a": "b"}) } func TestParseCSV(t *testing.T) { @@ -276,19 +347,3 @@ func TestParseCSV(t *testing.T) { c.Assert(act, qt.Equals, test.exp, msg) } } - -func haveHeader(m http.Header, key, needle string) bool { - var s []string - var ok bool - - if s, ok = m[key]; !ok { - return false - } - - for _, v := range s { - if v == needle { - return true - } - } - return false -} diff --git a/tpl/data/resources.go b/tpl/data/resources.go index ba98f12b4..68f18c48e 100644 --- a/tpl/data/resources.go +++ b/tpl/data/resources.go @@ -14,6 +14,7 @@ package data import ( + "bytes" "io/ioutil" "net/http" "net/url" @@ -37,7 +38,9 @@ var ( // getRemote loads the content of a remote file. This method is thread safe. func (ns *Namespace) getRemote(cache *filecache.Cache, unmarshal func([]byte) (bool, error), req *http.Request) error { url := req.URL.String() - id := helpers.MD5String(url) + var headers bytes.Buffer + req.Header.Write(&headers) + id := helpers.MD5String(url + headers.String()) var handled bool var retry bool @@ -94,10 +97,6 @@ func (ns *Namespace) getRemote(cache *filecache.Cache, unmarshal func([]byte) (b // getLocal loads the content of a local file func getLocal(url string, fs afero.Fs, cfg config.Provider) ([]byte, error) { filename := filepath.Join(cfg.GetString("workingDir"), url) - if e, err := helpers.Exists(filename, fs); !e { - return nil, err - } - return afero.ReadFile(fs, filename) } |