summary refs log tree commit diff stats
path: root/tests/test_training.py
blob: f0f1703b758604e2f28623dafd0aa90add6d1412 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from mlscraper.html import Page
from mlscraper.matches import TextValueExtractor
from mlscraper.samples import Sample
from mlscraper.samples import TrainingSet
from mlscraper.scrapers import ListScraper
from mlscraper.scrapers import ValueScraper
from mlscraper.selectors import CssRuleSelector
from mlscraper.selectors import PassThroughSelector
from mlscraper.training import train_scraper


def test_train_scraper_simple_list():
    """A flat list of texts trains a ``p`` list scraper with pass-through items.

    An ``<i>noise</i>`` element sits between the target paragraphs, so the
    trained outer selector is expected to be exactly the CSS rule ``p``.
    """
    samples = TrainingSet()
    markup = b"<html><body><p>a</p><i>noise</i><p>b</p><p>c</p></body></html>"
    expected = ["a", "b", "c"]
    samples.add_sample(Sample(Page(markup), expected))

    list_scraper = train_scraper(samples)

    # outer level: a list scraper driven by a single CSS rule
    assert isinstance(list_scraper, ListScraper)
    assert isinstance(list_scraper.selector, CssRuleSelector)
    assert list_scraper.selector.css_rule == "p"

    # inner level: each matched element is passed through unchanged and
    # handed to a TextValueExtractor
    value_scraper = list_scraper.scraper
    assert isinstance(value_scraper, ValueScraper)
    assert isinstance(value_scraper.selector, PassThroughSelector)
    assert isinstance(value_scraper.extractor, TextValueExtractor)


def test_train_scraper_list_of_dicts():
    """Nested lists of texts train a two-level list scraper.

    Each ``div`` holds one inner list of ``p`` texts, so the expected
    structure is: ``div`` list scraper -> ``p`` list scraper -> value
    scraper with a pass-through selector and a TextValueExtractor.
    """
    # NOTE(review): despite the test name, the sample value is a list of
    # lists, not a list of dicts; the name is kept for test-id stability.
    html = b"""
    <html>
    <body>
    <div><p>a</p><p>b</p></div>
    <div><p>c</p><p>d</p></div>
    </body>
    </html>
    """
    page = Page(html)
    sample = Sample(page, [["a", "b"], ["c", "d"]])
    training_set = TrainingSet()
    training_set.add_sample(sample)
    scraper = train_scraper(training_set)

    # outer level: one list item per <div>
    assert isinstance(scraper, ListScraper)
    assert isinstance(scraper.selector, CssRuleSelector)
    assert scraper.selector.css_rule == "div"

    # middle level: one list item per <p> inside each <div>
    inner_scraper = scraper.scraper
    assert isinstance(inner_scraper, ListScraper)
    assert isinstance(inner_scraper.selector, CssRuleSelector)
    assert inner_scraper.selector.css_rule == "p"

    # leaf level: the matched <p> itself is extracted
    value_scraper = inner_scraper.scraper
    assert isinstance(value_scraper, ValueScraper)
    assert isinstance(value_scraper.selector, PassThroughSelector)
    assert isinstance(value_scraper.extractor, TextValueExtractor)


def test_train_scraper_multipage():
    """Samples from several same-layout pages train one generalizing scraper.

    Two pages that differ only in their ``li`` texts are used for training.
    The resulting ``li`` list scraper is then applied to an unseen page.
    """
    training_set = TrainingSet()
    for first, second in ["ab", "cd"]:
        html = b"""
        <html><body>
        <div class="target">
        <ul><li>%s</li><li>%s</li></ul>
        </div>
        </body></html>
        """ % (
            first.encode(),
            second.encode(),
        )
        training_set.add_sample(Sample(Page(html), [first, second]))
    scraper = train_scraper(training_set)

    # check the scraper shape before reading selector attributes, as the
    # sibling tests do
    assert isinstance(scraper, ListScraper)
    assert isinstance(scraper.selector, CssRuleSelector)
    assert scraper.selector.css_rule == "li"

    # unseen page: well-formed markup (closed <ul>) with new texts
    assert scraper.get(
        Page(b"""<html><body><ul><li>first</li><li>second</li></ul></body></html>""")
    ) == ["first", "second"]


def test_train_scraper_stackoverflow(stackoverflow_samples):
    """Training on the stackoverflow samples reproduces the first sample.

    ``stackoverflow_samples`` is a pytest fixture defined outside this file
    (presumably in a conftest.py — confirm). It must be indexable and
    non-empty, because element ``[0]`` is read below.
    """
    training_set = TrainingSet()
    for so_sample in stackoverflow_samples:
        training_set.add_sample(so_sample)

    trained = train_scraper(training_set)
    assert isinstance(trained, ListScraper)
    assert isinstance(trained.selector, CssRuleSelector)

    # the scraper must recover the exact value it was trained on
    reference = stackoverflow_samples[0]
    assert trained.get(reference.page) == reference.value