diff options
-rw-r--r-- | mlscraper/matches.py | 5 | ||||
-rw-r--r-- | mlscraper/training.py | 6 |
2 files changed, 8 insertions, 3 deletions
diff --git a/mlscraper/matches.py b/mlscraper/matches.py index 3b039d5..d34c430 100644 --- a/mlscraper/matches.py +++ b/mlscraper/matches.py @@ -6,6 +6,7 @@ import typing from functools import cached_property from itertools import combinations from itertools import product +from statistics import mean from mlscraper.html import get_relative_depth from mlscraper.html import get_root_node @@ -135,7 +136,7 @@ class DictMatch(Match): keys = set(self.match_by_key.keys()).intersection( set(match.match_by_key.keys()) ) - return sum( + return mean( self.match_by_key[key].get_similarity_to(match.match_by_key[key]) for key in keys ) @@ -163,7 +164,7 @@ class ListMatch(Match): def get_similarity_to(self, match: "Match"): assert isinstance(match, self.__class__) - return sum( + return mean( lm1.get_similarity_to(lm2) for lm1, lm2 in product(self.matches, match.matches) ) diff --git a/mlscraper/training.py b/mlscraper/training.py index 3e3b259..d9b1b8a 100644 --- a/mlscraper/training.py +++ b/mlscraper/training.py @@ -1,6 +1,7 @@ import logging from itertools import combinations from itertools import product +from statistics import mean from mlscraper.matches import DictMatch from mlscraper.matches import ListMatch @@ -25,8 +26,11 @@ class NoScraperFoundException(TrainingException): def get_match_combination_priority(matches): + if len(matches) == 1: + return 1 + # check for similarity between matches - return sum(m1.get_similarity_to(m2) for m1, m2 in combinations(matches, 2)) + return mean(m1.get_similarity_to(m2) for m1, m2 in combinations(matches, 2)) def train_scraper(training_set: TrainingSet, complexity=100): |