summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Karl Lorey <git@karllorey.com> 2022-07-07 13:32:38 +0200
committer Karl Lorey <git@karllorey.com> 2022-07-07 13:32:38 +0200
commit 20c492c62dd7457c485c3bf79e28a81ea135d84d (patch)
tree 6f50838f75369059f6e162837c554c3749cc30ff
parent 2c58adb28b1bc8750d0a434f721676b8cf83734f (diff)
Use mean instead of sum for similarity to improve understandability
-rw-r--r-- mlscraper/matches.py 5
-rw-r--r-- mlscraper/training.py 6
2 files changed, 8 insertions, 3 deletions
diff --git a/mlscraper/matches.py b/mlscraper/matches.py
index 3b039d5..d34c430 100644
--- a/mlscraper/matches.py
+++ b/mlscraper/matches.py
@@ -6,6 +6,7 @@ import typing
from functools import cached_property
from itertools import combinations
from itertools import product
+from statistics import mean
from mlscraper.html import get_relative_depth
from mlscraper.html import get_root_node
@@ -135,7 +136,7 @@ class DictMatch(Match):
keys = set(self.match_by_key.keys()).intersection(
set(match.match_by_key.keys())
)
- return sum(
+ return mean(
self.match_by_key[key].get_similarity_to(match.match_by_key[key])
for key in keys
)
@@ -163,7 +164,7 @@ class ListMatch(Match):
def get_similarity_to(self, match: "Match"):
assert isinstance(match, self.__class__)
- return sum(
+ return mean(
lm1.get_similarity_to(lm2)
for lm1, lm2 in product(self.matches, match.matches)
)
diff --git a/mlscraper/training.py b/mlscraper/training.py
index 3e3b259..d9b1b8a 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -1,6 +1,7 @@
import logging
from itertools import combinations
from itertools import product
+from statistics import mean
from mlscraper.matches import DictMatch
from mlscraper.matches import ListMatch
@@ -25,8 +26,11 @@ class NoScraperFoundException(TrainingException):
def get_match_combination_priority(matches):
+ if len(matches) == 1:
+ return 1
+
# check for similarity between matches
- return sum(m1.get_similarity_to(m2) for m1, m2 in combinations(matches, 2))
+ return mean(m1.get_similarity_to(m2) for m1, m2 in combinations(matches, 2))
def train_scraper(training_set: TrainingSet, complexity=100):