# Source code for vangja.datasets.stocks

"""Private helper functions for downloading historical stock data.

This module provides functions for downloading OHLCV data from Yahoo Finance,
computing typical prices, and determining S&P 500 index composition history.

The only public function is :func:`get_sp500_tickers_for_range`, which returns
tickers that were consistently part of the S&P 500 during a given time range.

All other functions are private helpers prefixed with ``_``.
"""

from __future__ import annotations

import logging
from datetime import datetime
from io import StringIO
from pathlib import Path

import pandas as pd
import requests

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Wikipedia page whose HTML tables are parsed for S&P 500 membership data.
_SP500_WIKI_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
# Headers sent with the Wikipedia request; a browser-style User-Agent
# (presumably so the request is not rejected as an anonymous script — confirm).
_HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}


def _compute_typical_price(df: pd.DataFrame) -> pd.Series:
    """Return the row-wise average of the four OHLC price columns.

    Parameters
    ----------
    df : pd.DataFrame
        Frame that contains the columns ``Open``, ``High``, ``Low`` and
        ``Close``.

    Returns
    -------
    pd.Series
        ``(Open + High + Low + Close) / 4`` for every row of ``df``.
    """
    # Summed in the fixed order Open, High, Low, Close before dividing.
    ohlc_total = df["Open"] + df["High"] + df["Low"] + df["Close"]
    return ohlc_total / 4


def _safe_ticker_filename(ticker: str) -> str:
    """Turn a ticker symbol into a string usable inside a filename.

    Every ``^``, ``/`` and ``.`` character is replaced by an underscore;
    all other characters are left untouched.

    Parameters
    ----------
    ticker : str
        Ticker symbol (e.g., ``"^GSPC"``, ``"BRK.B"``).

    Returns
    -------
    str
        The ticker with the three characters above mapped to ``_``.
    """
    underscore_map = str.maketrans({"^": "_", "/": "_", ".": "_"})
    return ticker.translate(underscore_map)


def _download_stock_data(
    tickers: list[str], cache_path: Path | None = None
) -> pd.DataFrame:
    """Download historical daily OHLCV data for one or more tickers from
    1940-01-01 to 2026-01-01.

    Uses ``yfinance`` to batch-download data for efficiency when multiple
    tickers are requested. ``yfinance`` is imported only when at least one
    ticker actually has to be downloaded, so calls that are fully served
    from the cache (or that receive an empty ticker list) work without it.

    Parameters
    ----------
    tickers : list[str]
        List of ticker symbols to download (e.g., ``["AAPL", "^GSPC"]``).
    cache_path : Path or None, default None
        Directory path for caching downloaded data. If None, data is
        downloaded without caching. If provided, each ticker's data is
        stored as a CSV file in this directory. Parent directories are
        created if they do not exist. On subsequent calls, cached data
        is loaded instead of re-downloading.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns: ``ds`` (datetime), ``ticker`` (str),
        ``Open``, ``High``, ``Low``, ``Close``, ``Volume`` (float),
        and ``typical_price`` (float). When no data could be loaded or
        downloaded, an empty frame with exactly these columns is returned.

    Raises
    ------
    ImportError
        If a download is required and ``yfinance`` is not installed.
    """
    START_DATE = "1940-01-01"
    END_DATE = "2026-01-01"
    # Column layout of the empty frame returned when nothing is available.
    empty_columns = [
        "ds",
        "ticker",
        "Open",
        "High",
        "Low",
        "Close",
        "Volume",
        "typical_price",
    ]
    results: list[pd.DataFrame] = []
    tickers_to_download: list[str] = []

    # Check cache for each ticker and load if available
    if cache_path is not None:
        cache_path.mkdir(parents=True, exist_ok=True)
        for ticker in tickers:
            fname = cache_path / f"{_safe_ticker_filename(ticker)}.csv"
            if fname.exists():
                logger.info("Loading cached data for %s", ticker)
                results.append(pd.read_csv(fname, parse_dates=["ds"]))
            else:
                tickers_to_download.append(ticker)
    else:
        tickers_to_download = list(tickers)

    if tickers_to_download:
        # Imported lazily: the optional dependency is only needed when
        # something is missing from the cache.
        try:
            import yfinance as yf
        except ImportError as e:
            raise ImportError(
                "yfinance is required to download stock data. "
                "Install with: pip install vangja[datasets]"
            ) from e

        single_ticker = len(tickers_to_download) == 1

        # Download all missing tickers at once for speed
        if single_ticker:
            raw = yf.download(
                tickers_to_download[0],
                start=START_DATE,
                end=END_DATE,
                auto_adjust=True,
                progress=False,
            )
        else:
            raw = yf.download(
                tickers_to_download,
                start=START_DATE,
                end=END_DATE,
                auto_adjust=True,
                progress=False,
                group_by="ticker",
            )

        if raw.empty:
            logger.warning("No data returned from yfinance")
        else:
            for ticker in tickers_to_download:
                try:
                    # Extract per-ticker data
                    if single_ticker:
                        ticker_df = raw.copy()
                    else:
                        ticker_df = raw[ticker].copy()

                    # Flatten MultiIndex columns if present (newer yfinance):
                    # keep the level that holds the price field names.
                    if isinstance(ticker_df.columns, pd.MultiIndex):
                        for idx, level in enumerate(ticker_df.columns.levels):
                            if "Open" in level:
                                ticker_df.columns = (
                                    ticker_df.columns.get_level_values(idx)
                                )
                                break

                    ticker_df = ticker_df.dropna(how="all")
                    if ticker_df.empty:
                        logger.warning("No data for %s", ticker)
                        continue

                    ticker_df = ticker_df.reset_index()

                    # Normalize date column name
                    for col in ("Date", "Datetime", "date"):
                        if col in ticker_df.columns:
                            ticker_df = ticker_df.rename(columns={col: "ds"})
                            break

                    # Drop any timezone so ``ds`` is timezone-naive.
                    ticker_df["ds"] = pd.to_datetime(
                        ticker_df["ds"]
                    ).dt.tz_localize(None)
                    ticker_df["ticker"] = ticker
                    ticker_df["typical_price"] = _compute_typical_price(ticker_df)

                    if cache_path is not None:
                        fname = cache_path / f"{_safe_ticker_filename(ticker)}.csv"
                        ticker_df.to_csv(fname, index=False)

                    results.append(ticker_df)
                except (KeyError, TypeError):
                    # A missing ticker/column in the downloaded frame is
                    # treated as "no data" for that ticker only.
                    logger.warning("No data for %s", ticker)

    if not results:
        return pd.DataFrame(columns=empty_columns)

    return pd.concat(results, ignore_index=True)


