wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

download.go (4933B)


      1 package cmd
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 	"os"
      7 	"strings"
      8 	"text/tabwriter"
      9 	"time"
     10 
     11 	"github.com/jackc/pgx/v5"
     12 	"github.com/louloulibs/wrds-download/internal/config"
     13 	"github.com/louloulibs/wrds-download/internal/db"
     14 	"github.com/louloulibs/wrds-download/internal/export"
     15 	"github.com/spf13/cobra"
     16 )
     17 
     18 var (
     19 	dlSchema string
     20 	dlTable  string
     21 	dlColumns string
     22 	dlWhere  string
     23 	dlQuery  string
     24 	dlOut    string
     25 	dlFormat string
     26 	dlLimit  int
     27 	dlDryRun bool
     28 )
     29 
     30 var downloadCmd = &cobra.Command{
     31 	Use:   "download",
     32 	Short: "Download WRDS data to parquet or CSV",
     33 	Long: `Download data from WRDS to a local file.
     34 
     35 Examples:
     36   wrds download --schema crsp --table dsf --where "date='2020-01-02'" --out crsp_dsf.parquet
     37   wrds download --schema comp --table funda --columns "gvkey,datadate,sale" --out funda.parquet
     38   wrds download --query "SELECT permno, date, prc FROM crsp.dsf LIMIT 1000" --out out.parquet
     39   wrds download --schema comp --table funda --out funda.csv --format csv`,
     40 	RunE: runDownload,
     41 }
     42 
     43 func init() {
     44 	rootCmd.AddCommand(downloadCmd)
     45 
     46 	f := downloadCmd.Flags()
     47 	f.StringVar(&dlSchema, "schema", "", "Schema name (e.g. crsp)")
     48 	f.StringVar(&dlTable, "table", "", "Table name (e.g. dsf)")
     49 	f.StringVarP(&dlColumns, "columns", "c", "*", "Columns to select (comma-separated, default *)")
     50 	f.StringVar(&dlWhere, "where", "", "SQL WHERE clause (without the WHERE keyword)")
     51 	f.StringVar(&dlQuery, "query", "", "Full SQL query (overrides --schema/--table/--where)")
     52 	f.StringVar(&dlOut, "out", "", "Output file path (required)")
     53 	f.StringVar(&dlFormat, "format", "", "Output format: parquet or csv (inferred from extension if omitted)")
     54 	f.IntVar(&dlLimit, "limit", 0, "Limit number of rows (0 = no limit)")
     55 	f.BoolVar(&dlDryRun, "dry-run", false, "Preview the query, row count, and first 5 rows without downloading")
     56 }
     57 
     58 func runDownload(cmd *cobra.Command, args []string) error {
     59 	config.ApplyCredentials()
     60 
     61 	query, err := buildQuery()
     62 	if err != nil {
     63 		return err
     64 	}
     65 
     66 	if dlDryRun {
     67 		return runDryRun(query)
     68 	}
     69 
     70 	if dlOut == "" {
     71 		return fmt.Errorf("required flag \"out\" not set")
     72 	}
     73 
     74 	format := resolveFormat(dlOut, dlFormat)
     75 
     76 	fmt.Fprintf(os.Stderr, "Exporting to %s (%s)...\n", dlOut, format)
     77 
     78 	opts := export.Options{
     79 		Format: format,
     80 		ProgressFunc: func(rows int) {
     81 			fmt.Fprintf(os.Stderr, "Exported %d rows...\n", rows)
     82 		},
     83 	}
     84 	if err := export.Export(query, dlOut, opts); err != nil {
     85 		return fmt.Errorf("export failed: %w", err)
     86 	}
     87 
     88 	fmt.Fprintf(os.Stderr, "Done: %s\n", dlOut)
     89 	return nil
     90 }
     91 
     92 func runDryRun(query string) error {
     93 	dsn, err := db.DSNFromEnv()
     94 	if err != nil {
     95 		return fmt.Errorf("dsn: %w", err)
     96 	}
     97 
     98 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
     99 	defer cancel()
    100 
    101 	conn, err := pgx.Connect(ctx, dsn)
    102 	if err != nil {
    103 		return fmt.Errorf("connect: %w", err)
    104 	}
    105 	defer conn.Close(ctx)
    106 
    107 	fmt.Fprintln(os.Stdout, "Query:")
    108 	fmt.Fprintln(os.Stdout, " ", query)
    109 	fmt.Fprintln(os.Stdout)
    110 
    111 	// Row count
    112 	countQuery := fmt.Sprintf("SELECT count(*) FROM (%s) sub", query)
    113 	var count int64
    114 	if err := conn.QueryRow(ctx, countQuery).Scan(&count); err != nil {
    115 		return fmt.Errorf("count query: %w", err)
    116 	}
    117 	fmt.Fprintf(os.Stdout, "Row count: %d\n\n", count)
    118 
    119 	// Preview first 5 rows
    120 	previewQuery := fmt.Sprintf("SELECT * FROM (%s) sub LIMIT 5", query)
    121 	rows, err := conn.Query(ctx, previewQuery)
    122 	if err != nil {
    123 		return fmt.Errorf("preview query: %w", err)
    124 	}
    125 	defer rows.Close()
    126 
    127 	fds := rows.FieldDescriptions()
    128 	w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
    129 
    130 	// Header
    131 	headers := make([]string, len(fds))
    132 	for i, fd := range fds {
    133 		headers[i] = fd.Name
    134 	}
    135 	fmt.Fprintln(w, strings.Join(headers, "\t"))
    136 
    137 	// Rows
    138 	for rows.Next() {
    139 		vals, err := rows.Values()
    140 		if err != nil {
    141 			return fmt.Errorf("scan row: %w", err)
    142 		}
    143 		cells := make([]string, len(vals))
    144 		for i, v := range vals {
    145 			if v == nil {
    146 				cells[i] = "NULL"
    147 			} else {
    148 				cells[i] = fmt.Sprintf("%v", v)
    149 			}
    150 		}
    151 		fmt.Fprintln(w, strings.Join(cells, "\t"))
    152 	}
    153 	return w.Flush()
    154 }
    155 
    156 func buildQuery() (string, error) {
    157 	if dlQuery != "" {
    158 		return dlQuery, nil
    159 	}
    160 	if dlSchema == "" || dlTable == "" {
    161 		return "", fmt.Errorf("either --query or both --schema and --table must be specified")
    162 	}
    163 
    164 	sel := "*"
    165 	if dlColumns != "" && dlColumns != "*" {
    166 		parts := strings.Split(dlColumns, ",")
    167 		quoted := make([]string, len(parts))
    168 		for i, p := range parts {
    169 			quoted[i] = db.QuoteIdent(strings.TrimSpace(p))
    170 		}
    171 		sel = strings.Join(quoted, ", ")
    172 	}
    173 	q := fmt.Sprintf("SELECT %s FROM %s.%s", sel, db.QuoteIdent(dlSchema), db.QuoteIdent(dlTable))
    174 
    175 	if dlWhere != "" {
    176 		q += " WHERE " + dlWhere
    177 	}
    178 	if dlLimit > 0 {
    179 		q += fmt.Sprintf(" LIMIT %d", dlLimit)
    180 	}
    181 
    182 	return q, nil
    183 }
    184 
    185 func resolveFormat(path, flag string) string {
    186 	if flag != "" {
    187 		return strings.ToLower(flag)
    188 	}
    189 	lower := strings.ToLower(path)
    190 	if strings.HasSuffix(lower, ".csv") {
    191 		return "csv"
    192 	}
    193 	return "parquet"
    194 }