author    Karl Lorey <git@karllorey.com>  2022-06-17 21:58:11 +0200
committer Karl Lorey <git@karllorey.com>  2022-06-17 21:58:11 +0200
commit    edc327cbc52de5fc9d9cb8eb475d8007ea7337f1
tree      fdf40ce9115c8ca5a4e014aa70b13d4d6caabadf
parent    26e96a8e4e306bf350dc2f6d6b379d9509d18198
Fix ListScraper and introduce maximum complexity parameter
 mlscraper/html.py       | 16
 mlscraper/matches.py    | 40
 mlscraper/samples.py    | 63
 mlscraper/selectors.py  | 10
 mlscraper/training.py   | 77
 tests/test_samples.py   | 15
 tests/test_selectors.py |  2
 tests/test_training.py  | 49
 8 files changed, 122 insertions(+), 150 deletions(-)
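In rough terms, the commit makes list training work end to end. A minimal sketch of the new behaviour, condensed from test_train_scraper_simple_list added below; the page HTML is an assumption, since the test's setup lines sit outside the diff context:

    from mlscraper.html import Page
    from mlscraper.samples import Sample
    from mlscraper.samples import TrainingSet
    from mlscraper.training import train_scraper

    # HTML assumed; the sample values ["a", "b", "c"] match the new test.
    page = Page(b"<html><body><p>a</p><p>b</p><p>c</p></body></html>")
    training_set = TrainingSet()
    training_set.add_sample(Sample(page, ["a", "b", "c"]))

    # Previously failed for lists; per the new test assertions this now
    # returns a ListScraper selecting "p" and extracting each node's text.
    scraper = train_scraper(training_set)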
diff --git a/mlscraper/html.py b/mlscraper/html.py
index c19f0db..0a6f4de 100644
--- a/mlscraper/html.py
+++ b/mlscraper/html.py
@@ -14,9 +14,6 @@ from bs4 import NavigableString
from bs4 import Tag
from mlscraper.util import powerset_max_length
-PARENT_NODE_COUNT_MAX = 2
-CSS_CLASS_COMBINATIONS_MAX = 2
-
@dataclass
class Match(ABC):
@@ -33,7 +30,7 @@ class AttributeMatch(Match):
attr: str = None
-def _generate_css_selectors_for_node(soup: Tag):
+def _generate_css_selectors_for_node(soup: Tag, complexity: int):
"""
Generate a selector for the given node.
:param soup:
@@ -48,7 +45,7 @@ def _generate_css_selectors_for_node(soup: Tag):
# use classes
css_classes = soup.attrs.get("class", [])
- for css_class_combo in powerset_max_length(css_classes, CSS_CLASS_COMBINATIONS_MAX):
+ for css_class_combo in powerset_max_length(css_classes, complexity):
css_clases_str = "".join([f".{css_class}" for css_class in css_class_combo])
css_selector = soup.name + css_clases_str
yield css_selector
@@ -108,7 +105,7 @@ class Node:
return True
return False
- def generate_path_selectors(self):
+ def generate_path_selectors(self, complexity: int):
"""
Generate a selector for the path to the given node.
:return:
@@ -133,7 +130,7 @@ class Node:
# print(parents)
# loop from i=0 to i=len(parents) as we consider all parents
- parent_node_count_max = min(len(parents), PARENT_NODE_COUNT_MAX)
+ parent_node_count_max = min(len(parents), complexity)
for parent_node_count in range(parent_node_count_max + 1):
logging.info(
"generating path selectors with %d parents" % parent_node_count
@@ -146,7 +143,8 @@ class Node:
# make a list of selector generators for each node in the path
# todo limit generated selectors -> huge product
selector_generators_for_each_path_node = [
- _generate_css_selectors_for_node(n) for n in path_sampled
+ _generate_css_selectors_for_node(n, complexity)
+ for n in path_sampled
]
# generator that outputs selector paths
@@ -171,7 +169,7 @@ class Node:
def __repr__(self):
if isinstance(self.soup, NavigableString):
return f"<{self.__class__.__name__} {self.soup[:100]=}>"
- return f"<{self.__class__.__name__} {self.soup.name=} classes={self.soup.get('class', None)}>"
+ return f"<{self.__class__.__name__} {self.soup.name=} classes={self.soup.get('class', None)}, text={self.soup.text[:10]}...>"
def __hash__(self):
return self.soup.__hash__()
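The two module constants are folded into a single complexity argument that caps both how many parents a path selector may use and how many CSS classes may be combined. A sketch of the assumed behaviour of powerset_max_length (mlscraper.util is not shown in this diff):

    from itertools import chain, combinations

    def powerset_max_length(items, max_length):
        # Assumed to match mlscraper.util.powerset_max_length: the classic
        # powerset recipe, truncated at subsets of size max_length.
        return chain.from_iterable(
            combinations(items, r) for r in range(max_length + 1)
        )

    # For a <p class="btn large"> node, complexity=1 yields p, p.btn, p.large;
    # complexity=2 would additionally yield p.btn.large.
    for combo in powerset_max_length(["btn", "large"], 1):
        print("p" + "".join(f".{css_class}" for css_class in combo))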
diff --git a/mlscraper/matches.py b/mlscraper/matches.py
index ca5d66a..d1a8498 100644
--- a/mlscraper/matches.py
+++ b/mlscraper/matches.py
@@ -3,12 +3,12 @@ Matches are specific elements found on a page that match a sample.
"""
import logging
import typing
+from functools import cached_property
from mlscraper.html import AttributeMatch
from mlscraper.html import get_root_node
from mlscraper.html import Node
from mlscraper.html import TextMatch
-from mlscraper.selectors import Selector
class Match:
@@ -33,30 +33,6 @@ class Extractor:
raise NotImplementedError()
-class Matcher:
- """
- Class that finds/selects nodes and extracts items from these nodes.
- """
-
- selector = None
- extractor = None
-
- def __init__(self, selector: Selector, extractor: Extractor):
- self.selector = selector
- self.extractor = extractor
-
- def match_one(self, node: Node) -> Match:
- selected_node = self.selector.select_one(node)
- return Match(selected_node, self.extractor)
-
- def match_all(self, node: Node) -> typing.List[Match]:
- selected_nodes = self.selector.select_all(node)
- return [Match(n, self.extractor) for n in selected_nodes]
-
- def __repr__(self):
- return f"<{self.__class__.__name__} {self.selector=} {self.extractor=}>"
-
-
class TextValueExtractor(Extractor):
"""
Class to extract text from a node.
@@ -102,23 +78,13 @@ class AttributeValueExtractor(Extractor):
return isinstance(other, AttributeValueExtractor) and self.attr == other.attr
-class DictExtractor(Extractor):
- def __init__(self, matcher_by_key: typing.Dict[str, Matcher]):
- self.matcher_by_key = matcher_by_key
-
- def extract(self, node: Node):
- return {
- key: matcher.match_one(node) for key, matcher in self.matcher_by_key.items()
- }
-
-
class DictMatch(Match):
match_by_key = None
def __init__(self, match_by_key: dict):
self.match_by_key = match_by_key
- @property
+ @cached_property
def root(self) -> Node:
match_roots = [m.root for m in self.match_by_key.values()]
return get_root_node(match_roots)
@@ -136,7 +102,7 @@ class ListMatch(Match):
def __repr__(self):
return f"<{self.__class__.__name__} {self.matches=}>"
- @property
+ @cached_property
def root(self) -> Node:
return get_root_node([m.root for m in self.matches])
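Switching root from @property to @cached_property memoizes the walk over all submatches that get_root_node would otherwise repeat on every access. An isolated illustration of the caching semantics, with min standing in for get_root_node:

    from functools import cached_property

    class ListMatchSketch:
        def __init__(self, matches):
            self.matches = matches

        @cached_property
        def root(self):
            # Stand-in for get_root_node([m.root for m in self.matches]);
            # the body runs only once per instance.
            print("computing root")
            return min(self.matches)

    m = ListMatchSketch([3, 1, 2])
    assert m.root == 1  # prints "computing root"
    assert m.root == 1  # served from the cache, nothing printed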
diff --git a/mlscraper/samples.py b/mlscraper/samples.py
index f369de4..8f54be9 100644
--- a/mlscraper/samples.py
+++ b/mlscraper/samples.py
@@ -1,15 +1,10 @@
-import logging
import typing
from itertools import product
-from mlscraper.html import Node
from mlscraper.html import Page
from mlscraper.matches import DictMatch
from mlscraper.matches import generate_all_value_matches
from mlscraper.matches import ListMatch
-from mlscraper.matches import Matcher
-from mlscraper.selectors import CssRuleSelector
-from more_itertools import flatten
class ItemStructureException(Exception):
@@ -35,6 +30,7 @@ class Sample:
# generate list of combinations
# todo filter combinations that use the same matches twice
+ # todo create combinations only in order
match_combis = product(*matches_by_value)
return [ListMatch(tuple(match_combi)) for match_combi in match_combis]
@@ -155,60 +151,3 @@ def make_training_set(pages, items):
ts.add_sample(Sample(p, i))
return ts
-
-
-def make_matcher_for_samples(
- samples: typing.List[Sample], roots: typing.Optional[typing.List[Node]] = None
-) -> typing.Union[Matcher, None]:
- for sample in samples:
- # todo leverage generator or cache
- assert sample.get_matches(), f"no matches found for {sample}"
-
- for matcher in generate_matchers_for_samples(samples, roots):
- return matcher
- return None
-
-
-def generate_matchers_for_samples(
- samples: typing.List[Sample], roots: typing.Optional[typing.List[Node]] = None
-) -> typing.Generator:
- """
- Generate CSS selectors that match the given samples.
- :param samples:
- :param roots: root nodes to search from
- :return:
- """
- logging.info(f"generating matchers for samples {samples=} {roots=}")
- if not roots:
- roots = [s.page for s in samples]
- logging.info("roots not set, will use samples' pages")
-
- assert len(samples) == len(roots)
-
- # make a list containing sets of nodes for each possible combination of matches
- # -> enables fast searching and set ensures order
- # todo add only matches below roots here
- matches_per_sample = [s.get_matches() for s in samples]
- match_combinations = list(map(set, product(*matches_per_sample)))
- logging.info(f"match combinations: {match_combinations}")
- node_combinations = [{m.node for m in matches} for matches in match_combinations]
-
- for sample in samples:
- for match in sample.get_matches():
- for css_sel in match.root.generate_path_selectors():
- logging.info(f"testing selector: {css_sel}")
- matched_nodes = set(flatten(root.select(css_sel) for root in roots))
- if matched_nodes in node_combinations:
- logging.info(f"{css_sel} matches one of the possible combinations")
- i = node_combinations.index(matched_nodes)
- matches = match_combinations[i]
- match_extractors = {m.extractor for m in matches}
- if len(match_extractors) == 1:
- logging.info(f"{css_sel} matches same extractors")
- selector = CssRuleSelector(css_sel)
- extractor = next(iter(match_extractors))
- yield Matcher(selector, extractor)
- else:
- logging.info(
- f"{css_sel} would need different extractors, ignoring: {match_extractors}"
- )
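The surviving combination logic in Sample pairs every candidate match of one value with every candidate of the others; the new todo notes that duplicate and out-of-order combinations still slip through. A toy illustration of the cartesian product involved:

    from itertools import product

    # Two candidate matches for value "a", one for value "b":
    matches_by_value = [["a@p1", "a@p2"], ["b@p3"]]

    # Sample.get_matches wraps each combination in a ListMatch;
    # here we just print the raw tuples.
    for match_combi in product(*matches_by_value):
        print(match_combi)
    # ('a@p1', 'b@p3')
    # ('a@p2', 'b@p3')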
diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py
index 8a15ad3..1ba4c55 100644
--- a/mlscraper/selectors.py
+++ b/mlscraper/selectors.py
@@ -2,9 +2,7 @@ import logging
import typing
from mlscraper.html import Node
-from mlscraper.html import Page
from mlscraper.html import selector_matches_nodes
-from more_itertools import bucket
class Selector:
@@ -47,8 +45,10 @@ class CssRuleSelector(Selector):
return f"<{self.__class__.__name__} {self.css_rule=}>"
-def generate_selector_for_nodes(nodes: typing.List[Node], roots):
- logging.info(f"trying to find selector for nodes ({nodes=}, {roots=})")
+def generate_selector_for_nodes(nodes: typing.List[Node], roots, complexity: int):
+ logging.info(
+ f"trying to find selector for nodes ({nodes=}, {roots=}, {complexity=})"
+ )
assert nodes, "no nodes given"
if roots is None:
@@ -63,7 +63,7 @@ def generate_selector_for_nodes(nodes: typing.List[Node], roots):
selectors_seen = set()
for node in nodes:
- for sel in node.generate_path_selectors():
+ for sel in node.generate_path_selectors(complexity):
logging.info(f"selector: {sel}")
if sel not in selectors_seen:
logging.info(
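Callers of generate_selector_for_nodes now pass the complexity bound explicitly. A usage sketch mirroring the updated test in tests/test_selectors.py; the HTML here is illustrative:

    from mlscraper.html import Page
    from mlscraper.samples import Sample
    from mlscraper.selectors import generate_selector_for_nodes

    page = Page(b'<html><body><p class="test">hi</p><p>nope</p></body></html>')
    node = Sample(page, "hi").get_matches()[0].root

    # complexity=1 allows at most one parent and one CSS class per selector.
    selectors = list(generate_selector_for_nodes([node], None, 1))
    print([sel.css_rule for sel in selectors])  # e.g. ['p.test']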
diff --git a/mlscraper/training.py b/mlscraper/training.py
index 8512682..55e4aa7 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -32,17 +32,35 @@ def train_scraper(training_set: TrainingSet):
sample_matches = [s.get_matches() for s in training_set.item.samples]
roots = [s.page for s in training_set.item.samples]
- for match_combination in product(*sample_matches):
- logging.info(f"trying to train scraper for matches ({match_combination=})")
- scraper = train_scraper_for_matches(match_combination, roots)
- return scraper
+ match_combinations = [mc for mc in product(*sample_matches)]
+ logging.info(f"Trying {len(match_combinations)=}")
-
-def train_scraper_for_matches(matches, roots):
+ for complexity in range(3):
+ for match_combination in match_combinations:
+ logging.info(
+ f"progress {match_combinations.index(match_combination)/len(match_combinations)}"
+ )
+ try:
+ logging.info(
+ f"trying to train scraper for matches ({match_combination=})"
+ )
+ scraper = train_scraper_for_matches(
+ match_combination, roots, complexity
+ )
+ return scraper
+ except NoScraperFoundException:
+ logging.info(
+ f"no scraper found for complexity and match_combination ({complexity=}, {match_combination=})"
+ )
+ raise NoScraperFoundException(f"did not find scraper")
+
+
+def train_scraper_for_matches(matches, roots, complexity: int):
"""
Train a scraper that finds the given matches from the given roots.
:param matches: the matches to scrape
:param roots: the root elements containing the matches, e.g. pages or elements on pages
+ :param complexity: the complexity to try
"""
found_types = set(map(type, matches))
assert (
@@ -71,9 +89,15 @@ def train_scraper_for_matches(matches, roots):
if all(m.node == r for m, r in zip(matches, roots)):
# nodes are matched already, done
return ValueScraper(PassThroughSelector(), extractor=extractor)
+ else:
+ logging.info(
+ "no early return: %s",
+ [(m.node, r, m.node == r) for m, r in zip(matches, roots)],
+ )
selector = first(
- generate_selector_for_nodes([m.node for m in matches], roots), None
+ generate_selector_for_nodes([m.node for m in matches], roots, complexity),
+ None,
)
if not selector:
raise NoScraperFoundException(f"no selector found {matches=}")
@@ -90,35 +114,50 @@ def train_scraper_for_matches(matches, roots):
# matches are the matches for the keys
# roots are the original roots(?)
scraper_per_key = {
- k: train_scraper_for_matches([m.match_by_key[k] for m in matches], roots)
+ k: train_scraper_for_matches(
+ [m.match_by_key[k] for m in matches], roots, complexity
+ )
for k in keys
}
return DictScraper(scraper_per_key)
elif found_type == ListMatch:
logging.info("training ListScraper")
matches: typing.List[ListMatch]
+ logging.info(matches)
# so we have a list of ListMatch objects
# we have to find a selector that uniquely matches the list elements
# todo can be one of the parents
- match_roots = [m.root for m in matches]
- logging.info(f"{match_roots=}")
+ # for each match, generate all the nodes of list items
+ list_item_match_and_roots = [
+ (im, r) for m, r in zip(matches, roots) for im in m.matches
+ ]
+ list_item_nodes_and_roots = [
+ (im.root, r) for im, r in list_item_match_and_roots
+ ]
+ item_nodes, item_roots = unzip(list_item_nodes_and_roots)
# first selector is fine as it matches perfectly
# no need to try other selectors
# -> item_scraper would be the same
- selector = first(generate_selector_for_nodes(match_roots, roots))
+ selector = first(
+ generate_selector_for_nodes(list(item_nodes), list(item_roots), complexity),
+ None,
+ )
if selector:
- # for all the item_matches, create a tuple
- # that contains the item_match and the new root
- matches_and_roots = [
- (im, selector.select_one(r))
- for m, r in zip(matches, roots)
- for im in m.matches
+ logging.info(f"selector that matches list items found ({selector=})")
+ # so we have found a selector that matches the list items
+ # we now need a scraper, that scrapes each contained item
+ # todo im.root does not hold for all items, could be a parent
+ item_matches_and_item_roots = [
+ (im, im.root) for im, r in list_item_match_and_roots
]
- item_matches, list_roots = unzip(matches_and_roots)
+ logging.info(
+ f"training to extract list items now ({item_matches_and_item_roots})"
+ )
+ item_matches, item_roots = unzip(item_matches_and_item_roots)
item_scraper = train_scraper_for_matches(
- list(item_matches), list(list_roots)
+ list(item_matches), list(item_roots), complexity
)
return ListScraper(selector, item_scraper)
else:
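The training loop is now iterative deepening over selector complexity: every match combination is tried at the cheapest level before the budget grows, so the simplest working selector wins. A condensed sketch, assuming NoScraperFoundException and train_scraper_for_matches are importable from mlscraper.training:

    from itertools import product

    from mlscraper.training import NoScraperFoundException  # assumed import path
    from mlscraper.training import train_scraper_for_matches

    def train_scraper_condensed(sample_matches, roots):
        match_combinations = list(product(*sample_matches))
        for complexity in range(3):  # same bound as the loop above
            for match_combination in match_combinations:
                try:
                    return train_scraper_for_matches(
                        match_combination, roots, complexity
                    )
                except NoScraperFoundException:
                    continue  # try remaining combinations, then widen the budget
        raise NoScraperFoundException("did not find scraper")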
diff --git a/tests/test_samples.py b/tests/test_samples.py
index df18d86..b6677f0 100644
--- a/tests/test_samples.py
+++ b/tests/test_samples.py
@@ -3,7 +3,6 @@ from mlscraper.html import Page
from mlscraper.matches import DictMatch
from mlscraper.matches import ListMatch
from mlscraper.samples import ItemStructureException
-from mlscraper.samples import make_matcher_for_samples
from mlscraper.samples import make_training_set
from mlscraper.samples import Sample
@@ -64,17 +63,3 @@ class TestMatch:
assert len(match.matches) == 2
assert all(isinstance(m, DictMatch) for m in match.matches)
print(match.root)
-
-
-def test_make_matcher_for_samples():
- page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>'
- page1 = Page(page1_html)
- sample1 = Sample(page1, "test")
-
- page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>'
- page2 = Page(page2_html)
- sample2 = Sample(page2, "hallo")
-
- samples = [sample1, sample2]
- matcher = make_matcher_for_samples(samples)
- assert matcher.selector.css_rule in ["p.test", ".test"]
diff --git a/tests/test_selectors.py b/tests/test_selectors.py
index 57a5ded..c2d204a 100644
--- a/tests/test_selectors.py
+++ b/tests/test_selectors.py
@@ -16,6 +16,6 @@ def test_generate_selector_for_nodes():
nodes = [s.get_matches()[0].root for s in samples]
print(nodes)
- gen = generate_selector_for_nodes(nodes, None)
+ gen = generate_selector_for_nodes(nodes, None, 1)
# todo .test is also possible
assert ["p.test"] == [sel.css_rule for sel in gen]
diff --git a/tests/test_training.py b/tests/test_training.py
index 3e48756..ddd86fb 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -1,7 +1,12 @@
import pytest
from mlscraper.html import Page
+from mlscraper.matches import TextValueExtractor
from mlscraper.samples import Sample
from mlscraper.samples import TrainingSet
+from mlscraper.scrapers import ListScraper
+from mlscraper.scrapers import ValueScraper
+from mlscraper.selectors import CssRuleSelector
+from mlscraper.selectors import PassThroughSelector
from mlscraper.training import train_scraper
@@ -13,10 +18,50 @@ def test_train_scraper_simple_list():
["a", "b", "c"],
)
training_set.add_sample(sample)
- train_scraper(training_set)
+ scraper = train_scraper(training_set)
+
+ # check list scraper
+ assert isinstance(scraper, ListScraper)
+ assert isinstance(scraper.selector, CssRuleSelector)
+ assert scraper.selector.css_rule == "p"
+
+ # check item scraper
+ item_scraper = scraper.scraper
+ assert isinstance(item_scraper, ValueScraper)
+ assert isinstance(item_scraper.selector, PassThroughSelector)
+ assert isinstance(item_scraper.extractor, TextValueExtractor)
+
+
+def test_train_scraper_list_of_dicts():
+ html = b"""
+ <html>
+ <body>
+ <div><p>a</p><p>b</p></div>
+ <div><p>c</p><p>d</p></div>
+ </body>
+ </html>
+ """
+ page = Page(html)
+ sample = Sample(page, [["a", "b"], ["c", "d"]])
+ training_set = TrainingSet()
+ training_set.add_sample(sample)
+ scraper = train_scraper(training_set)
+ assert isinstance(scraper, ListScraper)
+ assert isinstance(scraper.selector, CssRuleSelector)
+ assert scraper.selector.css_rule == "div"
+
+ inner_scraper = scraper.scraper
+ assert isinstance(inner_scraper, ListScraper)
+ assert isinstance(inner_scraper.selector, CssRuleSelector)
+ assert inner_scraper.selector.css_rule == "p"
+
+ value_scraper = inner_scraper.scraper
+ assert isinstance(value_scraper, ValueScraper)
+ assert isinstance(value_scraper.selector, PassThroughSelector)
+ assert isinstance(value_scraper.extractor, TextValueExtractor)
-@pytest.mark.skip("does not work yet")
+@pytest.mark.skip("takes too long")
def test_train_scraper(stackoverflow_samples):
training_set = TrainingSet()
for s in stackoverflow_samples: