summaryrefslogtreecommitdiffstats
path: root/tpl/compare
diff options
context:
space:
mode:
authorJoe Mooring <joe.mooring@veriphor.com>2022-02-04 03:01:54 -0800
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>2022-02-05 17:41:43 +0100
commit9262719092d6ec8e034c7a097575defe8611dadb (patch)
treebe37829e7b3b9d2a5788889ccde9ca015c2922f0 /tpl/compare
parent3336762939d8600f9af943128e7072c8789a6ae3 (diff)
Validate comparison operator argument count
Fixes #9462
Diffstat (limited to 'tpl/compare')
-rw-r--r--tpl/compare/compare.go17
-rw-r--r--tpl/compare/compare_test.go17
2 files changed, 30 insertions, 4 deletions
diff --git a/tpl/compare/compare.go b/tpl/compare/compare.go
index 88b18f00c..8f3f536b5 100644
--- a/tpl/compare/compare.go
+++ b/tpl/compare/compare.go
@@ -95,10 +95,7 @@ func (n *Namespace) Eq(first interface{}, others ...interface{}) bool {
if n.caseInsensitive {
panic("caseInsensitive not implemented for Eq")
}
- if len(others) == 0 {
- panic("missing arguments for comparison")
- }
-
+ n.checkComparisonArgCount(1, others...)
normalize := func(v interface{}) interface{} {
if types.IsNil(v) {
return nil
@@ -145,6 +142,7 @@ func (n *Namespace) Eq(first interface{}, others ...interface{}) bool {
// Ne returns the boolean truth of arg1 != arg2 && arg1 != arg3 && arg1 != arg4.
func (n *Namespace) Ne(first interface{}, others ...interface{}) bool {
+ n.checkComparisonArgCount(1, others...)
for _, other := range others {
if n.Eq(first, other) {
return false
@@ -155,6 +153,7 @@ func (n *Namespace) Ne(first interface{}, others ...interface{}) bool {
// Ge returns the boolean truth of arg1 >= arg2 && arg1 >= arg3 && arg1 >= arg4.
func (n *Namespace) Ge(first interface{}, others ...interface{}) bool {
+ n.checkComparisonArgCount(1, others...)
for _, other := range others {
left, right := n.compareGet(first, other)
if !(left >= right) {
@@ -166,6 +165,7 @@ func (n *Namespace) Ge(first interface{}, others ...interface{}) bool {
// Gt returns the boolean truth of arg1 > arg2 && arg1 > arg3 && arg1 > arg4.
func (n *Namespace) Gt(first interface{}, others ...interface{}) bool {
+ n.checkComparisonArgCount(1, others...)
for _, other := range others {
left, right := n.compareGet(first, other)
if !(left > right) {
@@ -177,6 +177,7 @@ func (n *Namespace) Gt(first interface{}, others ...interface{}) bool {
// Le returns the boolean truth of arg1 <= arg2 && arg1 <= arg3 && arg1 <= arg4.
func (n *Namespace) Le(first interface{}, others ...interface{}) bool {
+ n.checkComparisonArgCount(1, others...)
for _, other := range others {
left, right := n.compareGet(first, other)
if !(left <= right) {
@@ -188,6 +189,7 @@ func (n *Namespace) Le(first interface{}, others ...interface{}) bool {
// Lt returns the boolean truth of arg1 < arg2 && arg1 < arg3 && arg1 < arg4.
func (n *Namespace) Lt(first interface{}, others ...interface{}) bool {
+ n.checkComparisonArgCount(1, others...)
for _, other := range others {
left, right := n.compareGet(first, other)
if !(left < right) {
@@ -197,6 +199,13 @@ func (n *Namespace) Lt(first interface{}, others ...interface{}) bool {
return true
}
+func (n *Namespace) checkComparisonArgCount(min int, others ...interface{}) bool {
+ if len(others) < min {
+ panic("missing arguments for comparison")
+ }
+ return true
+}
+
// Conditional can be used as a ternary operator.
// It returns a if condition, else b.
func (n *Namespace) Conditional(condition bool, a, b interface{}) interface{} {
diff --git a/tpl/compare/compare_test.go b/tpl/compare/compare_test.go
index 76fe2698a..9ef32fd85 100644
--- a/tpl/compare/compare_test.go
+++ b/tpl/compare/compare_test.go
@@ -440,3 +440,20 @@ func TestConditional(t *testing.T) {
c.Assert(n.Conditional(true, a, b), qt.Equals, a)
c.Assert(n.Conditional(false, a, b), qt.Equals, b)
}
+
+// Issue 9462
+func TestComparisonArgCount(t *testing.T) {
+ t.Parallel()
+ c := qt.New(t)
+
+ ns := New(false)
+
+ panicMsg := "missing arguments for comparison"
+
+ c.Assert(func() { ns.Eq(1) }, qt.PanicMatches, panicMsg)
+ c.Assert(func() { ns.Ge(1) }, qt.PanicMatches, panicMsg)
+ c.Assert(func() { ns.Gt(1) }, qt.PanicMatches, panicMsg)
+ c.Assert(func() { ns.Le(1) }, qt.PanicMatches, panicMsg)
+ c.Assert(func() { ns.Lt(1) }, qt.PanicMatches, panicMsg)
+ c.Assert(func() { ns.Ne(1) }, qt.PanicMatches, panicMsg)
+}