summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorIgor Radovanovic <74266147+IgorWounds@users.noreply.github.com>2023-09-18 15:29:33 +0200
committerIgor Radovanovic <74266147+IgorWounds@users.noreply.github.com>2023-09-18 15:29:33 +0200
commit959aef83ccfafc0411c1950031a6da57e650aa2b (patch)
treefbef9e8687f051d3ceeccd78e1ff1a185fab454a
parent9c553a0b5551e4d4e658ba089528ad7b6f3c5a80 (diff)
Improvements
Co-authored-by: @hjoaquim
-rw-r--r--openbb_sdk/providers/utils/unit_test_generator.py85
1 files changed, 56 insertions, 29 deletions
diff --git a/openbb_sdk/providers/utils/unit_test_generator.py b/openbb_sdk/providers/utils/unit_test_generator.py
index 894c4d417a7..262f4b96385 100644
--- a/openbb_sdk/providers/utils/unit_test_generator.py
+++ b/openbb_sdk/providers/utils/unit_test_generator.py
@@ -1,30 +1,42 @@
"""The unit test generator for the fetchers."""
import os
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple
from credentials_schema import test_credentials
from openbb_provider.abstract.fetcher import Fetcher
+from openbb_provider.registry import RegistryLoader
from openbb_provider.utils.helpers import to_snake_case
from pydantic.fields import ModelField
+
from sdk.core.openbb_core.app.provider_interface import ProviderInterface
-def get_provider_fetchers(
- available_providers: List[str],
-) -> Dict[str, Dict[str, Fetcher]]:
- """Return a list of all fetchers in the provider."""
- fetchers: Dict[str, Dict[str, Fetcher]] = {}
+def get_provider_fetchers() -> Dict[str, Dict[str, Fetcher]]:
+ registry = RegistryLoader.from_extensions()
+ provider_fetcher_map: Dict[str, Dict[str, Fetcher]] = {}
+ for provider_name, provider_cls in registry.providers.items():
+ provider_fetcher_map[provider_name] = {}
+ for fetcher_name, fetcher_cls in provider_cls.fetcher_dict.items():
+ provider_fetcher_map[provider_name][fetcher_name] = fetcher_cls
+ return provider_fetcher_map
+
- for provider in available_providers:
- provider_loaded = __import__(f"openbb_{provider}")
- provider_variable = getattr(provider_loaded, f"{provider}_provider")
- fetcher_dict = provider_variable.fetcher_dict
- for fetcher_name, fetcher_class in fetcher_dict.items():
- if provider not in fetchers:
- fetchers[provider] = {}
- fetchers[provider][fetcher_name] = fetcher_class
+# def get_provider_fetchers(
+# available_providers: List[str],
+# ) -> Dict[str, Dict[str, Fetcher]]:
+# """Return a list of all fetchers in the provider."""
+# fetchers: Dict[str, Dict[str, Fetcher]] = {}
- return fetchers
+# for provider in available_providers:
+# provider_loaded = __import__(f"openbb_{provider}")
+# provider_variable = getattr(provider_loaded, f"{provider}_provider")
+# fetcher_dict = provider_variable.fetcher_dict
+# for fetcher_name, fetcher_class in fetcher_dict.items():
+# if provider not in fetchers:
+# fetchers[provider] = {}
+# fetchers[provider][fetcher_name] = fetcher_class
+
+# return fetchers
def generate_fetcher_unit_tests(path: str) -> None:
@@ -70,7 +82,9 @@ def get_test_params(param_fields: Dict[str, ModelField]) -> Dict[str, Any]:
def write_test_credentials(path: str, provider: str) -> None:
"""Write the mocked credentials to the provider test folders."""
- credentials: Dict[str, str] = test_credentials.get(provider, {})
+ credentials: Tuple[str, str] = test_credentials.get(
+ provider, ("token", "MOCK_TOKEN")
+ )
template = """
test_credentials = UserService().default_user_settings.credentials.dict()
@@ -89,16 +103,21 @@ def vcr_config():
f.write(template.format(credentials_str=str(credentials)))
+def check_pattern_in_file(file_path: str, pattern: str) -> bool:
+ with open(file_path) as f:
+ lines = f.readlines()
+ for line in lines:
+ if pattern in line:
+ return True
+ return False
+
+
def write_fetcher_unit_tests() -> None:
"""Write the fetcher unit tests to the provider test folders."""
provider_interface = ProviderInterface()
- available_providers = provider_interface.available_providers
provider_interface_map = provider_interface.map
- # TODO: Check why the available_providers isn't working
- # but it works when you write it manually
-
- fetchers = get_provider_fetchers(available_providers)
+ fetchers = get_provider_fetchers()
provider_fetchers: Dict[str, Dict[str, str]] = {}
for provider, fetcher_dict in fetchers.items():
@@ -117,17 +136,21 @@ def write_fetcher_unit_tests() -> None:
provider_fetchers[model_name][fetcher_name] = path
# Check if the test is already in the file
- with open(path) as f:
- lines = f.readlines()
- for line in lines:
- if fetcher_path in line and fetcher_name in line:
- return
+ # with open(path) as f:
+ # lines = f.readlines()
+ # for line in lines:
+ # if fetcher_path in line and fetcher_name in line:
+ # return
- with open(path, "a") as f:
- f.write(f"from {fetcher_path} import {fetcher_name}\n")
+ pattern = f"from {fetcher_path} import {fetcher_name}"
+ if not check_pattern_in_file(path, pattern):
+ with open(path, "a") as f:
+ f.write(f"{pattern}\n")
# here we add the credentials and vcr config
- write_test_credentials(path, provider)
+ pattern = "vcr_config"
+ if not check_pattern_in_file(path, pattern):
+ write_test_credentials(path, provider)
test_template = """
@pytest.mark.record_http
@@ -141,6 +164,10 @@ def test_{fetcher_name_snake}(credentials=test_credentials):
for model_name, fetcher_dict in provider_fetchers.items():
for fetcher_name, path in fetcher_dict.items():
+ pattern = f"{fetcher_name}()"
+ if check_pattern_in_file(path, pattern):
+ continue
+
# Add logic here to grab the necessary standardized params and credentials
test_params = get_test_params(
param_fields=provider_interface_map[model_name]["openbb"][