summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authortehcoderer <me@tehcoderer.com>2024-05-30 10:46:48 -0400
committertehcoderer <me@tehcoderer.com>2024-05-30 10:46:48 -0400
commitcb9b4457c89a9db85ff49d5056abc98c2f463f7b (patch)
treeea0d80d1bd8aa1411e1ca2d2b09dbb922cde0489
parent795e91f09915936c78f88fe67282ec8dfec166b9 (diff)
testy?
-rw-r--r--openbb_platform/core/openbb_core/app/router.py25
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py3
-rw-r--r--openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py9
-rw-r--r--openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py10
4 files changed, 37 insertions, 10 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index 6185eea74b6..80232e2bbe5 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -19,6 +19,7 @@ from typing import (
)
from fastapi import APIRouter, Depends
+from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias
@@ -231,6 +232,26 @@ class Router:
self._routers: Dict[str, Router] = {}
@overload
+ def stream(self, func: Callable[P, OBBject]) -> Callable[P, StreamingResponse]:
+ pass
+
+ @overload
+ def stream(self, **kwargs) -> Callable:
+ pass
+
+ def stream(
+ self,
+ func: Optional[Callable[P, OBBject]] = None,
+ **kwargs,
+ ) -> Optional[Callable]:
+ """Stream decorator for routes."""
+ if func is None:
+ return lambda f: self.stream(f, **kwargs)
+
+ kwargs["is_stream"] = True
+ return self.command(func, **kwargs)
+
+ @overload
def command(self, func: Optional[Callable[P, OBBject]]) -> Callable[P, OBBject]:
pass
@@ -260,6 +281,8 @@ class Router:
examples=kwargs.pop("examples", []),
providers=ProviderInterface().available_providers,
)
+ kwargs["openapi_extra"]["is_stream"] = kwargs.pop("is_stream", False)
+
kwargs["operation_id"] = kwargs.get(
"operation_id", SignatureInspector.get_operation_id(func)
)
@@ -349,7 +372,7 @@ class SignatureInspector:
@classmethod
def complete(
- cls, func: Callable[P, OBBject], model: str
+ cls, func: Callable[P, OBBject], model: str, is_stream: bool = False
) -> Optional[Callable[P, OBBject]]:
"""Complete function signature."""
if isclass(return_type := func.__annotations__["return"]) and not issubclass(
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index 00fea7f5867..c0d3dbadce9 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -342,7 +342,8 @@ class ImportDefinition:
if route:
if route.deprecated:
hint_type_list.append(type(route.summary.metadata))
- function_hint_type_list = cls.get_function_hint_type_list(func=route.endpoint) # type: ignore
+ if not route.openapi_extra.get("is_stream", False):
+ function_hint_type_list = cls.get_function_hint_type_list(func=route.endpoint) # type: ignore
hint_type_list.extend(function_hint_type_list)
hint_type_list = list(set(hint_type_list))
diff --git a/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py b/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py
index 7bf50bcef99..0eccc948fa6 100644
--- a/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py
+++ b/openbb_platform/extensions/crypto/openbb_crypto/price/price_router.py
@@ -13,7 +13,6 @@ from openbb_core.app.provider_interface import (
from openbb_core.app.query import Query
from openbb_core.app.router import Router
from providers.binance.openbb_binance.models.crypto_historical import (
- BinanceCryptoHistoricalData,
BinanceCryptoHistoricalFetcher,
)
@@ -63,13 +62,11 @@ async def historical(
return await OBBject.from_query(Query(**locals()))
-@router.command(methods=["GET"])
-async def live(
- symbol: str = "ethbtc", lifetime: int = 10
-) -> BinanceCryptoHistoricalData:
+@router.stream(methods=["GET"])
+async def live(symbol: str = "ethbtc", lifetime: int = 10, tld: str = "us") -> OBBject:
"""Connect to Binance WebSocket Crypto Price data feed."""
generator = BinanceCryptoHistoricalFetcher().stream_data(
- params={"symbol": symbol, "lifetime": lifetime},
+ params={"symbol": symbol, "lifetime": lifetime, "tld": tld},
credentials=None,
)
return StreamingResponse(generator, media_type="application/x-ndjson")
diff --git a/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py b/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py
index 13dc85358f7..796910a8ac8 100644
--- a/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py
+++ b/openbb_platform/providers/binance/openbb_binance/models/crypto_historical.py
@@ -3,7 +3,7 @@
import json
import logging
from datetime import datetime, timedelta
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, Literal, Optional
import websockets
from openbb_core.provider.abstract.fetcher import Fetcher
@@ -21,6 +21,9 @@ logger = logging.getLogger(__name__)
class BinanceCryptoHistoricalQueryParams(CryptoHistoricalQueryParams):
"""Binance Crypto Historical Query Params."""
+ tld: Optional[Literal["us", "com"]] = Field(
+ default="us", description="Top-level domain of the Binance endpoint."
+ )
lifetime: Optional[int] = Field(
default=60, description="Lifetime of WebSocket in seconds."
)
@@ -65,11 +68,14 @@ class BinanceCryptoHistoricalFetcher(Fetcher):
) -> AsyncGenerator[dict, None]:
"""Return the raw data from the Binance endpoint."""
async with websockets.connect(
- f"wss://stream.binance.com:9443/ws/{query.symbol.lower()}@miniTicker"
+ f"wss://stream.binance.{query.tld}:9443/ws/{query.symbol.lower()}@miniTicker"
) as websocket:
logger.info("Connected to WebSocket server.")
end_time = datetime.now() + timedelta(seconds=query.lifetime)
+ print("Connected to WebSocket server.")
try:
+ chunk = await websocket.recv()
+ print(f"Chunk me baby: {chunk}")
while datetime.now() < end_time:
chunk = await websocket.recv()
yield json.loads(chunk)