aboutsummaryrefslogtreecommitdiffhomepage
path: root/tpl/data/data.go
diff options
context:
space:
mode:
authorBjørn Erik Pedersen <[email protected]>2021-06-05 12:44:45 +0200
committerBjørn Erik Pedersen <[email protected]>2021-06-06 13:32:12 +0200
commitfcd63de3a54fadcd30972654d8eb86dc4d889784 (patch)
tree5140863493b65783f73ecab8885f684fc623b1da /tpl/data/data.go
parent150d75738b54acddc485d363436757189144da6a (diff)
downloadhugo-fcd63de3a54fadcd30972654d8eb86dc4d889784.tar.gz
hugo-fcd63de3a54fadcd30972654d8eb86dc4d889784.zip
tpl/data: Misc header improvements, tests, allow multiple headers of same key
Closes #5617
Diffstat (limited to 'tpl/data/data.go')
-rw-r--r--tpl/data/data.go99
1 files changed, 69 insertions, 30 deletions
diff --git a/tpl/data/data.go b/tpl/data/data.go
index 4cb8b5e78..e993ed140 100644
--- a/tpl/data/data.go
+++ b/tpl/data/data.go
@@ -23,6 +23,10 @@ import (
"net/http"
"strings"
+ "github.com/gohugoio/hugo/common/maps"
+
+ "github.com/gohugoio/hugo/common/types"
+
"github.com/gohugoio/hugo/common/constants"
"github.com/gohugoio/hugo/common/loggers"
@@ -59,14 +63,10 @@ type Namespace struct {
// If you provide multiple parts for the URL they will be joined together to the final URL.
// GetCSV returns nil or a slice slice to use in a short code.
func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err error) {
- url := joinURL(args)
+ url, headers := toURLAndHeaders(args)
cache := ns.cacheGetCSV
unmarshal := func(b []byte) (bool, error) {
- if !bytes.Contains(b, []byte(sep)) {
- return false, _errors.Errorf("cannot find separator %s in CSV for %s", sep, url)
- }
-
if d, err = parseCSV(b, sep); err != nil {
err = _errors.Wrapf(err, "failed to parse CSV file %s", url)
@@ -82,17 +82,9 @@ func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err
return nil, _errors.Wrapf(err, "failed to create request for getCSV for resource %s", url)
}
- req.Header.Add("Accept", "text/csv")
- req.Header.Add("Accept", "text/plain")
-
- // Add custom user headers to the get request
- finalArg := args[len(args)-1]
-
- if userHeaders, ok := finalArg.(map[string]interface{}); ok {
- for key, val := range userHeaders {
- req.Header.Add(key, val.(string))
- }
- }
+ // Add custom user headers.
+ addUserProvidedHeaders(headers, req)
+ addDefaultHeaders(req, "text/csv", "text/plain")
err = ns.getResource(cache, unmarshal, req)
if err != nil {
@@ -108,7 +100,7 @@ func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err
// GetJSON returns nil or parsed JSON to use in a short code.
func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) {
var v interface{}
- url := joinURL(args)
+ url, headers := toURLAndHeaders(args)
cache := ns.cacheGetJSON
req, err := http.NewRequest("GET", url, nil)
@@ -124,17 +116,8 @@ func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) {
return false, nil
}
- req.Header.Add("Accept", "application/json")
- req.Header.Add("User-Agent", "Hugo Static Site Generator")
-
- // Add custom user headers to the get request
- finalArg := args[len(args)-1]
-
- if userHeaders, ok := finalArg.(map[string]interface{}); ok {
- for key, val := range userHeaders {
- req.Header.Add(key, val.(string))
- }
- }
+ addUserProvidedHeaders(headers, req)
+ addDefaultHeaders(req, "application/json")
err = ns.getResource(cache, unmarshal, req)
if err != nil {
@@ -145,8 +128,64 @@ func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) {
return v, nil
}
-func joinURL(urlParts []interface{}) string {
- return strings.Join(cast.ToStringSlice(urlParts), "")
+func addDefaultHeaders(req *http.Request, accepts ...string) {
+ for _, accept := range accepts {
+ if !hasHeaderValue(req.Header, "Accept", accept) {
+ req.Header.Add("Accept", accept)
+ }
+ }
+ if !hasHeaderKey(req.Header, "User-Agent") {
+ req.Header.Add("User-Agent", "Hugo Static Site Generator")
+ }
+}
+
+func addUserProvidedHeaders(headers map[string]interface{}, req *http.Request) {
+ if headers == nil {
+ return
+ }
+ for key, val := range headers {
+ vals := types.ToStringSlicePreserveString(val)
+ for _, s := range vals {
+ req.Header.Add(key, s)
+ }
+ }
+}
+
+func hasHeaderValue(m http.Header, key, value string) bool {
+ var s []string
+ var ok bool
+
+ if s, ok = m[key]; !ok {
+ return false
+ }
+
+ for _, v := range s {
+ if v == value {
+ return true
+ }
+ }
+ return false
+}
+
+func hasHeaderKey(m http.Header, key string) bool {
+ _, ok := m[key]
+ return ok
+}
+
+func toURLAndHeaders(urlParts []interface{}) (string, map[string]interface{}) {
+ if len(urlParts) == 0 {
+ return "", nil
+ }
+
+ // The last argument may be a map.
+ headers, err := maps.ToStringMapE(urlParts[len(urlParts)-1])
+ if err == nil {
+ urlParts = urlParts[:len(urlParts)-1]
+ } else {
+ headers = nil
+ }
+
+ return strings.Join(cast.ToStringSlice(urlParts), ""), headers
}
// parseCSV parses bytes of CSV data into a slice slice string or an error