diff options
author | tehcoderer <me@tehcoderer.com> | 2024-05-30 10:46:48 -0400 |
---|---|---|
committer | tehcoderer <me@tehcoderer.com> | 2024-05-30 10:46:48 -0400 |
commit | cb9b4457c89a9db85ff49d5056abc98c2f463f7b (patch) | |
tree | ea0d80d1bd8aa1411e1ca2d2b09dbb922cde0489 | |
parent | 795e91f09915936c78f88fe67282ec8dfec166b9 (diff) |
testy?
4 files changed, 37 insertions, 10 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 6185eea74b6..80232e2bbe5 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -19,6 +19,7 @@ from typing import ( ) from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse from pydantic import BaseModel from pydantic.v1.validators import find_validators from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias @@ -231,6 +232,26 @@ class Router: self._routers: Dict[str, Router] = {} @overload + def stream(self, func: Callable[P, OBBject]) -> Callable[P, StreamingResponse]: + pass + + @overload + def stream(self, **kwargs) -> Callable: + pass + + def stream( + self, + func: Optional[Callable[P, OBBject]] = None, + **kwargs, + ) -> Optional[Callable]: + """Stream decorator for routes.""" + if func is None: + return lambda f: self.stream(f, **kwargs) + + kwargs["is_stream"] = True + return self.command(func, **kwargs) + + @overload def command(self, func: Optional[Callable[P, OBBject]]) -> Callable[P, OBBject]: pass @@ -260,6 +281,8 @@ class Router: examples=kwargs.pop("examples", []), providers=ProviderInterface().available_providers, ) + kwargs["openapi_extra"]["is_stream"] = kwargs.pop("is_stream", False) + kwargs["operation_id"] = kwargs.get( "operation_id", SignatureInspector.get_operation_id(func) ) @@ -349,7 +372,7 @@ class SignatureInspector: @classmethod def complete( - cls, func: Callable[P, OBBject], model: str + cls, func: Callable[P, OBBject], model: str, is_stream: bool = False ) -> Optional[Callable[P, OBBject]]: """Complete function signature.""" if isclass(return_type := func.__annotations__["return"]) and not issubclass( diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 00fea7f5867..c0d3dbadce9 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -342,7 +342,8 @@ class ImportDefinition: if route: if route.deprecated: hint_type_list.append(type(route.summary.metadata)) - function_hint_type_list = cls.get_function_hint_type_list(func=route.endpoint) # type: ignore + if not route.openapi_extra.get("is_stream", False): + function_hint_type_list = cls.get_function_hint_type_list(func=route.endpoint) # type: ignore hint_type_list.extend(function_hint_type_list) hint_type_list = list(set(hint_type_list)) diff --git a/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py b/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py index 7bf50bcef99..0eccc948fa6 100644 --- a/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py +++ b/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py @@ -13,7 +13,6 @@ from openbb_core.app.provider_interface import ( from openbb_core.app.query import Query from openbb_core.app.router import Router from providers.binance.openbb_binance.models.crypto_historical import ( - BinanceCryptoHistoricalData, BinanceCryptoHistoricalFetcher, ) @@ -63,13 +62,11 @@ async def historical( return await OBBject.from_query(Query(**locals())) -@router.command(methods=["GET"]) -async def live( - symbol: str = "ethbtc", lifetime: int = 10 -) -> BinanceCryptoHistoricalData: +@router.stream(methods=["GET"]) +async def live(symbol: str = "ethbtc", lifetime: int = 10, tld: str = "us") -> OBBject: """Connect to Binance WebSocket Crypto Price data feed.""" generator = BinanceCryptoHistoricalFetcher().stream_data( - params={"symbol": symbol, "lifetime": lifetime}, + params={"symbol": symbol, "lifetime": lifetime, "tld": tld}, credentials=None, ) return StreamingResponse(generator, media_type="application/x-ndjson") diff --git a/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py b/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py index 13dc85358f7..796910a8ac8 100644 --- a/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py +++ b/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py @@ -3,7 +3,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional +from typing import Any, AsyncGenerator, AsyncIterator, Dict, Literal, Optional import websockets from openbb_core.provider.abstract.fetcher import Fetcher @@ -21,6 +21,9 @@ logger = logging.getLogger(__name__) class BinanceCryptoHistoricalQueryParams(CryptoHistoricalQueryParams): """Binance Crypto Historical Query Params.""" + tld: Optional[Literal["us", "com"]] = Field( + default="us", description="Top-level domain of the Binance endpoint." + ) lifetime: Optional[int] = Field( default=60, description="Lifetime of WebSocket in seconds." ) @@ -65,11 +68,14 @@ class BinanceCryptoHistoricalFetcher(Fetcher): ) -> AsyncGenerator[dict, None]: """Return the raw data from the Binance endpoint.""" async with websockets.connect( - f"wss://stream.binance.com:9443/ws/{query.symbol.lower()}@miniTicker" + f"wss://stream.binance.{query.tld}:9443/ws/{query.symbol.lower()}@miniTicker" ) as websocket: logger.info("Connected to WebSocket server.") end_time = datetime.now() + timedelta(seconds=query.lifetime) + print("Connected to WebSocket server.") try: + chunk = await websocket.recv() + print(f"Chunk me baby: {chunk}") while datetime.now() < end_time: chunk = await websocket.recv() yield json.loads(chunk) |