summaryrefslogtreecommitdiffstats
path: root/mlscraper/selectors.py
blob: 904e984e41d9050e7807621399710e5996e4ed09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import logging
from itertools import combinations
from itertools import product

from mlscraper.html import Node
from mlscraper.util import no_duplicates_generator_decorator
from more_itertools import first
from more_itertools import powerset


class Selector:
    """
    Base interface for objects that pick nodes out of a given node.

    Both methods raise NotImplementedError here; subclasses are expected to
    override them.
    """

    def select_one(self, node: Node) -> Node:
        """
        Return a single node selected from ``node``.
        """
        raise NotImplementedError

    def select_all(self, node: Node) -> list[Node]:
        """
        Return all nodes selected from ``node``.
        """
        raise NotImplementedError


class PassThroughSelector(Selector):
    """
    Selector that selects the very node it is given.
    """

    def select_one(self, node: Node) -> Node:
        """
        Return ``node`` itself, unchanged.
        """
        return node

    def select_all(self, node: Node) -> list[Node]:
        """
        Always raise RuntimeError.

        There is only the one node to pass through, so selecting "all" of
        them does not make sense.
        """
        raise RuntimeError("cannot apply select_all to PassThroughSelector")


class CssRuleSelector(Selector):
    """
    Selector that applies a CSS rule to a node via ``node.select``.

    Two instances compare equal (and hash alike) when their CSS rules are
    equal, so equality/set based de-duplication treats selectors built from
    the same rule as one and the same selector.
    """

    def __init__(self, css_rule):
        # the rule handed to ``node.select`` on every selection
        self.css_rule = css_rule

    def select_one(self, node: Node):
        """
        Return the first entry of ``node.select(css_rule)``.

        Raises AssertionError if the selection is empty.
        """
        selection = node.select(self.css_rule)
        if not selection:
            raise AssertionError(
                f"css rule does not match any node ({self.css_rule=}, {node=})"
            )
        return selection[0]

    def select_all(self, node: Node):
        """
        Return whatever ``node.select(css_rule)`` returns.
        """
        return node.select(self.css_rule)

    def __eq__(self, other):
        # value semantics: selectors are interchangeable if their rules are
        if not isinstance(other, CssRuleSelector):
            return NotImplemented
        return self.css_rule == other.css_rule

    def __hash__(self):
        # must agree with __eq__; do not reassign css_rule while the selector
        # is stored in a set or used as a dict key
        return hash(self.css_rule)

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.css_rule=}>"


def generate_unique_selectors_for_nodes(nodes: list[Node], roots, complexity: int):
    """
    Yield the selectors that, on every root, select exactly the given nodes.

    Candidates are taken from generate_selectors_for_nodes. A candidate is
    yielded only if, for each distinct root, its select_all result equals the
    list of given nodes that have this root as parent (in the order given).
    If roots is None, the root of each node is used instead.
    """
    if roots is None:
        logging.info("roots is None, setting roots manually")
        roots = [node.root for node in nodes]

    # what a unique selector has to return for each distinct root
    expected_by_root = {
        root: [node for node in nodes if node.has_parent(root)]
        for root in set(roots)
    }

    for candidate in generate_selectors_for_nodes(nodes, roots, complexity):
        for root, expected_nodes in expected_by_root.items():
            is_exact_match = candidate.select_all(root) == expected_nodes
            if not is_exact_match:
                # selects too much or too little on this root -> not unique
                break
        else:
            yield candidate


@no_duplicates_generator_decorator
def generate_selectors_for_nodes(nodes: list[Node], roots, complexity: int):
    """
    Generate selectors which match the given nodes.

    nodes and roots are paired by position, so both must have the same
    length. complexity is only logged here.

    Yields CssRuleSelector instances in this order:
    - one per direct css selector of the nodes
    - for each combination of one ancestor per node (ancestors that have the
      node's root as parent and are not the html tag) and each direct css
      selector of these ancestors: "<ancestor> <direct>" and, if every chosen
      ancestor is the direct parent of its node, "<ancestor> > <direct>"

    Raises AssertionError if nodes or roots is empty or their lengths differ.
    """

    # %-style arguments are only formatted if the record is actually emitted,
    # which avoids building the repr of all nodes on every call
    logging.info(
        "trying to find selector for nodes (nodes=%r, roots=%r, complexity=%r)",
        nodes,
        roots,
        complexity,
    )

    # raised explicitly instead of using assert statements so the checks
    # are not stripped when running with -O
    if not nodes:
        raise AssertionError("no nodes given")
    if not roots:
        raise AssertionError("no roots given")
    if len(nodes) != len(roots):
        raise AssertionError(
            f"nodes and roots differ in length ({len(nodes)} != {len(roots)})"
        )

    # materialized because the selectors are reused for every ancestor below
    direct_css_selectors = list(_generate_direct_css_selectors_for_nodes(nodes))
    for direct_css_selector in direct_css_selectors:
        yield CssRuleSelector(direct_css_selector)

    ancestors_below_roots = [
        [p for p in n.parents if p.has_parent(r) and p.tag_name != "html"]
        for n, r in zip(nodes, roots)
    ]
    for ancestors in product(*ancestors_below_roots):
        # only depends on the chosen ancestors, so compute it once per choice
        ancestors_are_parents = all(
            node.parent == parent for node, parent in zip(nodes, ancestors)
        )

        for ancestor_selector_raw in _generate_direct_css_selectors_for_nodes(
            ancestors
        ):
            # generate refinement selectors for parents
            # e.g. if selectivity of child selector is not enough
            for css_selector_raw in direct_css_selectors:
                yield CssRuleSelector(f"{ancestor_selector_raw} {css_selector_raw}")

                # make parent selector
                if ancestors_are_parents:
                    yield CssRuleSelector(
                        f"{ancestor_selector_raw} > {css_selector_raw}"
                    )


