diff options
author | Karl Lorey <git@karllorey.com> | 2022-06-15 17:33:48 +0200 |
---|---|---|
committer | Karl Lorey <git@karllorey.com> | 2022-06-15 17:33:48 +0200 |
commit | 7d7a07ea8baf7ee2af2a93b41ac42ee73d16fbda (patch) | |
tree | 98da7246e2071d3f0f32efe0c853819633e07983 | |
parent | 04aef4bfdd331e99d4d5f9f10111fed6b8e39de1 (diff) |
Rewrite training module to decrease complexity
-rw-r--r-- | mlscraper/html.py | 6 | ||||
-rw-r--r-- | mlscraper/matches.py | 7 | ||||
-rw-r--r-- | mlscraper/samples.py | 4 | ||||
-rw-r--r-- | mlscraper/selectors.py | 19 | ||||
-rw-r--r-- | mlscraper/training.py | 155 | ||||
-rw-r--r-- | tests/test_scrapers.py | 8 | ||||
-rw-r--r-- | tests/test_training.py | 4 |
7 files changed, 131 insertions, 72 deletions
diff --git a/mlscraper/html.py b/mlscraper/html.py index 2b2ecd2..c19f0db 100644 --- a/mlscraper/html.py +++ b/mlscraper/html.py @@ -102,6 +102,12 @@ class Node: # todo implement other find methods + def has_parent(self, node: "Node"): + for p in self.soup.parents: + if p == node.soup: + return True + return False + def generate_path_selectors(self): """ Generate a selector for the path to the given node. diff --git a/mlscraper/matches.py b/mlscraper/matches.py index 51c4255..ca5d66a 100644 --- a/mlscraper/matches.py +++ b/mlscraper/matches.py @@ -136,6 +136,7 @@ class ListMatch(Match): def __repr__(self): return f"<{self.__class__.__name__} {self.matches=}>" + @property def root(self) -> Node: return get_root_node([m.root for m in self.matches]) @@ -156,8 +157,10 @@ class ValueMatch(Match): return self.node -def generate_all_matches(node: Node, item) -> typing.Generator[Match, None, None]: - logging.info(f"generating all matches ({node=}, {item=})") +def generate_all_value_matches( + node: Node, item: str +) -> typing.Generator[Match, None, None]: + logging.info(f"generating all value matches ({node=}, {item=})") for html_match in node.find_all(item): matched_node = html_match.node if isinstance(html_match, TextMatch): diff --git a/mlscraper/samples.py b/mlscraper/samples.py index df7f576..f369de4 100644 --- a/mlscraper/samples.py +++ b/mlscraper/samples.py @@ -5,7 +5,7 @@ from itertools import product from mlscraper.html import Node from mlscraper.html import Page from mlscraper.matches import DictMatch -from mlscraper.matches import generate_all_matches +from mlscraper.matches import generate_all_value_matches from mlscraper.matches import ListMatch from mlscraper.matches import Matcher from mlscraper.selectors import CssRuleSelector @@ -28,7 +28,7 @@ class Sample: # todo: fix creating new sample objects, maybe by using Item class? if isinstance(self.value, str): - return list(generate_all_matches(self.page, self.value)) + return list(generate_all_value_matches(self.page, self.value)) if isinstance(self.value, list): matches_by_value = [Sample(self.page, v).get_matches() for v in self.value] diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py index 31b26c9..8a15ad3 100644 --- a/mlscraper/selectors.py +++ b/mlscraper/selectors.py @@ -4,6 +4,7 @@ import typing from mlscraper.html import Node from mlscraper.html import Page from mlscraper.html import selector_matches_nodes +from more_itertools import bucket class Selector: @@ -18,6 +19,15 @@ class Selector: raise NotImplementedError() +class PassThroughSelector(Selector): + def select_one(self, node: Node) -> Node: + return node + + def select_all(self, node: Node) -> typing.List[Node]: + # this does not make sense as we have only one node to pass through + raise RuntimeError("cannot apply select_all to PassThroughSelector") + + class CssRuleSelector(Selector): def __init__(self, css_rule): self.css_rule = css_rule @@ -38,7 +48,8 @@ class CssRuleSelector(Selector): def generate_selector_for_nodes(nodes: typing.List[Node], roots): - logging.info(f"trying to find selector for nodes ({nodes=})") + logging.info(f"trying to find selector for nodes ({nodes=}, {roots=})") + assert nodes, "no nodes given" if roots is None: logging.info("roots is None, setting roots manually") @@ -46,10 +57,8 @@ def generate_selector_for_nodes(nodes: typing.List[Node], roots): # todo roots and nodes can be uneven here because we just want to find a way # to select all the nodes from the given roots - - nodes_per_root = {} - for root in set(roots): - nodes_per_root[root] = [n for n in nodes if n.root == root] + nodes_per_root = {r: [n for n in nodes if n.has_parent(r)] for r in set(roots)} + logging.info(f"item by root: %s", nodes_per_root) selectors_seen = set() diff --git a/mlscraper/training.py b/mlscraper/training.py index 6647bdd..38b15d9 100644 --- a/mlscraper/training.py +++ b/mlscraper/training.py @@ -2,17 +2,18 @@ import logging import typing from itertools import product -from mlscraper.html import Node -from mlscraper.samples import DictItem -from mlscraper.samples import Item -from mlscraper.samples import ListItem -from mlscraper.samples import make_matcher_for_samples -from mlscraper.samples import Sample -from mlscraper.samples import ValueItem +from mlscraper.matches import DictMatch +from mlscraper.matches import ListMatch +from mlscraper.matches import ValueMatch +from mlscraper.samples import TrainingSet from mlscraper.scrapers import DictScraper from mlscraper.scrapers import ListScraper from mlscraper.scrapers import ValueScraper from mlscraper.selectors import generate_selector_for_nodes +from mlscraper.selectors import PassThroughSelector +from more_itertools import first +from more_itertools import flatten +from more_itertools import unzip class TrainingException(Exception): @@ -23,64 +24,98 @@ class NoScraperFoundException(TrainingException): pass -def train_scraper(item: Item, roots: typing.Optional[typing.List[Node]] = None): +def train_scraper(training_set: TrainingSet): """ Train a scraper able to extract the given training data. """ - logging.info(f"training {item}") - - # set roots to page if not set - if roots is None: - roots = [s.page for s in item.samples] - logging.info(f"roots inferred: {roots}") - - assert len(item.samples) == len(roots), f"{len(item.samples)=} != {len(roots)=}" - - if isinstance(item, ListItem): - # so we have to extract a list from each root - # to do this, we take all matches we can find - # and try to find a common selector for one of the combinations root elements - # if that works, we've succeeded - # if not, we're out of luck for now - # todo add root to get_matches to receive only matches below roots - matches_per_sample = [s.get_matches() for s in item.item.samples] - for match_combi in product(*matches_per_sample): - # match_combi is one possible way to combine matches to extract the list - logging.info(f"{match_combi=}") - # we now take the root of every element - match_roots = [m.root for m in match_combi] - logging.info(f"{match_roots=}") - for selector in generate_selector_for_nodes(match_roots, roots): - # roots are the newly matched root elements - item_scraper = train_scraper(item.item, match_roots) - scraper = ListScraper(selector, item_scraper) - return scraper - - raise NoScraperFoundException(f"no matcher found for {item}") - - if isinstance(item, DictItem): - # train a scraper for each key, keep roots - scraper_per_key = { - k: train_scraper(i, roots) for k, i in item.item_per_key.items() - } - return DictScraper(scraper_per_key) + logging.info(f"training {training_set=}") - if isinstance(item, ValueItem): - # find a selector that uniquely matches the value given the root node - matcher = make_matcher_for_samples(item.samples, roots) - if matcher: - return ValueScraper(matcher.selector, matcher.extractor) - else: - raise NoScraperFoundException(f"deriving matcher failed for {item}") + sample_matches = [s.get_matches() for s in training_set.item.samples] + roots = [s.page for s in training_set.item.samples] + for match_combination in product(*sample_matches): + logging.info(f"trying to train scraper for matches ({match_combination=})") + scraper = train_scraper_for_matches(match_combination, roots) + return scraper -def get_smallest_span_match_per_sample(samples: typing.List[Sample]): +def train_scraper_for_matches(matches, roots): """ - Get the best match for each sample by using the smallest span. - :param samples: - :return: + Train a scraper that finds the given matches from the given roots. + :param matches: the matches to scrape + :param roots: the root elements containing the matches, e.g. pages or elements on pages """ - best_match_per_sample = [ - sorted(s.get_matches(), key=lambda m: m.get_span())[0] for s in samples - ] - return best_match_per_sample + found_types = set(map(type, matches)) + assert ( + len(found_types) == 1 + ), f"different match types passed {found_types=}, {matches=}" + found_type = first(found_types) + + # make sure we have lists + matches = list(matches) + roots = list(roots) + + assert len(matches) == len(roots), f"got uneven inputs ({matches=}, {roots=})" + if found_type == ValueMatch: + logging.info("training ValueScraper") + matches: typing.List[ValueMatch] + + # if matches have different extractors, we can't find a common scraper + extractors = set(map(lambda m: m.extractor, matches)) + if len(extractors) != 1: + raise NoScraperFoundException( + "different extractors found for matches, aborting" + ) + extractor = first(extractors) + + # early return: nodes are matched already, e.g. for List of Values + if all(m.node == r for m, r in zip(matches, roots)): + # nodes are matched already, done + return ValueScraper(PassThroughSelector(), extractor=extractor) + + selector = first( + generate_selector_for_nodes([m.node for m in matches], roots), None + ) + if not selector: + raise NoScraperFoundException(f"no selector found {matches=}") + return ValueScraper(selector, extractor) + elif found_type == DictMatch: + logging.info("training DictScraper") + matches: typing.List[DictMatch] + + # what if some matches have missing keys? idk + # by using union of all keys, we'll get errors two lines below to be sure + keys = set(flatten(m.match_by_key.keys() for m in matches)) + + # train scraper for each key of dict + # matches are the matches for the keys + # roots are the original roots(?) + scraper_per_key = { + k: train_scraper_for_matches([m.match_by_key[k] for m in matches], roots) + for k in keys + } + return DictScraper(scraper_per_key) + elif found_type == ListMatch: + logging.info("training ListScraper") + matches: typing.List[ListMatch] + + # so we have a list of ListMatch objects + # we have to find a selector that uniquely matches the list elements + # todo can be one of the parents + match_roots = [m.root for m in matches] + logging.info(f"{match_roots=}") + selector = first(generate_selector_for_nodes(match_roots, roots)) + if selector: + # for all the item_matches, create a tuple + # that contains the item_match and the new root + matches_and_roots = [ + (im, selector.select_one(r)) + for m, r in zip(matches, roots) + for im in m.matches + ] + item_matches, list_roots = unzip(matches_and_roots) + item_scraper = train_scraper_for_matches( + list(item_matches), list(list_roots) + ) + return ListScraper(selector, item_scraper) + else: + raise RuntimeError(f"type not matched: {found_type}") diff --git a/tests/test_scrapers.py b/tests/test_scrapers.py index dfb4781..200dde5 100644 --- a/tests/test_scrapers.py +++ b/tests/test_scrapers.py @@ -5,6 +5,7 @@ from mlscraper.scrapers import DictScraper from mlscraper.scrapers import ListScraper from mlscraper.scrapers import ValueScraper from mlscraper.selectors import CssRuleSelector +from mlscraper.selectors import PassThroughSelector class TestListOfDictScraper: @@ -62,7 +63,12 @@ class TestValueScraper: assert vs.get(page1) == "test" assert vs.get(page2) == "hallo" + class TestListOfValuesScraper: def test_list_of_values_scraper(self): page = Page(b"<html><body><p>a</p><i>noise</i><p>b</p><p>c</p></body></html>") - ListScraper('p', ) + scraper = ListScraper( + CssRuleSelector("p"), + ValueScraper(PassThroughSelector(), TextValueExtractor()), + ) + assert scraper.get(page) == ["a", "b", "c"] diff --git a/tests/test_training.py b/tests/test_training.py index 947fbb0..49288a3 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -13,7 +13,7 @@ def test_train_scraper_simple_list(): ["a", "b", "c"], ) training_set.add_sample(sample) - train_scraper(training_set.item) + train_scraper(training_set) @pytest.mark.skip("fucking fails") @@ -22,7 +22,7 @@ def test_train_scraper(stackoverflow_samples): for s in stackoverflow_samples: training_set.add_sample(s) - scraper = train_scraper(training_set.item) + scraper = train_scraper(training_set) print(f"result scraper: {scraper}") print(f"selector for list items: {scraper.selector}") |