wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

test_integration.py (4739B)


      1 """Integration test: download a small CRSP MSF sample and verify output.
      2 
      3 Requires WRDS credentials (PGUSER/PGPASSWORD or ~/.config/wrds-dl/credentials).
      4 Skipped automatically when credentials are unavailable.
      5 
      6 If the Go wrds-dl binary is found, downloads the same data with both
      7 implementations and asserts their content hashes match.
      8 """
      9 
     10 from __future__ import annotations
     11 
     12 import hashlib
     13 import os
     14 import subprocess
     15 import tempfile
     16 from pathlib import Path
     17 
     18 import pyarrow.parquet as pq
     19 import pytest
     20 
     21 from wrds_dl.config import load_credentials
     22 
     23 # A narrow, deterministic query: 10 rows from crsp.msf for Jan 2020.
     24 QUERY = (
     25     "SELECT permno, date, prc, ret, shrout "
     26     "FROM crsp.msf "
     27     "WHERE date = '2020-01-31' "
     28     "ORDER BY permno "
     29     "LIMIT 10"
     30 )
     31 
# Directory two levels above the one containing this file (parents[0] is the
# containing directory); assumed to be the repository root -- confirm if the
# test layout changes.
REPO_ROOT = Path(__file__).resolve().parents[2]
# The parity test is skipped when this file does not exist.
GO_BINARY = REPO_ROOT / "wrds-dl"  # pre-built binary at repo root
     34 
     35 
     36 def _has_credentials() -> bool:
     37     if os.environ.get("PGUSER"):
     38         return True
     39     user, pw, _ = load_credentials()
     40     return bool(user and pw)
     41 
     42 
# Module-level pytest mark: pytest applies it to every test in this file, so
# all tests are skipped when no credentials are found.
# _has_credentials() is called once, when this module is imported.
pytestmark = pytest.mark.skipif(
    not _has_credentials(),
    reason="WRDS credentials not available",
)
     47 
     48 
     49 def _content_hash(parquet_path: str) -> str:
     50     """Read a parquet file, sort deterministically, and return a SHA-256 of the content.
     51 
     52     Converts all values to their repr() for a canonical representation
     53     that is independent of the parquet writer (parquet-go vs pyarrow).
     54     """
     55     table = pq.read_table(parquet_path)
     56     # Normalize column order alphabetically.
     57     col_names = sorted(table.column_names)
     58     table = table.select(col_names)
     59     # Sort rows by all columns.
     60     sort_keys = [(col, "ascending") for col in col_names]
     61     table = table.sort_by(sort_keys)
     62     # Hash a canonical string representation of every cell.
     63     h = hashlib.sha256()
     64     h.update(",".join(col_names).encode())
     65     for i in range(table.num_rows):
     66         for col_name in col_names:
     67             val = table.column(col_name)[i].as_py()
     68             h.update(repr(val).encode())
     69             h.update(b"|")
     70         h.update(b"\n")
     71     return h.hexdigest()
     72 
     73 
     74 def test_python_download_parquet():
     75     """Download a small sample with the Python CLI and verify the parquet output."""
     76     with tempfile.TemporaryDirectory() as tmpdir:
     77         out = os.path.join(tmpdir, "test_py.parquet")
     78 
     79         from click.testing import CliRunner
     80         from wrds_dl.cli import cli
     81 
     82         runner = CliRunner()
     83         result = runner.invoke(cli, ["download", "--query", QUERY, "--out", out])
     84         assert result.exit_code == 0, f"Python download failed: {result.output}"
     85 
     86         # Verify parquet file.
     87         table = pq.read_table(out)
     88         assert table.num_rows == 10
     89         assert set(table.column_names) == {"permno", "date", "prc", "ret", "shrout"}
     90 
     91         py_hash = _content_hash(out)
     92         assert len(py_hash) == 64  # valid sha256
     93 
     94 
     95 @pytest.mark.skipif(
     96     not GO_BINARY.is_file(),
     97     reason=f"Go binary not found at {GO_BINARY}",
     98 )
     99 def test_go_python_parity():
    100     """Download the same data with Go and Python, assert content hashes match."""
    101     with tempfile.TemporaryDirectory() as tmpdir:
    102         py_out = os.path.join(tmpdir, "py.parquet")
    103         go_out = os.path.join(tmpdir, "go.parquet")
    104 
    105         # Python download.
    106         from click.testing import CliRunner
    107         from wrds_dl.cli import cli
    108 
    109         runner = CliRunner()
    110         result = runner.invoke(cli, ["download", "--query", QUERY, "--out", py_out])
    111         assert result.exit_code == 0, f"Python download failed: {result.output}"
    112 
    113         # Go download.
    114         env = os.environ.copy()
    115         proc = subprocess.run(
    116             [str(GO_BINARY), "download", "--query", QUERY, "--out", go_out],
    117             capture_output=True,
    118             text=True,
    119             env=env,
    120             timeout=60,
    121         )
    122         assert proc.returncode == 0, f"Go download failed: {proc.stderr}"
    123 
    124         # Compare content hashes.
    125         py_hash = _content_hash(py_out)
    126         go_hash = _content_hash(go_out)
    127 
    128         # Read both tables for diagnostics on failure.
    129         py_table = pq.read_table(py_out)
    130         go_table = pq.read_table(go_out)
    131 
    132         assert py_table.num_rows == go_table.num_rows, (
    133             f"Row count mismatch: Python={py_table.num_rows}, Go={go_table.num_rows}"
    134         )
    135         assert set(py_table.column_names) == set(go_table.column_names), (
    136             f"Column mismatch: Python={py_table.column_names}, Go={go_table.column_names}"
    137         )
    138         assert py_hash == go_hash, (
    139             f"Content hash mismatch:\n"
    140             f"  Python: {py_hash}\n"
    141             f"  Go:     {go_hash}\n"
    142             f"  Python schema:\n{py_table.schema}\n"
    143             f"  Go schema:\n{go_table.schema}\n"
    144         )