author    Karl Lorey <git@karllorey.com>  2022-06-20 11:12:00 +0200
committer Karl Lorey <git@karllorey.com>  2022-06-20 11:12:00 +0200
commit    c6c371223d56f23ad4a588231b5e6f51bee4259c
tree      45d68736dfce2c8bee9d061e19224161f3916332
parent    3a4c3234653984768a992747ba45da5e34d3af9c

Re-implement selector generation with a speedup >10x (develop)
-rw-r--r--  mlscraper/html.py         108
-rw-r--r--  mlscraper/scrapers.py       2
-rw-r--r--  mlscraper/selectors.py    109
-rw-r--r--  mlscraper/training.py      16
-rw-r--r--  mlscraper/util.py          11
-rw-r--r--  tests/test_selectors.py    56
-rw-r--r--  tests/test_training.py     22
-rw-r--r--  tests/test_util.py          9

8 files changed, 176 insertions(+), 157 deletions(-)
diff --git a/mlscraper/html.py b/mlscraper/html.py
index f0d222e..6fb516b 100644
--- a/mlscraper/html.py
+++ b/mlscraper/html.py
@@ -6,13 +6,11 @@ import logging
import typing
from abc import ABC
from dataclasses import dataclass
-from itertools import combinations
-from itertools import product
+from functools import cached_property
from bs4 import BeautifulSoup
from bs4 import NavigableString
from bs4 import Tag
-from mlscraper.util import powerset_max_length
@dataclass
@@ -30,44 +28,6 @@ class AttributeMatch(Match):
    attr: str = None
-def _generate_css_selectors_for_node(soup: Tag, complexity: int):
-    """
-    Generate a selector for the given node.
-    :param soup:
-    :return:
-    """
-    assert isinstance(soup, Tag)
-
-    # use id
-    tag_id = soup.attrs.get("id", None)
-    if tag_id:
-        yield "#" + tag_id
-
-    yield soup.name
-
-    # use classes
-    css_classes = soup.attrs.get("class", [])
-    for css_class_combo in powerset_max_length(css_classes, complexity):
-        if css_class_combo:
-            css_clases_str = "".join([f".{css_class}" for css_class in css_class_combo])
-            yield css_clases_str
-            yield soup.name + css_clases_str
-        else:
-            # empty set, no selector
-            pass
-
-    # todo: nth applies to whole selectors
-    # -> should thus be a step after actual selector generation
-    if isinstance(soup.parent, Tag) and hasattr(soup, "name"):
-        children_tags = [c for c in soup.parent.children if isinstance(c, Tag)]
-        child_index = list(children_tags).index(soup) + 1
-        yield ":nth-child(%d)" % child_index
-
-        children_of_same_type = [c for c in children_tags if c.name == soup.name]
-        child_index = children_of_same_type.index(soup) + 1
-        yield ":nth-of-type(%d)" % child_index
-
-
class Node:
    soup = None
    _page = None
@@ -111,70 +71,22 @@ class Node:
                return True
        return False
+    @cached_property
+    def parents(self):
+        return [self._page._get_node_for_soup(p) for p in self.soup.parents]
+
    @property
    def classes(self):
        return self.soup.attrs.get("class", [])

    @property
+    def id(self):
+        return self.soup.attrs.get("id", None)
+
+    @property
    def tag_name(self):
        return self.soup.name
-    def generate_path_selectors(self, complexity: int):
-        """
-        Generate a selector for the path to the given node.
-        :return:
-        """
-        if not isinstance(self.soup, Tag):
-            error_msg = "Only tags can be selected with CSS, %s given" % type(self.soup)
-            raise RuntimeError(error_msg)
-
-        # we have a list of n ancestor notes and n-1 nodes including the last node
-        # the last node must get selected always
-
-        # so we will:
-        # 1) generate all selectors for current node
-        # 2) append possible selectors for the n-1 descendants
-        #    starting with all node selectors and increasing number of used descendants
-
-        # remove unique parents as they don't improve selection
-        # body is unique, html is unique, document is bs4 root element
-        parents = [
-            n for n in self.soup.parents if n.name not in ("body", "html", "[document]")
-        ]
-        # print(parents)
-
-        # loop from i=0 to i=len(parents) as we consider all parents
-        parent_node_count_max = min(len(parents), complexity)
-        for parent_node_count in range(parent_node_count_max + 1):
-            logging.info(
-                "generating path selectors with %d parents" % parent_node_count
-            )
-            # generate paths with exactly parent_node_count nodes
-            for parent_nodes_sampled in combinations(parents, parent_node_count):
-                path_sampled = (self.soup,) + parent_nodes_sampled
-                # logging.info(path_sampled)
-
-                # make a list of selector generators for each node in the path
-                # todo limit generated selectors -> huge product
-                selector_generators_for_each_path_node = [
-                    _generate_css_selectors_for_node(n, complexity)
-                    for n in path_sampled
-                ]
-
-                # generator that outputs selector paths
-                # e.g. (div, .wrapper, .main)
-                path_sampled_selectors = product(
-                    *selector_generators_for_each_path_node
-                )
-
-                # create an actual css selector for each selector path
-                # e.g. .main > .wrapper > .div
-                for path_sampled_selector in path_sampled_selectors:
-                    # if paths are not directly connected, i.e. (1)-(2)-3-(4)
-                    # join must be " " and not " > "
-                    css_selector = " ".join(reversed(path_sampled_selector))
-                    yield css_selector
-
    def select(self, css_selector):
        return [
            self._page._get_node_for_soup(n) for n in self.soup.select(css_selector)
@@ -183,7 +95,7 @@ class Node:
    def __repr__(self):
        if isinstance(self.soup, NavigableString):
            return f"<{self.__class__.__name__} {self.soup.strip()[:10]=}>"
-        return f"<{self.__class__.__name__} {self.soup.name=} classes={self.soup.get('class', None)}, text={self.soup.text.strip()[:10]}...>"
+        return f"<{self.__class__.__name__} {self.soup.name=} classes={self.soup.get('class', None)}, text={''.join(self.soup.stripped_strings)[:10]}...>"

    def __hash__(self):
        return self.soup.__hash__()
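
A minimal usage sketch for the new Node helpers (not part of the commit; assumes Page and Node behave as in mlscraper/html.py above):

    from mlscraper.html import Page

    page = Page(b'<html><body><div id="main"><p class="x">hi</p></div></body></html>')
    node = page.select("p.x")[0]

    assert node.id is None          # the <p> tag carries no id attribute
    assert node.classes == ["x"]
    assert node.tag_name == "p"

    # parents is a cached_property: the soup-to-Node mapping is computed once
    # per node and reused, instead of being rebuilt on every traversal
    assert [p.tag_name for p in node.parents[:2]] == ["div", "body"]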
diff --git a/mlscraper/scrapers.py b/mlscraper/scrapers.py
index 7ea16f5..d271719 100644
--- a/mlscraper/scrapers.py
+++ b/mlscraper/scrapers.py
@@ -37,7 +37,7 @@ class ListScraper(Scraper):
        ]

    def __repr__(self):
-        return f"<ListScraper {self.scraper=}>"
+        return f"<ListScraper {self.selector=} {self.scraper=}>"
class ValueScraper(Scraper):
diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py
index 1ba4c55..e07c084 100644
--- a/mlscraper/selectors.py
+++ b/mlscraper/selectors.py
@@ -1,8 +1,11 @@
import logging
import typing
+from itertools import product
from mlscraper.html import Node
-from mlscraper.html import selector_matches_nodes
+from mlscraper.util import no_duplicates_generator_decorator
+from more_itertools import first
+from more_itertools import powerset
class Selector:
@@ -45,45 +48,75 @@ class CssRuleSelector(Selector):
        return f"<{self.__class__.__name__} {self.css_rule=}>"
-def generate_selector_for_nodes(nodes: typing.List[Node], roots, complexity: int):
-    logging.info(
-        f"trying to find selector for nodes ({nodes=}, {roots=}, {complexity=})"
-    )
-    assert nodes, "no nodes given"
-
+def generate_unique_selectors_for_nodes(
+    nodes: typing.List[Node], roots, complexity: int
+):
+    """
+    Generate selectors that each match exactly the given nodes and nothing else.
+    """
    if roots is None:
        logging.info("roots is None, setting roots manually")
        roots = [n.root for n in nodes]

-    # todo roots and nodes can be uneven here because we just want to find a way
-    # to select all the nodes from the given roots
    nodes_per_root = {r: [n for n in nodes if n.has_parent(r)] for r in set(roots)}
-    logging.info(f"item by root: %s", nodes_per_root)
-
-    selectors_seen = set()
-
-    for node in nodes:
-        for sel in node.generate_path_selectors(complexity):
-            logging.info(f"selector: {sel}")
-            if sel not in selectors_seen:
-                logging.info(
-                    f"nodes per root: {nodes_per_root}",
-                )
-                # check if selector returns the desired nodes per root
-                if all(
-                    selector_matches_nodes(root, sel, nodes)
-                    for root, nodes in nodes_per_root.items()
-                ):
-                    logging.info(f"selector matches all nodes exactly ({sel=})")
-                    yield CssRuleSelector(sel)
-                else:
-                    for root, nodes in nodes_per_root.items():
-                        logging.info(
-                            f"{root=}, {sel=}, {selector_matches_nodes(root, sel, nodes)=}"
-                        )
-                    logging.info(f"selector does not match nodes exactly: {sel}")
-
-                # add to seen
-                selectors_seen.add(sel)
-            else:
-                logging.info(f"selector already checked: {sel}")
+    for selector in generate_selectors_for_nodes(nodes, roots, complexity):
+        if all(
+            selector.select_all(r) == nodes_of_root
+            for r, nodes_of_root in nodes_per_root.items()
+        ):
+            yield selector
+
+
+@no_duplicates_generator_decorator
+def generate_selectors_for_nodes(nodes: typing.List[Node], roots, complexity: int):
+ """
+ Generate a selector which matches the given nodes.
+ """
+
+ logging.info(
+ f"trying to find selector for nodes ({nodes=}, {roots=}, {complexity=})"
+ )
+ assert nodes, "no nodes given"
+ assert roots, "no roots given"
+ assert len(nodes) == len(roots)
+
+    direct_css_selectors = list(_generate_direct_css_selectors_for_nodes(nodes))
+    for direct_css_selector in direct_css_selectors:
+        yield CssRuleSelector(direct_css_selector)
+
+    parents_of_nodes_below_roots = [
+        [p for p in n.parents if p.has_parent(r) and p.tag_name not in ["html", "body"]]
+        for n, r in zip(nodes, roots)
+    ]
+    for parent_nodes in product(*parents_of_nodes_below_roots):
+        for parent_selector_raw in _generate_direct_css_selectors_for_nodes(
+            parent_nodes
+        ):
+            for css_selector_raw in direct_css_selectors:
+                css_selector_combined = parent_selector_raw + " " + css_selector_raw
+                yield CssRuleSelector(css_selector_combined)
+
+
+def _generate_direct_css_selectors_for_nodes(nodes: typing.List[Node]):
+ logging.info(f"generating direct css selector for nodes ({nodes=})")
+ common_classes = set.intersection(*[set(n.classes) for n in nodes])
+
+    is_same_tag = len({n.tag_name for n in nodes}) == 1
+    common_tag_name = nodes[0].tag_name
+    # the bare tag name can only select all nodes if they share it
+    if is_same_tag:
+        yield common_tag_name
+
+    common_ids = {n.id for n in nodes}
+    is_same_id = len(common_ids) == 1
+    if is_same_id and None not in common_ids:
+        yield "#" + first(common_ids)
+
+    for class_combination in powerset(common_classes):
+        if class_combination:
+            logging.info(f"- generating selector for ({class_combination=})")
+            css_selector = "".join(map(lambda cl: "." + cl, class_combination))
+            yield css_selector
+            if is_same_tag:
+                yield common_tag_name + css_selector
+        else:
+            # empty combination -> ignore
+            pass
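
An illustration (not part of the commit) of the rewritten two-stage generation: direct selectors for the target nodes come first, then "ancestor descendant" combinations built via product(). The HTML here is made up:

    from mlscraper.html import Page
    from mlscraper.selectors import generate_unique_selectors_for_nodes

    page = Page(
        b'<html><body>'
        b'<div class="wrap"><span class="price main">9 EUR</span></div>'
        b'<span class="other">noise</span>'
        b'</body></html>'
    )
    node = page.select("span.price")[0]

    for selector in generate_unique_selectors_for_nodes([node], None, 1):
        print(selector.css_rule)
    # 'span' is dropped by the uniqueness check (it also matches span.other);
    # candidates like '.price', 'span.price' and '.wrap span' survive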
diff --git a/mlscraper/training.py b/mlscraper/training.py
index b57648f..04485f2 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -1,10 +1,8 @@
import logging
import typing
from itertools import combinations
-from itertools import permutations
from itertools import product
-from mlscraper.html import get_relative_depth
from mlscraper.matches import DictMatch
from mlscraper.matches import ListMatch
from mlscraper.matches import ValueMatch
@@ -12,7 +10,7 @@ from mlscraper.samples import TrainingSet
from mlscraper.scrapers import DictScraper
from mlscraper.scrapers import ListScraper
from mlscraper.scrapers import ValueScraper
-from mlscraper.selectors import generate_selector_for_nodes
+from mlscraper.selectors import generate_unique_selectors_for_nodes
from mlscraper.selectors import PassThroughSelector
from more_itertools import first
from more_itertools import flatten
@@ -47,6 +45,8 @@ def train_scraper(training_set: TrainingSet):
    match_combinations_by_depth = sorted(
        match_combinations, key=lambda mc: sum(m.depth for m in mc), reverse=True
    )
+    # todo compute selectivity of classes to use selective ones first
+    # todo cache selectors of node combinations to avoid re-selecting after increasing complexity
    for complexity in range(3):
        for match_combination in match_combinations_by_depth:
            logging.info(
@@ -115,11 +115,14 @@ def train_scraper_for_matches(matches, roots, complexity: int):
        )
        selector = first(
-            generate_selector_for_nodes([m.node for m in matches], roots, complexity),
+            generate_unique_selectors_for_nodes(
+                [m.node for m in matches], roots, complexity
+            ),
            None,
        )
        if not selector:
            raise NoScraperFoundException(f"no selector found {matches=}")
+        logging.info(f"found selector for ValueScraper ({selector=})")
        return ValueScraper(selector, extractor)
    elif found_type == DictMatch:
        logging.info("training DictScraper")
@@ -138,6 +141,7 @@ def train_scraper_for_matches(matches, roots, complexity: int):
            )
            for k in keys
        }
+        logging.info(f"found DictScraper ({scraper_per_key=})")
        return DictScraper(scraper_per_key)
    elif found_type == ListMatch:
        logging.info("training ListScraper")
@@ -160,7 +164,9 @@ def train_scraper_for_matches(matches, roots, complexity: int):
        # no need to try other selectors
        # -> item_scraper would be the same
        selector = first(
-            generate_selector_for_nodes(list(item_nodes), list(item_roots), complexity),
+            generate_unique_selectors_for_nodes(
+                list(item_nodes), list(item_roots), complexity
+            ),
            None,
        )
        if selector:
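
An end-to-end sketch (not part of the commit) of how training now resolves selectors through generate_unique_selectors_for_nodes; the markup is illustrative, and a scalar sample value is assumed to be supported as in mlscraper's samples module:

    from mlscraper.html import Page
    from mlscraper.samples import Sample, TrainingSet
    from mlscraper.training import train_scraper

    training_set = TrainingSet()
    html = b'<html><body><p class="title">Hello</p><p>noise</p></body></html>'
    training_set.add_sample(Sample(Page(html), "Hello"))

    scraper = train_scraper(training_set)
    print(scraper.get(Page(b'<html><body><p class="title">Bye</p></body></html>')))
    # -> 'Bye', extracted via a unique selector such as '.title' or 'p.title'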
diff --git a/mlscraper/util.py b/mlscraper/util.py
index 1e3f920..00bc25d 100644
--- a/mlscraper/util.py
+++ b/mlscraper/util.py
@@ -3,3 +3,14 @@ from more_itertools import powerset
def powerset_max_length(candidates, length):
    return filter(lambda s: len(s) <= length, powerset(candidates))
+
+
+def no_duplicates_generator_decorator(func):
+    def inner(*args, **kwargs):
+        seen = set()
+        for item in func(*args, **kwargs):
+            if item not in seen:
+                yield item
+                seen.add(item)
+
+    return inner
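
Worth noting (illustration, not part of the commit): the decorator dedups via a set, so it only collapses items that are hashable and compare equal. Plain strings qualify, while custom objects such as CssRuleSelector are only deduped if they define __eq__ and __hash__:

    from mlscraper.util import no_duplicates_generator_decorator

    @no_duplicates_generator_decorator
    def rules():
        yield "p.test"
        yield "p.test"  # equal string: suppressed
        yield ".test"

    assert list(rules()) == ["p.test", ".test"]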
diff --git a/tests/test_selectors.py b/tests/test_selectors.py
index 3e4570a..9e7acc0 100644
--- a/tests/test_selectors.py
+++ b/tests/test_selectors.py
@@ -1,21 +1,49 @@
from mlscraper.html import Page
-from mlscraper.selectors import generate_selector_for_nodes
+from mlscraper.selectors import generate_unique_selectors_for_nodes
-def test_generate_selector_for_nodes():
-    page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>'
-    page1 = Page(page1_html)
+def get_css_selectors_for_node(node):
+ """
+ helper to extract plain css rules
+ """
+ return [
+ selector.css_rule
+ for selector in generate_unique_selectors_for_nodes([node], None, 100)
+ ]
-    page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>'
-    page2 = Page(page2_html)
-    nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2]))
-    gen = generate_selector_for_nodes(nodes, None, 1)
-    selectors_found = [sel.css_rule for sel in gen]
-    assert {".test", "p.test"} == set(selectors_found)
+class TestGenerateUniqueSelectorsForNodes:
+    def test_basic(self):
+        page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>'
+        page1 = Page(page1_html)
+        page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>'
+        page2 = Page(page2_html)
-class TestGenerateSelectorForNodes:
-    def test_generate_selector_for_nodes(self):
-        # generate_selector_for_nodes()
-        pass
+        nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2]))
+        gen = generate_unique_selectors_for_nodes(nodes, None, 1)
+        selectors_found = [sel.css_rule for sel in gen]
+
+        assert "p" not in selectors_found
+        assert "div" not in selectors_found
+
+        assert ".test" in selectors_found
+        assert "p.test" in selectors_found
+
+    def test_ids(self):
+        page = Page(
+            b"""
+            <html><body>
+            <div id="target">test</div>
+            <div>irrelevant</div>
+            </body></html>"""
+        )
+        node = page.select("#target")[0]
+        selectors = get_css_selectors_for_node(node)
+        assert selectors == ["#target"]
+
+    def test_multi_parents(self):
+        page = Page(b'<html><body><div id="target"><p>test</p></div><div><p></p></div>')
+        node = page.select("#target")[0].select("p")[0]
+        selectors = get_css_selectors_for_node(node)
+        assert "#target p" in selectors
diff --git a/tests/test_training.py b/tests/test_training.py
index ccd0441..f0f1703 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -1,4 +1,3 @@
-import pytest
from mlscraper.html import Page
from mlscraper.matches import TextValueExtractor
from mlscraper.samples import Sample
@@ -61,6 +60,27 @@ def test_train_scraper_list_of_dicts():
    assert isinstance(value_scraper.extractor, TextValueExtractor)
+def test_train_scraper_multipage():
+    training_set = TrainingSet()
+    for items in ["ab", "cd"]:
+        html = b"""
+        <html><body>
+        <div class="target">
+            <ul><li>%s</li><li>%s</li></ul>
+        </div>
+        </body></html>
+        """ % (
+            items[0].encode(),
+            items[1].encode(),
+        )
+        training_set.add_sample(Sample(Page(html), [items[0], items[1]]))
+    scraper = train_scraper(training_set)
+    assert scraper.selector.css_rule == "li"
+    assert scraper.get(
+        Page(b"""<html><body><ul><li>first</li><li>second</li></body></html>""")
+    ) == ["first", "second"]
+
+
def test_train_scraper_stackoverflow(stackoverflow_samples):
    training_set = TrainingSet()
    for s in stackoverflow_samples:
diff --git a/tests/test_util.py b/tests/test_util.py
index e69de29..c2428e8 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -0,0 +1,9 @@
+from mlscraper.util import no_duplicates_generator_decorator
+
+
+def test_no_duplicates_generator_decorator():
+    @no_duplicates_generator_decorator
+    def decorated_generator():
+        yield from [1, 1, 2, 3, 3, 3]
+
+    assert list(decorated_generator()) == [1, 2, 3]