from __future__ import annotations import calendar import os import time from unittest.mock import patch import pytest import server from server import ( _apply_filters, _build_condition, _get_market_factory, _resolve_cookies, clear_session, find_most_active, find_top_gainers, find_top_losers, fundamental_scan, get_historical_candles, get_stock_quotes, list_fields, list_markets, screen_market, set_session, technical_scan, ) from tradingview_screener.query import And, Or # --------------------------------------------------------------------------- # _resolve_cookies # --------------------------------------------------------------------------- class TestResolveCookies: def test_per_call_takes_priority(self): assert _resolve_cookies("per_call_id") == {"sessionid": "per_call_id"} def test_in_memory_over_env(self): server._session_cookies = {"sessionid": "mem_id"} with patch.dict(os.environ, {"TV_SESSION_ID": "env_id"}, clear=True): assert _resolve_cookies() == {"sessionid": "mem_id"} server._session_cookies = None def test_env_var_fallback(self): server._session_cookies = None with patch.dict(os.environ, {"TV_SESSION_ID": "env_id"}, clear=True): assert _resolve_cookies() == {"sessionid": "env_id"} def test_no_cookies(self): server._session_cookies = None with patch.dict(os.environ, {}, clear=True): assert _resolve_cookies() is None # --------------------------------------------------------------------------- # _build_condition # --------------------------------------------------------------------------- def _assert_condition(cond: dict, left: str, operation: str, right: object = None): """Helper: check the inner expression dict, handling nested 'expression' keys.""" inner = cond if "expression" not in cond else cond.get("expression", cond) assert inner.get("left") == left, f"expected left={left}, got {inner.get('left')}" assert inner.get("operation") == operation, f"expected operation={operation}, got {inner.get('operation')}" if right is not None or "right" in inner: assert inner.get("right") == right, f"expected right={right}, got {inner.get('right')}" class TestBuildCondition: def test_greater(self): c = _build_condition({"field": "RSI", "operator": ">", "value": 70}) _assert_condition(c, "RSI", "greater", 70) def test_greater_alt(self): c = _build_condition({"field": "RSI", "operator": "greater", "value": 70}) _assert_condition(c, "RSI", "greater", 70) def test_greater_equal(self): c = _build_condition({"field": "RSI", "operator": ">=", "value": 70}) _assert_condition(c, "RSI", "egreater", 70) def test_less(self): c = _build_condition({"field": "RSI", "operator": "<", "value": 30}) _assert_condition(c, "RSI", "less", 30) def test_equal(self): c = _build_condition({"field": "close", "operator": "==", "value": 100}) _assert_condition(c, "close", "equal", 100) def test_not_equal(self): c = _build_condition({"field": "sector", "operator": "!=", "value": "Finance"}) _assert_condition(c, "sector", "nequal", "Finance") def test_between(self): c = _build_condition({"field": "close", "operator": "between", "value": [100, 200]}) _assert_condition(c, "close", "in_range", [100, 200]) def test_not_between(self): c = _build_condition({"field": "close", "operator": "not_between", "value": [50, 100]}) _assert_condition(c, "close", "not_in_range", [50, 100]) def test_isin(self): c = _build_condition({"field": "exchange", "operator": "isin", "value": ["NASDAQ", "NYSE"]}) _assert_condition(c, "exchange", "in_range", ["NASDAQ", "NYSE"]) def test_not_in(self): c = _build_condition({"field": "exchange", "operator": "not_in", "value": ["OTC"]}) _assert_condition(c, "exchange", "not_in_range", ["OTC"]) def test_has(self): c = _build_condition({"field": "description", "operator": "has", "value": "tech"}) _assert_condition(c, "description", "has", "tech") def test_has_none_of(self): c = _build_condition({"field": "description", "operator": "has_none_of", "value": ["test"]}) _assert_condition(c, "description", "has_none_of", ["test"]) def test_like(self): c = _build_condition({"field": "name", "operator": "like", "value": "AAPL"}) _assert_condition(c, "name", "match", "AAPL") def test_not_like(self): c = _build_condition({"field": "name", "operator": "not_like", "value": "TEST"}) _assert_condition(c, "name", "nmatch", "TEST") def test_empty(self): c = _build_condition({"field": "description", "operator": "empty"}) _assert_condition(c, "description", "empty", None) def test_not_empty(self): c = _build_condition({"field": "description", "operator": "not_empty"}) _assert_condition(c, "description", "nempty", None) def test_crosses(self): c = _build_condition({"field": "MACD.macd", "operator": "crosses", "value": "MACD.signal"}) _assert_condition(c, "MACD.macd", "crosses", "MACD.signal") def test_crosses_above(self): c = _build_condition({"field": "MACD.macd", "operator": "crosses_above", "value": "MACD.signal"}) _assert_condition(c, "MACD.macd", "crosses_above", "MACD.signal") def test_crosses_below(self): c = _build_condition({"field": "MACD.macd", "operator": "crosses_below", "value": "MACD.signal"}) _assert_condition(c, "MACD.macd", "crosses_below", "MACD.signal") def test_above_pct(self): c = _build_condition({"field": "close", "operator": "above_pct", "value": ["SMA50", 1.02]}) _assert_condition(c, "close", "above%", ["SMA50", 1.02]) def test_below_pct(self): c = _build_condition({"field": "close", "operator": "below_pct", "value": ["SMA50", 0.98]}) _assert_condition(c, "close", "below%", ["SMA50", 0.98]) def test_between_pct(self): c = _build_condition({"field": "close", "operator": "between_pct", "value": ["SMA50", 0.98, 1.02]}) _assert_condition(c, "close", "in_range%", ["SMA50", 0.98, 1.02]) def test_in_day_range(self): c = _build_condition({"field": "RSI", "operator": "in_day_range", "value": [30, 70]}) _assert_condition(c, "RSI", "in_day_range", [30, 70]) def test_in_week_range(self): c = _build_condition({"field": "RSI", "operator": "in_week_range", "value": [30, 70]}) _assert_condition(c, "RSI", "in_week_range", [30, 70]) def test_in_month_range(self): c = _build_condition({"field": "RSI", "operator": "in_month_range", "value": [30, 70]}) _assert_condition(c, "RSI", "in_month_range", [30, 70]) def test_or_group(self): c = _build_condition({ "operator": "or", "filters": [ {"field": "RSI", "operator": ">", "value": 70}, {"field": "RSI", "operator": "<", "value": 30}, ], }) assert callable(Or) assert "operation" in c assert c["operation"]["operator"] == "or" assert len(c["operation"]["operands"]) == 2 def test_and_group(self): c = _build_condition({ "operator": "and", "filters": [ {"field": "close", "operator": ">", "value": 100}, {"field": "volume", "operator": ">", "value": 1000000}, ], }) assert callable(And) assert c["operation"]["operator"] == "and" assert len(c["operation"]["operands"]) == 2 def test_nested_groups(self): c = _build_condition({ "operator": "or", "filters": [ { "operator": "and", "filters": [ {"field": "RSI", "operator": ">", "value": 70}, {"field": "volume", "operator": ">", "value": 1000000}, ], }, {"field": "close", "operator": "<", "value": 10}, ], }) op = c["operation"] assert op["operator"] == "or" assert len(op["operands"]) == 2 def test_invalid_filter_raises(self): with pytest.raises(ValueError, match="Invalid filter"): _build_condition({"invalid": "data"}) # --------------------------------------------------------------------------- # _apply_filters # --------------------------------------------------------------------------- class TestApplyFilters: def test_none_filters(self): q = _get_market_factory("stocks", "america") original_query = dict(q.query) result = _apply_filters(q, None) assert result.query == original_query def test_empty_filters(self): q = _get_market_factory("stocks", "america") original_query = dict(q.query) result = _apply_filters(q, []) assert result.query == original_query def test_simple_filters_only(self): q = _get_market_factory("stocks", "america") q = _apply_filters(q, [{"field": "close", "operator": ">", "value": 100}]) flt = q.query.get("filter", []) assert isinstance(flt, list) assert any(f.get("left") == "close" for f in flt) def test_nested_filters_only(self): q = _get_market_factory("stocks", "america") q = _apply_filters(q, [ {"operator": "or", "filters": [ {"field": "RSI", "operator": ">", "value": 70}, {"field": "RSI", "operator": "<", "value": 30}, ]}, ]) f2 = q.query.get("filter2", {}) assert isinstance(f2, dict) def test_mixed_filters(self): q = _get_market_factory("stocks", "america") q = _apply_filters(q, [ {"field": "close", "operator": ">", "value": 10}, {"operator": "or", "filters": [ {"field": "RSI", "operator": ">", "value": 70}, {"field": "RSI", "operator": "<", "value": 30}, ]}, ]) flt = q.query.get("filter", []) assert any(f.get("left") == "close" for f in flt) f2 = q.query.get("filter2", {}) assert "or" in str(f2) def test_preserves_default_filter2(self): q = _get_market_factory("stocks", "america") q = _apply_filters(q, [ {"operator": "or", "filters": [ {"field": "RSI", "operator": ">", "value": 70}, {"field": "RSI", "operator": "<", "value": 30}, ]}, ]) f2 = q.query.get("filter2") assert f2 is not None assert f2.get("operator") == "and" # --------------------------------------------------------------------------- # _get_market_factory # --------------------------------------------------------------------------- class TestGetMarketFactory: def test_stocks_default(self): q = _get_market_factory("stocks", None) assert "america" in str(q.query.get("markets", "")) def test_stocks_with_country(self): q = _get_market_factory("stocks", "india") assert "india" in str(q.query.get("markets", "")) def test_unknown_falls_back_to_stocks(self): q = _get_market_factory("bogus", None) assert "america" in str(q.query.get("markets", "")) def test_crypto(self): q = _get_market_factory("crypto", None) m = str(q.query.get("markets", "")).lower() assert "crypto" in m def test_forex(self): q = _get_market_factory("forex", None) m = str(q.query.get("markets", "")).lower() assert "forex" in m def test_options(self): q = _get_market_factory("options", None) index_filters = q.query.get("index_filters", []) assert any("ESM2026" in str(f) for f in index_filters) # --------------------------------------------------------------------------- # MCP Tools (with mocked _exec_query) # --------------------------------------------------------------------------- MOCK_RESULT = (1, [{"name": "TEST", "close": 150.0, "change": 2.5, "volume": 1000000}]) @pytest.fixture(autouse=True) def reset_session(): server._session_cookies = None yield @patch("server._exec_query", return_value=MOCK_RESULT) class TestMCPTools: def test_get_stock_quotes(self, mock_exec): result = get_stock_quotes(tickers=["NASDAQ:NVDA"]) assert "Found 1 result(s)" in result assert "TEST" in result mock_exec.assert_called_once() def test_get_stock_quotes_with_columns(self, mock_exec): result = get_stock_quotes( tickers=["NASDAQ:AAPL"], columns=["name", "close", "RSI"], ) assert "Found 1 result(s)" in result def test_get_stock_quotes_crypto(self, mock_exec): result = get_stock_quotes( tickers=["BINANCE:BTCUSDT"], market_type="crypto", ) assert "Found 1 result(s)" in result def test_screen_market_no_filters(self, mock_exec): result = screen_market(limit=10) assert "Total: 1" in result def test_screen_market_with_filters(self, mock_exec): result = screen_market( columns=["name", "close", "RSI"], filters=[{"field": "RSI", "operator": ">", "value": 70}], order_by="RSI", limit=20, ) assert "Total: 1" in result def test_screen_market_nested_filters(self, mock_exec): result = screen_market( filters=[{"operator": "or", "filters": [ {"field": "RSI", "operator": ">", "value": 70}, {"field": "RSI", "operator": "<", "value": 30}, ]}], ) assert "Total: 1" in result def test_screen_market_with_offset(self, mock_exec): result = screen_market(limit=5, offset=10) assert "Total: 1" in result def test_screen_market_crypto(self, mock_exec): result = screen_market(market_type="crypto", limit=5) assert "Total: 1" in result def test_find_top_gainers(self, mock_exec): result = find_top_gainers(limit=10) assert "Total: 1" in result def test_find_top_gainers_with_min_price(self, mock_exec): result = find_top_gainers(limit=10, min_price=10.0) assert "Total: 1" in result def test_find_top_losers(self, mock_exec): result = find_top_losers(limit=10) assert "Total: 1" in result def test_find_most_active(self, mock_exec): result = find_most_active(limit=10) assert "Total: 1" in result def test_technical_scan(self, mock_exec): result = technical_scan( filters=[{"field": "RSI", "operator": "<", "value": 30}], limit=10, ) assert "Total: 1" in result def test_technical_scan_with_custom_columns(self, mock_exec): result = technical_scan( filters=[{"field": "RSI", "operator": ">", "value": 70}], columns=["name", "close", "RSI", "MACD.macd"], limit=5, ) assert "Total: 1" in result def test_fundamental_scan(self, mock_exec): result = fundamental_scan( filters=[{"field": "market_cap_basic", "operator": ">", "value": 1e9}], limit=10, ) assert "Total: 1" in result def test_fundamental_scan_with_custom_columns(self, mock_exec): result = fundamental_scan( filters=[{"field": "price_earnings_ttm", "operator": "<", "value": 20}], columns=["name", "close", "price_earnings_ttm"], ) assert "Total: 1" in result class TestListTools: def test_list_markets(self): result = list_markets() assert "asset_types" in result assert "stock_countries" in result assert "america" in result assert "crypto" in result def test_list_fields_all(self): result = list_fields() assert "price" in result assert "technical" in result assert "fundamental" in result assert "general" in result def test_list_fields_price(self): result = list_fields(category="price") assert "close" in result assert "volume" in result def test_list_fields_technical(self): result = list_fields(category="technical") assert "RSI" in result assert "MACD.macd" in result def test_list_fields_fundamental(self): result = list_fields(category="fundamental") assert "market_cap_basic" in result assert "price_earnings_ttm" in result def test_list_fields_invalid_category(self): result = list_fields(category="bogus") assert "not found" in result class TestSessionTools: def test_set_and_clear_session(self): result = set_session("test_session_123") assert "stored" in result assert _resolve_cookies() == {"sessionid": "test_session_123"} result2 = clear_session() assert "cleared" in result2 with patch.dict(os.environ, {}, clear=True): assert _resolve_cookies() is None def test_clear_session_when_none(self): server._session_cookies = None result = clear_session() assert "cleared" in result # --------------------------------------------------------------------------- # get_historical_candles # --------------------------------------------------------------------------- _MOCK_CANDLES = [ {"time": 1746662400, "open": 100.0, "high": 110.0, "low": 95.0, "close": 105.0, "volume": 1_000_000}, {"time": 1746748800, "open": 105.0, "high": 115.0, "low": 100.0, "close": 108.0, "volume": 1_200_000}, {"time": 1746835200, "open": 108.0, "high": 112.0, "low": 104.0, "close": 106.0, "volume": 900_000}, ] class TestGetHistoricalCandles: @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_basic_count_based(self, mock_fetch): result = get_historical_candles("NASDAQ:AAPL", resolution="D", count=3) mock_fetch.assert_called_once_with("NASDAQ:AAPL", "D", 3, from_ts=None, to_ts=None) assert "NASDAQ:AAPL" in result assert "3 candles" in result @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_from_date_only(self, mock_fetch): result = get_historical_candles("NASDAQ:AAPL", from_date="2026-05-08") kw = mock_fetch.call_args.kwargs expected_from = int(calendar.timegm(time.strptime("2026-05-08", "%Y-%m-%d"))) assert kw["from_ts"] == expected_from assert kw["to_ts"] is not None # defaults to now assert "2026-05-08" in result @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_from_and_to_date(self, mock_fetch): result = get_historical_candles("NASDAQ:AAPL", from_date="2026-05-08", to_date="2026-05-13") kw = mock_fetch.call_args.kwargs expected_from = int(calendar.timegm(time.strptime("2026-05-08", "%Y-%m-%d"))) expected_to = int(calendar.timegm(time.strptime("2026-05-13", "%Y-%m-%d"))) + 86399 assert kw["from_ts"] == expected_from assert kw["to_ts"] == expected_to assert "2026-05-08 → 2026-05-13" in result @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_datetime_strings(self, mock_fetch): get_historical_candles("NASDAQ:AAPL", from_date="2026-05-17 14:00", to_date="2026-05-17 14:05") kw = mock_fetch.call_args.kwargs expected_from = int(calendar.timegm(time.strptime("2026-05-17 14:00", "%Y-%m-%d %H:%M"))) expected_to = int(calendar.timegm(time.strptime("2026-05-17 14:05", "%Y-%m-%d %H:%M"))) assert kw["from_ts"] == expected_from assert kw["to_ts"] == expected_to @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_datetime_with_seconds(self, mock_fetch): get_historical_candles("NASDAQ:AAPL", from_date="2026-05-17 14:00:30", to_date="2026-05-17 14:05:00") kw = mock_fetch.call_args.kwargs expected_from = int(calendar.timegm(time.strptime("2026-05-17 14:00:30", "%Y-%m-%d %H:%M:%S"))) assert kw["from_ts"] == expected_from @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_date_only_to_date_end_of_day(self, mock_fetch): # date-only to_date should be bumped to end of that day (23:59:59) get_historical_candles("NASDAQ:AAPL", from_date="2026-05-08", to_date="2026-05-13") kw = mock_fetch.call_args.kwargs midnight = int(calendar.timegm(time.strptime("2026-05-13", "%Y-%m-%d"))) assert kw["to_ts"] == midnight + 86399 @patch("server._fetch_candles", return_value=[]) def test_no_data_returned(self, mock_fetch): result = get_historical_candles("NASDAQ:AAPL") assert "No historical data" in result @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_invalid_date_format_raises(self, mock_fetch): with pytest.raises(ValueError, match="Unrecognised date"): get_historical_candles("NASDAQ:AAPL", from_date="05/08/2026") @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_count_capped_at_500(self, mock_fetch): get_historical_candles("NASDAQ:AAPL", count=9999) assert mock_fetch.call_args.args[2] == 500 @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_detect_patterns_included_by_default(self, mock_fetch): result = get_historical_candles("NASDAQ:AAPL") assert "Pattern Detection" in result @patch("server._fetch_candles", return_value=_MOCK_CANDLES) def test_detect_patterns_disabled(self, mock_fetch): result = get_historical_candles("NASDAQ:AAPL", detect_patterns=False) assert "Pattern Detection" not in result