diff options
author | Bjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com> | 2019-04-18 17:06:54 +0200 |
---|---|---|
committer | Bjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com> | 2019-04-18 23:42:01 +0200 |
commit | 06f56fc983d460506d39b3a6f638b1632af07073 (patch) | |
tree | 88f8a53b384e303f008842c859f0170b3450cdcf | |
parent | d7a67dcb51829b12d492d3f2ee4f6e2a3834da63 (diff) |
tpl/collections: Make Pages etc. work with the in func
Fixes #5875
-rw-r--r-- | tpl/collections/collections.go | 31 | ||||
-rw-r--r-- | tpl/collections/collections_test.go | 9 |
2 files changed, 23 insertions, 17 deletions
diff --git a/tpl/collections/collections.go b/tpl/collections/collections.go index 15f17ff6d..69d0f7042 100644 --- a/tpl/collections/collections.go +++ b/tpl/collections/collections.go @@ -250,27 +250,26 @@ func (ns *Namespace) In(l interface{}, v interface{}) bool { lv := reflect.ValueOf(l) vv := reflect.ValueOf(v) + if !vv.Type().Comparable() { + // TODO(bep) consider adding error to the signature. + return false + } + + // Normalize numeric types to float64 etc. + vvk := normalize(vv) + switch lv.Kind() { case reflect.Array, reflect.Slice: for i := 0; i < lv.Len(); i++ { - lvv := lv.Index(i) - lvv, isNil := indirect(lvv) - if isNil { + lvv, isNil := indirectInterface(lv.Index(i)) + if isNil || !lvv.Type().Comparable() { continue } - switch lvv.Kind() { - case reflect.String: - if vv.Type() == lvv.Type() && vv.String() == lvv.String() { - return true - } - default: - if isNumber(vv.Kind()) && isNumber(lvv.Kind()) { - f1, err1 := numberToFloat(vv) - f2, err2 := numberToFloat(lvv) - if err1 == nil && err2 == nil && f1 == f2 { - return true - } - } + + lvvk := normalize(lvv) + + if lvvk == vvk { + return true } } case reflect.String: diff --git a/tpl/collections/collections_test.go b/tpl/collections/collections_test.go index 741dd074d..c87490b2c 100644 --- a/tpl/collections/collections_test.go +++ b/tpl/collections/collections_test.go @@ -276,6 +276,7 @@ func TestFirst(t *testing.T) { func TestIn(t *testing.T) { t.Parallel() + assert := require.New(t) ns := New(&deps.Deps{}) @@ -302,12 +303,18 @@ func TestIn(t *testing.T) { {"this substring should be found", "substring", true}, {"this substring should not be found", "subseastring", false}, {nil, "foo", false}, + // Pointers + {pagesPtr{p1, p2, p3, p2}, p2, true}, + {pagesPtr{p1, p2, p3, p2}, p4, false}, + // Structs + {pagesVals{p3v, p2v, p3v, p2v}, p2v, true}, + {pagesVals{p3v, p2v, p3v, p2v}, p4v, false}, } { errMsg := fmt.Sprintf("[%d] %v", i, test) result := ns.In(test.l1, test.l2) - assert.Equal(t, test.expect, result, errMsg) + assert.Equal(test.expect, result, errMsg) } } |