author     Karl Lorey <git@karllorey.com>  2020-09-26 15:15:31 +0200
committer  Karl Lorey <git@karllorey.com>  2020-09-26 15:15:31 +0200
commit     1d977917dc6b362e4490750ee2e0cee56e4df7e4 (patch)
tree       3b1d969a68771e1c6b4c2213be013e05a502db6a
parent     f023ca303be41804b954cf79c5de1ca5e810d5c6 (diff)
Make rule-based selection faster and more robust
-rw-r--r--  mlscraper/__init__.py  64
-rw-r--r--  mlscraper/parser.py     2
-rw-r--r--  mlscraper/training.py   4
-rw-r--r--  test/test_new.py        5
4 files changed, 46 insertions, 29 deletions
diff --git a/mlscraper/__init__.py b/mlscraper/__init__.py
index 551687a..d6c799b 100644
--- a/mlscraper/__init__.py
+++ b/mlscraper/__init__.py
@@ -2,6 +2,7 @@ import logging
 import random
 import re
 from collections import Counter, namedtuple
+from itertools import chain
 from typing import List
 
 import pandas as pd
@@ -19,6 +20,7 @@ from mlscraper.util import (
     get_selectors,
     derive_css_selector,
     generate_path_selectors,
+    generate_unique_path_selectors,
 )
 
 SingleItemSample = namedtuple("SingleItemSample", ["data", "html"])
@@ -66,38 +68,50 @@ class RuleBasedSingleItemScraper:
         rules = {}  # attr -> selector
         for attr in attributes:
-            logging.info("Training attribute %s")
-            selector_scoring = {}  # selector -> score
-
-            # for all potential selectors
-            for sample in samples:
-                soup_node = sample.item[attr].node._soup_node
-                for css_selector in generate_path_selectors(soup_node):
-                    # check if selector matches the desired sample on each page
-                    # todo avoid doing it over and over for each sample
-                    sample_nodes = set(s.item[attr].node for s in samples)
-                    matches = set(
-                        flatten([s.page.select(css_selector) for s in samples])
-                    )
-                    score = len(sample_nodes.intersection(matches)) / len(
-                        sample_nodes.union(matches)
-                    )
-                    selector_scoring[css_selector] = score
+            logging.info("Training attribute %s" % attr)
+
+            # get all potential matches
+            matching_nodes = flatten([s.page.find(s.item[attr]) for s in samples])
+            selectors = set(
+                chain(
+                    *(
+                        generate_unique_path_selectors(node._soup_node)
+                        for node in matching_nodes
+                    )
+                )
+            )
+
+            # check if they are unique on every page
+            # -> for all potential selectors: compute score
+            selector_scoring = {}  # selector -> score
+            for selector in selectors:
+                if selector not in selector_scoring:
+                    logging.info("testing %s" % selector)
+                    matches_per_page = [s.page.select(selector) for s in samples]
+                    matches_per_page_right = [
+                        len(m) == 1 and m[0].get_text() == s.item[attr]
+                        for m, s in zip(matches_per_page, samples)
+                    ]
+                    score = sum(matches_per_page_right) / len(samples)
+                    selector_scoring[selector] = score
 
             # find the selector with the best coverage, i.e. the highest accuracy
-            print("Scoring: %s" % selector_scoring)
+            logging.info("Scoring for %s: %s", attr, selector_scoring)
             selectors_sorted = sorted(
                 selector_scoring, key=selector_scoring.get, reverse=True
             )
-            print(selectors_sorted)
-            selector_best = selectors_sorted[0]
-            if selector_scoring[selector_best] < 1:
-                logging.warning(
-                    "Best selector for %s does not work for all samples (score is %f)"
-                    % (attr, selector_scoring[selector_best])
-                )
-            rules[attr] = selector_best
+            logging.info("Best scores for %s: %s", attr, selectors_sorted[:3])
+            try:
+                selector_best = selectors_sorted[0]
+                if selector_scoring[selector_best] < 1:
+                    logging.warning(
+                        "Best selector for %s does not work for all samples (score is %f)"
+                        % (attr, selector_scoring[selector_best])
+                    )
+                rules[attr] = selector_best
+            except IndexError:
+                logging.warning("No selector found for %s", attr)
 
         print(rules)
         return RuleBasedSingleItemScraper(rules)
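
Note: the hunk above replaces the old set-overlap score (intersection over union of expected nodes and all matched nodes) with a stricter per-page test: a selector now scores by the fraction of sample pages on which it matches exactly one node carrying the expected text. A minimal standalone sketch of that per-page scoring idea using plain BeautifulSoup; score_selector and its arguments are illustrative, not part of mlscraper's API:

    from bs4 import BeautifulSoup

    def score_selector(selector, soups, expected_values):
        # a selector is "right" on a page only if it is unique there
        # (exactly one match) and extracts exactly the expected text
        right_per_page = [
            len(matches) == 1 and matches[0].get_text() == value
            for matches, value in zip(
                [soup.select(selector) for soup in soups], expected_values
            )
        ]
        # score = fraction of sample pages on which the selector works
        return sum(right_per_page) / len(soups)

    soups = [BeautifulSoup('<p class="item">result</p><p>noise</p>', "html.parser")]
    assert score_selector("p.item", soups, ["result"]) == 1.0

Compared to the old overlap score, this rejects selectors that hit the right node plus extras on some page, which is the "more robust" part of the commit; generating candidates once via generate_unique_path_selectors instead of per sample and selector is the "faster" part.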
diff --git a/mlscraper/parser.py b/mlscraper/parser.py
index f8dc5f7..4bbce1a 100644
--- a/mlscraper/parser.py
+++ b/mlscraper/parser.py
@@ -35,7 +35,7 @@ class SoupPage(Page):
         return []
 
     def find(self, needle):
-        assert type(needle) == str, "can only find strings ATM"
+        assert type(needle) == str, "can only find strings, %s given" % type(needle)
         text_matches = self._soup.find_all(text=re.compile(needle))
         logging.debug("Matches for %s: %s", needle, text_matches)
         text_parents = (ns.parent for ns in text_matches)
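
For context, find locates candidate nodes by compiling the needle into a regex and taking the parent of each matching text node. A hedged sketch of that BeautifulSoup pattern (re.escape is added here to survive needles containing regex metacharacters; the method above compiles the needle as-is):

    import re
    from bs4 import BeautifulSoup

    soup = BeautifulSoup("<div><span>price: 42 EUR</span></div>", "html.parser")
    # find_all(text=...) yields the matching NavigableStrings;
    # the element holding each string is its .parent
    text_matches = soup.find_all(text=re.compile(re.escape("42")))
    nodes = [ns.parent for ns in text_matches]
    print(nodes)  # [<span>price: 42 EUR</span>]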
diff --git a/mlscraper/training.py b/mlscraper/training.py
index 875ba22..28517a4 100644
--- a/mlscraper/training.py
+++ b/mlscraper/training.py
@@ -22,3 +22,7 @@ class SingleItemPageSample:
     def __init__(self, page: Page, item: dict):
         self.page = page
         self.item = item
+
+    def find_nodes(self, attr):
+        needle = self.item[attr]
+        return self.page.find(needle)
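
A quick usage sketch for the new find_nodes helper, wired to the parser above (the HTML and attribute name are made up for illustration):

    from mlscraper.parser import make_soup_page
    from mlscraper.training import SingleItemPageSample

    page = make_soup_page("<html><body><p>result</p></body></html>")
    sample = SingleItemPageSample(page, {"res": "result"})
    # delegates to page.find and returns the nodes containing "result"
    print(sample.find_nodes("res"))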
diff --git a/test/test_new.py b/test/test_new.py
index 9ba5f72..c6a8874 100644
--- a/test/test_new.py
+++ b/test/test_new.py
@@ -5,9 +5,8 @@ from mlscraper.parser import make_soup_page, ExtractionResult
 
 def test_basic():
     html = '<html><body><div class="parent"><p class="item">result</p></div><p class="item">not a result</p></body></html>'
     page = make_soup_page(html)
-    node = page.select(".item")[0]
-    item = {"res": ExtractionResult(node)}
+    item = {"res": "result"}
     samples = [SingleItemPageSample(page, item)]
     scraper = RuleBasedSingleItemScraper.build(samples)
-    assert scraper.scrape(html)["res"] == "result"
+    assert scraper.scrape(html) == item
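
With this change, a sample is declared purely as attribute name to expected string ({"res": "result"}); node discovery happens inside training via find/find_nodes, and scrape() returns a plain dict, which is why the assertion can compare against item directly.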