def _generate_direct_css_selectors_for_nodes(nodes: list[Node]):
    """
    Yield raw css selectors built from traits shared by the given nodes.

    First every selector of
    _generate_direct_css_selectors_for_nodes_without_pseudo is yielded as is.
    Afterwards, unless one of the nodes is an html or body tag, each of these
    selectors is yielded again with a ``:nth-child(k)`` suffix if all nodes
    share the same index k (see below for how the index is determined).
    """
    # pseudo classes apply to already generated selectors
    # and can thus be applied in retrospect
    yield from _generate_direct_css_selectors_for_nodes_without_pseudo(nodes)

    # html and body never get an nth-child variant; this does not depend on
    # the selector, so decide once instead of re-checking for every selector
    if any(n.tag_name in ("html", "body") for n in nodes):
        return

    # see: https://developer.mozilla.org/en-US/docs/Web/CSS/:nth-child
    # pull to the end as far as possible
    # (selectors are generated a second time instead of being stored so that
    # the plain selectors above are still produced lazily)
    for css_selector in _generate_direct_css_selectors_for_nodes_without_pseudo(nodes):
        # NOTE(review): this is the node's position within
        # parent.select(css_selector), not its position among the parent's
        # children, which is what :nth-child counts -- confirm this is intended
        child_indexes = [n.parent.select(css_selector).index(n) for n in nodes]
        if len(set(child_indexes)) == 1:
            # nth is indexed with 1
            nth = 1 + child_indexes[0]
            yield f"{css_selector}:nth-child({nth})"


def _escape_css_identifier(identifier: str) -> str:
    """
    Escape identifier for use as a css identifier (id, class, attribute name).

    Follows the rules of the CSSOM ``CSS.escape`` serialization. Identifiers
    made of ascii letters, digits, "-" and "_" that do not start with a digit
    are returned unchanged.
    """
    escaped = []
    for index, char in enumerate(identifier):
        code_point = ord(char)
        is_leading_digit = char in "0123456789" and (
            index == 0 or (index == 1 and identifier[0] == "-")
        )
        if code_point == 0:
            escaped.append("\ufffd")
        elif code_point < 0x20 or code_point == 0x7F or is_leading_digit:
            # control characters and leading digits need a hex escape,
            # terminated by a space
            escaped.append(f"\\{code_point:x} ")
        elif identifier == "-":
            # a lone dash is not a valid identifier
            escaped.append("\\-")
        elif (
            code_point >= 0x80
            or char in "-_"
            or (char.isascii() and char.isalnum())
        ):
            escaped.append(char)
        else:
            escaped.append(f"\\{char}")
    return "".join(escaped)


def _escape_css_string(value: str) -> str:
    """
    Escape value for use inside a double-quoted css string.
    """
    return value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\a ")


def _generate_direct_css_selectors_for_nodes_without_pseudo(nodes: list[Node]):
    """
    Yield raw css selectors built from traits shared by all given nodes.

    In this order:
    - the tag name, if all nodes have the same one
    - "#id", if all nodes have the same non-empty id
    - every non-empty combination of the common classes (smallest first),
      each followed by its "tag.classes" variant if there is a common tag
    - if there is a common tag: "tag[attr]" for every common attribute except
      id, class and rel, followed by 'tag[attr="value"]' if all nodes have
      the same string value for it

    Ids, classes, attribute names and attribute values are escaped, and
    classes and attributes are processed in sorted order so the output does
    not depend on set iteration order. Yields nothing for empty input.
    """
    if not nodes:
        return

    # check for same tag name
    is_same_tag = len({n.tag_name for n in nodes}) == 1
    common_tag_name = nodes[0].tag_name if is_same_tag else None
    if common_tag_name:
        yield common_tag_name

    # check for common id
    common_ids = {n.id for n in nodes}
    if len(common_ids) == 1:
        (common_id,) = common_ids
        # skip missing and empty ids, "#" alone is not a selector
        if common_id:
            yield "#" + _escape_css_identifier(common_id)

    # check for common classes
    common_classes = set.intersection(*[set(n.classes) for n in nodes])
    escaped_classes = [
        _escape_css_identifier(cl) for cl in sorted(common_classes) if cl
    ]
    # same order as a powerset without the empty combination
    for size in range(1, len(escaped_classes) + 1):
        for class_combination in combinations(escaped_classes, size):
            css_selector = "".join("." + cl for cl in class_combination)
            yield css_selector

            # if same tag name, also yield tag_name + selector
            if common_tag_name:
                yield common_tag_name + css_selector

    # check for common attributes
    # see: https://developer.mozilla.org/en-US/docs/Web/CSS/Attribute_selectors
    if common_tag_name:
        common_attributes = set.intersection(
            *[set(n.html_attributes.keys()) for n in nodes]
        )
        common_attributes_filtered = sorted(
            ca for ca in common_attributes if ca not in ("id", "class", "rel")
        )
        for common_attribute in common_attributes_filtered:
            attribute_selector = (
                f"{common_tag_name}[{_escape_css_identifier(common_attribute)}"
            )
            yield attribute_selector + "]"

            # check for common attribute values
            # compared by equality instead of via a set so that unhashable
            # values do not raise; presumably multi-valued attributes come as
            # lists (verify against Node.html_attributes), those only get the
            # presence selector above
            values = [n.html_attributes[common_attribute] for n in nodes]
            common_attribute_value = values[0]
            if isinstance(common_attribute_value, str) and all(
                value == common_attribute_value for value in values[1:]
            ):
                escaped_value = _escape_css_string(common_attribute_value)
                yield f'{attribute_selector}="{escaped_value}"]'