def _parse_constituents_table(table: pd.DataFrame) -> pd.DataFrame:
    """Normalize the current-constituents table scraped from Wikipedia.

    The input is copied, MultiIndex headers are flattened, recognised
    columns are renamed to ``ticker``, ``date_added`` and ``security``,
    and ``date_added`` (when present) is converted to datetimes with
    unparseable values coerced to ``NaT``.

    Parameters
    ----------
    table : pd.DataFrame
        Raw DataFrame from ``pd.read_html`` for the first table on
        the Wikipedia S&P 500 companies page.

    Returns
    -------
    pd.DataFrame
        Copy of ``table`` with standardized column names; unrecognised
        columns keep their (flattened) original names.
    """

    def _standard_name(label: object) -> str | None:
        # Map a header label to its standard name, or None to leave it alone.
        lowered = str(label).lower()
        if "symbol" in lowered or lowered == "ticker":
            return "ticker"
        if "date" in lowered and "added" in lowered:
            return "date_added"
        if lowered == "security":
            return "security"
        return None

    parsed = table.copy()

    if isinstance(parsed.columns, pd.MultiIndex):
        # Prefer the innermost header level; fall back to the outermost one
        # when the innermost is empty or an auto-generated "Unnamed" label.
        flat_names: list[str] = []
        for levels in parsed.columns:
            innermost = levels[-1]
            if innermost and "Unnamed" not in str(innermost):
                flat_names.append(str(innermost))
            else:
                flat_names.append(str(levels[0]))
        parsed.columns = flat_names

    renames: dict[str, str] = {}
    for label in parsed.columns:
        standard = _standard_name(label)
        if standard is not None:
            renames[label] = standard
    parsed = parsed.rename(columns=renames)

    if "date_added" in parsed.columns:
        parsed["date_added"] = pd.to_datetime(parsed["date_added"], errors="coerce")

    return parsed


