diff options
Diffstat (limited to 'mlscraper/training.py')
-rw-r--r-- | mlscraper/training.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/mlscraper/training.py b/mlscraper/training.py
index b57648f..04485f2 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -1,10 +1,8 @@
 import logging
 import typing
 from itertools import combinations
-from itertools import permutations
 from itertools import product
 
-from mlscraper.html import get_relative_depth
 from mlscraper.matches import DictMatch
 from mlscraper.matches import ListMatch
 from mlscraper.matches import ValueMatch
@@ -12,7 +10,7 @@ from mlscraper.samples import TrainingSet
 from mlscraper.scrapers import DictScraper
 from mlscraper.scrapers import ListScraper
 from mlscraper.scrapers import ValueScraper
-from mlscraper.selectors import generate_selector_for_nodes
+from mlscraper.selectors import generate_unique_selectors_for_nodes
 from mlscraper.selectors import PassThroughSelector
 from more_itertools import first
 from more_itertools import flatten
@@ -47,6 +45,8 @@ def train_scraper(training_set: TrainingSet):
     match_combinations_by_depth = sorted(
         match_combinations, key=lambda mc: sum(m.depth for m in mc), reverse=True
     )
+    # todo compute selectivity of classes to use selective ones first
+    # todo cache selectors of node combinations to avoid re-selecting after increasing complexity
     for complexity in range(3):
         for match_combination in match_combinations_by_depth:
             logging.info(
@@ -115,11 +115,14 @@ def train_scraper_for_matches(matches, roots, complexity: int):
             )
 
         selector = first(
-            generate_selector_for_nodes([m.node for m in matches], roots, complexity),
+            generate_unique_selectors_for_nodes(
+                [m.node for m in matches], roots, complexity
+            ),
             None,
         )
         if not selector:
             raise NoScraperFoundException(f"no selector found {matches=}")
+        logging.info(f"found selector for ValueScraper ({selector=})")
         return ValueScraper(selector, extractor)
     elif found_type == DictMatch:
         logging.info("training DictScraper")
@@ -138,6 +141,7 @@ def train_scraper_for_matches(matches, roots, complexity: int):
             )
             for k in keys
         }
+        logging.info(f"found DictScraper ({scraper_per_key=})")
         return DictScraper(scraper_per_key)
     elif found_type == ListMatch:
         logging.info("training ListScraper")
@@ -160,7 +164,9 @@ def train_scraper_for_matches(matches, roots, complexity: int):
         # no need to try other selectors
         # -> item_scraper would be the same
         selector = first(
-            generate_selector_for_nodes(list(item_nodes), list(item_roots), complexity),
+            generate_unique_selectors_for_nodes(
+                list(item_nodes), list(item_roots), complexity
+            ),
             None,
         )
         if selector: