export.py (4464B)
"""Export query results to Parquet or CSV with streaming and progress."""

from __future__ import annotations

import csv
from decimal import Decimal
from typing import Callable

import psycopg
import pyarrow as pa
import pyarrow.parquet as pq

from wrds_dl.db import dsn_from_env

# Number of rows fetched per round trip, written per Parquet batch, and
# counted between progress callbacks.
ROW_GROUP_SIZE = 10_000

# Map PostgreSQL type OIDs to PyArrow types.
_PG_OID_TO_ARROW: dict[int, pa.DataType] = {
    16: pa.bool_(),  # bool
    21: pa.int32(),  # int2
    23: pa.int32(),  # int4
    20: pa.int64(),  # int8
    700: pa.float32(),  # float4
    701: pa.float64(),  # float8
    1082: pa.date32(),  # date
    1114: pa.timestamp("us"),  # timestamp
    1184: pa.timestamp("us", tz="UTC"),  # timestamptz
}


def _arrow_type_for_oid(oid: int) -> pa.DataType:
    """Return the Arrow type mapped to *oid*, or ``pa.string()`` if unmapped."""
    return _PG_OID_TO_ARROW.get(oid, pa.string())


def _decimal_to_str(value: Decimal) -> str:
    """Render *value* in fixed-point notation without trailing zeros.

    ``str(value.normalize())`` is not used because it switches to scientific
    notation (``Decimal("100")`` becomes ``"1E+2"``) and because
    ``normalize()`` rounds to the active decimal context precision.
    Formatting with ``"f"`` and trimming the text avoids both.
    """
    text = format(value, "f")
    if "." in text:
        # Only trim when there is a fractional part, so "100" stays "100".
        text = text.rstrip("0").rstrip(".")
    return text


def export_data(
    query: str,
    out_path: str,
    fmt: str = "parquet",
    progress_fn: Callable[[int], None] | None = None,
) -> None:
    """Run *query* against WRDS and write results to *out_path*.

    Args:
        query: SQL text executed through a named (server-side) cursor.
        out_path: Destination file path.
        fmt: ``"csv"`` writes CSV; any other value writes Parquet.
        progress_fn: Optional callback receiving the running row count.

    Raises:
        RuntimeError: If the executed query exposes no result columns.
    """
    conn = psycopg.connect(dsn_from_env())
    try:
        # A named cursor lets rows be fetched in chunks instead of all at once.
        with conn.cursor(name="wrds_export") as cur:
            cur.itersize = ROW_GROUP_SIZE
            cur.execute(query)

            if cur.description is None:
                raise RuntimeError("Query returned no columns")

            col_names = [desc.name for desc in cur.description]
            col_oids = [desc.type_code for desc in cur.description]

            if fmt == "csv":
                _write_csv(cur, col_names, out_path, progress_fn)
            else:
                _write_parquet(cur, col_names, col_oids, out_path, progress_fn)
    finally:
        conn.close()


def _write_csv(
    cur: psycopg.Cursor,
    col_names: list[str],
    out_path: str,
    progress_fn: Callable[[int], None] | None,
) -> None:
    """Stream every row of *cur* into a CSV file with a header row.

    *progress_fn*, when given, is called with the running row count each
    time that count reaches a multiple of ``ROW_GROUP_SIZE``.
    """
    # Explicit UTF-8 so the output does not depend on the platform's
    # default text encoding.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(col_names)
        total = 0
        for row in cur:
            writer.writerow(_format_row(row))
            total += 1
            if progress_fn is not None and total % ROW_GROUP_SIZE == 0:
                progress_fn(total)


def _write_parquet(
    cur: psycopg.Cursor,
    col_names: list[str],
    col_oids: list[int],
    out_path: str,
    progress_fn: Callable[[int], None] | None,
) -> None:
    """Stream every row of *cur* into a zstd-compressed Parquet file.

    Rows are buffered and written in batches of ``ROW_GROUP_SIZE``.
    *progress_fn*, when given, is called with the running row count after
    each full batch; the final partial batch is written without a callback.
    """
    arrow_types = [_arrow_type_for_oid(oid) for oid in col_oids]
    schema = pa.schema(list(zip(col_names, arrow_types)))

    writer = pq.ParquetWriter(out_path, schema, compression="zstd")
    try:
        batch_rows: list[tuple] = []
        total = 0

        for row in cur:
            batch_rows.append(row)
            if len(batch_rows) >= ROW_GROUP_SIZE:
                _flush_batch(writer, schema, batch_rows, col_names)
                total += len(batch_rows)
                batch_rows = []
                if progress_fn is not None:
                    progress_fn(total)

        # Write whatever is left over after the last full batch.
        if batch_rows:
            _flush_batch(writer, schema, batch_rows, col_names)
    finally:
        writer.close()


def _flush_batch(
    writer: pq.ParquetWriter,
    schema: pa.Schema,
    rows: list[tuple],
    col_names: list[str],
) -> None:
    """Convert a batch of rows into a PyArrow table and write it.

    Columns are addressed by position rather than by name, so a result set
    with repeated column names keeps each column's values separate.
    """
    arrays = []
    for idx in range(len(col_names)):
        # Strip trailing zeros from Decimal values (numeric columns)
        # so output matches Go's pgx behaviour.
        values = [
            _decimal_to_str(row[idx]) if isinstance(row[idx], Decimal) else row[idx]
            for row in rows
        ]
        try:
            arrays.append(pa.array(values, type=schema.field(idx).type))
        except (pa.ArrowInvalid, pa.ArrowTypeError):
            # Fallback: convert to strings
            arrays.append(
                pa.array(
                    [str(v) if v is not None else None for v in values],
                    type=pa.string(),
                )
            )

    table = pa.Table.from_arrays(arrays, names=col_names)
    writer.write_table(table)


def _format_row(row: tuple) -> list[str]:
    """Format a row for CSV output.

    ``None`` becomes an empty string, ``Decimal`` values are rendered in
    fixed-point form without trailing zeros, and everything else goes
    through ``str()``.
    """
    out = []
    for v in row:
        if v is None:
            out.append("")
        elif isinstance(v, Decimal):
            out.append(_decimal_to_str(v))
        else:
            out.append(str(v))
    return out