From 80d3ac613101886e210ee1d6f9d5506cf3a6ae07 Mon Sep 17 00:00:00 2001 From: Henrique Joaquim Date: Fri, 22 Mar 2024 14:22:29 +0000 Subject: [BugFix] No event loop when exporting images (#6249) * duplicated warning * if the event loop was closed - create a new one to process the image --- .../obbject_extensions/charting/openbb_charting/core/backend.py | 8 +++++++- .../charting/openbb_charting/core/openbb_figure.py | 6 +----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py b/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py index ce8b706f5cc..457df7d5981 100644 --- a/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py +++ b/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py @@ -152,7 +152,6 @@ class Backend(PyWry): theme: Optional[str] = None, ) -> dict: """Get the json update for the backend.""" - posthog: Dict[str, Any] = dict(collect_logs=self.charting_settings.log_collect) if ( self.charting_settings.log_collect @@ -222,6 +221,11 @@ class Backend(PyWry): self.send_outgoing(outgoing) if export_image and isinstance(export_image, Path): + if self.loop.is_closed(): + # Create a new event loop + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.loop.run_until_complete(self.process_image(export_image)) async def process_image(self, export_image: Path): @@ -507,6 +511,7 @@ if not PLOTLYJS_PATH.exists() and not JUPYTER_NOTEBOOK: def create_backend(charting_settings: Optional["ChartingSettings"] = None): + """Create the backend.""" # # pylint: disable=import-outside-toplevel from openbb_core.app.model.charts.charting_settings import ChartingSettings @@ -517,6 +522,7 @@ def create_backend(charting_settings: Optional["ChartingSettings"] = None): def get_backend() -> Backend: + """Get the backend instance.""" if BACKEND is None: raise ValueError("Backend not created") return BACKEND diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/core/openbb_figure.py b/openbb_platform/obbject_extensions/charting/openbb_charting/core/openbb_figure.py index 27963e1cd5b..32e5876817f 100644 --- a/openbb_platform/obbject_extensions/charting/openbb_charting/core/openbb_figure.py +++ b/openbb_platform/obbject_extensions/charting/openbb_charting/core/openbb_figure.py @@ -923,11 +923,7 @@ class OpenBBFigure(go.Figure): # If the backend fails, we just show the figure normally # This is a very rare case, but it's better to have a fallback - if getattr(self._charting_settings, "debug_mode", False): - warn(f"Failed to show figure with backend: {e}") - warn( - f"Failed to show figure with backend: {e}" - ) # remove this line when the above lines are figured out + warn(f"Failed to show figure with backend. {e}") # We check if any figures were initialized before the backend failed # If so, we show them with the default plotly backend -- cgit v1.2.3 From 23135abf1b3e42bf5438dc9003b20966ce01e5cf Mon Sep 17 00:00:00 2001 From: montezdesousa <79287829+montezdesousa@users.noreply.github.com> Date: Fri, 22 Mar 2024 17:53:40 +0000 Subject: [BugFix] - Bring back mypy (#6242) * first batch * another * another * another batch unyped defs * final untyped defs * update yml * pylint * remove unecessary __init__.py files * Revert "remove unecessary __init__.py files" This reverts commit 5977eccfee7e2f536d7f2d039ba7a753d4c66695. * fix company_news typing * rename func * fix tests * tests --- .github/workflows/linting.yml | 2 +- .../core/openbb_core/api/router/commands.py | 15 ++- .../core/openbb_core/app/command_runner.py | 8 +- .../core/openbb_core/app/deprecation.py | 4 +- .../openbb_core/app/model/abstract/singleton.py | 2 +- .../core/openbb_core/app/model/example.py | 4 +- .../core/openbb_core/app/model/metadata.py | 24 +++-- .../core/openbb_core/app/model/obbject.py | 73 +++++++------ .../core/openbb_core/app/provider_interface.py | 118 +++++++++++---------- openbb_platform/core/openbb_core/app/router.py | 10 +- .../core/openbb_core/app/static/account.py | 10 +- .../openbb_core/app/static/utils/decorators.py | 10 +- openbb_platform/core/openbb_core/app/utils.py | 8 +- .../core/openbb_core/provider/abstract/fetcher.py | 2 +- .../core/openbb_core/provider/registry_map.py | 12 ++- .../openbb_core/provider/standard_models/cpi.py | 2 +- .../openbb_core/provider/standard_models/spot.py | 2 +- .../core/openbb_core/provider/utils/helpers.py | 48 ++++++--- .../core/tests/api/test_auth/test_user_auth.py | 2 +- .../tests/api/test_dependency/test_coverage.py | 2 +- .../core/tests/api/test_dependency/test_system.py | 2 +- .../formatters/test_formatter_with_exceptions.py | 2 +- .../core/tests/app/logs/test_handlers_manager.py | 2 +- .../core/tests/app/logs/test_logging_service.py | 46 +++++--- .../tests/app/logs/utils/test_expired_files.py | 3 +- .../core/tests/app/model/abstract/test_tagged.py | 2 +- .../core/tests/app/model/abstract/test_warning.py | 2 +- .../core/tests/app/model/charts/test_chart.py | 2 +- .../core/tests/app/model/hub/test_hub_session.py | 5 +- .../core/tests/app/model/test_command_context.py | 2 +- .../core/tests/app/model/test_defaults.py | 4 +- .../core/tests/app/model/test_metadata.py | 2 +- .../core/tests/app/model/test_obbject.py | 18 ++-- .../core/tests/app/model/test_system_settings.py | 36 +++---- .../core/tests/app/service/test_user_service.py | 11 +- .../core/tests/app/static/test_filters.py | 4 +- .../core/tests/app/test_command_runner.py | 17 ++- openbb_platform/core/tests/app/test_deprecation.py | 10 +- .../core/tests/app/test_extension_loader.py | 25 +++-- .../core/tests/app/test_platform_router.py | 2 +- openbb_platform/core/tests/app/test_utils.py | 12 +-- .../core/tests/provider/abstract/test_data.py | 10 +- .../core/tests/provider/abstract/test_provider.py | 5 +- .../tests/provider/abstract/test_query_params.py | 6 +- .../standard_models/test_standard_models.py | 4 +- .../core/tests/provider/utils/test_client.py | 18 ++-- .../core/tests/provider/utils/test_helpers.py | 2 +- .../fmp/openbb_fmp/models/company_news.py | 2 +- .../openbb_intrinio/models/company_news.py | 8 +- 49 files changed, 343 insertions(+), 279 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index edf6025d597..ccabe3a349d 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -68,8 +68,8 @@ jobs: - run: | # Run linters for openbb_platform if [ -n "${{ env.platform_files }}" ]; then - # TODO: Add mypy to this part of the linting workflow once we're ready pylint ${{ env.platform_files }} + mypy ${{ env.platform_files }} --ignore-missing-imports --check-untyped-defs else echo "No Python files changed in openbb_platform" fi diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py index 96334f549fa..7b644cca372 100644 --- a/openbb_platform/core/openbb_core/api/router/commands.py +++ b/openbb_platform/core/openbb_core/api/router/commands.py @@ -134,19 +134,26 @@ def validate_output(c_out: OBBject) -> OBBject: json_schema_extra = field.json_schema_extra if field else None # case where 1st layer field needs to be excluded - if json_schema_extra and json_schema_extra.get("exclude_from_api", None): + if ( + json_schema_extra + and isinstance(json_schema_extra, dict) + and json_schema_extra.get("exclude_from_api", None) + ): delattr(c_out, key) # if it's a model with nested fields elif is_model(type_): for field_name, field in type_.__fields__.items(): - if field.json_schema_extra and field.json_schema_extra.get( - "exclude_from_api", None + extra = getattr(field, "json_schema_extra", None) + if ( + extra + and isinstance(extra, dict) + and extra.get("exclude_from_api", None) ): delattr(value, field_name) # if it's a yet a nested model we need to go deeper in the recursion - elif is_model(field.annotation): + elif is_model(getattr(field, "annotation", None)): exclude_fields_from_api(field_name, getattr(value, field_name)) for k, v in c_out.model_copy(): diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index c8d7f43e377..33fb4def99f 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -6,7 +6,7 @@ from datetime import datetime from inspect import Parameter, signature from sys import exc_info from time import perf_counter_ns -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from warnings import catch_warnings, showwarning, warn from fastapi.params import Query @@ -181,7 +181,7 @@ class ParametersBuilder: def _warn_kwargs( provider_choices: Dict[str, Any], extra_params: Dict[str, Any], - model: BaseModel, + model: Type[BaseModel], ) -> None: """Warn if kwargs received and ignored by the validation model.""" # We only check the extra_params annotation because ignored fields @@ -247,7 +247,7 @@ class ParametersBuilder: @classmethod def build( cls, - args: Tuple[Any], + args: Tuple[Any, ...], execution_context: ExecutionContext, func: Callable, route: str, @@ -317,7 +317,7 @@ class StaticCommandRunner: async def _execute_func( cls, route: str, - args: Tuple[Any], + args: Tuple[Any, ...], execution_context: ExecutionContext, func: Callable, kwargs: Dict[str, Any], diff --git a/openbb_platform/core/openbb_core/app/deprecation.py b/openbb_platform/core/openbb_core/app/deprecation.py index 482338bf32d..5dd38c1ae54 100644 --- a/openbb_platform/core/openbb_core/app/deprecation.py +++ b/openbb_platform/core/openbb_core/app/deprecation.py @@ -12,10 +12,10 @@ from openbb_core.app.version import VERSION, get_major_minor class DeprecationSummary(str): """A string subclass that can be used to store deprecation metadata.""" - def __new__(cls, value, metadata): + def __new__(cls, value: str, metadata: DeprecationWarning): """Create a new instance of the class.""" obj = str.__new__(cls, value) - obj.metadata = metadata + setattr(obj, "metadata", metadata) return obj diff --git a/openbb_platform/core/openbb_core/app/model/abstract/singleton.py b/openbb_platform/core/openbb_core/app/model/abstract/singleton.py index 7d6bad71a1b..3575186ee22 100644 --- a/openbb_platform/core/openbb_core/app/model/abstract/singleton.py +++ b/openbb_platform/core/openbb_core/app/model/abstract/singleton.py @@ -7,7 +7,7 @@ class SingletonMeta(type, Generic[T]): # TODO : check if we want to update this to be thread safe _instances: Dict[T, T] = {} - def __call__(cls, *args, **kwargs): + def __call__(cls: "SingletonMeta", *args, **kwargs): if cls not in cls._instances: instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance diff --git a/openbb_platform/core/openbb_core/app/model/example.py b/openbb_platform/core/openbb_core/app/model/example.py index 27f0e77084a..25772107824 100644 --- a/openbb_platform/core/openbb_core/app/model/example.py +++ b/openbb_platform/core/openbb_core/app/model/example.py @@ -2,7 +2,7 @@ from abc import abstractmethod from datetime import date, datetime, timedelta -from typing import Any, Dict, List, Literal, Optional, Union, _GenericAlias +from typing import Any, Dict, List, Literal, Optional, Union, _GenericAlias # type: ignore from pydantic import ( BaseModel, @@ -58,7 +58,7 @@ class APIEx(Example): @staticmethod def _unpack_type(type_: type) -> set: - """Unpack types from types, example Union[List[str], int] -> {str, int}.""" + """Unpack types from types, example Union[List[str], int] -> {typing._GenericAlias, int}.""" if ( hasattr(type_, "__args__") and type(type_) # pylint: disable=unidiomatic-typecheck diff --git a/openbb_platform/core/openbb_core/app/model/metadata.py b/openbb_platform/core/openbb_core/app/model/metadata.py index e5ea6a0f5bf..cf421fca7e0 100644 --- a/openbb_platform/core/openbb_core/app/model/metadata.py +++ b/openbb_platform/core/openbb_core/app/model/metadata.py @@ -1,6 +1,6 @@ from datetime import datetime from inspect import isclass -from typing import Any, Dict +from typing import Any, Dict, Optional, Sequence, Union import numpy as np import pandas as pd @@ -36,7 +36,7 @@ class Metadata(BaseModel): value is kept or trimmed to 80 characters. """ for arg, arg_val in v.items(): - new_arg_val = None + new_arg_val: Optional[Union[str, dict[str, Sequence[Any]]]] = None # Data if isclass(type(arg_val)) and issubclass(type(arg_val), Data): @@ -47,30 +47,32 @@ class Metadata(BaseModel): # List[Data] if isinstance(arg_val, list) and issubclass(type(arg_val[0]), Data): - columns = [list(d.model_dump().keys()) for d in arg_val] - columns = (item for sublist in columns for item in sublist) # flatten + _columns = [list(d.model_dump().keys()) for d in arg_val] + ld_columns = ( + item for sublist in _columns for item in sublist + ) # flatten new_arg_val = { "type": f"List[{type(arg_val[0]).__name__}]", - "columns": list(set(columns)), + "columns": list(set(ld_columns)), } # DataFrame elif isinstance(arg_val, pd.DataFrame): - columns = ( + df_columns = ( list(arg_val.index.names) + arg_val.columns.tolist() if any(index is not None for index in list(arg_val.index.names)) else arg_val.columns.tolist() ) new_arg_val = { "type": f"{type(arg_val).__name__}", - "columns": columns, + "columns": df_columns, } # List[DataFrame] elif isinstance(arg_val, list) and issubclass( type(arg_val[0]), pd.DataFrame ): - columns = [ + ldf_columns = [ ( list(df.index.names) + df.columns.tolist() if any(index is not None for index in list(df.index.names)) @@ -80,7 +82,7 @@ class Metadata(BaseModel): ] new_arg_val = { "type": f"List[{type(arg_val[0]).__name__}]", - "columns": columns, + "columns": ldf_columns, } # Series @@ -92,7 +94,7 @@ class Metadata(BaseModel): # List[Series] elif isinstance(arg_val, list) and isinstance(arg_val[0], pd.Series): - columns = [ + ls_columns = [ ( list(series.index.names) + [series.name] if any(index is not None for index in list(series.index.names)) @@ -102,7 +104,7 @@ class Metadata(BaseModel): ] new_arg_val = { "type": f"List[{type(arg_val[0]).__name__}]", - "columns": columns, + "columns": ls_columns, } # ndarray diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py index 9f2a2da8746..89232702405 100644 --- a/openbb_platform/core/openbb_core/app/model/obbject.py +++ b/openbb_platform/core/openbb_core/app/model/obbject.py @@ -8,6 +8,7 @@ from typing import ( ClassVar, Dict, Generic, + Hashable, List, Literal, Optional, @@ -82,30 +83,32 @@ class OBBject(Tagged, Generic[T]): @classmethod def results_type_repr(cls, params: Optional[Any] = None) -> str: - """Return the results type name.""" + """Return the results type representation.""" results_field = cls.model_fields.get("results") - type_ = params[0] if params else results_field.annotation - name = type_.__name__ if hasattr(type_, "__name__") else str(type_) + type_repr = "Any" + if results_field: + type_ = params[0] if params else results_field.annotation + type_repr = getattr(type_, "__name__", str(type_)) - if (json_schema_extra := results_field.json_schema_extra) is not None: - model = json_schema_extra.get("model") + if json_schema_extra := getattr(results_field, "json_schema_extra", {}): + model = json_schema_extra.get("model", "Any") - if json_schema_extra.get("is_union"): - return f"Union[List[{model}], {model}]" - if json_schema_extra.get("has_list"): - return f"List[{model}]" + if json_schema_extra.get("is_union"): + return f"Union[List[{model}], {model}]" + if json_schema_extra.get("has_list"): + return f"List[{model}]" - return model + return model - if "typing." in str(type_): - unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_)) - name = sub( - r"(\w+\.)*(\w+)?(\, NoneType)?", - r"\2", - unpack_optional, - ) + if "typing." in str(type_): + unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_)) + type_repr = sub( + r"(\w+\.)*(\w+)?(\, NoneType)?", + r"\2", + unpack_optional, + ) - return name + return type_repr @classmethod def model_parametrized_name(cls, params: Any) -> str: @@ -189,9 +192,9 @@ class OBBject(Tagged, Generic[T]): # List[List | str | int | float] | Dict[str, Dict | List | BaseModel] else: try: - df = pd.DataFrame(res) + df = pd.DataFrame(res) # type: ignore[call-overload] # Set index, if any - if index is not None and index in df.columns: + if df is not None and index is not None and index in df.columns: df.set_index(index, inplace=True) except ValueError: @@ -245,7 +248,7 @@ class OBBject(Tagged, Generic[T]): orient: Literal[ "dict", "list", "series", "split", "tight", "records", "index" ] = "list", - ) -> Dict[str, List]: + ) -> Union[Dict[Hashable, Any], List[Dict[Hashable, Any]]]: """Convert results field to a dictionary using any of pandas to_dict options. Parameters @@ -256,25 +259,21 @@ class OBBject(Tagged, Generic[T]): Returns ------- - Dict[str, List] - Dictionary of lists. + Union[Dict[Hashable, Any], List[Dict[Hashable, Any]]] + Dictionary of lists or list of dictionaries if orient is "records". """ - df = self.to_dataframe(index=None) # type: ignore - transpose = False - if orient == "list": - transpose = True - if not isinstance(self.results, dict): - transpose = False - else: # Only enter the loop if self.results is a dictionary - self.results: Dict[str, Any] = self.results # type: ignore - for _, value in self.results.items(): - if not isinstance(value, dict): - transpose = False - break - if transpose: + df = self.to_dataframe(index=None) + if ( + orient == "list" + and isinstance(self.results, dict) + and all( + isinstance(value, dict) + for value in self.results.values() # pylint: disable=no-member + ) + ): df = df.T results = df.to_dict(orient=orient) - if orient == "list" and "index" in results: + if isinstance(results, dict) and orient == "list" and "index" in results: del results["index"] return results diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py index eab7ea1e807..0b5267e58b8 100644 --- a/openbb_platform/core/openbb_core/app/provider_interface.py +++ b/openbb_platform/core/openbb_core/app/provider_interface.py @@ -18,7 +18,7 @@ from openbb_core.provider.query_executor import QueryExecutor from openbb_core.provider.registry_map import MapType, RegistryMap from openbb_core.provider.utils.helpers import to_snake_case -TupleFieldType = Tuple[str, type, Any] +TupleFieldType = Tuple[str, Optional[Type], Optional[Any]] @dataclass @@ -26,8 +26,8 @@ class DataclassField: """Dataclass field.""" name: str - type_: type - default: Any + annotation: Optional[Type] + default: Optional[Any] @dataclass @@ -154,20 +154,19 @@ class ProviderInterface(metaclass=SingletonMeta): def create_executor(self) -> QueryExecutor: """Get query executor.""" - return self._query_executor(self._registry_map.registry) # type: ignore + return self._query_executor(self._registry_map.registry) # type: ignore[operator] @staticmethod def _merge_fields( current: DataclassField, incoming: DataclassField, query: bool = False ) -> DataclassField: - current_name = current.name - current_type = current.type_ - current_desc = getattr(current.default, "description", "") + """Merge 2 dataclass fields.""" + curr_name = current.name + curr_type: Optional[Type] = current.annotation + curr_desc = getattr(current.default, "description", "") - incoming_type = incoming.type_ - incoming_desc = getattr(incoming.default, "description", "") - - F: Union[Callable, object] = Query if query else FieldInfo + inc_type: Optional[Type] = incoming.annotation + inc_desc = getattr(incoming.default, "description", "") def split_desc(desc: str) -> str: """Split field description.""" @@ -175,34 +174,33 @@ class ProviderInterface(metaclass=SingletonMeta): detail = item[0] if item else "" return detail - curr_detail = split_desc(current_desc) - inc_detail = split_desc(incoming_desc) + curr_detail = split_desc(curr_desc) + inc_detail = split_desc(inc_desc) - providers = f"{current.default.title},{incoming.default.title}" + curr_title = getattr(current.default, "title", "") + inc_title = getattr(incoming.default, "title", "") + providers = ",".join([curr_title, inc_title]) formatted_prov = providers.replace(",", ", ") if SequenceMatcher(None, curr_detail, inc_detail).ratio() > 0.8: new_desc = f"{curr_detail} (provider: {formatted_prov})" else: - new_desc = f"{current_desc};\n {incoming_desc}" + new_desc = f"{curr_desc};\n {inc_desc}" - merged_default = F( # type: ignore - default=current.default.default, + QF: Callable = Query if query else FieldInfo # type: ignore[assignment] + merged_default = QF( + default=getattr(current.default, "default", None), title=providers, description=new_desc, ) - merged_type = ( - Union[current_type, incoming_type] - if current_type != incoming_type - else current_type + merged_type: Optional[Type] = ( + Union[curr_type, inc_type] # type: ignore[assignment] + if curr_type != inc_type + else curr_type ) - return DataclassField( - name=current_name, - type_=merged_type, # type: ignore - default=merged_default, - ) + return DataclassField(curr_name, merged_type, merged_default) @staticmethod def _create_field( @@ -213,9 +211,7 @@ class ProviderInterface(metaclass=SingletonMeta): force_optional: bool = False, ) -> DataclassField: new_name = name.replace(".", "_") - # field.type_ don't work for nested types - # field.outer_type_ don't work for Optional nested types - type_ = field.annotation + annotation = field.annotation additional_description = "" if (extra := field.json_schema_extra) and ( @@ -239,7 +235,7 @@ class ProviderInterface(metaclass=SingletonMeta): if field.is_required(): if force_optional: - type_ = Optional[type_] # type: ignore + annotation = Optional[annotation] # type: ignore default = None else: default = ... @@ -247,24 +243,32 @@ class ProviderInterface(metaclass=SingletonMeta): default = field.default if query: - # We need to use query if we want the field description to show up in the - # swagger, it's a fastapi limitation - default = Query( - default=default, - title=provider_name, - description=description, - alias=field.alias or None, - json_schema_extra=field.json_schema_extra, + # We need to use query if we want the field description to show + # up in the swagger, it's a fastapi limitation + return DataclassField( + new_name, + annotation, + Query( + default=default, + title=provider_name, + description=description, + alias=field.alias or None, + json_schema_extra=getattr(field, "json_schema_extra", None), + ), ) - elif provider_name: - default: FieldInfo = Field( - default=default or None, - title=provider_name, - description=description, - json_schema_extra=field.json_schema_extra, + if provider_name: + return DataclassField( + new_name, + annotation, + Field( + default=default or None, + title=provider_name, + description=description, + json_schema_extra=field.json_schema_extra, + ), ) - return DataclassField(new_name, type_, default) + return DataclassField(new_name, annotation, default) @classmethod def _extract_params( @@ -282,7 +286,7 @@ class ProviderInterface(metaclass=SingletonMeta): standard[incoming.name] = ( incoming.name, - incoming.type_, + incoming.annotation, incoming.default, ) else: @@ -305,7 +309,7 @@ class ProviderInterface(metaclass=SingletonMeta): extra[updated.name] = ( updated.name, - updated.type_, + updated.annotation, updated.default, ) @@ -331,7 +335,7 @@ class ProviderInterface(metaclass=SingletonMeta): standard[incoming.name] = ( incoming.name, - incoming.type_, + incoming.annotation, incoming.default, ) else: @@ -357,7 +361,7 @@ class ProviderInterface(metaclass=SingletonMeta): extra[updated.name] = ( updated.name, - updated.type_, + updated.annotation, updated.default, ) @@ -393,14 +397,14 @@ class ProviderInterface(metaclass=SingletonMeta): standard, extra = self._extract_params(providers) result[model_name] = { - "standard": make_dataclass( # type: ignore + "standard": make_dataclass( cls_name=model_name, - fields=list(standard.values()), + fields=list(standard.values()), # type: ignore[arg-type] bases=(StandardParams,), ), - "extra": make_dataclass( # type: ignore + "extra": make_dataclass( cls_name=model_name, - fields=list(extra.values()), + fields=list(extra.values()), # type: ignore[arg-type] bases=(ExtraParams,), ), } @@ -464,14 +468,14 @@ class ProviderInterface(metaclass=SingletonMeta): extra: dict standard, extra = self._extract_data(providers) result[model_name] = { - "standard": make_dataclass( # type: ignore + "standard": make_dataclass( cls_name=model_name, - fields=list(standard.values()), + fields=list(standard.values()), # type: ignore[arg-type] bases=(StandardData,), ), - "extra": make_dataclass( # type: ignore + "extra": make_dataclass( cls_name=model_name, - fields=list(extra.values()), + fields=list(extra.values()), # type: ignore[arg-type] bases=(ExtraData,), ), } diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 3968d5892a0..218629be8ba 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -350,7 +350,7 @@ class SignatureInspector: func = cls.inject_return_type( func=func, - return_map=provider_interface.return_map.get(model), + return_map=provider_interface.return_map.get(model, {}), model=model, ) @@ -374,11 +374,7 @@ class SignatureInspector: return_map: Dict[str, dict], model: str, ) -> Callable[P, OBBject]: - """ - Inject full return model into the function. - Also updates __name__ and __doc__ for API schemas. - """ - + """Inject full return model into the function. Also updates __name__ and __doc__ for API schemas.""" results: Dict[str, Any] = {"list_type": [], "dict_type": []} for provider, return_data in return_map.items(): @@ -397,7 +393,7 @@ class SignatureInspector: if not v: continue - inner_type = SerializeAsAny[ + inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type] Annotated[ Union[tuple(v)], # type: ignore Field(discriminator="provider"), diff --git a/openbb_platform/core/openbb_core/app/static/account.py b/openbb_platform/core/openbb_core/app/static/account.py index a60e49f34bc..d589c265663 100644 --- a/openbb_platform/core/openbb_core/app/static/account.py +++ b/openbb_platform/core/openbb_core/app/static/account.py @@ -42,10 +42,11 @@ class Account: def _log_account_command(func): # pylint: disable=E0213 """Log account command.""" - @wraps(func) + @wraps(func) # type: ignore[arg-type] def wrapped(self, *args, **kwargs): try: - result = func(self, *args, **kwargs) # pylint: disable=E1102 + # pylint: disable=E1102 + result = func(self, *args, **kwargs) # type: ignore[operator] except Exception as e: raise OpenBBError(e) from e finally: @@ -57,8 +58,9 @@ class Account: ls.log( user_settings=user_settings, system_settings=system_settings, - route=f"/account/{func.__name__}", # pylint: disable=E1101 - func=func, + # pylint: disable=E1101 + route=f"/account/{func.__name__}", # type: ignore[attr-defined] + func=func, # type: ignore[arg-type] kwargs={}, # don't want any credentials being logged by accident exec_info=exc_info(), ) diff --git a/openbb_platform/core/openbb_core/app/static/utils/decorators.py b/openbb_platform/core/openbb_core/app/static/utils/decorators.py index 8e9c9bd00f6..61d9b46578b 100644 --- a/openbb_platform/core/openbb_core/app/static/utils/decorators.py +++ b/openbb_platform/core/openbb_core/app/static/utils/decorators.py @@ -54,8 +54,9 @@ def exception_handler(func: Callable[P, R]) -> Callable[P, R]: # Get the last traceback object from the exception tb = e.__traceback__ - while tb.tb_next is not None: - tb = tb.tb_next + if tb: + while tb.tb_next is not None: + tb = tb.tb_next if isinstance(e, ValidationError): error_list = [] @@ -70,7 +71,10 @@ def exception_handler(func: Callable[P, R]) -> Callable[P, R]: f"input_type={type(error['input']).__name__}, " f"input_value={error['input']}]\n" ) - error_info = f" For further information visit {error['url']}\n" + url = error.get("url") + error_info = ( + f" For further information visit {url}\n" if url else "" + ) error_list.append(arg_error + error_details + error_info) error_list.insert(0, validation_error) diff --git a/openbb_platform/core/openbb_core/app/utils.py b/openbb_platform/core/openbb_core/app/utils.py index 4004e3060a9..f7c6c1ca7dc 100644 --- a/openbb_platform/core/openbb_core/app/utils.py +++ b/openbb_platform/core/openbb_core/app/utils.py @@ -3,7 +3,7 @@ import ast import json from datetime import time -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import pandas as pd @@ -17,7 +17,7 @@ from openbb_core.provider.abstract.data import Data def basemodel_to_df( data: Union[List[Data], Data], - index: Optional[Union[None, str, Iterable]] = None, + index: Optional[str] = None, ) -> pd.DataFrame: """Convert list of BaseModel to a Pandas DataFrame.""" if isinstance(data, list): @@ -44,9 +44,7 @@ def basemodel_to_df( df.set_index("date", inplace=True) df.sort_index(axis=0, inplace=True) else: - df = ( - df.set_index(index) if index is not None and index in df.columns else df - ) + df = df.set_index(index) if index and index in df.columns else df return df diff --git a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py index 4ff3482118a..011f03f197b 100644 --- a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py +++ b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py @@ -65,7 +65,7 @@ class Fetcher(Generic[Q, R]): super().__init_subclass__(*args, **kwargs) if cls.aextract_data != Fetcher.aextract_data: - cls.extract_data = cls.aextract_data + cls.extract_data = cls.aextract_data # type: ignore[method-assign] elif cls.extract_data == Fetcher.extract_data: raise NotImplementedError( "Fetcher subclass must implement either extract_data or aextract_data" diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index 38c49aadef2..eb4573a0efb 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -146,7 +146,7 @@ class RegistryMap: def extract_data_model(fetcher: Fetcher, provider_str: str) -> BaseModel: """Extract info (fields and docstring) from fetcher query params or data.""" model: BaseModel = RegistryMap._get_model(fetcher, "data") - + model_name = getattr(model, "__name__", "") fields = {} for field_name, field in model.model_fields.items(): field.serialization_alias = field_name @@ -161,8 +161,8 @@ class RegistryMap: ), ) - provider_model = create_model( - model.__name__.replace("Data", ""), + provider_model = create_model( # type: ignore[call-overload] + model_name.replace("Data", ""), __base__=model, __doc__=model.__doc__, __module__=model.__module__, @@ -171,7 +171,11 @@ class RegistryMap: # Replace the provider models in the modules with the new models we created # To make sure provider field is defined to be the provider string - setattr(sys.modules[model.__module__], model.__name__, provider_model) + # This is hacky, but we need to have `provider: Literal['provider_name']` + # in the model to serve as union discriminator for the API validation + # the alternative would be to specify it manually in all the models + if model_name: + setattr(sys.modules[model.__module__], model_name, provider_model) return provider_model diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py index a1e37f81ba6..472a21df233 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py @@ -76,7 +76,7 @@ class ConsumerPriceIndexQueryParams(QueryParams): country: str = Field( description=QUERY_DESCRIPTIONS.get("country"), - choices=CPI_COUNTRIES, # type: ignore + json_schema_extra={"choices": CPI_COUNTRIES}, # type: ignore[dict-item] ) units: CPI_UNITS = Field( default="growth_same", diff --git a/openbb_platform/core/openbb_core/provider/standard_models/spot.py b/openbb_platform/core/openbb_core/provider/standard_models/spot.py index 24b37ad4179..ca871e1f0c8 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/spot.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/spot.py @@ -32,7 +32,7 @@ class SpotRateQueryParams(QueryParams): category: str = Field( default="spot_rate", description="Rate category. Options: spot_rate, par_yield.", - choices=["par_yield", "spot_rate"], + json_schema_extra={"choices": ["par_yield", "spot_rate"]}, ) @field_validator("category", mode="before", check_fields=False) diff --git a/openbb_platform/core/openbb_core/provider/utils/helpers.py b/openbb_platform/core/openbb_core/provider/utils/helpers.py index 99064cd8396..3a5c557341f 100644 --- a/openbb_platform/core/openbb_core/provider/utils/helpers.py +++ b/openbb_platform/core/openbb_core/provider/utils/helpers.py @@ -2,11 +2,20 @@ import asyncio import re -from datetime import datetime +from datetime import date, datetime from difflib import SequenceMatcher from functools import partial from inspect import iscoroutinefunction -from typing import Awaitable, Callable, List, Literal, Optional, TypeVar, Union, cast +from typing import ( + Awaitable, + Callable, + List, + Literal, + Optional, + TypeVar, + Union, + cast, +) import requests from anyio import start_blocking_portal @@ -21,6 +30,7 @@ from openbb_core.provider.utils.client import ( T = TypeVar("T") P = ParamSpec("P") +D = TypeVar("D", bound="Data") def check_item(item: str, allowed: List[str], threshold: float = 0.75) -> None: @@ -177,13 +187,13 @@ async def amake_requests( is_exception = isinstance(result, Exception) if is_exception and kwargs.get("raise_for_status", False): - raise result + raise result # type: ignore[misc] if is_exception or not result: continue - results.extend( # type: ignore - result if isinstance(result, list) else [result] + results.extend( + result if isinstance(result, list) else [result] # type: ignore[list-item] ) return results @@ -283,17 +293,23 @@ def run_async( def filter_by_dates( - data: List[Data], - start_date: datetime, - end_date: datetime, -) -> List[Data]: + data: List[D], start_date: Optional[date] = None, end_date: Optional[date] = None +) -> List[D]: """Filter data by dates.""" - if not any([start_date, end_date]): + if start_date is None and end_date is None: return data - return list( - filter( - lambda d: start_date <= d.date.date() <= end_date, - data, - ) - ) + def _filter(d: Data) -> bool: + _date = getattr(d, "date", None) + dt = _date.date() if _date and isinstance(_date, datetime) else _date + if dt: + if start_date and end_date: + return start_date <= dt <= end_date + if start_date: + return dt >= start_date + if end_date: + return dt <= end_date + return True + return False + + return list(filter(_filter, data)) diff --git a/openbb_platform/core/tests/api/test_auth/test_user_auth.py b/openbb_platform/core/tests/api/test_auth/test_user_auth.py index ce2c9e55818..22a8ccd9cd0 100644 --- a/openbb_platform/core/tests/api/test_auth/test_user_auth.py +++ b/openbb_platform/core/tests/api/test_auth/test_user_auth.py @@ -59,6 +59,6 @@ def test_get_user_settings_(mock_user_service): mock_user_settings = MagicMock(spec=UserSettings, profile=MagicMock(active=True)) mock_user_service.default_user_settings = mock_user_settings mock_user_service.return_value = mock_user_service - result = asyncio.run(get_user_settings(MagicMock(), mock_user_service)) + result = asyncio.run(get_user_settings(MagicMock(), mock_user_service)) # type: ignore[arg-type] assert result == mock_user_settings diff --git a/openbb_platform/core/tests/api/test_dependency/test_coverage.py b/openbb_platform/core/tests/api/test_dependency/test_coverage.py index ee5d388847f..38bef9cced3 100644 --- a/openbb_platform/core/tests/api/test_dependency/test_coverage.py +++ b/openbb_platform/core/tests/api/test_dependency/test_coverage.py @@ -9,6 +9,6 @@ from openbb_core.api.dependency.coverage import get_command_map def test_get_system_settings(): """Test get_system_settings.""" - response = asyncio.run(get_command_map(MagicMock())) + response = asyncio.run(get_command_map(MagicMock())) # type: ignore[arg-type] assert response diff --git a/openbb_platform/core/tests/api/test_dependency/test_system.py b/openbb_platform/core/tests/api/test_dependency/test_system.py index 5aa014a23aa..953f9e2b384 100644 --- a/openbb_platform/core/tests/api/test_dependency/test_system.py +++ b/openbb_platform/core/tests/api/test_dependency/test_system.py @@ -14,6 +14,6 @@ def test_get_system_settings(mock_system_service): """Test get_system_settings.""" mock_system_service.return_value.system_settings = SystemSettings() - response = asyncio.run(get_system_settings(MagicMock(), mock_system_service)) + response = asyncio.run(get_system_settings(MagicMock(), mock_system_service)) # type: ignore[arg-type] assert response diff --git a/openbb_platform/core/tests/app/logs/formatters/test_formatter_with_exceptions.py b/openbb_platform/core/tests/app/logs/formatters/test_formatter_with_exceptions.py index 662187bfdb4..90a2cae1b37 100644 --- a/openbb_platform/core/tests/app/logs/formatters/test_formatter_with_exceptions.py +++ b/openbb_platform/core/tests/app/logs/formatters/test_formatter_with_exceptions.py @@ -272,7 +272,7 @@ def test_filter_log_line(input_text, expected_output, formatter): def test_formatException_invalid(): with pytest.raises(Exception): - formatter.formatException(Exception("Big bad error")) + formatter.formatException(Exception("Big bad error")) # type: ignore[attr-defined] def test_format(formatter): diff --git a/openbb_platform/core/tests/app/logs/test_handlers_manager.py b/openbb_platform/core/tests/app/logs/test_handlers_manager.py index 12cb4565d9d..43c23214e88 100644 --- a/openbb_platform/core/tests/app/logs/test_handlers_manager.py +++ b/openbb_platform/core/tests/app/logs/test_handlers_manager.py @@ -85,4 +85,4 @@ def test_update_handlers(): for hdlr in handlers: if isinstance(hdlr, (MockPosthogHandler, MockPathTrackingFileHandler)): assert hdlr.settings == changed_settings - assert hdlr.formatter.settings == changed_settings + assert hdlr.formatter.settings == changed_settings # type: ignore[union-attr] diff --git a/openbb_platform/core/tests/app/logs/test_logging_service.py b/openbb_platform/core/tests/app/logs/test_logging_service.py index bd7460bc3c2..316b603588c 100644 --- a/openbb_platform/core/tests/app/logs/test_logging_service.py +++ b/openbb_platform/core/tests/app/logs/test_logging_service.py @@ -8,6 +8,7 @@ from openbb_core.app.model.abstract.error import OpenBBError from pydantic import BaseModel # ruff: noqa: S106 +# pylint: disable=redefined-outer-name, protected-access class MockLoggingSettings: @@ -24,9 +25,7 @@ class MockOBBject(BaseModel): @pytest.fixture(scope="function") def logging_service(): mock_system_settings = Mock() - mock_system_settings = "mock_system_settings" mock_user_settings = Mock() - mock_user_settings = "mock_user_settings" mock_setup_handlers = Mock() mock_log_startup = Mock() @@ -40,19 +39,37 @@ def logging_service(): "openbb_core.app.logs.logging_service.LoggingService._log_startup", mock_log_startup, ): - logging_service = LoggingService( + _logging_service = LoggingService( system_settings=mock_system_settings, user_settings=mock_user_settings, ) - assert mock_setup_handlers.assert_called_once - assert mock_log_startup.assert_called_once + return _logging_service - return logging_service +def test_correctly_initialized(): + mock_system_settings = Mock() + mock_user_settings = Mock() + mock_setup_handlers = Mock() + mock_log_startup = Mock() + + with patch( + "openbb_core.app.logs.logging_service.LoggingSettings", + MockLoggingSettings, + ), patch( + "openbb_core.app.logs.logging_service.LoggingService._setup_handlers", + mock_setup_handlers, + ), patch( + "openbb_core.app.logs.logging_service.LoggingService._log_startup", + mock_log_startup, + ): + LoggingService( + system_settings=mock_system_settings, + user_settings=mock_user_settings, + ) -def test_correctly_initialized(logging_service): - assert logging_service + mock_setup_handlers.assert_called_once() + mock_log_startup.assert_called_once() def test_logging_settings_setter(logging_service): @@ -68,8 +85,8 @@ def test_logging_settings_setter(logging_service): custom_user_settings, ) - assert logging_service.logging_settings.system_settings == "custom_system_settings" - assert logging_service.logging_settings.user_settings == "custom_user_settings" + assert logging_service.logging_settings.system_settings == "custom_system_settings" # type: ignore[attr-defined] + assert logging_service.logging_settings.user_settings == "custom_user_settings" # type: ignore[attr-defined] def test_log_startup(logging_service): @@ -93,7 +110,10 @@ def test_log_startup(logging_service): expected_log_data = { "route": "test_route", "PREFERENCES": "your_preferences", - "KEYS": {"username": "defined", "password": "defined"}, + "KEYS": { + "username": "defined", + "password": "defined", # pragma: allowlist secret + }, "SYSTEM": "your_system_settings", "custom_headers": {"X-OpenBB-Test": "test"}, } @@ -101,7 +121,7 @@ def test_log_startup(logging_service): "STARTUP: %s ", json.dumps(expected_log_data), ) - mock_get_logger.assert_called_once + mock_get_logger.assert_called_once() @pytest.mark.parametrize( @@ -163,7 +183,7 @@ def test_log( exec_info=exec_info, custom_headers=custom_headers, ) - assert mock_log_startup.assert_called_once + mock_log_startup.assert_called_once() else: mock_info = mock_get_logger.return_value.info diff --git a/openbb_platform/core/tests/app/logs/utils/test_expired_files.py b/openbb_platform/core/tests/app/logs/utils/test_expired_files.py index 8d785ef80d2..7e2721b1058 100644 --- a/openbb_platform/core/tests/app/logs/utils/test_expired_files.py +++ b/openbb_platform/core/tests/app/logs/utils/test_expired_files.py @@ -2,6 +2,7 @@ import os import tempfile from pathlib import Path from time import time +from typing import List from unittest.mock import MagicMock, patch import pytest @@ -92,7 +93,7 @@ def mock_path(): def test_remove_file_list_no_files(mock_path): # Arrange # Let's assume the file list is empty, meaning there are no files to remove - file_list = [] + file_list: List[Path] = [] # Act remove_file_list(file_list) diff --git a/openbb_platform/core/tests/app/model/abstract/test_tagged.py b/openbb_platform/core/tests/app/model/abstract/test_tagged.py index 9a8ecd6c12e..33df298ad58 100644 --- a/openbb_platform/core/tests/app/model/abstract/test_tagged.py +++ b/openbb_platform/core/tests/app/model/abstract/test_tagged.py @@ -8,7 +8,7 @@ def test_tagged_model(): def test_fields(): - fields = Tagged.__fields__ + fields = Tagged.model_fields fields_keys = fields.keys() assert "id" in fields_keys diff --git a/openbb_platform/core/tests/app/model/abstract/test_warning.py b/openbb_platform/core/tests/app/model/abstract/test_warning.py index 5c466442f12..46692077cdd 100644 --- a/openbb_platform/core/tests/app/model/abstract/test_warning.py +++ b/openbb_platform/core/tests/app/model/abstract/test_warning.py @@ -19,7 +19,7 @@ def test_warn_model(category, message): def test_fields(): - fields = Warning_.__fields__ + fields = Warning_.model_fields fields_keys = fields.keys() assert "category" in fields_keys diff --git a/openbb_platform/core/tests/app/model/charts/test_chart.py b/openbb_platform/core/tests/app/model/charts/test_chart.py index 0278a3625b9..3d3ff9960bd 100644 --- a/openbb_platform/core/tests/app/model/charts/test_chart.py +++ b/openbb_platform/core/tests/app/model/charts/test_chart.py @@ -41,7 +41,7 @@ def test_charting_config_validation(): chart = Chart(content=content, format=chart_format) with pytest.raises(ValueError): - chart.content = "Invalid Content" + chart.content = "Invalid Content" # type: ignore[assignment] assert chart.content == content assert chart.format == chart_format diff --git a/openbb_platform/core/tests/app/model/hub/test_hub_session.py b/openbb_platform/core/tests/app/model/hub/test_hub_session.py index 96247341fef..645a097ac1f 100644 --- a/openbb_platform/core/tests/app/model/hub/test_hub_session.py +++ b/openbb_platform/core/tests/app/model/hub/test_hub_session.py @@ -1,11 +1,12 @@ from openbb_core.app.model.hub.hub_session import HubSession +from pydantic import SecretStr # ruff: noqa: S105 S106 def test_hub_session(): session = HubSession( - access_token="mock_access_token", + access_token=SecretStr("mock_access_token"), token_type="mock_token_type", email="mock_email", user_uuid="mock_user_uuid", @@ -21,7 +22,7 @@ def test_hub_session(): def test_fields(): - fields = HubSession.__fields__ + fields = HubSession.model_fields fields_keys = fields.keys() assert "access_token" in fields_keys diff --git a/openbb_platform/core/tests/app/model/test_command_context.py b/openbb_platform/core/tests/app/model/test_command_context.py index 38a346b2387..76423c8b9c5 100644 --- a/openbb_platform/core/tests/app/model/test_command_context.py +++ b/openbb_platform/core/tests/app/model/test_command_context.py @@ -13,7 +13,7 @@ def test_command_context(): def test_fields(): - fields = CommandContext.__fields__ + fields = CommandContext.model_fields fields_keys = fields.keys() assert "user_settings" in fields_keys diff --git a/openbb_platform/core/tests/app/model/test_defaults.py b/openbb_platform/core/tests/app/model/test_defaults.py index e134eb36214..af87fa52e24 100644 --- a/openbb_platform/core/tests/app/model/test_defaults.py +++ b/openbb_platform/core/tests/app/model/test_defaults.py @@ -1,14 +1,14 @@ from openbb_core.app.model.defaults import Defaults -def test_defaultst(): +def test_defaults(): cc = Defaults(routes={"test": {"test": "test"}}) assert isinstance(cc, Defaults) assert cc.routes == {"test": {"test": "test"}} def test_fields(): - fields = Defaults.__fields__ + fields = Defaults.model_fields fields_keys = fields.keys() assert "routes" in fields_keys diff --git a/openbb_platform/core/tests/app/model/test_metadata.py b/openbb_platform/core/tests/app/model/test_metadata.py index 9c5cd1e2f9d..60c4f95c583 100644 --- a/openbb_platform/core/tests/app/model/test_metadata.py +++ b/openbb_platform/core/tests/app/model/test_metadata.py @@ -21,7 +21,7 @@ def test_Metadata(): def test_fields(): "Smoke test" - fields = Metadata.__fields__.keys() + fields = Metadata.model_fields.keys() assert "arguments" in fields assert "duration" in fields assert "route" in fields diff --git a/openbb_platform/core/tests/app/model/test_obbject.py b/openbb_platform/core/tests/app/model/test_obbject.py index ffe5b8368b2..431b2b56280 100644 --- a/openbb_platform/core/tests/app/model/test_obbject.py +++ b/openbb_platform/core/tests/app/model/test_obbject.py @@ -12,7 +12,7 @@ from pandas.testing import assert_frame_equal def test_OBBject(): """Smoke test.""" - co = OBBject() + co: OBBject = OBBject() assert isinstance(co, OBBject) @@ -29,7 +29,7 @@ def test_fields(): def test_to_dataframe_no_results(): """Test helper.""" - co = OBBject() + co: OBBject = OBBject() with pytest.raises(Exception): co.to_dataframe() @@ -214,7 +214,7 @@ class MockDataFrame(Data): def test_to_dataframe(results, expected_df): """Test helper.""" # Arrange - co = OBBject(results=results) + co: OBBject = OBBject(results=results) # Act and Assert if isinstance(expected_df, pd.DataFrame): @@ -253,7 +253,7 @@ def test_to_dataframe(results, expected_df): def test_to_dataframe_w_args(results, index, sort_by): """Test helper.""" # Arrange - co = OBBject(results=results) + co: OBBject = OBBject(results=results) # Act and Assert result = co.to_dataframe(index=index, sort_by=sort_by) @@ -281,7 +281,7 @@ def test_to_dataframe_w_args(results, index, sort_by): def test_to_df_daylight_savings(results): """Test helper.""" # Arrange - co = OBBject(results=results) + co: OBBject = OBBject(results=results) # Act and Assert expected_df = basemodel_to_df(results, index="date") @@ -342,7 +342,7 @@ def test_to_df_daylight_savings(results): def test_to_dict(results, expected_dict): """Test helper.""" # Arrange - co = OBBject(results=results) + co: OBBject = OBBject(results=results) # Act and Assert if isinstance(expected_dict, (list, dict)): @@ -357,7 +357,7 @@ def test_to_dict(results, expected_dict): def test_show_chart_exists(): """Test helper.""" - mock_instance = OBBject() + mock_instance: OBBject = OBBject() # Arrange mock_instance.chart = MagicMock(spec=Chart) mock_instance.chart.fig = MagicMock() @@ -372,7 +372,7 @@ def test_show_chart_exists(): def test_show_chart_no_chart(): """Test helper.""" - mock_instance = OBBject() + mock_instance: OBBject = OBBject() # Act and Assert with pytest.raises(OpenBBError, match="Chart not found."): @@ -381,7 +381,7 @@ def test_show_chart_no_chart(): def test_show_chart_no_fig(): """Test helper.""" - mock_instance = OBBject() + mock_instance: OBBject = OBBject() # Arrange mock_instance.chart = Chart() diff --git a/openbb_platform/core/tests/app/model/test_system_settings.py b/openbb_platform/core/tests/app/model/test_system_settings.py index c44fb072318..90a1c63778f 100644 --- a/openbb_platform/core/tests/app/model/test_system_settings.py +++ b/openbb_platform/core/tests/app/model/test_system_settings.py @@ -28,12 +28,12 @@ def test_create_openbb_directory_directory_and_files_not_exist(tmpdir): ) # Act - SystemSettings.create_openbb_directory(values) + SystemSettings.create_openbb_directory(values) # type: ignore[operator] # Assert - assert os.path.exists(values.openbb_directory) - assert os.path.exists(values.user_settings_path) - assert os.path.exists(values.system_settings_path) + assert os.path.exists(values.openbb_directory) # type: ignore[attr-defined] + assert os.path.exists(values.user_settings_path) # type: ignore[attr-defined] + assert os.path.exists(values.system_settings_path) # type: ignore[attr-defined] def test_create_openbb_directory_directory_exists_user_settings_missing(tmpdir): @@ -47,15 +47,15 @@ def test_create_openbb_directory_directory_exists_user_settings_missing(tmpdir): ) # Create the openbb directory - Path(values.openbb_directory).mkdir(parents=True, exist_ok=True) + Path(values.openbb_directory).mkdir(parents=True, exist_ok=True) # type: ignore[attr-defined] # Act - SystemSettings.create_openbb_directory(values) + SystemSettings.create_openbb_directory(values) # type: ignore[operator] # Assert - assert os.path.exists(values.openbb_directory) - assert os.path.exists(values.user_settings_path) - assert os.path.exists(values.system_settings_path) + assert os.path.exists(values.openbb_directory) # type: ignore[attr-defined] + assert os.path.exists(values.user_settings_path) # type: ignore[attr-defined] + assert os.path.exists(values.system_settings_path) # type: ignore[attr-defined] def test_create_openbb_directory_directory_exists_system_settings_missing(tmpdir): @@ -69,19 +69,19 @@ def test_create_openbb_directory_directory_exists_system_settings_missing(tmpdir ) # Create the openbb directory - Path(values.openbb_directory).mkdir(parents=True, exist_ok=True) + Path(values.openbb_directory).mkdir(parents=True, exist_ok=True) # type: ignore[attr-defined] # Create the user_settings.json file - with open(values.user_settings_path, "w") as f: + with open(values.user_settings_path, "w") as f: # type: ignore[attr-defined] f.write("{}") # Act - SystemSettings.create_openbb_directory(values) + SystemSettings.create_openbb_directory(values) # type: ignore[operator] # Assert - assert os.path.exists(values.openbb_directory) - assert os.path.exists(values.user_settings_path) - assert os.path.exists(values.system_settings_path) + assert os.path.exists(values.openbb_directory) # type: ignore[attr-defined] + assert os.path.exists(values.user_settings_path) # type: ignore[attr-defined] + assert os.path.exists(values.system_settings_path) # type: ignore[attr-defined] @pytest.mark.parametrize( @@ -138,7 +138,7 @@ def test_create_openbb_directory_directory_exists_system_settings_missing(tmpdir def test_validate_posthog_handler(values, expected_handlers): values = MockSystemSettings(**values) # Act - result = SystemSettings.validate_posthog_handler(values) + result = SystemSettings.validate_posthog_handler(values) # type: ignore[operator] # Assert assert result.logging_handlers == expected_handlers @@ -160,7 +160,7 @@ def test_validate_posthog_handler(values, expected_handlers): def test_validate_logging_handlers(handlers, valid): # Act and Assert if valid: - assert SystemSettings.validate_logging_handlers(handlers) == handlers + assert SystemSettings.validate_logging_handlers(handlers) == handlers # type: ignore[call-arg] else: with pytest.raises(ValueError, match="Invalid logging handler"): - SystemSettings.validate_logging_handlers(handlers) + SystemSettings.validate_logging_handlers(handlers) # type: ignore[call-arg] diff --git a/openbb_platform/core/tests/app/service/test_user_service.py b/openbb_platform/core/tests/app/service/test_user_service.py index 070d91935d8..aefe199382f 100644 --- a/openbb_platform/core/tests/app/service/test_user_service.py +++ b/openbb_platform/core/tests/app/service/test_user_service.py @@ -4,6 +4,7 @@ import json import tempfile from pathlib import Path +from openbb_core.app.model.defaults import Defaults from openbb_core.app.service.user_service import ( UserService, UserSettings, @@ -26,9 +27,9 @@ def test_write_default_user_settings(): # Create a UserSettings object with some test data user_settings = UserSettings() - user_settings.credentials = {"username": "test"} - user_settings.preferences = {"theme": "dark"} - user_settings.defaults = {"language": "en"} + user_settings.credentials = {"username": "test"} # type: ignore[assignment] + user_settings.preferences = {"theme": "dark"} # type: ignore[assignment] + user_settings.defaults = {"language": "en"} # type: ignore[assignment] # Write the user settings to the temporary file UserService.write_default_user_settings(user_settings, temp_path) @@ -50,13 +51,13 @@ def test_update_default(): """Test update default user settings.""" # Some settings - defaults_test = {"routes": {"test": {"test": "test"}}} + defaults_test = Defaults(routes={"test": {"test": "test"}}) other_settings = UserSettings(defaults=defaults_test) # Update the default settings updated_settings = UserService.update_default(other_settings) - assert updated_settings.defaults.model_dump() == defaults_test + assert updated_settings.defaults.model_dump() == defaults_test.model_dump() def test_merge_dicts(): diff --git a/openbb_platform/core/tests/app/static/test_filters.py b/openbb_platform/core/tests/app/static/test_filters.py index 9