diff options
author | Karl Lorey <git@karllorey.com> | 2022-06-20 11:12:00 +0200 |
---|---|---|
committer | Karl Lorey <git@karllorey.com> | 2022-06-20 11:12:00 +0200 |
commit | c6c371223d56f23ad4a588231b5e6f51bee4259c (patch) | |
tree | 45d68736dfce2c8bee9d061e19224161f3916332 /tests | |
parent | 3a4c3234653984768a992747ba45da5e34d3af9c (diff) |
Re-implement selector generation with a speedup >10x (branch: develop)
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_selectors.py | 56 | ||||
-rw-r--r-- | tests/test_training.py | 22 | ||||
-rw-r--r-- | tests/test_util.py | 9 |
3 files changed, 72 insertions, 15 deletions
diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 3e4570a..9e7acc0 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -1,21 +1,49 @@ from mlscraper.html import Page -from mlscraper.selectors import generate_selector_for_nodes +from mlscraper.selectors import generate_unique_selectors_for_nodes -def test_generate_selector_for_nodes(): - page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>' - page1 = Page(page1_html) +def get_css_selectors_for_node(node): + """ + helper to extract plain css rules + """ + return [ + selector.css_rule + for selector in generate_unique_selectors_for_nodes([node], None, 100) + ] - page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>' - page2 = Page(page2_html) - nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2])) - gen = generate_selector_for_nodes(nodes, None, 1) - selectors_found = [sel.css_rule for sel in gen] - assert {".test", "p.test"} == set(selectors_found) +class TestGenerateUniqueSelectorsForNodes: + def test_basic(self): + page1_html = '<html><body><p class="test">test</p><p>bla</p></body></html>' + page1 = Page(page1_html) + page2_html = '<html><body><div></div><p class="test">hallo</p></body></html>' + page2 = Page(page2_html) -class TestGenerateSelectorForNodes: - def test_generate_selector_for_nodes(self): - # generate_selector_for_nodes() - pass + nodes = list(map(lambda p: p.select("p.test")[0], [page1, page2])) + gen = generate_unique_selectors_for_nodes(nodes, None, 1) + selectors_found = [sel.css_rule for sel in gen] + + assert "p" not in selectors_found + assert "div" not in selectors_found + + assert ".test" in selectors_found + assert "p.test" in selectors_found + + def test_ids(self): + page = Page( + b""" + <html><body> + <div id="target">test</div> + <div>irrelevant</div> + </body></html>""" + ) + node = page.select("#target")[0] + selectors = get_css_selectors_for_node(node) + assert selectors == ["#target"] + + 
def test_multi_parents(self): + page = Page(b'<html><body><div id="target"><p>test</p></div><div><p></p></div>') + node = page.select("#target")[0].select("p")[0] + selectors = get_css_selectors_for_node(node) + assert "#target p" in selectors diff --git a/tests/test_training.py b/tests/test_training.py index ccd0441..f0f1703 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,4 +1,3 @@ -import pytest from mlscraper.html import Page from mlscraper.matches import TextValueExtractor from mlscraper.samples import Sample @@ -61,6 +60,27 @@ def test_train_scraper_list_of_dicts(): assert isinstance(value_scraper.extractor, TextValueExtractor) +def test_train_scraper_multipage(): + training_set = TrainingSet() + for items in ["ab", "cd"]: + html = b""" + <html><body> + <div class="target"> + <ul><li>%s</li><li>%s</li></ul> + </div> + </body></html> + """ % ( + items[0].encode(), + items[1].encode(), + ) + training_set.add_sample(Sample(Page(html), [items[0], items[1]])) + scraper = train_scraper(training_set) + assert scraper.selector.css_rule == "li" + assert scraper.get( + Page(b"""<html><body><ul><li>first</li><li>second</li></body></html>""") + ) == ["first", "second"] + + def test_train_scraper_stackoverflow(stackoverflow_samples): training_set = TrainingSet() for s in stackoverflow_samples: diff --git a/tests/test_util.py b/tests/test_util.py index e69de29..c2428e8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -0,0 +1,9 @@ +from mlscraper.util import no_duplicates_generator_decorator + + +def test_no_duplicates_generator_decorator(): + @no_duplicates_generator_decorator + def decorated_generator(): + yield from [1, 1, 2, 3, 3, 3] + + assert list(decorated_generator()) == [1, 2, 3] |