summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorIgor Radovanovic <74266147+IgorWounds@users.noreply.github.com>2023-09-18 20:31:02 +0200
committerIgor Radovanovic <74266147+IgorWounds@users.noreply.github.com>2023-09-18 20:31:02 +0200
commitceff6b9745fef5860ee6bc56b14222f6d8d07b26 (patch)
tree99ef4beb08db19ad28b6941b66fe225934e5d06e
parent4ee4646f5bd3d70bdf90f8a1ffba212120b74d01 (diff)
Final fixes
-rw-r--r--openbb_sdk/providers/tests/test_provider_fetcher.py42
-rw-r--r--openbb_sdk/providers/tests/test_provider_field_dupes.py14
-rw-r--r--openbb_sdk/providers/tests/utils/__init__.py0
-rw-r--r--openbb_sdk/providers/tests/utils/unit_test_generator.py3
4 files changed, 22 insertions, 37 deletions
diff --git a/openbb_sdk/providers/tests/test_provider_fetcher.py b/openbb_sdk/providers/tests/test_provider_fetcher.py
index 07e1fdf17e5..d464a65e661 100644
--- a/openbb_sdk/providers/tests/test_provider_fetcher.py
+++ b/openbb_sdk/providers/tests/test_provider_fetcher.py
@@ -1,46 +1,28 @@
+"""Test if providers and fetchers are covered by tests."""
import os
-
import unittest
from importlib import import_module
-
-from openbb_provider.abstract.fetcher import Fetcher
-from openbb_provider.registry import RegistryLoader
from typing import Dict
+from openbb_provider.abstract.provider import Provider
+from openbb_provider.registry import RegistryLoader
-# TODO : this should be imported from utils
-def get_provider_fetchers() -> Dict[str, Dict[str, Fetcher]]:
- """Get the fetchers from the provider registry."""
- registry = RegistryLoader.from_extensions()
- provider_fetcher_map: Dict[str, Dict[str, Fetcher]] = {}
- for provider_name, provider_cls in registry.providers.items():
- provider_fetcher_map[provider_name] = {}
- for fetcher_name, fetcher_cls in provider_cls.fetcher_dict.items():
- provider_fetcher_map[provider_name][fetcher_name] = fetcher_cls
- return provider_fetcher_map
-
-
-# TODO : this should be imported from utils
-def check_pattern_in_file(file_path: str, pattern: str) -> bool:
- """Check if a pattern is in a file."""
- with open(file_path) as f:
- lines = f.readlines()
- for line in lines:
- if pattern in line:
- return True
- return False
+from providers.tests.utils.unit_test_generator import (
+ check_pattern_in_file,
+ get_provider_fetchers,
+)
-def get_providers():
+def get_providers() -> Dict[str, Provider]:
"""Get the providers from the provider registry."""
- providers = {}
+ providers: Dict[str, Provider] = {}
registry = RegistryLoader.from_extensions()
for provider_name, provider_cls in registry.providers.items():
providers[provider_name] = provider_cls
return providers
-def get_provider_test_files(provider):
+def get_provider_test_files(provider: Provider):
"""Given a provider, return the path to the test file."""
fetchers_dict = provider.fetcher_dict
fetcher_module_name = fetchers_dict[list(fetchers_dict.keys())[0]].__module__
@@ -54,10 +36,10 @@ def get_provider_test_files(provider):
class ProviderFetcherTest(unittest.TestCase):
- """Tests for providers and fetchers"""
+ """Tests for providers and fetchers."""
def test_provider_w_tests(self):
- """Test the provider fetchers - ensure all providers have tests"""
+ """Test the provider fetchers and ensure all providers have tests."""
providers = get_providers()
for provider_name, provider_cls in providers.items():
diff --git a/openbb_sdk/providers/tests/test_provider_field_dupes.py b/openbb_sdk/providers/tests/test_provider_field_dupes.py
index 6bf4efff95e..b3021dc1c1a 100644
--- a/openbb_sdk/providers/tests/test_provider_field_dupes.py
+++ b/openbb_sdk/providers/tests/test_provider_field_dupes.py
@@ -1,3 +1,4 @@
+"""Test for common fields in the provider models that should be standard."""
import glob
import importlib
import inspect
@@ -55,7 +56,8 @@ def get_subclasses_w_keys(module: object, cls: Type) -> Dict[Type, List[str]]:
def get_subclasses(
python_files: List[str], package_name: str, cls: Type
) -> Dict[Type, List[str]]:
- """
+ """Get the subclasses of a class defined in a list of python files.
+
Given a list of python files, and a class, return a dictionary of
subclasses of that class that are defined in those files.
@@ -112,8 +114,7 @@ def child_parent_map(map_: Dict, parents: Dict, module: object) -> None:
def get_path_components(path: str):
- """Given a path, return a list of path components"""
-
+ """Given a path, return a list of path components."""
path_components = []
head, tail = os.path.split(path)
@@ -127,7 +128,8 @@ def get_path_components(path: str):
def match_provider_and_fields(
providers_w_fields: List[Dict[str, List[str]]], duplicated_fields: List[str]
) -> List[str]:
- """
+ """Get the provider and fields that match the duplicated fields.
+
Given a list of providers with fields and duplicated fields,
return a list of matching "Provider:'dup_field'".
"""
@@ -156,11 +158,11 @@ class ProviderFieldDupesTest(unittest.TestCase):
"""Test for common fields in the provider models that should be standard."""
def test_provider_field_dupes(self):
- """
+ """Check for duplicate fields in the provider models.
+
This function checks for duplicate fields in the provider models
and identifies the fields that should be standardized.
"""
-
standard_models_directory = os.path.dirname(standard_models.__file__)
standard_models_files = glob.glob(
os.path.join(standard_models_directory, "*.py")
diff --git a/openbb_sdk/providers/tests/utils/__init__.py b/openbb_sdk/providers/tests/utils/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/openbb_sdk/providers/tests/utils/__init__.py
diff --git a/openbb_sdk/providers/tests/utils/unit_test_generator.py b/openbb_sdk/providers/tests/utils/unit_test_generator.py
index 774e15b803e..8e77ff11dd6 100644
--- a/openbb_sdk/providers/tests/utils/unit_test_generator.py
+++ b/openbb_sdk/providers/tests/utils/unit_test_generator.py
@@ -2,13 +2,14 @@
import os
from typing import Any, Dict, Tuple
-from credentials_schema import test_credentials
from openbb_provider.abstract.fetcher import Fetcher
from openbb_provider.registry import RegistryLoader
from openbb_provider.utils.helpers import to_snake_case
from pydantic.fields import ModelField
from sdk.core.openbb_core.app.provider_interface import ProviderInterface
+from providers.tests.utils.credentials_schema import test_credentials
+
def get_provider_fetchers() -> Dict[str, Dict[str, Fetcher]]:
"""Get the fetchers from the provider registry."""