From 0566bbf7c7f2898fcd1d6156b27733cd48aa0449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Mon, 6 Jun 2022 09:48:40 +0200 Subject: Fix raw TOML dates in where/eq Note that this has only been a problem with "raw dates" in TOML files in /data and similar. The predefined front matter dates `.Date` etc. are converted to a Go Time and has worked fine even after upgrading to v2 of the go-toml lib. Fixes #9979 --- tpl/compare/compare.go | 34 ++++++++++++++++++---------------- tpl/compare/compare_test.go | 29 +++++++++++++++-------------- tpl/compare/init.go | 7 ++++++- 3 files changed, 39 insertions(+), 31 deletions(-) (limited to 'tpl/compare') diff --git a/tpl/compare/compare.go b/tpl/compare/compare.go index 9905003b2..0b2d065ab 100644 --- a/tpl/compare/compare.go +++ b/tpl/compare/compare.go @@ -23,16 +23,19 @@ import ( "github.com/gohugoio/hugo/compare" "github.com/gohugoio/hugo/langs" + "github.com/gohugoio/hugo/common/hreflect" + "github.com/gohugoio/hugo/common/htime" "github.com/gohugoio/hugo/common/types" ) // New returns a new instance of the compare-namespaced template functions. -func New(caseInsensitive bool) *Namespace { - return &Namespace{caseInsensitive: caseInsensitive} +func New(loc *time.Location, caseInsensitive bool) *Namespace { + return &Namespace{loc: loc, caseInsensitive: caseInsensitive} } // Namespace provides template functions for the "compare" namespace. type Namespace struct { + loc *time.Location // Enable to do case insensitive string compares. caseInsensitive bool } @@ -101,6 +104,11 @@ func (n *Namespace) Eq(first any, others ...any) bool { if types.IsNil(v) { return nil } + + if at, ok := v.(htime.AsTimeProvider); ok { + return at.AsTime(n.loc) + } + vv := reflect.ValueOf(v) switch vv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -269,9 +277,8 @@ func (ns *Namespace) compareGetWithCollator(collator *langs.Collator, a any, b a leftStr = &str } case reflect.Struct: - switch av.Type() { - case timeType: - left = float64(toTimeUnix(av)) + if hreflect.IsTime(av.Type()) { + left = float64(ns.toTimeUnix(av)) } case reflect.Bool: left = 0 @@ -297,9 +304,8 @@ func (ns *Namespace) compareGetWithCollator(collator *langs.Collator, a any, b a rightStr = &str } case reflect.Struct: - switch bv.Type() { - case timeType: - right = float64(toTimeUnix(bv)) + if hreflect.IsTime(bv.Type()) { + right = float64(ns.toTimeUnix(bv)) } case reflect.Bool: right = 0 @@ -337,14 +343,10 @@ func (ns *Namespace) compareGetWithCollator(collator *langs.Collator, a any, b a return left, right } -var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - -func toTimeUnix(v reflect.Value) int64 { - if v.Kind() == reflect.Interface { - return toTimeUnix(v.Elem()) - } - if v.Type() != timeType { +func (ns *Namespace) toTimeUnix(v reflect.Value) int64 { + t, ok := hreflect.AsTime(v, ns.loc) + if !ok { panic("coding error: argument must be time.Time type reflect Value") } - return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int() + return t.Unix() } diff --git a/tpl/compare/compare_test.go b/tpl/compare/compare_test.go index 782941b15..ce2016b38 100644 --- a/tpl/compare/compare_test.go +++ b/tpl/compare/compare_test.go @@ -88,7 +88,7 @@ func TestDefaultFunc(t *testing.T) { then := time.Now() now := time.Now() - ns := New(false) + ns := New(time.UTC, false) for i, test := range []struct { dflt any @@ -147,7 +147,7 @@ func TestDefaultFunc(t *testing.T) { func TestCompare(t *testing.T) { t.Parallel() - n := New(false) + n := New(time.UTC, false) twoEq := func(a, b any) bool { return n.Eq(a, b) @@ -269,7 +269,7 @@ func TestEqualExtend(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) for _, test := range []struct { first any @@ -294,7 +294,7 @@ func TestNotEqualExtend(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) for _, test := range []struct { first any @@ -314,7 +314,7 @@ func TestGreaterEqualExtend(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) for _, test := range []struct { first any @@ -335,7 +335,7 @@ func TestGreaterThanExtend(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) for _, test := range []struct { first any @@ -355,7 +355,7 @@ func TestLessEqualExtend(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) for _, test := range []struct { first any @@ -376,7 +376,7 @@ func TestLessThanExtend(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) for _, test := range []struct { first any @@ -395,7 +395,7 @@ func TestLessThanExtend(t *testing.T) { func TestCase(t *testing.T) { c := qt.New(t) - n := New(false) + n := New(time.UTC, false) c.Assert(n.Eq("az", "az"), qt.Equals, true) c.Assert(n.Eq("az", stringType("az")), qt.Equals, true) @@ -403,7 +403,7 @@ func TestCase(t *testing.T) { func TestStringType(t *testing.T) { c := qt.New(t) - n := New(true) + n := New(time.UTC, true) c.Assert(n.Lt("az", "Za"), qt.Equals, true) c.Assert(n.Gt("ab", "Ab"), qt.Equals, true) @@ -411,11 +411,12 @@ func TestStringType(t *testing.T) { func TestTimeUnix(t *testing.T) { t.Parallel() + n := New(time.UTC, false) var sec int64 = 1234567890 tv := reflect.ValueOf(time.Unix(sec, 0)) i := 1 - res := toTimeUnix(tv) + res := n.toTimeUnix(tv) if sec != res { t.Errorf("[%d] timeUnix got %v but expected %v", i, res, sec) } @@ -428,13 +429,13 @@ func TestTimeUnix(t *testing.T) { } }() iv := reflect.ValueOf(sec) - toTimeUnix(iv) + n.toTimeUnix(iv) }(t) } func TestConditional(t *testing.T) { c := qt.New(t) - n := New(false) + n := New(time.UTC, false) a, b := "a", "b" c.Assert(n.Conditional(true, a, b), qt.Equals, a) @@ -446,7 +447,7 @@ func TestComparisonArgCount(t *testing.T) { t.Parallel() c := qt.New(t) - ns := New(false) + ns := New(time.UTC, false) panicMsg := "missing arguments for comparison" diff --git a/tpl/compare/init.go b/tpl/compare/init.go index 2308b235e..98c07f41b 100644 --- a/tpl/compare/init.go +++ b/tpl/compare/init.go @@ -15,6 +15,7 @@ package compare import ( "github.com/gohugoio/hugo/deps" + "github.com/gohugoio/hugo/langs" "github.com/gohugoio/hugo/tpl/internal" ) @@ -22,7 +23,11 @@ const name = "compare" func init() { f := func(d *deps.Deps) *internal.TemplateFuncsNamespace { - ctx := New(false) + if d.Language == nil { + panic("language must be set") + } + + ctx := New(langs.GetLocation(d.Language), false) ns := &internal.TemplateFuncsNamespace{ Name: name, -- cgit v1.2.3