diff options
author | Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com> | 2024-05-27 15:07:43 +0200 |
---|---|---|
committer | Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com> | 2024-05-27 15:07:43 +0200 |
commit | 58e8ada6f4fd28f53a62ba461f790e6012cdc8be (patch) | |
tree | ef48cd36ec99fc40bb4f474021053e75b3585807 | |
parent | 38f4316749b9976bb83a727cef96d5e0d572a1d7 (diff) |
Ready POC
5 files changed, 56 insertions, 104 deletions
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index cbfbfc125b1..82c6c234d8b 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -1479,7 +1479,6 @@ class ReferenceGenerator: expanded_types[field], is_required, "website" ) field_type = f"Union[{field_type}, {expanded_type}]" - cleaned_description = ( str(field_info.description) .strip().replace("\n", " ").replace(" ", " ").replace('"', "'") @@ -1506,7 +1505,6 @@ class ReferenceGenerator: # Manually setting to List[<field_type>] for multiple items # Should be removed if TYPE_EXPANSION is updated to include this field_type = f"Union[{field_type}, List[{field_type}]]" - default_value = "" if field_info.default is PydanticUndefined else field_info.default # fmt: skip provider_field_params.append( diff --git a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py index 8f16e8b7bfa..386ea0420f8 100644 --- a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py +++ b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py @@ -5,6 +5,7 @@ from typing import ( Any, + AsyncIterator, Dict, Generic, Optional, @@ -54,13 +55,18 @@ class Fetcher(Generic[Q, R]): """Asynchronously extract the data from the provider.""" @staticmethod + async def atransform_data( + query: Q, data: Any, **kwargs + ) -> Union[R, AnnotatedResult[R]]: + """Asynchronously transform the provider-specific data.""" + + @staticmethod def extract_data(query: Q, credentials: Optional[Dict[str, str]]) -> Any: """Extract the data from the provider.""" @staticmethod def transform_data(query: Q, data: Any, **kwargs) -> Union[R, AnnotatedResult[R]]: """Transform the provider-specific data.""" - raise NotImplementedError def __init_subclass__(cls, *args, **kwargs): """Initialize the subclass.""" @@ -75,6 +81,15 @@ class Fetcher(Generic[Q, R]): " default." ) + if cls.atransform_data != Fetcher.atransform_data: + cls.transform_data = cls.atransform_data + elif cls.transform_data == Fetcher.transform_data: + raise NotImplementedError( + "Fetcher subclass must implement either transform_data or atransform_data" + " method. If both are implemented, atransform_data will be used as the" + " default." + ) + @classmethod async def fetch_data( cls, @@ -95,13 +110,14 @@ class Fetcher(Generic[Q, R]): params: Dict[str, Any], credentials: Optional[Dict[str, str]] = None, **kwargs, - ) -> Union[R, AnnotatedResult[R]]: + ) -> Union[AsyncIterator[R], AsyncIterator[AnnotatedResult[R]]]: """Fetch data from a provider.""" query = cls.transform_query(params=params) data = await maybe_coroutine( cls.aextract_data, query=query, credentials=credentials, **kwargs ) - async for d in data: + transformed_data = cls.atransform_data(query=query, data=data, **kwargs) + async for d in transformed_data: yield d @classproperty @@ -242,64 +258,3 @@ class Fetcher(Generic[Q, R]): assert issubclass( type(transformed_data), cls.return_type ), f"Transformed data must be of the correct type. Expected: {cls.return_type} Got: {type(transformed_data)}" - - -# class StreamFetcher(Generic[Q, R]): -# """Class to fetch live streaming data using WebSocket connections.""" - -# @classmethod -# async def connect( -# cls, -# uri: str, -# ): -# """Connect to a WebSocket server.""" -# cls.websocket = await websockets.connect(uri) -# print("Connected to WebSocket server.") -# asyncio.create_task(cls.receive_data()) - -# @staticmethod -# def transform_data(data: Any, **kwargs) -> Union[R, AnnotatedResult[R]]: -# """Transform the provider-specific data.""" -# raise NotImplementedError - -# @classmethod -# async def receive_data(cls, **kwargs): -# """Receive data from the WebSocket server.""" -# try: -# while True: -# message = await cls.websocket.recv() -# processed_data = await cls.process_message(message, **kwargs) -# if processed_data: -# print(processed_data) - -# except websockets.exceptions.ConnectionClosed: -# print("WebSocket connection closed.") - -# @classmethod -# async def process_message(cls, message: str, **kwargs) -> Optional[R]: -# """Process incoming WebSocket messages.""" -# try: -# json_data = json.loads(message) -# transformed_data = cls.transform_data(json_data, **kwargs) -# return transformed_data -# except Exception as e: -# print(f"Error processing message: {e}") -# return None - -# @classmethod -# async def disconnect(cls): -# """Disconnect the WebSocket.""" -# await cls.websocket.close() - -# @classmethod -# async def fetch_data( -# cls, # pylint: disable=unused-argument -# params: Dict[str, Any], -# credentials: Optional[Dict[str, str]] = None, # pylint: disable=unused-argument -# **kwargs, -# ) -> Union[R, AnnotatedResult[R]]: -# """Fetch live data from a provider.""" -# # In a streaming context, this method may just ensure the connection is open. -# if not hasattr(cls, "websocket"): -# await cls.connect(params.get("uri")) -# # Data handling is asynchronous and managed by `receive_data`. diff --git a/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py b/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py index b27eaf0c5d9..6a4d477b235 100644 --- a/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py +++ b/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py @@ -12,12 +12,8 @@ from openbb_core.app.provider_interface import ( ) from openbb_core.app.query import Query from openbb_core.app.router import Router -from providers.binance.openbb_binance.models.crypto_historical import ( - BinanceCryptoHistoricalFetcher, -) from openbb_crypto.price.price_router import router as price_router -from fastapi.responses import StreamingResponse router = Router(prefix="", description="Cryptocurrency market data.") router.include_router(price_router) @@ -39,15 +35,3 @@ async def search( ) -> OBBject: """Search available cryptocurrency pairs within a provider.""" return await OBBject.from_query(Query(**locals())) - - -@router.command( - methods=["GET"], -) -async def stream_price(symbol: str = "ethbtc", lifetime: int = 10) -> OBBject: - """Define the POC.""" - generator = BinanceCryptoHistoricalFetcher().stream_data( - params={"symbol": symbol, "lifetime": lifetime}, - credentials=None, - ) - return StreamingResponse(generator, media_type="application/x-ndjson") diff --git a/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py b/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py index efce0336149..7bf50bcef99 100644 --- a/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py +++ b/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py @@ -1,6 +1,7 @@ # pylint: disable=W0613:unused-argument """Crypto Price Router.""" +from fastapi.responses import StreamingResponse from openbb_core.app.model.command_context import CommandContext from openbb_core.app.model.example import APIEx from openbb_core.app.model.obbject import OBBject @@ -11,6 +12,10 @@ from openbb_core.app.provider_interface import ( ) from openbb_core.app.query import Query from openbb_core.app.router import Router +from providers.binance.openbb_binance.models.crypto_historical import ( + BinanceCryptoHistoricalData, + BinanceCryptoHistoricalFetcher, +) router = Router(prefix="/price") @@ -56,3 +61,15 @@ async def historical( ) -> OBBject: """Get historical price data for cryptocurrency pair(s) within a provider.""" return await OBBject.from_query(Query(**locals())) + + +@router.command(methods=["GET"]) +async def live( + symbol: str = "ethbtc", lifetime: int = 10 +) -> BinanceCryptoHistoricalData: + """Connect to Binance WebSocket Crypto Price data feed.""" + generator = BinanceCryptoHistoricalFetcher().stream_data( + params={"symbol": symbol, "lifetime": lifetime}, + credentials=None, + ) + return StreamingResponse(generator, media_type="application/x-ndjson") diff --git a/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py b/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py index 3c8d440c166..4aca9311fb8 100644 --- a/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py +++ b/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py @@ -1,6 +1,7 @@ """Binance Crypto Historical WS Data.""" import json +import logging from datetime import datetime, timedelta from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional @@ -56,39 +57,36 @@ class BinanceCryptoHistoricalFetcher(Fetcher): return BinanceCryptoHistoricalQueryParams(**params) @staticmethod - def transform_data( - query: BinanceCryptoHistoricalQueryParams, - data: Dict[str, Any], - ) -> BinanceCryptoHistoricalData: - """Return the transformed data.""" - data["date"] = ( - datetime.now().isoformat() if "date" not in data else data["date"] - ) - - return BinanceCryptoHistoricalData(**data) - - @staticmethod async def aextract_data( query: BinanceCryptoHistoricalQueryParams, credentials: Optional[Dict[str, str]] = None, **kwargs: Any, - ) -> AsyncIterator[str]: + ) -> AsyncGenerator[dict, None]: """Return the raw data from the Binance endpoint.""" async with websockets.connect( f"wss://stream.binance.com:9443/ws/{query.symbol.lower()}@miniTicker" ) as websocket: - print("Connected to WebSocket server.") + logging.info("Connected to WebSocket server.") end_time = datetime.now() + timedelta(seconds=query.lifetime) try: while datetime.now() < end_time: - message = await websocket.recv() - data = json.loads(message) - transformed_data = BinanceCryptoHistoricalFetcher.transform_data( - query, data - ) - yield transformed_data.model_dump_json() + "\n" + chunk = await websocket.recv() + yield json.loads(chunk) except websockets.exceptions.ConnectionClosed as e: - print("WebSocket connection closed.") + logging.error("WebSocket connection closed.") raise e finally: - print("WebSocket connection closed.") + logging.info("WebSocket connection closed.") + + @staticmethod + async def atransform_data( + query: BinanceCryptoHistoricalQueryParams, + data: Dict[str, Any], + ) -> AsyncIterator[str]: + """Return the transformed data.""" + async for chunk in data: + chunk["date"] = ( + datetime.now().isoformat() if "date" not in chunk else chunk["date"] + ) + result = BinanceCryptoHistoricalData(**chunk) + yield result.model_dump_json() + "\n" |