author     Karl Lorey <git@karllorey.com>    2022-06-20 11:12:00 +0200
committer  Karl Lorey <git@karllorey.com>    2022-06-20 11:12:00 +0200
commit     c6c371223d56f23ad4a588231b5e6f51bee4259c (patch)
tree       45d68736dfce2c8bee9d061e19224161f3916332 /tests
parent     3a4c3234653984768a992747ba45da5e34d3af9c (diff)

Re-implement selector generation with a speedup >10x (develop)
Diffstat (limited to 'tests')
-rw-r--r--  tests/test_selectors.py  56
-rw-r--r--  tests/test_training.py   22
-rw-r--r--  tests/test_util.py        9
3 files changed, 72 insertions(+), 15 deletions(-)
diff --git a/tests/test_selectors.py b/tests/test_selectors.py
index 3e4570a..9e7acc0 100644
--- a/tests/test_selectors.py
+++ b/tests/test_selectors.py
@@ -1,21 +1,49 @@
from mlscraper.html import Page
-from mlscraper.selectors import generate_selector_for_nodes
+from mlscraper.selectors import generate_unique_selectors_for_nodes
-def test_generate_selector_for_nodes():
- page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>'
- page1 = Page(page1_html)
+def get_css_selectors_for_node(node):
+ """
+ helper to extract plain css rules
+ """
+ return [
+ selector.css_rule
+ for selector in generate_unique_selectors_for_nodes([node], None, 100)
+ ]
- page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>'
- page2 = Page(page2_html)
- nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2]))
- gen = generate_selector_for_nodes(nodes, None, 1)
- selectors_found = [sel.css_rule for sel in gen]
- assert {".test", "p.test"} == set(selectors_found)
+class TestGenerateUniqueSelectorsForNodes:
+ def test_basic(self):
+ page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>'
+ page1 = Page(page1_html)
+ page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>'
+ page2 = Page(page2_html)
-class TestGenerateSelectorForNodes:
- def test_generate_selector_for_nodes(self):
- # generate_selector_for_nodes()
- pass
+ nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2]))
+ gen = generate_unique_selectors_for_nodes(nodes, None, 1)
+ selectors_found = [sel.css_rule for sel in gen]
+
+ assert "p" not in selectors_found
+ assert "div" not in selectors_found
+
+ assert ".test" in selectors_found
+ assert "p.test" in selectors_found
+
+ def test_ids(self):
+ page = Page(
+ b"""
+ <html><body>
+ <div id="target">test</div>
+ <div>irrelevant</div>
+ </body></html>"""
+ )
+ node = page.select("#target")[0]
+ selectors = get_css_selectors_for_node(node)
+ assert selectors == ["#target"]
+
+ def test_multi_parents(self):
+ page = Page(b'<html><body><div id="target"><p>test</p></div><div><p></p></div>')
+ node = page.select("#target")[0].select("p")[0]
+ selectors = get_css_selectors_for_node(node)
+ assert "#target p" in selectors
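
Note: the tests above exercise the renamed entry point. generate_unique_selectors_for_nodes takes a list of nodes, a second argument passed as None here, and a numeric limit, and yields selector objects exposing a css_rule attribute. A minimal usage sketch based only on the calls in these tests (the meaning of the second and third arguments is inferred from how the tests pass them, not from documentation):

    from mlscraper.html import Page
    from mlscraper.selectors import generate_unique_selectors_for_nodes

    page = Page(b'<html><body><p class="test">hello</p><p>other</p></body></html>')
    node = page.select("p.test")[0]

    # yields selector objects whose css_rule uniquely matches the given nodes;
    # the second argument is None in the tests and the third appears to cap
    # the candidate search (assumption based on the tests above)
    for selector in generate_unique_selectors_for_nodes([node], None, 100):
        print(selector.css_rule)  # e.g. ".test" or "p.test"
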
diff --git a/tests/test_training.py b/tests/test_training.py
index ccd0441..f0f1703 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -1,4 +1,3 @@
-import pytest
from mlscraper.html import Page
from mlscraper.matches import TextValueExtractor
from mlscraper.samples import Sample
@@ -61,6 +60,27 @@ def test_train_scraper_list_of_dicts():
assert isinstance(value_scraper.extractor, TextValueExtractor)
+def test_train_scraper_multipage():
+ training_set = TrainingSet()
+ for items in ["ab", "cd"]:
+ html = b"""
+ <html><body>
+ <div class="target">
+ <ul><li>%s</li><li>%s</li></ul>
+ </div>
+ </body></html>
+ """ % (
+ items[0].encode(),
+ items[1].encode(),
+ )
+ training_set.add_sample(Sample(Page(html), [items[0], items[1]]))
+ scraper = train_scraper(training_set)
+ assert scraper.selector.css_rule == "li"
+ assert scraper.get(
+ Page(b"""<html><body><ul><li>first</li><li>second</li></body></html>""")
+ ) == ["first", "second"]
+
+
def test_train_scraper_stackoverflow(stackoverflow_samples):
training_set = TrainingSet()
for s in stackoverflow_samples:
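
Note: the new multipage test doubles as an end-to-end usage example: build a TrainingSet from Sample objects that pair a Page with the expected values, call train_scraper, and apply the resulting scraper to unseen pages. A condensed sketch of that flow, grounded in the test above; the import paths for TrainingSet and train_scraper are assumed, since they are not visible in this diff context:

    from mlscraper.html import Page
    from mlscraper.samples import Sample
    # assumed import paths (not shown in the diff context above)
    from mlscraper.samples import TrainingSet
    from mlscraper.training import train_scraper

    training_set = TrainingSet()
    for items in (["a", "b"], ["c", "d"]):
        html = b"<html><body><ul><li>%s</li><li>%s</li></ul></body></html>" % (
            items[0].encode(),
            items[1].encode(),
        )
        training_set.add_sample(Sample(Page(html), items))

    scraper = train_scraper(training_set)
    scraper.get(Page(b"<html><body><ul><li>first</li><li>second</li></ul></body></html>"))
    # expected, per the test above: ["first", "second"]
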
diff --git a/tests/test_util.py b/tests/test_util.py
index e69de29..c2428e8 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -0,0 +1,9 @@
+from mlscraper.util import no_duplicates_generator_decorator
+
+
+def test_no_duplicates_generator_decorator():
+ @no_duplicates_generator_decorator
+ def decorated_generator():
+ yield from [1, 1, 2, 3, 3, 3]
+
+ assert list(decorated_generator()) == [1, 2, 3]
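
Note: no_duplicates_generator_decorator wraps a generator so values already yielded are skipped, which is exactly what this test asserts. A possible implementation sketch, assuming hashable yielded values; this illustrates the asserted behaviour and is not the library's actual code:

    import functools

    def no_duplicates_generator_decorator(generator_func):
        """Wrap a generator function so each distinct value is yielded only once."""
        @functools.wraps(generator_func)
        def wrapper(*args, **kwargs):
            seen = set()
            for value in generator_func(*args, **kwargs):
                if value not in seen:  # assumes yielded values are hashable
                    seen.add(value)
                    yield value
        return wrapper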