diff options
author | Karl Lorey <git@karllorey.com> | 2022-06-20 11:12:00 +0200 |
---|---|---|
committer | Karl Lorey <git@karllorey.com> | 2022-06-20 11:12:00 +0200 |
commit | c6c371223d56f23ad4a588231b5e6f51bee4259c (patch) | |
tree | 45d68736dfce2c8bee9d061e19224161f3916332 | |
parent | 3a4c3234653984768a992747ba45da5e34d3af9c (diff) |
Re-implement selector generation with a speedup >10x (branch: develop)
-rw-r--r-- | mlscraper/html.py | 108 | ||||
-rw-r--r-- | mlscraper/scrapers.py | 2 | ||||
-rw-r--r-- | mlscraper/selectors.py | 109 | ||||
-rw-r--r-- | mlscraper/training.py | 16 | ||||
-rw-r--r-- | mlscraper/util.py | 11 | ||||
-rw-r--r-- | tests/test_selectors.py | 56 | ||||
-rw-r--r-- | tests/test_training.py | 22 | ||||
-rw-r--r-- | tests/test_util.py | 9 |
8 files changed, 176 insertions, 157 deletions
diff --git a/mlscraper/html.py b/mlscraper/html.py index f0d222e..6fb516b 100644 --- a/mlscraper/html.py +++ b/mlscraper/html.py @@ -6,13 +6,11 @@ import logging import typing from abc import ABC from dataclasses import dataclass -from itertools import combinations -from itertools import product +from functools import cached_property from bs4 import BeautifulSoup from bs4 import NavigableString from bs4 import Tag -from mlscraper.util import powerset_max_length @dataclass @@ -30,44 +28,6 @@ class AttributeMatch(Match): attr: str = None -def _generate_css_selectors_for_node(soup: Tag, complexity: int): - """ - Generate a selector for the given node. - :param soup: - :return: - """ - assert isinstance(soup, Tag) - - # use id - tag_id = soup.attrs.get("id", None) - if tag_id: - yield "#" + tag_id - - yield soup.name - - # use classes - css_classes = soup.attrs.get("class", []) - for css_class_combo in powerset_max_length(css_classes, complexity): - if css_class_combo: - css_clases_str = "".join([f".{css_class}" for css_class in css_class_combo]) - yield css_clases_str - yield soup.name + css_clases_str - else: - # empty set, no selector - pass - - # todo: nth applies to whole selectors - # -> should thus be a step after actual selector generation - if isinstance(soup.parent, Tag) and hasattr(soup, "name"): - children_tags = [c for c in soup.parent.children if isinstance(c, Tag)] - child_index = list(children_tags).index(soup) + 1 - yield ":nth-child(%d)" % child_index - - children_of_same_type = [c for c in children_tags if c.name == soup.name] - child_index = children_of_same_type.index(soup) + 1 - yield ":nth-of-type(%d)" % child_index - - class Node: soup = None _page = None @@ -111,70 +71,22 @@ class Node: return True return False + @cached_property + def parents(self): + return [self._page._get_node_for_soup(p) for p in self.soup.parents] + @property def classes(self): return self.soup.attrs.get("class", []) @property + def id(self): + return 
self.soup.attrs.get("id", None) + + @property def tag_name(self): return self.soup.name - def generate_path_selectors(self, complexity: int): - """ - Generate a selector for the path to the given node. - :return: - """ - if not isinstance(self.soup, Tag): - error_msg = "Only tags can be selected with CSS, %s given" % type(self.soup) - raise RuntimeError(error_msg) - - # we have a list of n ancestor notes and n-1 nodes including the last node - # the last node must get selected always - - # so we will: - # 1) generate all selectors for current node - # 2) append possible selectors for the n-1 descendants - # starting with all node selectors and increasing number of used descendants - - # remove unique parents as they don't improve selection - # body is unique, html is unique, document is bs4 root element - parents = [ - n for n in self.soup.parents if n.name not in ("body", "html", "[document]") - ] - # print(parents) - - # loop from i=0 to i=len(parents) as we consider all parents - parent_node_count_max = min(len(parents), complexity) - for parent_node_count in range(parent_node_count_max + 1): - logging.info( - "generating path selectors with %d parents" % parent_node_count - ) - # generate paths with exactly parent_node_count nodes - for parent_nodes_sampled in combinations(parents, parent_node_count): - path_sampled = (self.soup,) + parent_nodes_sampled - # logging.info(path_sampled) - - # make a list of selector generators for each node in the path - # todo limit generated selectors -> huge product - selector_generators_for_each_path_node = [ - _generate_css_selectors_for_node(n, complexity) - for n in path_sampled - ] - - # generator that outputs selector paths - # e.g. (div, .wrapper, .main) - path_sampled_selectors = product( - *selector_generators_for_each_path_node - ) - - # create an actual css selector for each selector path - # e.g. 
.main > .wrapper > .div - for path_sampled_selector in path_sampled_selectors: - # if paths are not directly connected, i.e. (1)-(2)-3-(4) - # join must be " " and not " > " - css_selector = " ".join(reversed(path_sampled_selector)) - yield css_selector - def select(self, css_selector): return [ self._page._get_node_for_soup(n) for n in self.soup.select(css_selector) @@ -183,7 +95,7 @@ class Node: def __repr__(self): if isinstance(self.soup, NavigableString): return f"<{self.__class__.__name__} {self.soup.strip()[:10]=}>" - return f"<{self.__class__.__name__} {self.soup.name=} classes={self.soup.get('class', None)}, text={self.soup.text.strip()[:10]}...>" + return f"<{self.__class__.__name__} {self.soup.name=} classes={self.soup.get('class', None)}, text={''.join(self.soup.stripped_strings)[:10]}...>" def __hash__(self): return self.soup.__hash__() diff --git a/mlscraper/scrapers.py b/mlscraper/scrapers.py index 7ea16f5..d271719 100644 --- a/mlscraper/scrapers.py +++ b/mlscraper/scrapers.py @@ -37,7 +37,7 @@ class ListScraper(Scraper): ] def __repr__(self): - return f"<ListScraper {self.scraper=}>" + return f"<ListScraper {self.selector=} {self.scraper=}>" class ValueScraper(Scraper): diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py index 1ba4c55..e07c084 100644 --- a/mlscraper/selectors.py +++ b/mlscraper/selectors.py @@ -1,8 +1,11 @@ import logging import typing +from itertools import product from mlscraper.html import Node -from mlscraper.html import selector_matches_nodes +from mlscraper.util import no_duplicates_generator_decorator +from more_itertools import first +from more_itertools import powerset class Selector: @@ -45,45 +48,75 @@ class CssRuleSelector(Selector): return f"<{self.__class__.__name__} {self.css_rule=}>" -def generate_selector_for_nodes(nodes: typing.List[Node], roots, complexity: int): - logging.info( - f"trying to find selector for nodes ({nodes=}, {roots=}, {complexity=})" - ) - assert nodes, "no nodes given" - +def 
generate_unique_selectors_for_nodes( + nodes: typing.List[Node], roots, complexity: int +): + """ + generate a unique selector which only matches the given nodes. + """ if roots is None: logging.info("roots is None, setting roots manually") roots = [n.root for n in nodes] - # todo roots and nodes can be uneven here because we just want to find a way - # to select all the nodes from the given roots nodes_per_root = {r: [n for n in nodes if n.has_parent(r)] for r in set(roots)} - logging.info(f"item by root: %s", nodes_per_root) - - selectors_seen = set() - - for node in nodes: - for sel in node.generate_path_selectors(complexity): - logging.info(f"selector: {sel}") - if sel not in selectors_seen: - logging.info( - f"nodes per root: {nodes_per_root}", - ) - # check if selector returns the desired nodes per root - if all( - selector_matches_nodes(root, sel, nodes) - for root, nodes in nodes_per_root.items() - ): - logging.info(f"selector matches all nodes exactly ({sel=})") - yield CssRuleSelector(sel) - else: - for root, nodes in nodes_per_root.items(): - logging.info( - f"{root=}, {sel=}, {selector_matches_nodes(root, sel, nodes)=}" - ) - logging.info(f"selector does not match nodes exactly: {sel}") - - # add to seen - selectors_seen.add(sel) - else: - logging.info(f"selector already checked: {sel}") + for selector in generate_selectors_for_nodes(nodes, roots, complexity): + if all( + selector.select_all(r) == nodes_of_root + for r, nodes_of_root in nodes_per_root.items() + ): + yield selector + + +@no_duplicates_generator_decorator +def generate_selectors_for_nodes(nodes: typing.List[Node], roots, complexity: int): + """ + Generate a selector which matches the given nodes. 
+ """ + + logging.info( + f"trying to find selector for nodes ({nodes=}, {roots=}, {complexity=})" + ) + assert nodes, "no nodes given" + assert roots, "no roots given" + assert len(nodes) == len(roots) + + direct_css_selectors = list(_generate_direct_css_selectors_for_nodes(nodes)) + for direct_css_selector in direct_css_selectors: + yield CssRuleSelector(direct_css_selector) + + parents_of_nodes_below_roots = [ + [p for p in n.parents if p.has_parent(r) and p.tag_name not in ["html", "body"]] + for n, r in zip(nodes, roots) + ] + for parent_nodes in product(*parents_of_nodes_below_roots): + for parent_selector_raw in _generate_direct_css_selectors_for_nodes( + parent_nodes + ): + for css_selector_raw in direct_css_selectors: + css_selector_combined = parent_selector_raw + " " + css_selector_raw + yield CssRuleSelector(css_selector_combined) + + +def _generate_direct_css_selectors_for_nodes(nodes: typing.List[Node]): + logging.info(f"generating direct css selector for nodes ({nodes=})") + common_classes = set.intersection(*[set(n.classes) for n in nodes]) + + is_same_tag = len({n.tag_name for n in nodes}) == 1 + common_tag_name = nodes[0].tag_name + yield common_tag_name + + common_ids = {n.id for n in nodes} + is_same_id = len(common_ids) == 1 + if is_same_id and None not in common_ids: + yield "#" + first(common_ids) + + for class_combination in powerset(common_classes): + if class_combination: + logging.info(f"- generating selector for ({class_combination=})") + css_selector = "".join(map(lambda cl: "." 
+ cl, class_combination)) + yield css_selector + if is_same_tag: + yield common_tag_name + css_selector + else: + # empty combination -> ignore + pass diff --git a/mlscraper/training.py b/mlscraper/training.py index b57648f..04485f2 100644 --- a/mlscraper/training.py +++ b/mlscraper/training.py @@ -1,10 +1,8 @@ import logging import typing from itertools import combinations -from itertools import permutations from itertools import product -from mlscraper.html import get_relative_depth from mlscraper.matches import DictMatch from mlscraper.matches import ListMatch from mlscraper.matches import ValueMatch @@ -12,7 +10,7 @@ from mlscraper.samples import TrainingSet from mlscraper.scrapers import DictScraper from mlscraper.scrapers import ListScraper from mlscraper.scrapers import ValueScraper -from mlscraper.selectors import generate_selector_for_nodes +from mlscraper.selectors import generate_unique_selectors_for_nodes from mlscraper.selectors import PassThroughSelector from more_itertools import first from more_itertools import flatten @@ -47,6 +45,8 @@ def train_scraper(training_set: TrainingSet): match_combinations_by_depth = sorted( match_combinations, key=lambda mc: sum(m.depth for m in mc), reverse=True ) + # todo compute selectivity of classes to use selective ones first + # todo cache selectors of node combinations to avoid re-selecting after increasing complexity for complexity in range(3): for match_combination in match_combinations_by_depth: logging.info( @@ -115,11 +115,14 @@ def train_scraper_for_matches(matches, roots, complexity: int): ) selector = first( - generate_selector_for_nodes([m.node for m in matches], roots, complexity), + generate_unique_selectors_for_nodes( + [m.node for m in matches], roots, complexity + ), None, ) if not selector: raise NoScraperFoundException(f"no selector found {matches=}") + logging.info(f"found selector for ValueScraper ({selector=})") return ValueScraper(selector, extractor) elif found_type == DictMatch: 
logging.info("training DictScraper") @@ -138,6 +141,7 @@ def train_scraper_for_matches(matches, roots, complexity: int): ) for k in keys } + logging.info(f"found DictScraper ({scraper_per_key=})") return DictScraper(scraper_per_key) elif found_type == ListMatch: logging.info("training ListScraper") @@ -160,7 +164,9 @@ def train_scraper_for_matches(matches, roots, complexity: int): # no need to try other selectors # -> item_scraper would be the same selector = first( - generate_selector_for_nodes(list(item_nodes), list(item_roots), complexity), + generate_unique_selectors_for_nodes( + list(item_nodes), list(item_roots), complexity + ), None, ) if selector: diff --git a/mlscraper/util.py b/mlscraper/util.py index 1e3f920..00bc25d 100644 --- a/mlscraper/util.py +++ b/mlscraper/util.py @@ -3,3 +3,14 @@ from more_itertools import powerset def powerset_max_length(candidates, length): return filter(lambda s: len(s) <= length, powerset(candidates)) + + +def no_duplicates_generator_decorator(func): + def inner(*args, **kwargs): + seen = set() + for item in func(*args, **kwargs): + if item not in seen: + yield item + seen.add(item) + + return inner diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 3e4570a..9e7acc0 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -1,21 +1,49 @@ from mlscraper.html import Page -from mlscraper.selectors import generate_selector_for_nodes +from mlscraper.selectors import generate_unique_selectors_for_nodes -def test_generate_selector_for_nodes(): - page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>' - page1 = Page(page1_html) +def get_css_selectors_for_node(node): + """ + helper to extract plain css rules + """ + return [ + selector.css_rule + for selector in generate_unique_selectors_for_nodes([node], None, 100) + ] - page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>' - page2 = Page(page2_html) - nodes = list(map(lambda p: p.select("p.test")[0], [page1, 
page2])) - gen = generate_selector_for_nodes(nodes, None, 1) - selectors_found = [sel.css_rule for sel in gen] - assert {".test", "p.test"} == set(selectors_found) +class TestGenerateUniqueSelectorsForNodes: + def test_basic(self): + page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>' + page1 = Page(page1_html) + page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>' + page2 = Page(page2_html) -class TestGenerateSelectorForNodes: - def test_generate_selector_for_nodes(self): - # generate_selector_for_nodes() - pass + nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2])) + gen = generate_unique_selectors_for_nodes(nodes, None, 1) + selectors_found = [sel.css_rule for sel in gen] + + assert "p" not in selectors_found + assert "div" not in selectors_found + + assert ".test" in selectors_found + assert "p.test" in selectors_found + + def test_ids(self): + page = Page( + b""" + <html><body> + <div id="target">test</div> + <div>irrelevant</div> + </body></html>""" + ) + node = page.select("#target")[0] + selectors = get_css_selectors_for_node(node) + assert selectors == ["#target"] + + def test_multi_parents(self): + page = Page(b'<html><body><div id="target"><p>test</p></div><div><p></p></div>') + node = page.select("#target")[0].select("p")[0] + selectors = get_css_selectors_for_node(node) + assert "#target p" in selectors diff --git a/tests/test_training.py b/tests/test_training.py index ccd0441..f0f1703 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,4 +1,3 @@ -import pytest from mlscraper.html import Page from mlscraper.matches import TextValueExtractor from mlscraper.samples import Sample @@ -61,6 +60,27 @@ def test_train_scraper_list_of_dicts(): assert isinstance(value_scraper.extractor, TextValueExtractor) +def test_train_scraper_multipage(): + training_set = TrainingSet() + for items in ["ab", "cd"]: + html = b""" + <html><body> + <div class="target"> + 
<ul><li>%s</li><li>%s</li></ul> + </div> + </body></html> + """ % ( + items[0].encode(), + items[1].encode(), + ) + training_set.add_sample(Sample(Page(html), [items[0], items[1]])) + scraper = train_scraper(training_set) + assert scraper.selector.css_rule == "li" + assert scraper.get( + Page(b"""<html><body><ul><li>first</li><li>second</li></body></html>""") + ) == ["first", "second"] + + def test_train_scraper_stackoverflow(stackoverflow_samples): training_set = TrainingSet() for s in stackoverflow_samples: diff --git a/tests/test_util.py b/tests/test_util.py index e69de29..c2428e8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -0,0 +1,9 @@ +from mlscraper.util import no_duplicates_generator_decorator + + +def test_no_duplicates_generator_decorator(): + @no_duplicates_generator_decorator + def decorated_generator(): + yield from [1, 1, 2, 3, 3, 3] + + assert list(decorated_generator()) == [1, 2, 3] |