def _parse_changes_table(table: pd.DataFrame) -> pd.DataFrame:
    """Normalize the historical-changes table scraped from Wikipedia.

    Works on a copy of ``table``. Two-level headers are mapped to standard
    names by inspecting both levels; flat headers with at least six
    columns are renamed by position. Rows whose ``date`` cannot be parsed
    are dropped.

    Parameters
    ----------
    table : pd.DataFrame
        Raw DataFrame from ``pd.read_html`` for the second table on
        the Wikipedia S&P 500 companies page.

    Returns
    -------
    pd.DataFrame
        DataFrame restricted to whichever of ``date``, ``added_ticker``,
        ``added_name``, ``removed_ticker``, ``removed_name`` are present.
    """
    positional_names = (
        "date",
        "added_ticker",
        "added_name",
        "removed_ticker",
        "removed_name",
        "reason",
    )

    def _name_for(position: int, header: tuple) -> str:
        # Derive a standard column name from a (level0, level1, ...) header.
        top = str(header[0]).strip().lower()
        sub = str(header[1]).strip().lower() if len(header) > 1 else ""
        names_company = "security" in sub or "name" in sub

        if "date" in top:
            return "date"
        if "added" in top and "ticker" in sub:
            return "added_ticker"
        if "added" in top and names_company:
            return "added_name"
        if "removed" in top and "ticker" in sub:
            return "removed_ticker"
        if "removed" in top and names_company:
            return "removed_name"
        if "reason" in top:
            return "reason"
        # Unrecognised header: fall back to a positional placeholder.
        return f"col_{position}"

    parsed = table.copy()

    if isinstance(parsed.columns, pd.MultiIndex):
        # Merged header cells produce a MultiIndex; name by header content.
        parsed.columns = [
            _name_for(position, header)
            for position, header in enumerate(parsed.columns)
        ]
    else:
        # Flat columns — rename the first six by position.
        labels = parsed.columns.tolist()
        if len(labels) >= 6:
            parsed = parsed.rename(columns=dict(zip(labels[:6], positional_names)))

    if "date" in parsed.columns:
        parsed["date"] = pd.to_datetime(parsed["date"], errors="coerce")
        parsed = parsed.dropna(subset=["date"])

    wanted = positional_names[:5]
    return parsed[[name for name in wanted if name in parsed.columns]]


