diff options
author | Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com> | 2023-09-18 15:29:33 +0200 |
---|---|---|
committer | Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com> | 2023-09-18 15:29:33 +0200 |
commit | 959aef83ccfafc0411c1950031a6da57e650aa2b (patch) | |
tree | fbef9e8687f051d3ceeccd78e1ff1a185fab454a | |
parent | 9c553a0b5551e4d4e658ba089528ad7b6f3c5a80 (diff) |
Improvements
Co-authored-by: @hjoaquim
-rw-r--r-- | openbb_sdk/providers/utils/unit_test_generator.py | 85 |
1 files changed, 56 insertions, 29 deletions
diff --git a/openbb_sdk/providers/utils/unit_test_generator.py b/openbb_sdk/providers/utils/unit_test_generator.py index 894c4d417a7..262f4b96385 100644 --- a/openbb_sdk/providers/utils/unit_test_generator.py +++ b/openbb_sdk/providers/utils/unit_test_generator.py @@ -1,30 +1,42 @@ """The unit test generator for the fetchers.""" import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple from credentials_schema import test_credentials from openbb_provider.abstract.fetcher import Fetcher +from openbb_provider.registry import RegistryLoader from openbb_provider.utils.helpers import to_snake_case from pydantic.fields import ModelField + from sdk.core.openbb_core.app.provider_interface import ProviderInterface -def get_provider_fetchers( - available_providers: List[str], -) -> Dict[str, Dict[str, Fetcher]]: - """Return a list of all fetchers in the provider.""" - fetchers: Dict[str, Dict[str, Fetcher]] = {} +def get_provider_fetchers() -> Dict[str, Dict[str, Fetcher]]: + registry = RegistryLoader.from_extensions() + provider_fetcher_map: Dict[str, Dict[str, Fetcher]] = {} + for provider_name, provider_cls in registry.providers.items(): + provider_fetcher_map[provider_name] = {} + for fetcher_name, fetcher_cls in provider_cls.fetcher_dict.items(): + provider_fetcher_map[provider_name][fetcher_name] = fetcher_cls + return provider_fetcher_map + - for provider in available_providers: - provider_loaded = __import__(f"openbb_{provider}") - provider_variable = getattr(provider_loaded, f"{provider}_provider") - fetcher_dict = provider_variable.fetcher_dict - for fetcher_name, fetcher_class in fetcher_dict.items(): - if provider not in fetchers: - fetchers[provider] = {} - fetchers[provider][fetcher_name] = fetcher_class +# def get_provider_fetchers( +# available_providers: List[str], +# ) -> Dict[str, Dict[str, Fetcher]]: +# """Return a list of all fetchers in the provider.""" +# fetchers: Dict[str, Dict[str, Fetcher]] = {} - return fetchers +# for provider in available_providers: +# provider_loaded = __import__(f"openbb_{provider}") +# provider_variable = getattr(provider_loaded, f"{provider}_provider") +# fetcher_dict = provider_variable.fetcher_dict +# for fetcher_name, fetcher_class in fetcher_dict.items(): +# if provider not in fetchers: +# fetchers[provider] = {} +# fetchers[provider][fetcher_name] = fetcher_class + +# return fetchers def generate_fetcher_unit_tests(path: str) -> None: @@ -70,7 +82,9 @@ def get_test_params(param_fields: Dict[str, ModelField]) -> Dict[str, Any]: def write_test_credentials(path: str, provider: str) -> None: """Write the mocked credentials to the provider test folders.""" - credentials: Dict[str, str] = test_credentials.get(provider, {}) + credentials: Tuple[str, str] = test_credentials.get( + provider, ("token", "MOCK_TOKEN") + ) template = """ test_credentials = UserService().default_user_settings.credentials.dict() @@ -89,16 +103,21 @@ def vcr_config(): f.write(template.format(credentials_str=str(credentials))) +def check_pattern_in_file(file_path: str, pattern: str) -> bool: + with open(file_path) as f: + lines = f.readlines() + for line in lines: + if pattern in line: + return True + return False + + def write_fetcher_unit_tests() -> None: """Write the fetcher unit tests to the provider test folders.""" provider_interface = ProviderInterface() - available_providers = provider_interface.available_providers provider_interface_map = provider_interface.map - # TODO: Check why the available_providers isn't working - # but it works when you write it manually - - fetchers = get_provider_fetchers(available_providers) + fetchers = get_provider_fetchers() provider_fetchers: Dict[str, Dict[str, str]] = {} for provider, fetcher_dict in fetchers.items(): @@ -117,17 +136,21 @@ def write_fetcher_unit_tests() -> None: provider_fetchers[model_name][fetcher_name] = path # Check if the test is already in the file - with open(path) as f: - lines = f.readlines() - for line in lines: - if fetcher_path in line and fetcher_name in line: - return + # with open(path) as f: + # lines = f.readlines() + # for line in lines: + # if fetcher_path in line and fetcher_name in line: + # return - with open(path, "a") as f: - f.write(f"from {fetcher_path} import {fetcher_name}\n") + pattern = f"from {fetcher_path} import {fetcher_name}" + if not check_pattern_in_file(path, pattern): + with open(path, "a") as f: + f.write(f"{pattern}\n") # here we add the credentials and vcr config - write_test_credentials(path, provider) + pattern = "vcr_config" + if not check_pattern_in_file(path, pattern): + write_test_credentials(path, provider) test_template = """ @pytest.mark.record_http @@ -141,6 +164,10 @@ def test_{fetcher_name_snake}(credentials=test_credentials): for model_name, fetcher_dict in provider_fetchers.items(): for fetcher_name, path in fetcher_dict.items(): + pattern = f"{fetcher_name}()" + if check_pattern_in_file(path, pattern): + continue + # Add logic here to grab the necessary standardized params and credentials test_params = get_test_params( param_fields=provider_interface_map[model_name]["openbb"][ |