summaryrefslogtreecommitdiffstats
path: root/tpl
diff options
context:
space:
mode:
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>2023-07-11 09:48:57 +0200
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>2023-07-11 12:11:39 +0200
commit5bec50838c9d5ce097a82d91b7924f22453de1a9 (patch)
tree8b99d19e7ec74063fdf9aff401c63523dabc3c01 /tpl
parentf650e4d751d83566a88f7fd78d264730b5d06262 (diff)
tpl/collections: Fix WordCount (etc.) regression in Where, Sort, Delimit
Fixes #11234
Diffstat (limited to 'tpl')
-rw-r--r--tpl/collections/collections.go5
-rw-r--r--tpl/collections/collections_test.go5
-rw-r--r--tpl/collections/integration_test.go41
-rw-r--r--tpl/collections/sort.go9
-rw-r--r--tpl/collections/sort_test.go5
-rw-r--r--tpl/collections/where.go35
-rw-r--r--tpl/collections/where_test.go13
7 files changed, 86 insertions, 27 deletions
diff --git a/tpl/collections/collections.go b/tpl/collections/collections.go
index 04b777dfb..92aa2b9e6 100644
--- a/tpl/collections/collections.go
+++ b/tpl/collections/collections.go
@@ -16,6 +16,7 @@
package collections
import (
+ "context"
"fmt"
"html/template"
"math/rand"
@@ -99,7 +100,7 @@ func (ns *Namespace) After(n any, l any) (any, error) {
// Delimit takes a given list l and returns a string delimited by sep.
// If last is passed to the function, it will be used as the final delimiter.
-func (ns *Namespace) Delimit(l, sep any, last ...any) (template.HTML, error) {
+func (ns *Namespace) Delimit(ctx context.Context, l, sep any, last ...any) (template.HTML, error) {
d, err := cast.ToStringE(sep)
if err != nil {
return "", err
@@ -125,7 +126,7 @@ func (ns *Namespace) Delimit(l, sep any, last ...any) (template.HTML, error) {
var str string
switch lv.Kind() {
case reflect.Map:
- sortSeq, err := ns.Sort(l)
+ sortSeq, err := ns.Sort(ctx, l)
if err != nil {
return "", err
}
diff --git a/tpl/collections/collections_test.go b/tpl/collections/collections_test.go
index 86192c480..43f8377f3 100644
--- a/tpl/collections/collections_test.go
+++ b/tpl/collections/collections_test.go
@@ -14,6 +14,7 @@
package collections
import (
+ "context"
"errors"
"fmt"
"html/template"
@@ -166,9 +167,9 @@ func TestDelimit(t *testing.T) {
var err error
if test.last == nil {
- result, err = ns.Delimit(test.seq, test.delimiter)
+ result, err = ns.Delimit(context.Background(), test.seq, test.delimiter)
} else {
- result, err = ns.Delimit(test.seq, test.delimiter, test.last)
+ result, err = ns.Delimit(context.Background(), test.seq, test.delimiter, test.last)
}
c.Assert(err, qt.IsNil, errMsg)
diff --git a/tpl/collections/integration_test.go b/tpl/collections/integration_test.go
index 80d2f043a..7ef0b6c47 100644
--- a/tpl/collections/integration_test.go
+++ b/tpl/collections/integration_test.go
@@ -155,3 +155,44 @@ func TestAppendNilsToSliceWithNils(t *testing.T) {
}
}
+
+// Issue 11234.
+func TestWhereWithWordCount(t *testing.T) {
+ t.Parallel()
+
+ files := `
+-- config.toml --
+baseURL = 'http://example.com/'
+-- layouts/index.html --
+Home: {{ range where site.RegularPages "WordCount" "gt" 50 }}{{ .Title }}|{{ end }}
+-- layouts/shortcodes/lorem.html --
+{{ "ipsum " | strings.Repeat (.Get 0 | int) }}
+
+-- content/p1.md --
+---
+title: "p1"
+---
+{{< lorem 100 >}}
+-- content/p2.md --
+---
+title: "p2"
+---
+{{< lorem 20 >}}
+-- content/p3.md --
+---
+title: "p3"
+---
+{{< lorem 60 >}}
+ `
+
+ b := hugolib.NewIntegrationTestBuilder(
+ hugolib.IntegrationTestConfig{
+ T: t,
+ TxtarString: files,
+ },
+ ).Build()
+
+ b.AssertFileContent("public/index.html", `
+Home: p1|p3|
+`)
+}
diff --git a/tpl/collections/sort.go b/tpl/collections/sort.go
index 4a2106039..2040f8490 100644
--- a/tpl/collections/sort.go
+++ b/tpl/collections/sort.go
@@ -14,6 +14,7 @@
package collections
import (
+ "context"
"errors"
"reflect"
"sort"
@@ -26,7 +27,7 @@ import (
)
// Sort returns a sorted copy of the list l.
-func (ns *Namespace) Sort(l any, args ...any) (any, error) {
+func (ns *Namespace) Sort(ctx context.Context, l any, args ...any) (any, error) {
if l == nil {
return nil, errors.New("sequence must be provided")
}
@@ -36,6 +37,8 @@ func (ns *Namespace) Sort(l any, args ...any) (any, error) {
return nil, errors.New("can't iterate over a nil value")
}
+ ctxv := reflect.ValueOf(ctx)
+
var sliceType reflect.Type
switch seqv.Kind() {
case reflect.Array, reflect.Slice:
@@ -78,7 +81,7 @@ func (ns *Namespace) Sort(l any, args ...any) (any, error) {
v := p.Pairs[i].Value
var err error
for i, elemName := range path {
- v, err = evaluateSubElem(v, elemName)
+ v, err = evaluateSubElem(ctxv, v, elemName)
if err != nil {
return nil, err
}
@@ -108,7 +111,7 @@ func (ns *Namespace) Sort(l any, args ...any) (any, error) {
v := p.Pairs[i].Value
var err error
for i, elemName := range path {
- v, err = evaluateSubElem(v, elemName)
+ v, err = evaluateSubElem(ctxv, v, elemName)
if err != nil {
return nil, err
}
diff --git a/tpl/collections/sort_test.go b/tpl/collections/sort_test.go
index da9c75d04..1ec95882f 100644
--- a/tpl/collections/sort_test.go
+++ b/tpl/collections/sort_test.go
@@ -14,6 +14,7 @@
package collections
import (
+ "context"
"fmt"
"reflect"
"testing"
@@ -240,9 +241,9 @@ func TestSort(t *testing.T) {
var result any
var err error
if test.sortByField == nil {
- result, err = ns.Sort(test.seq)
+ result, err = ns.Sort(context.Background(), test.seq)
} else {
- result, err = ns.Sort(test.seq, test.sortByField, test.sortAsc)
+ result, err = ns.Sort(context.Background(), test.seq, test.sortByField, test.sortAsc)
}
if b, ok := test.expect.(bool); ok && !b {
diff --git a/tpl/collections/where.go b/tpl/collections/where.go
index b20c290fa..2904b7cdd 100644
--- a/tpl/collections/where.go
+++ b/tpl/collections/where.go
@@ -14,6 +14,7 @@
package collections
import (
+ "context"
"errors"
"fmt"
"reflect"
@@ -24,7 +25,7 @@ import (
)
// Where returns a filtered subset of collection c.
-func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
+func (ns *Namespace) Where(ctx context.Context, c, key any, args ...any) (any, error) {
seqv, isNil := indirect(reflect.ValueOf(c))
if isNil {
return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(c).Type().String())
@@ -35,6 +36,8 @@ func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
return nil, err
}
+ ctxv := reflect.ValueOf(ctx)
+
var path []string
kv := reflect.ValueOf(key)
if kv.Kind() == reflect.String {
@@ -43,9 +46,9 @@ func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
switch seqv.Kind() {
case reflect.Array, reflect.Slice:
- return ns.checkWhereArray(seqv, kv, mv, path, op)
+ return ns.checkWhereArray(ctxv, seqv, kv, mv, path, op)
case reflect.Map:
- return ns.checkWhereMap(seqv, kv, mv, path, op)
+ return ns.checkWhereMap(ctxv, seqv, kv, mv, path, op)
default:
return nil, fmt.Errorf("can't iterate over %v", c)
}
@@ -275,7 +278,7 @@ func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error
return false, nil
}
-func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
+func evaluateSubElem(ctx, obj reflect.Value, elemName string) (reflect.Value, error) {
if !obj.IsValid() {
return zero, errors.New("can't evaluate an invalid value")
}
@@ -301,12 +304,20 @@ func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error)
index := hreflect.GetMethodIndexByName(objPtr.Type(), elemName)
if index != -1 {
+ var args []reflect.Value
mt := objPtr.Type().Method(index)
+ num := mt.Type.NumIn()
+ maxNumIn := 1
+ if num > 1 && mt.Type.In(1).Implements(hreflect.ContextInterface) {
+ args = []reflect.Value{ctx}
+ maxNumIn = 2
+ }
+
switch {
case mt.PkgPath != "":
return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
- case mt.Type.NumIn() > 1:
- return zero, fmt.Errorf("%s is a method of type %s but requires more than 1 parameter", elemName, typ)
+ case mt.Type.NumIn() > maxNumIn:
+ return zero, fmt.Errorf("%s is a method of type %s but requires more than %d parameter", elemName, typ, maxNumIn)
case mt.Type.NumOut() == 0:
return zero, fmt.Errorf("%s is a method of type %s but returns no output", elemName, typ)
case mt.Type.NumOut() > 2:
@@ -316,7 +327,7 @@ func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error)
case mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType):
return zero, fmt.Errorf("%s is a method of type %s returning two values but the second value is not an error type", elemName, typ)
}
- res := objPtr.Method(mt.Index).Call([]reflect.Value{})
+ res := objPtr.Method(mt.Index).Call(args)
if len(res) == 2 && !res[1].IsNil() {
return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
}
@@ -371,7 +382,7 @@ func parseWhereArgs(args ...any) (mv reflect.Value, op string, err error) {
// checkWhereArray handles the where-matching logic when the seqv value is an
// Array or Slice.
-func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
+func (ns *Namespace) checkWhereArray(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
for i := 0; i < seqv.Len(); i++ {
@@ -385,7 +396,7 @@ func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string,
vvv = rvv
for i, elemName := range path {
var err error
- vvv, err = evaluateSubElem(vvv, elemName)
+ vvv, err = evaluateSubElem(ctxv, vvv, elemName)
if err != nil {
continue
@@ -417,14 +428,14 @@ func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string,
}
// checkWhereMap handles the where-matching logic when the seqv value is a Map.
-func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
+func (ns *Namespace) checkWhereMap(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
rv := reflect.MakeMap(seqv.Type())
keys := seqv.MapKeys()
for _, k := range keys {
elemv := seqv.MapIndex(k)
switch elemv.Kind() {
case reflect.Array, reflect.Slice:
- r, err := ns.checkWhereArray(elemv, kv, mv, path, op)
+ r, err := ns.checkWhereArray(ctxv, elemv, kv, mv, path, op)
if err != nil {
return nil, err
}
@@ -443,7 +454,7 @@ func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op
switch elemvv.Kind() {
case reflect.Array, reflect.Slice:
- r, err := ns.checkWhereArray(elemvv, kv, mv, path, op)
+ r, err := ns.checkWhereArray(ctxv, elemvv, kv, mv, path, op)
if err != nil {
return nil, err
}
diff --git a/tpl/collections/where_test.go b/tpl/collections/where_test.go
index e5ae85e88..1b787daa2 100644
--- a/tpl/collections/where_test.go
+++ b/tpl/collections/where_test.go
@@ -14,6 +14,7 @@
package collections
import (
+ "context"
"fmt"
"html/template"
"reflect"
@@ -641,9 +642,9 @@ func TestWhere(t *testing.T) {
var err error
if len(test.op) > 0 {
- results, err = ns.Where(test.seq, test.key, test.op, test.match)
+ results, err = ns.Where(context.Background(), test.seq, test.key, test.op, test.match)
} else {
- results, err = ns.Where(test.seq, test.key, test.match)
+ results, err = ns.Where(context.Background(), test.seq, test.key, test.match)
}
if b, ok := test.expect.(bool); ok && !b {
if err == nil {
@@ -662,17 +663,17 @@ func TestWhere(t *testing.T) {
}
var err error
- _, err = ns.Where(map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1)
+ _, err = ns.Where(context.Background(), map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1)
if err == nil {
t.Errorf("Where called with none string op value didn't return an expected error")
}
- _, err = ns.Where(map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1, 2)
+ _, err = ns.Where(context.Background(), map[string]int{"a": 1, "b": 2}, "a", []byte("="), 1, 2)
if err == nil {
t.Errorf("Where called with more than two variable arguments didn't return an expected error")
}
- _, err = ns.Where(map[string]int{"a": 1, "b": 2}, "a")
+ _, err = ns.Where(context.Background(), map[string]int{"a": 1, "b": 2}, "a")
if err == nil {
t.Errorf("Where called with no variable arguments didn't return an expected error")
}
@@ -842,7 +843,7 @@ func TestEvaluateSubElem(t *testing.T) {
{reflect.ValueOf(map[int]string{1: "foo", 2: "bar"}), "1", false},
{reflect.ValueOf([]string{"foo", "bar"}), "1", false},
} {
- result, err := evaluateSubElem(test.value, test.key)
+ result, err := evaluateSubElem(reflect.ValueOf(context.Background()), test.value, test.key)
if b, ok := test.expect.(bool); ok && !b {
if err == nil {
t.Errorf("[%d] evaluateSubElem didn't return an expected error", i)