def _fetch_sp500_wiki_tables(
    cache_path: Path | None = None,
    *,
    timeout: float = 30.0,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Fetch S&P 500 constituent and changes tables from Wikipedia.

    Downloads the "List of S&P 500 companies" Wikipedia page and parses
    both the current constituents table and the historical changes table.

    Parameters
    ----------
    cache_path : Path or None, default None
        Directory for caching parsed tables as CSV files. The directory
        (including parents) is created if missing. If both
        ``sp500_constituents.csv`` and ``sp500_changes.csv`` exist in
        this directory, they are loaded instead of downloading.
    timeout : float, default 30.0
        Timeout in seconds passed to the HTTP request, so a stalled
        connection raises instead of blocking indefinitely.

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        ``(constituents_df, changes_df)``

    Raises
    ------
    requests.HTTPError
        If the Wikipedia request returns an error status.
    requests.Timeout
        If the request does not complete within ``timeout``.
    """
    const_file: Path | None = None
    changes_file: Path | None = None

    if cache_path is not None:
        cache_path.mkdir(parents=True, exist_ok=True)
        const_file = cache_path / "sp500_constituents.csv"
        changes_file = cache_path / "sp500_changes.csv"

        # Only use the cache when both tables are available.
        if const_file.exists() and changes_file.exists():
            const_df = pd.read_csv(const_file)
            if "date_added" in const_df.columns:
                const_df["date_added"] = pd.to_datetime(const_df["date_added"])
            changes_df = pd.read_csv(changes_file)
            if "date" in changes_df.columns:
                changes_df["date"] = pd.to_datetime(changes_df["date"])
            return const_df, changes_df

    resp = requests.get(_SP500_WIKI_URL, headers=_HEADERS, timeout=timeout)
    resp.raise_for_status()
    # The first table on the page is parsed as the current constituents and
    # the second as the historical changes.
    tables = pd.read_html(StringIO(resp.text))
    const_df = _parse_constituents_table(tables[0])
    changes_df = _parse_changes_table(tables[1])

    if const_file is not None and changes_file is not None:
        const_df.to_csv(const_file, index=False)
        changes_df.to_csv(changes_file, index=False)

    return const_df, changes_df


def _get_sp500_constituents_at_date(
    target_date: str | datetime | pd.Timestamp,
    const_df: pd.DataFrame | None = None,
    changes_df: pd.DataFrame | None = None,
    cache_path: Path | None = None,
) -> set[str]:
    """Reconstruct the S&P 500 membership as of ``target_date``.

    Starts from the tickers in ``const_df`` and walks the change log from
    newest to oldest, reverting every change dated strictly after
    ``target_date``: additions are removed again and removals are restored.

    Parameters
    ----------
    target_date : str, datetime, or pd.Timestamp
        The date to determine S&P 500 composition for.
    const_df : pd.DataFrame or None
        Pre-fetched constituents table. If None, fetched from Wikipedia.
    changes_df : pd.DataFrame or None
        Pre-fetched changes table. If None, fetched from Wikipedia.
        Note that both tables are (re)fetched when either one is None.
    cache_path : Path or None
        Directory for caching Wikipedia data.

    Returns
    -------
    set[str]
        Set of ticker symbols in the S&P 500 on the target date.

    Notes
    -----
    Accuracy depends on Wikipedia's historical changes table, which has
    comprehensive data from approximately 1997 onwards. Earlier dates
    may be less accurate.
    """

    def _clean(value: object) -> str:
        # Stripped ticker text, or "" when the cell is missing/NaN.
        return str(value).strip() if pd.notna(value) else ""

    cutoff = pd.Timestamp(target_date)

    if const_df is None or changes_df is None:
        const_df, changes_df = _fetch_sp500_wiki_tables(cache_path)

    # Seed with today's membership.
    members = set(const_df["ticker"].dropna().astype(str).str.strip())

    newest_first = changes_df.sort_values("date", ascending=False)
    for _, change in newest_first.iterrows():
        when = change["date"]
        # Undated changes and changes on/before the cutoff are left in place.
        if pd.isna(when) or when <= cutoff:
            continue

        # Revert the addition first: the ticker was not yet a member.
        added_symbol = _clean(change.get("added_ticker"))
        if added_symbol:
            members.discard(added_symbol)

        # Then revert the removal: the ticker was still a member.
        removed_symbol = _clean(change.get("removed_ticker"))
        if removed_symbol:
            members.add(removed_symbol)

    return members


def get_sp500_tickers_for_range(
    start_date: str | datetime | pd.Timestamp,
    end_date: str | datetime | pd.Timestamp,
    cache_path: Path | None = None,
) -> list[str]:
    """Get tickers consistently in the S&P 500 during a date range.

    Returns tickers that were part of the S&P 500 for the entire duration
    between ``start_date`` and ``end_date``. A ticker is excluded if it was
    removed at any point during the range, even if it was later re-added.

    Parameters
    ----------
    start_date : str, datetime, or pd.Timestamp
        Start of the date range (inclusive).
    end_date : str, datetime, or pd.Timestamp
        End of the date range (inclusive).
    cache_path : Path or None, default None
        Directory for caching Wikipedia data as CSV files. If None, data is
        fetched without caching. If provided, parent directories are created
        if they do not exist.

    Returns
    -------
    list[str]
        Sorted list of ticker symbols that were consistently in the S&P 500
        during the entire date range. Any ``.`` in a ticker is replaced by
        ``-`` in the returned symbols.

    Raises
    ------
    ValueError
        If ``start_date`` is after ``end_date``.

    Notes
    -----
    Accuracy depends on Wikipedia's "List of S&P 500 companies" historical
    changes table, which has comprehensive data from approximately 1997
    onwards. Results for earlier periods may be less accurate.

    Examples
    --------
    >>> from vangja.datasets.stocks import get_sp500_tickers_for_range
    >>> tickers = get_sp500_tickers_for_range(
    ...     "2020-01-01", "2020-12-31"
    ... )  # doctest: +SKIP
    >>> "AAPL" in tickers  # doctest: +SKIP
    True
    """
    start = pd.Timestamp(start_date)
    end = pd.Timestamp(end_date)

    if start > end:
        raise ValueError(f"start_date ({start}) must be before end_date ({end})")

    # Fetch tables once, share across calls
    const_df, changes_df = _fetch_sp500_wiki_tables(cache_path)

    # Get constituents at start of range
    constituents_at_start = _get_sp500_constituents_at_date(
        start,
        const_df=const_df,
        changes_df=changes_df,
    )

    # Find tickers removed during the range: changes dated after ``start``
    # and up to and including ``end``.
    range_changes = changes_df[
        (changes_df["date"] > start) & (changes_df["date"] <= end)
    ]
    removed_during: set[str] = set()
    for _, row in range_changes.iterrows():
        removed = row.get("removed_ticker")
        if pd.notna(removed) and str(removed).strip():
            removed_during.add(str(removed).strip())

    consistent = constituents_at_start - removed_during

    return sorted(ticker.replace(".", "-") for ticker in consistent)