download.go (4933B)
1 package cmd 2 3 import ( 4 "context" 5 "fmt" 6 "os" 7 "strings" 8 "text/tabwriter" 9 "time" 10 11 "github.com/jackc/pgx/v5" 12 "github.com/louloulibs/wrds-download/internal/config" 13 "github.com/louloulibs/wrds-download/internal/db" 14 "github.com/louloulibs/wrds-download/internal/export" 15 "github.com/spf13/cobra" 16 ) 17 18 var ( 19 dlSchema string 20 dlTable string 21 dlColumns string 22 dlWhere string 23 dlQuery string 24 dlOut string 25 dlFormat string 26 dlLimit int 27 dlDryRun bool 28 ) 29 30 var downloadCmd = &cobra.Command{ 31 Use: "download", 32 Short: "Download WRDS data to parquet or CSV", 33 Long: `Download data from WRDS to a local file. 34 35 Examples: 36 wrds download --schema crsp --table dsf --where "date='2020-01-02'" --out crsp_dsf.parquet 37 wrds download --schema comp --table funda --columns "gvkey,datadate,sale" --out funda.parquet 38 wrds download --query "SELECT permno, date, prc FROM crsp.dsf LIMIT 1000" --out out.parquet 39 wrds download --schema comp --table funda --out funda.csv --format csv`, 40 RunE: runDownload, 41 } 42 43 func init() { 44 rootCmd.AddCommand(downloadCmd) 45 46 f := downloadCmd.Flags() 47 f.StringVar(&dlSchema, "schema", "", "Schema name (e.g. crsp)") 48 f.StringVar(&dlTable, "table", "", "Table name (e.g. dsf)") 49 f.StringVarP(&dlColumns, "columns", "c", "*", "Columns to select (comma-separated, default *)") 50 f.StringVar(&dlWhere, "where", "", "SQL WHERE clause (without the WHERE keyword)") 51 f.StringVar(&dlQuery, "query", "", "Full SQL query (overrides --schema/--table/--where)") 52 f.StringVar(&dlOut, "out", "", "Output file path (required)") 53 f.StringVar(&dlFormat, "format", "", "Output format: parquet or csv (inferred from extension if omitted)") 54 f.IntVar(&dlLimit, "limit", 0, "Limit number of rows (0 = no limit)") 55 f.BoolVar(&dlDryRun, "dry-run", false, "Preview the query, row count, and first 5 rows without downloading") 56 } 57 58 func runDownload(cmd *cobra.Command, args []string) error { 59 config.ApplyCredentials() 60 61 query, err := buildQuery() 62 if err != nil { 63 return err 64 } 65 66 if dlDryRun { 67 return runDryRun(query) 68 } 69 70 if dlOut == "" { 71 return fmt.Errorf("required flag \"out\" not set") 72 } 73 74 format := resolveFormat(dlOut, dlFormat) 75 76 fmt.Fprintf(os.Stderr, "Exporting to %s (%s)...\n", dlOut, format) 77 78 opts := export.Options{ 79 Format: format, 80 ProgressFunc: func(rows int) { 81 fmt.Fprintf(os.Stderr, "Exported %d rows...\n", rows) 82 }, 83 } 84 if err := export.Export(query, dlOut, opts); err != nil { 85 return fmt.Errorf("export failed: %w", err) 86 } 87 88 fmt.Fprintf(os.Stderr, "Done: %s\n", dlOut) 89 return nil 90 } 91 92 func runDryRun(query string) error { 93 dsn, err := db.DSNFromEnv() 94 if err != nil { 95 return fmt.Errorf("dsn: %w", err) 96 } 97 98 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 99 defer cancel() 100 101 conn, err := pgx.Connect(ctx, dsn) 102 if err != nil { 103 return fmt.Errorf("connect: %w", err) 104 } 105 defer conn.Close(ctx) 106 107 fmt.Fprintln(os.Stdout, "Query:") 108 fmt.Fprintln(os.Stdout, " ", query) 109 fmt.Fprintln(os.Stdout) 110 111 // Row count 112 countQuery := fmt.Sprintf("SELECT count(*) FROM (%s) sub", query) 113 var count int64 114 if err := conn.QueryRow(ctx, countQuery).Scan(&count); err != nil { 115 return fmt.Errorf("count query: %w", err) 116 } 117 fmt.Fprintf(os.Stdout, "Row count: %d\n\n", count) 118 119 // Preview first 5 rows 120 previewQuery := fmt.Sprintf("SELECT * FROM (%s) sub LIMIT 5", query) 121 rows, err := conn.Query(ctx, previewQuery) 122 if err != nil { 123 return fmt.Errorf("preview query: %w", err) 124 } 125 defer rows.Close() 126 127 fds := rows.FieldDescriptions() 128 w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) 129 130 // Header 131 headers := make([]string, len(fds)) 132 for i, fd := range fds { 133 headers[i] = fd.Name 134 } 135 fmt.Fprintln(w, strings.Join(headers, "\t")) 136 137 // Rows 138 for rows.Next() { 139 vals, err := rows.Values() 140 if err != nil { 141 return fmt.Errorf("scan row: %w", err) 142 } 143 cells := make([]string, len(vals)) 144 for i, v := range vals { 145 if v == nil { 146 cells[i] = "NULL" 147 } else { 148 cells[i] = fmt.Sprintf("%v", v) 149 } 150 } 151 fmt.Fprintln(w, strings.Join(cells, "\t")) 152 } 153 return w.Flush() 154 } 155 156 func buildQuery() (string, error) { 157 if dlQuery != "" { 158 return dlQuery, nil 159 } 160 if dlSchema == "" || dlTable == "" { 161 return "", fmt.Errorf("either --query or both --schema and --table must be specified") 162 } 163 164 sel := "*" 165 if dlColumns != "" && dlColumns != "*" { 166 parts := strings.Split(dlColumns, ",") 167 quoted := make([]string, len(parts)) 168 for i, p := range parts { 169 quoted[i] = db.QuoteIdent(strings.TrimSpace(p)) 170 } 171 sel = strings.Join(quoted, ", ") 172 } 173 q := fmt.Sprintf("SELECT %s FROM %s.%s", sel, db.QuoteIdent(dlSchema), db.QuoteIdent(dlTable)) 174 175 if dlWhere != "" { 176 q += " WHERE " + dlWhere 177 } 178 if dlLimit > 0 { 179 q += fmt.Sprintf(" LIMIT %d", dlLimit) 180 } 181 182 return q, nil 183 } 184 185 func resolveFormat(path, flag string) string { 186 if flag != "" { 187 return strings.ToLower(flag) 188 } 189 lower := strings.ToLower(path) 190 if strings.HasSuffix(lower, ".csv") { 191 return "csv" 192 } 193 return "parquet" 194 }