author    Karl Lorey <git@karllorey.com>  2022-06-15 17:33:48 +0200
committer Karl Lorey <git@karllorey.com>  2022-06-15 17:33:48 +0200
commit    7d7a07ea8baf7ee2af2a93b41ac42ee73d16fbda (patch)
tree      98da7246e2071d3f0f32efe0c853819633e07983
parent    04aef4bfdd331e99d4d5f9f10111fed6b8e39de1 (diff)
Rewrite training module to decrease complexity
-rw-r--r--  mlscraper/html.py       |   6
-rw-r--r--  mlscraper/matches.py    |   7
-rw-r--r--  mlscraper/samples.py    |   4
-rw-r--r--  mlscraper/selectors.py  |  19
-rw-r--r--  mlscraper/training.py   | 155
-rw-r--r--  tests/test_scrapers.py  |   8
-rw-r--r--  tests/test_training.py  |   4
7 files changed, 131 insertions, 72 deletions
diff --git a/mlscraper/html.py b/mlscraper/html.py
index 2b2ecd2..c19f0db 100644
--- a/mlscraper/html.py
+++ b/mlscraper/html.py
@@ -102,6 +102,12 @@ class Node:
# todo implement other find methods
+ def has_parent(self, node: "Node"):
+ for p in self.soup.parents:
+ if p == node.soup:
+ return True
+ return False
+
def generate_path_selectors(self):
"""
Generate a selector for the path to the given node.
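A minimal usage sketch of the new Node.has_parent helper. The page content is hypothetical; it assumes Page behaves as a Node and that CssRuleSelector.select_one accepts it, consistent with how the training code in this commit calls select_one on page roots:

    from mlscraper.html import Page
    from mlscraper.selectors import CssRuleSelector

    page = Page(b"<html><body><div><p>hi</p></div></body></html>")
    div = CssRuleSelector("div").select_one(page)
    p = CssRuleSelector("p").select_one(page)
    assert p.has_parent(div)      # <p> sits below <div>
    assert not div.has_parent(p)  # ancestry is one-directional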
diff --git a/mlscraper/matches.py b/mlscraper/matches.py
index 51c4255..ca5d66a 100644
--- a/mlscraper/matches.py
+++ b/mlscraper/matches.py
@@ -136,6 +136,7 @@ class ListMatch(Match):
def __repr__(self):
return f"<{self.__class__.__name__} {self.matches=}>"
+ @property
def root(self) -> Node:
return get_root_node([m.root for m in self.matches])
@@ -156,8 +157,10 @@ class ValueMatch(Match):
return self.node
-def generate_all_matches(node: Node, item) -> typing.Generator[Match, None, None]:
- logging.info(f"generating all matches ({node=}, {item=})")
+def generate_all_value_matches(
+ node: Node, item: str
+) -> typing.Generator[Match, None, None]:
+ logging.info(f"generating all value matches ({node=}, {item=})")
for html_match in node.find_all(item):
matched_node = html_match.node
if isinstance(html_match, TextMatch):
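A sketch of the renamed generate_all_value_matches. The page content is hypothetical; the node and extractor attributes of the yielded ValueMatch objects are taken from elsewhere in this diff:

    from mlscraper.html import Page
    from mlscraper.matches import generate_all_value_matches

    page = Page(b"<html><body><span>42</span><p>42</p></body></html>")
    for match in generate_all_value_matches(page, "42"):
        # one ValueMatch per place the value occurs on the page
        print(match.node, match.extractor)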
diff --git a/mlscraper/samples.py b/mlscraper/samples.py
index df7f576..f369de4 100644
--- a/mlscraper/samples.py
+++ b/mlscraper/samples.py
@@ -5,7 +5,7 @@ from itertools import product
from mlscraper.html import Node
from mlscraper.html import Page
from mlscraper.matches import DictMatch
-from mlscraper.matches import generate_all_matches
+from mlscraper.matches import generate_all_value_matches
from mlscraper.matches import ListMatch
from mlscraper.matches import Matcher
from mlscraper.selectors import CssRuleSelector
@@ -28,7 +28,7 @@ class Sample:
# todo: fix creating new sample objects, maybe by using Item class?
if isinstance(self.value, str):
- return list(generate_all_matches(self.page, self.value))
+ return list(generate_all_value_matches(self.page, self.value))
if isinstance(self.value, list):
matches_by_value = [Sample(self.page, v).get_matches() for v in self.value]
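How Sample.get_matches branches on the value type after the rename, as a rough sketch with hypothetical page content:

    from mlscraper.html import Page
    from mlscraper.samples import Sample

    page = Page(b"<html><body><p>a</p><p>b</p></body></html>")
    value_matches = Sample(page, "a").get_matches()        # ValueMatch candidates
    list_matches = Sample(page, ["a", "b"]).get_matches()  # ListMatch candidates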
diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py
index 31b26c9..8a15ad3 100644
--- a/mlscraper/selectors.py
+++ b/mlscraper/selectors.py
@@ -4,6 +4,7 @@ import typing
from mlscraper.html import Node
from mlscraper.html import Page
from mlscraper.html import selector_matches_nodes
+from more_itertools import bucket
class Selector:
@@ -18,6 +19,15 @@ class Selector:
raise NotImplementedError()
+class PassThroughSelector(Selector):
+ def select_one(self, node: Node) -> Node:
+ return node
+
+ def select_all(self, node: Node) -> typing.List[Node]:
+ # this does not make sense as we have only one node to pass through
+ raise RuntimeError("cannot apply select_all to PassThroughSelector")
+
+

class CssRuleSelector(Selector):
def __init__(self, css_rule):
self.css_rule = css_rule
@@ -38,7 +48,8 @@ class CssRuleSelector(Selector):
def generate_selector_for_nodes(nodes: typing.List[Node], roots):
- logging.info(f"trying to find selector for nodes ({nodes=})")
+ logging.info(f"trying to find selector for nodes ({nodes=}, {roots=})")
+ assert nodes, "no nodes given"
if roots is None:
logging.info("roots is None, setting roots manually")
@@ -46,10 +57,8 @@ def generate_selector_for_nodes(nodes: typing.List[Node], roots):
# todo roots and nodes can be uneven here because we just want to find a way
# to select all the nodes from the given roots
-
- nodes_per_root = {}
- for root in set(roots):
- nodes_per_root[root] = [n for n in nodes if n.root == root]
+ nodes_per_root = {r: [n for n in nodes if n.has_parent(r)] for r in set(roots)}
+ logging.info(f"item by root: %s", nodes_per_root)
selectors_seen = set()
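A sketch of the two selector paths touched here: PassThroughSelector hands its input node back unchanged (select_all raises by design), and generate_selector_for_nodes now groups nodes by root via has_parent, so roots and nodes may be uneven. Hypothetical usage, assuming page.find_all yields matches carrying a .node attribute as in matches.py:

    from mlscraper.html import Page
    from mlscraper.selectors import PassThroughSelector
    from mlscraper.selectors import generate_selector_for_nodes
    from more_itertools import first

    page = Page(b"<html><body><p>a</p><p>b</p></body></html>")
    nodes = [m.node for v in ("a", "b") for m in page.find_all(v)]
    selector = first(generate_selector_for_nodes(nodes, [page]), None)
    if selector:
        print(selector.css_rule)  # plausibly "p"
    assert PassThroughSelector().select_one(nodes[0]) is nodes[0]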
diff --git a/mlscraper/training.py b/mlscraper/training.py
index 6647bdd..38b15d9 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -2,17 +2,18 @@ import logging
import typing
from itertools import product
-from mlscraper.html import Node
-from mlscraper.samples import DictItem
-from mlscraper.samples import Item
-from mlscraper.samples import ListItem
-from mlscraper.samples import make_matcher_for_samples
-from mlscraper.samples import Sample
-from mlscraper.samples import ValueItem
+from mlscraper.matches import DictMatch
+from mlscraper.matches import ListMatch
+from mlscraper.matches import ValueMatch
+from mlscraper.samples import TrainingSet
from mlscraper.scrapers import DictScraper
from mlscraper.scrapers import ListScraper
from mlscraper.scrapers import ValueScraper
from mlscraper.selectors import generate_selector_for_nodes
+from mlscraper.selectors import PassThroughSelector
+from more_itertools import first
+from more_itertools import flatten
+from more_itertools import unzip
class TrainingException(Exception):
@@ -23,64 +24,98 @@ class NoScraperFoundException(TrainingException):
pass
-def train_scraper(item: Item, roots: typing.Optional[typing.List[Node]] = None):
+def train_scraper(training_set: TrainingSet):
"""
Train a scraper able to extract the given training data.
"""
- logging.info(f"training {item}")
-
- # set roots to page if not set
- if roots is None:
- roots = [s.page for s in item.samples]
- logging.info(f"roots inferred: {roots}")
-
- assert len(item.samples) == len(roots), f"{len(item.samples)=} != {len(roots)=}"
-
- if isinstance(item, ListItem):
- # so we have to extract a list from each root
- # to do this, we take all matches we can find
- # and try to find a common selector for one of the combinations root elements
- # if that works, we've succeeded
- # if not, we're out of luck for now
- # todo add root to get_matches to receive only matches below roots
- matches_per_sample = [s.get_matches() for s in item.item.samples]
- for match_combi in product(*matches_per_sample):
- # match_combi is one possible way to combine matches to extract the list
- logging.info(f"{match_combi=}")
- # we now take the root of every element
- match_roots = [m.root for m in match_combi]
- logging.info(f"{match_roots=}")
- for selector in generate_selector_for_nodes(match_roots, roots):
- # roots are the newly matched root elements
- item_scraper = train_scraper(item.item, match_roots)
- scraper = ListScraper(selector, item_scraper)
- return scraper
-
- raise NoScraperFoundException(f"no matcher found for {item}")
-
- if isinstance(item, DictItem):
- # train a scraper for each key, keep roots
- scraper_per_key = {
- k: train_scraper(i, roots) for k, i in item.item_per_key.items()
- }
- return DictScraper(scraper_per_key)
+ logging.info(f"training {training_set=}")
- if isinstance(item, ValueItem):
- # find a selector that uniquely matches the value given the root node
- matcher = make_matcher_for_samples(item.samples, roots)
- if matcher:
- return ValueScraper(matcher.selector, matcher.extractor)
- else:
- raise NoScraperFoundException(f"deriving matcher failed for {item}")
+ sample_matches = [s.get_matches() for s in training_set.item.samples]
+ roots = [s.page for s in training_set.item.samples]
+    for match_combination in product(*sample_matches):
+        logging.info(f"trying to train scraper for matches ({match_combination=})")
+        try:
+            return train_scraper_for_matches(match_combination, roots)
+        except NoScraperFoundException:
+            # this combination admits no common selector, try the next one
+            continue
+    raise NoScraperFoundException(f"no combination of matches worked ({training_set=})")
-def get_smallest_span_match_per_sample(samples: typing.List[Sample]):
+def train_scraper_for_matches(matches, roots):
"""
- Get the best match for each sample by using the smallest span.
- :param samples:
- :return:
+ Train a scraper that finds the given matches from the given roots.
+ :param matches: the matches to scrape
+ :param roots: the root elements containing the matches, e.g. pages or elements on pages
"""
- best_match_per_sample = [
- sorted(s.get_matches(), key=lambda m: m.get_span())[0] for s in samples
- ]
- return best_match_per_sample
+ found_types = set(map(type, matches))
+ assert (
+ len(found_types) == 1
+ ), f"different match types passed {found_types=}, {matches=}"
+ found_type = first(found_types)
+
+ # make sure we have lists
+ matches = list(matches)
+ roots = list(roots)
+
+ assert len(matches) == len(roots), f"got uneven inputs ({matches=}, {roots=})"
+ if found_type == ValueMatch:
+ logging.info("training ValueScraper")
+ matches: typing.List[ValueMatch]
+
+ # if matches have different extractors, we can't find a common scraper
+ extractors = set(map(lambda m: m.extractor, matches))
+ if len(extractors) != 1:
+ raise NoScraperFoundException(
+ "different extractors found for matches, aborting"
+ )
+ extractor = first(extractors)
+
+    # early return: if every node already equals its root (e.g. the items
+    # of a list of values), there is nothing left to select
+    if all(m.node == r for m, r in zip(matches, roots)):
+        return ValueScraper(PassThroughSelector(), extractor=extractor)
+
+ selector = first(
+ generate_selector_for_nodes([m.node for m in matches], roots), None
+ )
+ if not selector:
+ raise NoScraperFoundException(f"no selector found {matches=}")
+ return ValueScraper(selector, extractor)
+ elif found_type == DictMatch:
+ logging.info("training DictScraper")
+ matches: typing.List[DictMatch]
+
+    # if some matches lack a key, using the union of all keys makes the
+    # lookup below raise a KeyError, so missing keys fail loudly
+    keys = set(flatten(m.match_by_key.keys() for m in matches))
+
+    # train a scraper for each key of the dict: the matches are looked up
+    # per key, while the roots stay the original roots
+ scraper_per_key = {
+ k: train_scraper_for_matches([m.match_by_key[k] for m in matches], roots)
+ for k in keys
+ }
+ return DictScraper(scraper_per_key)
+ elif found_type == ListMatch:
+ logging.info("training ListScraper")
+ matches: typing.List[ListMatch]
+
+ # so we have a list of ListMatch objects
+ # we have to find a selector that uniquely matches the list elements
+ # todo can be one of the parents
+ match_roots = [m.root for m in matches]
+ logging.info(f"{match_roots=}")
+        selector = first(generate_selector_for_nodes(match_roots, roots), None)
+ if selector:
+ # for all the item_matches, create a tuple
+ # that contains the item_match and the new root
+ matches_and_roots = [
+ (im, selector.select_one(r))
+ for m, r in zip(matches, roots)
+ for im in m.matches
+ ]
+ item_matches, list_roots = unzip(matches_and_roots)
+ item_scraper = train_scraper_for_matches(
+ list(item_matches), list(list_roots)
+ )
+            return ListScraper(selector, item_scraper)
+        raise NoScraperFoundException(f"no selector found for {match_roots=}")
+    else:
+        raise RuntimeError(f"type not matched: {found_type}")
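Taken together, the rewritten entry point dispatches on match types (ValueMatch, DictMatch, ListMatch) instead of item types. A rough end-to-end sketch mirroring test_train_scraper_simple_list below, assuming TrainingSet has a no-arg constructor as the tests suggest:

    from mlscraper.html import Page
    from mlscraper.samples import Sample
    from mlscraper.samples import TrainingSet
    from mlscraper.training import train_scraper

    page = Page(b"<html><body><p>a</p><p>b</p><p>c</p></body></html>")
    training_set = TrainingSet()
    training_set.add_sample(Sample(page, ["a", "b", "c"]))
    # expected: a ListScraper wrapping a ValueScraper with a PassThroughSelector,
    # as exercised in tests/test_scrapers.py
    scraper = train_scraper(training_set)
    assert scraper.get(page) == ["a", "b", "c"]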
diff --git a/tests/test_scrapers.py b/tests/test_scrapers.py
index dfb4781..200dde5 100644
--- a/tests/test_scrapers.py
+++ b/tests/test_scrapers.py
@@ -5,6 +5,7 @@ from mlscraper.scrapers import DictScraper
from mlscraper.scrapers import ListScraper
from mlscraper.scrapers import ValueScraper
from mlscraper.selectors import CssRuleSelector
+from mlscraper.selectors import PassThroughSelector
class TestListOfDictScraper:
@@ -62,7 +63,12 @@ class TestValueScraper:
assert vs.get(page1) == "test"
assert vs.get(page2) == "hallo"
+
class TestListOfValuesScraper:
def test_list_of_values_scraper(self):
page = Page(b"<html><body><p>a</p><i>noise</i><p>b</p><p>c</p></body></html>")
- ListScraper('p', )
+ scraper = ListScraper(
+ CssRuleSelector("p"),
+ ValueScraper(PassThroughSelector(), TextValueExtractor()),
+ )
+ assert scraper.get(page) == ["a", "b", "c"]
diff --git a/tests/test_training.py b/tests/test_training.py
index 947fbb0..49288a3 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -13,7 +13,7 @@ def test_train_scraper_simple_list():
["a", "b", "c"],
)
training_set.add_sample(sample)
- train_scraper(training_set.item)
+ train_scraper(training_set)
@pytest.mark.skip("fucking fails")
@@ -22,7 +22,7 @@ def test_train_scraper(stackoverflow_samples):
for s in stackoverflow_samples:
training_set.add_sample(s)
- scraper = train_scraper(training_set.item)
+ scraper = train_scraper(training_set)
print(f"result scraper: {scraper}")
print(f"selector for list items: {scraper.selector}")