wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

export.py (4464B)


      1 """Export query results to Parquet or CSV with streaming and progress."""
      2 
      3 from __future__ import annotations
      4 
      5 import csv
      6 from decimal import Decimal
      7 from typing import Callable
      8 
      9 import psycopg
     10 import pyarrow as pa
     11 import pyarrow.parquet as pq
     12 
     13 from wrds_dl.db import dsn_from_env
     14 
# Batch size used throughout: server-side cursor fetch size, Parquet row
# group size, and the CSV progress-report interval.
ROW_GROUP_SIZE = 10_000

# Map PostgreSQL type OIDs to PyArrow types.
# OIDs without an entry fall back to pa.string() via _arrow_type_for_oid;
# NUMERIC (OID 1700) has no entry, so numeric columns land in the string
# fallback (Decimal values are stringified during batch conversion).
# NOTE(review): int2 (OID 21) is widened to int32 rather than int16 —
# possibly intentional; confirm it matches the expected output schema.
_PG_OID_TO_ARROW: dict[int, pa.DataType] = {
    16: pa.bool_(),        # bool
    21: pa.int32(),        # int2 (widened)
    23: pa.int32(),        # int4
    20: pa.int64(),        # int8
    700: pa.float32(),     # float4
    701: pa.float64(),     # float8
    1082: pa.date32(),     # date
    1114: pa.timestamp("us"),  # timestamp
    1184: pa.timestamp("us", tz="UTC"),  # timestamptz
}
     29 
     30 
     31 def _arrow_type_for_oid(oid: int) -> pa.DataType:
     32     return _PG_OID_TO_ARROW.get(oid, pa.string())
     33 
     34 
     35 def export_data(
     36     query: str,
     37     out_path: str,
     38     fmt: str = "parquet",
     39     progress_fn: Callable[[int], None] | None = None,
     40 ) -> None:
     41     """Run *query* against WRDS and write results to *out_path*."""
     42     conn = psycopg.connect(dsn_from_env())
     43     try:
     44         with conn.cursor(name="wrds_export") as cur:
     45             cur.itersize = ROW_GROUP_SIZE
     46             cur.execute(query)
     47 
     48             if cur.description is None:
     49                 raise RuntimeError("Query returned no columns")
     50 
     51             col_names = [desc.name for desc in cur.description]
     52             col_oids = [desc.type_code for desc in cur.description]
     53 
     54             if fmt == "csv":
     55                 _write_csv(cur, col_names, out_path, progress_fn)
     56             else:
     57                 _write_parquet(cur, col_names, col_oids, out_path, progress_fn)
     58     finally:
     59         conn.close()
     60 
     61 
     62 def _write_csv(
     63     cur: psycopg.Cursor,
     64     col_names: list[str],
     65     out_path: str,
     66     progress_fn: Callable[[int], None] | None,
     67 ) -> None:
     68     with open(out_path, "w", newline="") as f:
     69         writer = csv.writer(f)
     70         writer.writerow(col_names)
     71         total = 0
     72         for row in cur:
     73             writer.writerow(_format_row(row))
     74             total += 1
     75             if progress_fn and total % ROW_GROUP_SIZE == 0:
     76                 progress_fn(total)
     77 
     78 
     79 def _write_parquet(
     80     cur: psycopg.Cursor,
     81     col_names: list[str],
     82     col_oids: list[int],
     83     out_path: str,
     84     progress_fn: Callable[[int], None] | None,
     85 ) -> None:
     86     arrow_types = [_arrow_type_for_oid(oid) for oid in col_oids]
     87     schema = pa.schema([(name, typ) for name, typ in zip(col_names, arrow_types)])
     88 
     89     writer = pq.ParquetWriter(out_path, schema, compression="zstd")
     90     try:
     91         batch_rows: list[tuple] = []
     92         total = 0
     93 
     94         for row in cur:
     95             batch_rows.append(row)
     96             if len(batch_rows) >= ROW_GROUP_SIZE:
     97                 _flush_batch(writer, schema, batch_rows, col_names)
     98                 total += len(batch_rows)
     99                 batch_rows = []
    100                 if progress_fn:
    101                     progress_fn(total)
    102 
    103         if batch_rows:
    104             _flush_batch(writer, schema, batch_rows, col_names)
    105             total += len(batch_rows)
    106     finally:
    107         writer.close()
    108 
    109 
    110 def _flush_batch(
    111     writer: pq.ParquetWriter,
    112     schema: pa.Schema,
    113     rows: list[tuple],
    114     col_names: list[str],
    115 ) -> None:
    116     """Convert a batch of rows into a PyArrow table and write it."""
    117     columns: dict[str, list] = {name: [] for name in col_names}
    118     for row in rows:
    119         for i, val in enumerate(row):
    120             # Strip trailing zeros from Decimal values (numeric columns)
    121             # so output matches Go's pgx behaviour.
    122             if isinstance(val, Decimal):
    123                 val = str(val.normalize())
    124             columns[col_names[i]].append(val)
    125 
    126     arrays = []
    127     for i, name in enumerate(col_names):
    128         try:
    129             arrays.append(pa.array(columns[name], type=schema.field(name).type))
    130         except (pa.ArrowInvalid, pa.ArrowTypeError):
    131             # Fallback: convert to strings
    132             arrays.append(pa.array([str(v) if v is not None else None for v in columns[name]],
    133                                    type=pa.string()))
    134 
    135     table = pa.table(dict(zip(col_names, arrays)))
    136     writer.write_table(table)
    137 
    138 
    139 def _format_row(row: tuple) -> list[str]:
    140     """Format a row for CSV output."""
    141     out = []
    142     for v in row:
    143         if v is None:
    144             out.append("")
    145         elif isinstance(v, Decimal):
    146             out.append(str(v.normalize()))
    147         else:
    148             out.append(str(v))
    149     return out