Diffstat (limited to 'mlscraper/training.py')
-rw-r--r--  mlscraper/training.py  16
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/mlscraper/training.py b/mlscraper/training.py
index b57648f..04485f2 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -1,10 +1,8 @@
import logging
import typing
from itertools import combinations
-from itertools import permutations
from itertools import product
-from mlscraper.html import get_relative_depth
from mlscraper.matches import DictMatch
from mlscraper.matches import ListMatch
from mlscraper.matches import ValueMatch
@@ -12,7 +10,7 @@ from mlscraper.samples import TrainingSet
from mlscraper.scrapers import DictScraper
from mlscraper.scrapers import ListScraper
from mlscraper.scrapers import ValueScraper
-from mlscraper.selectors import generate_selector_for_nodes
+from mlscraper.selectors import generate_unique_selectors_for_nodes
from mlscraper.selectors import PassThroughSelector
from more_itertools import first
from more_itertools import flatten
@@ -47,6 +45,8 @@ def train_scraper(training_set: TrainingSet):
    match_combinations_by_depth = sorted(
        match_combinations, key=lambda mc: sum(m.depth for m in mc), reverse=True
    )
+    # todo compute selectivity of classes to use selective ones first
+    # todo cache selectors of node combinations to avoid re-selecting after increasing complexity
    for complexity in range(3):
        for match_combination in match_combinations_by_depth:
            logging.info(
@@ -115,11 +115,14 @@ def train_scraper_for_matches(matches, roots, complexity: int):
        )
        selector = first(
-            generate_selector_for_nodes([m.node for m in matches], roots, complexity),
+            generate_unique_selectors_for_nodes(
+                [m.node for m in matches], roots, complexity
+            ),
            None,
        )
        if not selector:
            raise NoScraperFoundException(f"no selector found {matches=}")
+        logging.info(f"found selector for ValueScraper ({selector=})")
        return ValueScraper(selector, extractor)
    elif found_type == DictMatch:
        logging.info("training DictScraper")
@@ -138,6 +141,7 @@ def train_scraper_for_matches(matches, roots, complexity: int):
            )
            for k in keys
        }
+        logging.info(f"found DictScraper ({scraper_per_key=})")
        return DictScraper(scraper_per_key)
    elif found_type == ListMatch:
        logging.info("training ListScraper")
@@ -160,7 +164,9 @@ def train_scraper_for_matches(matches, roots, complexity: int):
        # no need to try other selectors
        # -> item_scraper would be the same
        selector = first(
-            generate_selector_for_nodes(list(item_nodes), list(item_roots), complexity),
+            generate_unique_selectors_for_nodes(
+                list(item_nodes), list(item_roots), complexity
+            ),
            None,
        )
        if selector:
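
Below is a short, purely illustrative sketch (not part of the commit) of the lookup pattern the diff switches to: generate_unique_selectors_for_nodes is consumed lazily through more_itertools.first(..., None), so selector generation can stop at the first unique selector and fall back to None when nothing is found. The candidate_selectors generator is a hypothetical stand-in; only the call shape (nodes, roots, complexity) and the first(..., None) / NoScraperFoundException handling are taken from the diff above.

# Illustrative sketch, not the library's implementation.
from more_itertools import first


class NoScraperFoundException(Exception):
    pass


def candidate_selectors(nodes, roots, complexity):
    # hypothetical stand-in for generate_unique_selectors_for_nodes:
    # yields selectors that uniquely match all nodes within their roots,
    # and yields nothing when no selector exists at this complexity
    if complexity >= 1:
        yield ".title"


def pick_selector(nodes, roots, complexity):
    # mirrors the diff: take the first generated selector, default to None,
    # and raise if the generator produced nothing
    selector = first(candidate_selectors(nodes, roots, complexity), None)
    if not selector:
        raise NoScraperFoundException(f"no selector found {nodes=}")
    return selector


print(pick_selector(["<h1>title</h1>"], ["<html>...</html>"], complexity=1))

Using the default argument of first() avoids catching StopIteration by hand and keeps the no-selector branch explicit.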