diff options
author | Karl Lorey <git@karllorey.com> | 2022-07-07 13:32:38 +0200 |
---|---|---|
committer | Karl Lorey <git@karllorey.com> | 2022-07-07 13:32:38 +0200 |
commit | 20c492c62dd7457c485c3bf79e28a81ea135d84d (patch) | |
tree | 6f50838f75369059f6e162837c554c3749cc30ff | |
parent | 2c58adb28b1bc8750d0a434f721676b8cf83734f (diff) |
Use mean instead of sum for similarity to improve understandability
-rw-r--r-- | mlscraper/matches.py | 5 | ||||
-rw-r--r-- | mlscraper/training.py | 6 |
2 files changed, 8 insertions, 3 deletions
diff --git a/mlscraper/matches.py b/mlscraper/matches.py index 3b039d5..d34c430 100644 --- a/mlscraper/matches.py +++ b/mlscraper/matches.py @@ -6,6 +6,7 @@ import typing from functools import cached_property from itertools import combinations from itertools import product +from statistics import mean from mlscraper.html import get_relative_depth from mlscraper.html import get_root_node @@ -135,7 +136,7 @@ class DictMatch(Match): keys = set(self.match_by_key.keys()).intersection( set(match.match_by_key.keys()) ) - return sum( + return mean( self.match_by_key[key].get_similarity_to(match.match_by_key[key]) for key in keys ) @@ -163,7 +164,7 @@ class ListMatch(Match): def get_similarity_to(self, match: "Match"): assert isinstance(match, self.__class__) - return sum( + return mean( lm1.get_similarity_to(lm2) for lm1, lm2 in product(self.matches, match.matches) ) diff --git a/mlscraper/training.py b/mlscraper/training.py index 3e3b259..d9b1b8a 100644 --- a/mlscraper/training.py +++ b/mlscraper/training.py @@ -1,6 +1,7 @@ import logging from itertools import combinations from itertools import product +from statistics import mean from mlscraper.matches import DictMatch from mlscraper.matches import ListMatch @@ -25,8 +26,11 @@ class NoScraperFoundException(TrainingException): def get_match_combination_priority(matches): + if len(matches) == 1: + return 1 + # check for similarity between matches - return sum(m1.get_similarity_to(m2) for m1, m2 in combinations(matches, 2)) + return mean(m1.get_similarity_to(m2) for m1, m2 in combinations(matches, 2)) def train_scraper(training_set: TrainingSet, complexity=100): |