wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

cli.py (5991B)


      1 """Click CLI — download and info subcommands matching the Go wrds-dl interface."""
      2 
      3 from __future__ import annotations
      4 
      5 import json
      6 
      7 import click
      8 import psycopg
      9 
     10 from wrds_dl.config import apply_credentials
     11 from wrds_dl.db import build_query, connect, dsn_from_env, table_meta
     12 
     13 
@click.group()
def cli() -> None:
    """Download data from the WRDS PostgreSQL database to Parquet or CSV."""
    # Root command group. The `download` and `info` subcommands attach to it
    # via `@cli.command()`; the group itself does no work. Its docstring is
    # shown by Click as the top-level --help text.
     17 
     18 
     19 @cli.command()
     20 @click.option("--schema", default="", help="Schema name (e.g. crsp)")
     21 @click.option("--table", default="", help="Table name (e.g. dsf)")
     22 @click.option("-c", "--columns", default="*", help="Columns to select (comma-separated, default *)")
     23 @click.option("--where", "where_clause", default="", help="SQL WHERE clause (without the WHERE keyword)")
     24 @click.option("--query", default="", help="Full SQL query (overrides --schema/--table/--where)")
     25 @click.option("--out", default="", help="Output file path (required unless --dry-run)")
     26 @click.option("--format", "fmt", default="", help="Output format: parquet or csv (inferred from extension)")
     27 @click.option("--limit", default=0, type=int, help="Limit number of rows (0 = no limit)")
     28 @click.option("--dry-run", is_flag=True, help="Preview query, row count, and first 5 rows")
     29 def download(
     30     schema: str,
     31     table: str,
     32     columns: str,
     33     where_clause: str,
     34     query: str,
     35     out: str,
     36     fmt: str,
     37     limit: int,
     38     dry_run: bool,
     39 ) -> None:
     40     """Download WRDS data to Parquet or CSV."""
     41     apply_credentials()
     42 
     43     # Build query.
     44     if query:
     45         sql = query
     46     elif schema and table:
     47         sql = build_query(schema, table, columns, where_clause, limit)
     48     else:
     49         raise click.UsageError("Either --query or both --schema and --table must be specified")
     50 
     51     if dry_run:
     52         _run_dry_run(sql)
     53         return
     54 
     55     if not out:
     56         raise click.UsageError('Required option "--out" not provided')
     57 
     58     # Resolve format.
     59     resolved_fmt = fmt.lower() if fmt else ("csv" if out.lower().endswith(".csv") else "parquet")
     60 
     61     click.echo(f"Exporting to {out} ({resolved_fmt})...", err=True)
     62 
     63     from wrds_dl.export import export_data
     64 
     65     def progress(rows: int) -> None:
     66         click.echo(f"Exported {rows} rows...", err=True)
     67 
     68     export_data(sql, out, resolved_fmt, progress)
     69     click.echo(f"Done: {out}", err=True)
     70 
     71 
     72 def _run_dry_run(sql: str) -> None:
     73     """Print query, row count, and first 5 rows."""
     74     conn = psycopg.connect(dsn_from_env())
     75     try:
     76         with conn.cursor() as cur:
     77             click.echo("Query:")
     78             click.echo(f"  {sql}")
     79             click.echo()
     80 
     81             # Row count.
     82             cur.execute(f"SELECT count(*) FROM ({sql}) sub")
     83             row = cur.fetchone()
     84             count = row[0] if row else 0
     85             click.echo(f"Row count: {count}")
     86             click.echo()
     87 
     88             # Preview first 5 rows.
     89             cur.execute(f"SELECT * FROM ({sql}) sub LIMIT 5")
     90             if cur.description is None:
     91                 return
     92 
     93             col_names = [desc.name for desc in cur.description]
     94             rows = cur.fetchall()
     95 
     96             # Calculate column widths.
     97             widths = [len(name) for name in col_names]
     98             str_rows = []
     99             for row in rows:
    100                 cells = [str(v) if v is not None else "NULL" for v in row]
    101                 str_rows.append(cells)
    102                 for i, cell in enumerate(cells):
    103                     widths[i] = max(widths[i], len(cell))
    104 
    105             # Print header and rows.
    106             header = "  ".join(name.ljust(widths[i]) for i, name in enumerate(col_names))
    107             click.echo(header)
    108             for cells in str_rows:
    109                 click.echo("  ".join(cell.ljust(widths[i]) for i, cell in enumerate(cells)))
    110     finally:
    111         conn.close()
    112 
    113 
    114 @cli.command()
    115 @click.option("--schema", required=True, help="Schema name (required)")
    116 @click.option("--table", required=True, help="Table name (required)")
    117 @click.option("--json", "as_json", is_flag=True, help="Output as JSON")
    118 def info(schema: str, table: str, as_json: bool) -> None:
    119     """Show table metadata (columns, types, row count)."""
    120     apply_credentials()
    121 
    122     conn = connect()
    123     try:
    124         meta = table_meta(conn, schema, table)
    125     finally:
    126         conn.close()
    127 
    128     if as_json:
    129         _print_info_json(meta)
    130     else:
    131         _print_info_table(meta)
    132 
    133 
    134 def _print_info_json(meta) -> None:
    135     data = {
    136         "schema": meta.schema,
    137         "table": meta.table,
    138         "comment": meta.comment or None,
    139         "row_count": meta.row_count,
    140         "size": meta.size or None,
    141         "columns": [
    142             {
    143                 "name": c.name,
    144                 "type": c.data_type,
    145                 "nullable": c.nullable,
    146                 **({"description": c.description} if c.description else {}),
    147             }
    148             for c in meta.columns
    149         ],
    150     }
    151     # Match Go: omit null keys
    152     data = {k: v for k, v in data.items() if v is not None}
    153     click.echo(json.dumps(data, indent=2))
    154 
    155 
    156 def _print_info_table(meta) -> None:
    157     click.echo(f"{meta.schema}.{meta.table}")
    158     if meta.comment:
    159         click.echo(f"  {meta.comment}")
    160 
    161     parts = []
    162     if meta.row_count > 0:
    163         parts.append(f"~{meta.row_count} rows")
    164     if meta.size:
    165         parts.append(meta.size)
    166     if parts:
    167         click.echo(f"  {', '.join(parts)}")
    168 
    169     click.echo()
    170 
    171     # Column table with tab-aligned output.
    172     widths = [4, 4, 8, 11]  # NAME, TYPE, NULLABLE, DESCRIPTION minimums
    173     rows = []
    174     for c in meta.columns:
    175         nullable = "YES" if c.nullable else "NO"
    176         row = [c.name, c.data_type, nullable, c.description]
    177         rows.append(row)
    178         for i, cell in enumerate(row):
    179             widths[i] = max(widths[i], len(cell))
    180 
    181     header = "  ".join(
    182         label.ljust(widths[i])
    183         for i, label in enumerate(["NAME", "TYPE", "NULLABLE", "DESCRIPTION"])
    184     )
    185     click.echo(header)
    186     for row in rows:
    187         click.echo("  ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)))