author | Karl Lorey <git@karllorey.com> | 2022-06-13 19:59:36 +0200
committer | Karl Lorey <git@karllorey.com> | 2022-06-13 19:59:36 +0200
commit | 88d7101a2fe3ae008181d0f53c011bb990f29a8c (patch)
tree | ed55ee80d7725a6b4902aa23e5cceabc7e0444ca
parent | 15e6a963a2270a0d523b6f4500007d16faa1b243 (diff)
Import mlscraper-experiments
-rw-r--r-- | examples/quotes_to_scrape.py | 38
-rw-r--r-- | examples/stackoverflow.py | 40
-rw-r--r-- | mlscraper/__init__.py | 391
-rw-r--r-- | mlscraper/ml.py | 66
-rw-r--r-- | mlscraper/parser.py | 78
-rw-r--r-- | mlscraper/samples.py | 147
-rw-r--r-- | mlscraper/scrapers.py | 53
-rw-r--r-- | mlscraper/selectors.py | 108
-rw-r--r-- | mlscraper/training.py | 82
-rw-r--r-- | mlscraper/util.py | 486
-rw-r--r-- | requirements_dev.txt | 12
-rw-r--r-- | requirements_fixed.txt | 28
-rw-r--r-- | setup.cfg | 4
-rw-r--r-- | setup.py | 2
-rw-r--r-- | tests/static/so.html | 2219
-rw-r--r-- | tests/test_basic.py | 76
-rw-r--r-- | tests/test_new.py | 10
-rw-r--r-- | tests/test_samples.py | 63
-rw-r--r-- | tests/test_scrapers.py | 88
-rw-r--r-- | tests/test_selectors.py | 48
-rw-r--r-- | tests/test_training.py | 34
-rw-r--r-- | tests/test_util.py | 104
22 files changed, 3207 insertions, 970 deletions
diff --git a/examples/quotes_to_scrape.py b/examples/quotes_to_scrape.py
deleted file mode 100644
index d0b28e2..0000000
--- a/examples/quotes_to_scrape.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import requests
-
-from mlscraper import SingleItemScraper, SingleItemPageSample
-
-
-def main():
-    items = {
-        "http://quotes.toscrape.com/author/Eleanor-Roosevelt/": {
-            "name": "Eleanor Roosevelt",
-            "birth": "October 11, 1884",
-        },
-        "http://quotes.toscrape.com/author/Andre-Gide/": {
-            "name": "André Gide",
-            "birth": "November 22, 1869",
-        },
-        "http://quotes.toscrape.com/author/Thomas-A-Edison/": {
-            "name": "Thomas A. Edison",
-            "birth": "February 11, 1847",
-        },
-    }
-    results = {url: requests.get(url) for url in items.keys()}
-
-    # train scraper
-    samples = [
-        SingleItemPageSample(results[url].content, items[url]) for url in items.keys()
-    ]
-    scraper = SingleItemScraper.build(samples)
-
-    print("Scraping Albert Einstein")
-    html = requests.get("http://quotes.toscrape.com/author/Albert-Einstein/").content
-    result = scraper.scrape(html)
-
-    print("Result: %s" % result)
-    # > Result: {'birth': 'March 14, 1879', 'name': 'Albert Einstein'}
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/stackoverflow.py b/examples/stackoverflow.py
deleted file mode 100644
index 45bb78e..0000000
--- a/examples/stackoverflow.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import logging
-
-import requests
-
-from mlscraper import SingleItemPageSample, RuleBasedSingleItemScraper
-
-
-def main():
-    items = {
-        "https://stackoverflow.com/questions/11227809/why-is-processing-a-sorted-array-faster-than-processing-an-unsorted-array": {
-            "title": "Why is processing a sorted array faster than processing an unsorted array?"
-        },
-        "https://stackoverflow.com/questions/927358/how-do-i-undo-the-most-recent-local-commits-in-git": {
-            "title": "How do I undo the most recent local commits in Git?"
-        },
-        "https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do": {
-            "title": "What does the “yield” keyword do?"
-        },
-    }
-
-    results = {url: requests.get(url) for url in items.keys()}
-
-    # train scraper
-    samples = [
-        SingleItemPageSample(results[url].content, items[url]) for url in items.keys()
-    ]
-    scraper = RuleBasedSingleItemScraper.build(samples)
-
-    print("Scraping new question")
-    html = requests.get(
-        "https://stackoverflow.com/questions/2003505/how-do-i-delete-a-git-branch-locally-and-remotely"
-    ).content
-    result = scraper.scrape(html)
-
-    print("Result: %s" % result)
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO)
-    main()
diff --git a/mlscraper/__init__.py b/mlscraper/__init__.py
index 9aac9d5..f45ed94 100644
--- a/mlscraper/__init__.py
+++ b/mlscraper/__init__.py
@@ -1,388 +1,3 @@
-__author__ = """Karl Lorey"""
-__email__ = "git@karllorey.com"
-__version__ = "0.1.2"
-
-import logging
-import random
-import re
-from collections import Counter
-from itertools import chain
-from typing import List
-
-import pandas as pd
-from bs4 import BeautifulSoup
-from more_itertools import flatten
-
-from mlscraper.ml import NodePreprocessing, train_pipeline
-from mlscraper.parser import make_soup_page, ExtractionResult
-from mlscraper.training import SingleItemPageSample, MultiItemPageSample
-from mlscraper.util import (
-    get_common_ancestor_for_paths,
-    get_common_ancestor_for_nodes,
-    get_tree_path,
-    generate_css_selectors_for_node,
-    get_selectors,
-    derive_css_selector,
-    generate_path_selectors,
-    generate_unique_path_selectors,
-)
-
-
-def create_single_item_samples(url_to_item):
-    """
-    Creates single page training data for you.
-
-    :param url_to_item: dict with url as key and expected dict as value
-    :return: samples
-    """
-    import requests
-
-    results = {url: requests.get(url) for url in url_to_item.keys()}
-    assert all(resp.status_code == 200 for resp in results.values())
-
-    pages = {url: make_soup_page(results[url].content) for url in url_to_item.keys()}
-
-    # train scraper
-    samples = []
-    for url in url_to_item:
-        page = pages[url]
-        item = url_to_item[url]
-
-        # use random sample if found several times to try to get all possible selectors
-        item_extraction = {
-            k: ExtractionResult(random.choice(page.find(v))) for k, v in item.items()
-        }
-        sample = SingleItemPageSample(page, item_extraction)
-        samples.append(sample)
-    return samples
-
-
-class RuleBasedSingleItemScraper:
-    """A simple scraper that will simply try to infer the best css selectors."""
-
-    def __init__(self, classes_per_attr):
-        self.classes_per_attr = classes_per_attr
-
-    @staticmethod
-    def build(samples: List[SingleItemPageSample]):
-        attributes = set(flatten(s.item.keys() for s in samples))
-
-        rules = {}  # attr -> selector
-        for attr in attributes:
-            logging.info("Training attribute %s" % attr)
-
-            # get all potential matches
-            matching_nodes = flatten([s.page.find(s.item[attr]) for s in samples])
-            # since uniqueness requires selection over and over, we don't use generate_unique_path... here
-            path_selector_generator = (
-                generate_path_selectors(node._soup_node) for node in matching_nodes
-            )
-            selectors = set(chain(*path_selector_generator))
-
-            # check if they are unique on every page
-            # -> for all potential selectors: compute score
-            selector_scoring = {}  # selector -> score
-            for i, selector in enumerate(selectors):
-                if selector not in selector_scoring:
-                    logging.info("testing %s (%d/%d)", selector, i, len(selectors))
-                    matches_per_page = (s.page.select(selector) for s in samples)
-                    matches_per_page_right = [
-                        len(m) == 1 and m[0].get_text() == s.item[attr]
-                        for m, s in zip(matches_per_page, samples)
-                    ]
-                    score = sum(matches_per_page_right) / len(samples)
-                    selector_scoring[selector] = score
-
-            # find the selector with the best coverage, i.e. the highest accuracy
-            logging.info("Scoring for %s: %s", attr, selector_scoring)
-            # sort by score (desc) and selector length (asc)
-            selectors_sorted = sorted(
-                selector_scoring.items(), key=lambda x: (x[1], -len(x[0])), reverse=True
-            )
-            logging.info("Best scores for %s: %s", attr, selectors_sorted[:3])
-            try:
-                selector_best = selectors_sorted[0][0]
-                if selector_scoring[selector_best] < 1:
-                    logging.warning(
-                        "Best selector for %s does not work for all samples (score is %f)"
-                        % (attr, selector_scoring[selector_best])
-                    )
-
-                rules[attr] = selector_best
-            except IndexError:
-                logging.warning("No selector found for %s", attr)
-        print(rules)
-        return RuleBasedSingleItemScraper(rules)
-
-    def scrape(self, html):
-        page = make_soup_page(html)
-        item = {
-            k: page.select(self.classes_per_attr[k])[0].get_text()
-            for k in self.classes_per_attr.keys()
-        }
-        return item
-
-
-class SingleItemScraper:
-    def __init__(self, classifiers, min_match_proba=0.7):
-        self.classifiers = classifiers
-        self.min_match_proba = min_match_proba
-
-    @staticmethod
-    def build(samples: List[SingleItemPageSample]):
-        """
-        Build a scraper by inferring rules.
-
-        :param samples: Samples to train
-        :return: the scraper
-        """
-
-        # parse html
-        soups = [sample.page._soup for sample in samples]
-
-        # find samples on the pages
-        matches = []
-        for sample, soup in zip(samples, soups):
-            matches_per_item = {}
-            for key in sample.item.keys():
-                needle = sample.item[key]
-
-                # currently, we can only find strings with .text extraction
-                assert isinstance(needle, str), "Only strings supported"
-
-                # search for text, check if parent returns this text
-                text_matches = soup.find_all(text=re.compile(needle))
-                logging.debug("Matches for %s: %s", needle, text_matches)
-                text_parents = (ns.parent for ns in text_matches)
-                tag_matches = [p for p in text_parents if extract_text(p) == needle]
-                matches_per_item[key] = tag_matches
-            matches.append(matches_per_item)
-        # print(matches)
-
-        matches_unique = []
-        for matches_item, sample in zip(matches, samples):
-            if all(len(matches_item[attr]) == 1 for attr in sample.item.keys()):
-                matches_item_unique = {
-                    attr: matches_item[attr][0] for attr in sample.item.keys()
-                }
-                matches_unique.append(matches_item_unique)
-            else:
-                logging.warning(
-                    "Sample values not unique on page, discarding: %s -> %s"
-                    % (sample, matches_item)
-                )
-                matches_unique.append(None)
-        # print(matches_unique)
-
-        # for each attribute:
-        attributes = set(flatten(sample.item.keys() for sample in samples))
-        print(attributes)
-        classifiers = {}
-        for attr in attributes:
-            print(attr)
-            # 1. take all items with unique samples
-            # 2. mark nodes that match sample as true, others as false
-            training_data = []
-            for matches_item, soup in zip(matches_unique, soups):
-                if matches_item:
-                    node_to_find = matches_item[attr]
-                    training_data.extend(
-                        [(node, node == node_to_find) for node in soup.descendants]
-                    )
-                else:
-                    logging.warning("Skipping one sample for %s" % attr)
-
-            # 3. train classifier
-            df = pd.DataFrame(training_data, columns=["node", "target"])
-            if len(df[df["target"] == False]) > 100:
-                df_train = pd.concat(
-                    [
-                        df[df["target"] == True],
-                        df[df["target"] == False].sample(frac=0.01),
-                    ]
-                )
-            else:
-                df_train = df
-
-            pipeline = train_pipeline(df_train["node"], df_train["target"])
-            # pipeline = train_pipeline(df["node"], df["target"])
-            classifiers[attr] = pipeline
-
-        return SingleItemScraper(classifiers)
-
-    def scrape(self, html):
-        soup = BeautifulSoup(html, "lxml")
-
-        # data is a dict, because page is one item
-        data = {}
-
-        nodes = list(soup.descendants)
-        for attr in self.classifiers.keys():
-            # predict proba of all nodes
-            node_predictions = self.classifiers[attr].predict_proba(nodes)
-
-            # turn it into a data frame
-            df = pd.DataFrame(node_predictions, columns=["is_noise", "is_target"])
-
-            # re-add nodes to extract them later
-            df["node"] = pd.Series(nodes)
-
-            # get best match
-            df_nodes_by_proba = df.sort_values("is_target", ascending=False)
-            best_match = df_nodes_by_proba.iloc[0]
-
-            # use if probability > threshold
-            if best_match["is_target"] > self.min_match_proba:
-                # todo apply extractor based on attribute (.text, attrs[href], etc.)
-                data[attr] = extract_text(best_match["node"])
-            else:
-                logging.warning(
-                    "%s not found in html, probability %f < %f",
-                    attr,
-                    best_match["is_target"],
-                    self.min_match_proba,
-                )
-        # return the data dictionary
-        return data
-
-
-def extract_text(node):
-    return node.text.strip()
-
-
-class MultiItemScraper:
-    """
-    Extracts several items from a single page, e.g. all results from a page of search results.
-    """
-
-    def __init__(self, parent_selector, value_selectors):
-        self.parent_selector = parent_selector
-        self.value_selectors = value_selectors
-
-    @staticmethod
-    def build(samples: List[MultiItemPageSample]):
-        """
-        Build the scraper by inferring rules.
-        """
-        assert len(samples) == 1, "can only train with one sample"
-
-        items = samples[0].items
-
-        # observation:
-        # - if multiple common distinctive ancestors exist,
-        #   one can choose the ones easier to match, as this will not affect later value selection
-
-        # assumptions:
-        # - all samples given must be all samples that can be found on a page
-        #   -> much easier evaluation because we can detect false positives
-        # - for each sample, there's at least one distinct ancestor containing only this sample
-        #   -> will fail for inline results but allow for much easier selector generation
-        # - there won't be too many duplicate values so we can ignore samples that contain them
-        #   -> makes ancestor computation much easier by increasing false positives
-
-        # glossary
-        # - ancestors: the path of elements in the DOM from root to a specific element
-
-        soup = samples[0].page._soup
-
-        # 1. find all examples on the site
-        matches = []
-        for item in items:
-            matches_item = {}
-            for key, value in item.items():
-                elements_with_value = soup.find_all(text=value)
-                print("{}: {}".format(key, elements_with_value))
-                matches_item[key] = elements_with_value
-            matches.append(matches_item)
-        print(matches)
-
-        # 1.1 exclude duplicate samples for now to avoid errors
-        # (e.g. 2018 existing 4x on a site)
-        matches_unique = []
-        for item, matches_item in zip(items, matches):
-            multiple_occurence_keys = [
-                k for k in matches_item if len(matches_item[k]) != 1
-            ]
-            if not multiple_occurence_keys:
-                matches_unique.append({k: v[0] for k, v in matches_item.items()})
-            else:
-                matches_unique.append(None)
-                error = "Item %r dropped because attribute(s) %r were found several times on page"
-                logging.warning(error, item, multiple_occurence_keys)
-        print(matches_unique)
-
-        # 2. extract the distinct ancestors for each sample (one sample, one ancestor)
-        # 2.1 find deepest common ancestor of each item
-        deepest_common_ancestor_per_item = [
-            get_common_ancestor_for_nodes([node for node in match.values()])
-            for match in matches_unique
-        ]
-        print(deepest_common_ancestor_per_item)
-        assert len(set(deepest_common_ancestor_per_item)) == len(
-            deepest_common_ancestor_per_item
-        )
-
-        # 2.2 get a list of distinctive ancestors for each item
-        deepest_common_ancestor_of_items = get_common_ancestor_for_nodes(
-            deepest_common_ancestor_per_item
-        )
-        print(deepest_common_ancestor_of_items)
-
-        tree_paths = [get_tree_path(node) for node in deepest_common_ancestor_per_item]
-        unique_ancestors_per_item = []
-        for tree_path_of_item in tree_paths:
-            uniques = [
-                node
-                for node in tree_path_of_item
-                if not any(node in tp for tp in tree_paths if tp != tree_path_of_item)
-            ]
-            unique_ancestors_per_item.append(uniques)
-
-        # 3. find a selector to match exactly one distinctive ancestor for each sample
-        ancestor_selector = derive_css_selector(unique_ancestors_per_item, soup)
-        if not ancestor_selector:
-            raise RuntimeError("Found no selector")
-        print(ancestor_selector)
-
-        # select all ancestors on the page
-        ancestors = soup.select(ancestor_selector)
-
-        # 4. find simplest selector from these distinctive ancestors to the sample values
-        value_selectors = {}
-        for attr in set([attr for item in items for attr in item.keys()]):
-            # try to infer a rule that matches attribute for most items given the selector
-
-            # compute value nodes and resp. ancestor nodes
-            value_nodes = [match[attr].parent for match in matches_unique]
-            value_parents = [
-                next(iter(set(value_node.parents).intersection(ancestors)))
-                for value_node in value_nodes
-            ]
-            print(value_nodes)
-            print(value_parents)
-
-            # get all potential candidates for value selectors
-            value_selector_cand = [
-                list(get_selectors(node, parent))
-                for node, parent in zip(value_nodes, value_parents)
-            ]
-
-            # merge all selector candidates and find the most common one
-            value_selector = Counter(flatten(value_selector_cand)).most_common(1)[0][0]
-            value_selectors[attr] = value_selector
-        print(value_selectors)
-
-        return MultiItemScraper(ancestor_selector, value_selectors)
-
-    def scrape(self, html):
-        data = []
-
-        soup = BeautifulSoup(html, "lxml")
-        ancestors = soup.select(self.parent_selector)
-        for ancestor in ancestors:
-            data_single = {
-                attr: ancestor.select(selector)[0].text.strip()
-                for attr, selector in self.value_selectors.items()
-            }
-            data.append(data_single)
-        return data
+# create hierarchy for training
+# - for all items: find selectable object
+# - for all selectable objects: find path to other items
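The three comment lines that now make up __init__.py sketch the planned training flow. As a rough, hypothetical illustration of that hierarchy, using the Item and Sample classes introduced in mlscraper/samples.py further down (the page object is left as a placeholder, since building pages from HTML lives in mlscraper/util.py and is not part of the shown diff):

# hypothetical sketch, not part of the commit
from mlscraper.samples import Item, Sample

page = None  # placeholder: Page construction (mlscraper/util.py) is not shown in this diff

item = Item.create_from({"name": "Eleanor Roosevelt", "birth": "October 11, 1884"})
# -> a DictItem holding one ValueItem per key

item.add_sample(Sample(page, {"name": "Eleanor Roosevelt", "birth": "October 11, 1884"}))
# every nested value is stored as its own Sample, so the training step can
# first find a selectable object per value ("find selectable object") and
# then derive paths between the values ("find path to other items")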
diff --git a/mlscraper/ml.py b/mlscraper/ml.py
deleted file mode 100644
index 7ba35f6..0000000
--- a/mlscraper/ml.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import logging
-
-import pandas as pd
-from sklearn.base import TransformerMixin
-from sklearn.pipeline import Pipeline
-from sklearn.tree import DecisionTreeClassifier
-
-from mlscraper.util import generate_unique_path_selectors, get_tree_path
-
-
-class NodePreprocessing(TransformerMixin):
-    """Preprocesses a list of nodes."""
-
-    def __init__(self):
-        self.css_selectors = None
-
-    def fit(self, X, y):
-        # get all css selectors
-        css_selectors = set()
-        for node, is_target in zip(X, y):
-            if is_target:
-                # todo doesn't work for multi-result pages
-                # because a unique selector cannot select multiple elements on one page
-
-                css_selectors |= set(generate_unique_path_selectors(node))
-        self.css_selectors = css_selectors
-        logging.info("found %d css selectors" % len(css_selectors))
-
-        return self
-
-    def transform(self, X, y=None, **fit_params):
-        logging.info("starting transformation (%d nodes)" % len(X))
-
-        # create basic df
-        df = pd.DataFrame(X, columns=["node"])
-
-        def get_root(node):
-            return get_tree_path(node)[-1]
-
-        # so we basically want to know which nodes match the selectors
-        # the problem is that hashing takes very long in bs4
-
-        for i, css_selector in enumerate(self.css_selectors):
-            logging.info("Selector %d/%d" % (i, len(self.css_selectors)))
-            col = "select: %s" % css_selector
-            df[col] = df["node"].apply(lambda n: n in get_root(n).select(css_selector))
-
-        return df[[c for c in df.columns if c not in ["node"]]]
-
-
-def train_pipeline(nodes, targets):
-    assert len(nodes) == len(targets), "len(nodes) != len(targets)"
-
-    pipeline_steps = [
-        ("pre", NodePreprocessing()),
-        ("classifier", DecisionTreeClassifier(class_weight="balanced")),
-    ]
-    pipeline = Pipeline(steps=pipeline_steps)
-
-    pipeline.fit(nodes, targets)
-
-    return pipeline
-
-
-def extract_classes(n):
-    return n.attrs.get("class", []) if hasattr(n, "attrs") else []
diff --git a/mlscraper/parser.py b/mlscraper/parser.py
deleted file mode 100644
index 4bbce1a..0000000
--- a/mlscraper/parser.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# everything related to parsing html
-import logging
-import re
-from abc import ABC
-
-from bs4 import BeautifulSoup, Tag
-
-
-class Page(ABC):
-    def select(self, css_selector):
-        raise NotImplementedError()
-
-    def find(self, needle):
-        raise NotImplementedError()
-
-
-class Node(ABC):
-    pass
-
-
-class SoupPage(Page):
-
-    _soup = None
-
-    def __init__(self, soup: BeautifulSoup):
-        self._soup = soup
-
-    def select(self, css_selector):
-        try:
-            return [SoupNode(res) for res in self._soup.select(css_selector)]
-        except NotImplementedError:
-            logging.warning(
-                "ignoring selector %s: not implemented by BS4" % css_selector
-            )
-            return []
-
-    def find(self, needle):
-        assert type(needle) == str, "can only find strings, %s given" % type(needle)
-        text_matches = self._soup.find_all(text=re.compile(needle))
-        logging.debug("Matches for %s: %s", needle, text_matches)
-        text_parents = (ns.parent for ns in text_matches)
-        tag_matches = [p for p in text_parents if extract_soup_text(p) == needle]
-        return [SoupNode(m) for m in tag_matches]
-
-
-def extract_soup_text(tag: Tag):
-    return tag.text.strip()
-
-
-class SoupNode(Node):
-    _soup_node = None
-
-    def __init__(self, node):
-        self._soup_node = node
-
-    def __eq__(self, other):
-        return self._soup_node.__eq__(other._soup_node)
-
-    def __hash__(self):
-        return self._soup_node.__hash__()
-
-    def get_text(self):
-        return extract_soup_text(self._soup_node)
-
-
-def make_soup_page(html):
-    soup = BeautifulSoup(html, "lxml")
-    return SoupPage(soup)
-
-
-class ExtractionResult:
-    """Specific result found on a page"""
-
-    node = None
-    # extraction_method = None
-
-    def __init__(self, node: Node):
-        self.node = node
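For reference, the Page/Node abstraction removed here was driven from the old __init__.py roughly as follows (a sketch against the pre-commit code only; the HTML snippet is made up):

# sketch of the removed API, shown only for orientation
from mlscraper.parser import ExtractionResult, make_soup_page

page = make_soup_page("<html><body><h3>Eleanor Roosevelt</h3></body></html>")
nodes = page.find("Eleanor Roosevelt")   # SoupNode wrappers whose text equals the needle
headings = page.select("h3")             # CSS selection, also returning SoupNode objects
extraction = ExtractionResult(nodes[0])  # wraps the concrete node found on this page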
+ """ + + samples = None + + @classmethod + def create_from(cls, item): + if isinstance(item, str): + return ValueItem() + elif isinstance(item, list): + return ListItem() + elif isinstance(item, dict): + return DictItem() + else: + raise ItemStructureException(f"unsupported item: {type(item)}") + + def __init__(self): + self.samples = [] + + def add_sample(self, sample: Sample): + self.samples.append(sample) + + def __repr__(self): + return f"<{self.__class__.__name__} {self.samples=}>" + + +class DictItem(Item): + item_per_key = None + + def __init__(self): + super().__init__() + self.item_per_key = {} + + def add_sample(self, sample: Sample): + if not isinstance(sample.value, dict): + raise ItemStructureException(f"dict expected, {sample.value} given") + + super().add_sample(sample) + + for key, value in sample.value.items(): + if key not in self.item_per_key: + self.item_per_key[key] = Item.create_from(value) + + value_sample = Sample(sample.page, value) + self.item_per_key[key].add_sample(value_sample) + + +class ListItem(Item): + item = None + + def __init__(self): + super().__init__() + self.item = None + + def add_sample(self, sample: Sample): + if not isinstance(sample.value, list): + raise ItemStructureException(f"list expected, {sample.value} given") + + super().add_sample(sample) + + if not self.item and len(sample.value): + self.item = Item.create_from(sample.value[0]) + + for v in sample.value: + self.item.add_sample(Sample(sample.page, v)) + + +class ValueItem(Item): + def add_sample(self, sample: Sample): + if not isinstance(sample.value, str): + raise ItemStructureException(f"str expected, {sample.value} given") + super().add_sample(sample) + + +def make_training_set(pages, items): + assert len(pages) == len(items) + + ts = TrainingSet() + for p, i in zip(pages, items): + ts.add_sample(Sample(p, i)) + + return ts diff --git a/mlscraper/scrapers.py b/mlscraper/scrapers.py new file mode 100644 index 0000000..06f534d --- /dev/null +++ b/mlscraper/scrapers.py @@ -0,0 +1,53 @@ +import typing + +from mlscraper.util import Extractor, Node, Selector + + +class Scraper: + def get(self, node: Node): + raise NotImplementedError() + + +class DictScraper(Scraper): + scraper_per_key = None + + def __init__(self, scraper_per_key: typing.Dict[str, Scraper]): + self.scraper_per_key = scraper_per_key + + def get(self, node: Node): + return {key: scraper.get(node) for key, scraper in self.scraper_per_key.items()} + + def __repr__(self): + return f"<DictScraper {self.scraper_per_key=}>" + + +class ListScraper(Scraper): + selector = None + scraper = None + + def __init__(self, selector: Selector, scraper: Scraper): + self.selector = selector + self.scraper = scraper + + def get(self, node: Node): + return [ + self.scraper.get(item_node) for item_node in self.selector.select_all(node) + ] + + def __repr__(self): + return f"<ListScraper {self.scraper=}>" + + +class ValueScraper(Scraper): + selector = None + extractor = None + + def __init__(self, selector: Selector, extractor: Extractor): + self.selector = selector + self.extractor = extractor + + def get(self, node: Node): + return self.extractor.extract(self.selector.select_one(node)) + + def __repr__(self): + return f"<ValueScraper {self.selector=}, {self.extractor=}>" diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py new file mode 100644 index 0000000..3e3ae82 --- /dev/null +++ b/mlscraper/selectors.py @@ -0,0 +1,108 @@ +import logging +import typing +from itertools import product + +from more_itertools import flatten + +from 
diff --git a/mlscraper/scrapers.py b/mlscraper/scrapers.py
new file mode 100644
index 0000000..06f534d
--- /dev/null
+++ b/mlscraper/scrapers.py
@@ -0,0 +1,53 @@
+import typing
+
+from mlscraper.util import Extractor, Node, Selector
+
+
+class Scraper:
+    def get(self, node: Node):
+        raise NotImplementedError()
+
+
+class DictScraper(Scraper):
+    scraper_per_key = None
+
+    def __init__(self, scraper_per_key: typing.Dict[str, Scraper]):
+        self.scraper_per_key = scraper_per_key
+
+    def get(self, node: Node):
+        return {key: scraper.get(node) for key, scraper in self.scraper_per_key.items()}
+
+    def __repr__(self):
+        return f"<DictScraper {self.scraper_per_key=}>"
+
+
+class ListScraper(Scraper):
+    selector = None
+    scraper = None
+
+    def __init__(self, selector: Selector, scraper: Scraper):
+        self.selector = selector
+        self.scraper = scraper
+
+    def get(self, node: Node):
+        return [
+            self.scraper.get(item_node) for item_node in self.selector.select_all(node)
+        ]
+
+    def __repr__(self):
+        return f"<ListScraper {self.scraper=}>"
+
+
+class ValueScraper(Scraper):
+    selector = None
+    extractor = None
+
+    def __init__(self, selector: Selector, extractor: Extractor):
+        self.selector = selector
+        self.extractor = extractor
+
+    def get(self, node: Node):
+        return self.extractor.extract(self.selector.select_one(node))
+
+    def __repr__(self):
+        return f"<ValueScraper {self.selector=}, {self.extractor=}>"
diff --git a/mlscraper/selectors.py b/mlscraper/selectors.py
new file mode 100644
index 0000000..3e3ae82
--- /dev/null
+++ b/mlscraper/selectors.py
@@ -0,0 +1,108 @@
+import logging
+import typing
+from itertools import product
+
+from more_itertools import flatten
+
+from mlscraper.samples import Sample
+from mlscraper.util import Matcher, Node, Page, Selector
+
+
+class CssRuleSelector(Selector):
+    def __init__(self, css_rule):
+        self.css_rule = css_rule
+
+    def select_one(self, page: Page):
+        return page.select(self.css_rule)[0]
+
+    def select_all(self, page):
+        return page.select(self.css_rule)
+
+    def __repr__(self):
+        return f"<{self.__class__.__name__} {self.css_rule=}>"
+
+
+def generate_selector_for_nodes(nodes, roots):
+    if roots is None:
+        logging.info("roots is None, setting roots manually")
+        roots = [n.get_root() for n in nodes]
+
+    nodes_per_root = {}
+    for root in set(roots):
+        nodes_per_root[root] = {n for n, r in zip(nodes, roots) if r == root}
+
+    selectors_seen = set()
+
+    for node in nodes:
+        for sel in node.generate_path_selectors():
+            if sel not in selectors_seen:
+                print([set(root.select(sel)) for root in nodes_per_root.keys()])
+                print([nodes_per_root[root] for root in nodes_per_root.keys()])
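Combined with scrapers.py above, a trained scraper for one item could be assembled roughly like this (a sketch under assumptions: the CSS rules are invented, and TextExtractor stands in for whatever Extractor implementation mlscraper/util.py provides, since util.py is not shown in this diff):

from mlscraper.scrapers import DictScraper, ValueScraper
from mlscraper.selectors import CssRuleSelector
from mlscraper.util import Extractor


class TextExtractor(Extractor):
    # stand-in extractor; the real implementation lives in mlscraper/util.py
    def extract(self, node):
        return node.get_text()  # assumes nodes expose get_text() like the old SoupNode


author_scraper = DictScraper(
    {
        "name": ValueScraper(CssRuleSelector("h3.author-title"), TextExtractor()),
        "birth": ValueScraper(CssRuleSelector(".author-born-date"), TextExtractor()),
    }
)
# author_scraper.get(page) would yield {"name": "...", "birth": "..."} for a given page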