wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

export.go (7700B)


      1 package export
      2 
      3 import (
      4 	"context"
      5 	"encoding/csv"
      6 	"fmt"
      7 	"math/big"
      8 	"os"
      9 	"strings"
     10 	"time"
     11 
     12 	"github.com/jackc/pgx/v5"
     13 	"github.com/jackc/pgx/v5/pgconn"
     14 	"github.com/jackc/pgx/v5/pgtype"
     15 	"github.com/parquet-go/parquet-go"
     16 	"github.com/parquet-go/parquet-go/compress/zstd"
     17 
     18 	"github.com/louloulibs/wrds-download/internal/db"
     19 )
     20 
// Options controls the export behaviour.
type Options struct {
	Format       string         // "parquet" or "csv"; empty infers from outPath's extension (default parquet)
	ProgressFunc func(rows int) // optional; called periodically with total rows exported so far
}

// rowGroupSize is both the parquet row-group size and the cadence (in rows)
// of ProgressFunc callbacks.
const rowGroupSize = 10_000
     28 
     29 // Export runs query against the WRDS PostgreSQL instance and writes output to outPath.
     30 // Format is determined by opts.Format (default: parquet).
     31 func Export(query, outPath string, opts Options) error {
     32 	format := strings.ToLower(opts.Format)
     33 	if format == "" {
     34 		if strings.HasSuffix(strings.ToLower(outPath), ".csv") {
     35 			format = "csv"
     36 		} else {
     37 			format = "parquet"
     38 		}
     39 	}
     40 
     41 	dsn, err := db.DSNFromEnv()
     42 	if err != nil {
     43 		return fmt.Errorf("dsn: %w", err)
     44 	}
     45 
     46 	ctx := context.Background()
     47 	conn, err := pgx.Connect(ctx, dsn)
     48 	if err != nil {
     49 		return fmt.Errorf("connect: %w", err)
     50 	}
     51 	defer conn.Close(ctx)
     52 
     53 	rows, err := conn.Query(ctx, query)
     54 	if err != nil {
     55 		return fmt.Errorf("query: %w", err)
     56 	}
     57 	defer rows.Close()
     58 
     59 	switch format {
     60 	case "csv":
     61 		return writeCSV(rows, outPath, opts.ProgressFunc)
     62 	default:
     63 		return writeParquet(rows, outPath, opts.ProgressFunc)
     64 	}
     65 }
     66 
     67 // writeCSV streams rows into a CSV file with a header row.
     68 func writeCSV(rows pgx.Rows, outPath string, progressFn func(int)) error {
     69 	f, err := os.Create(outPath)
     70 	if err != nil {
     71 		return fmt.Errorf("create csv: %w", err)
     72 	}
     73 	defer f.Close()
     74 
     75 	w := csv.NewWriter(f)
     76 	defer w.Flush()
     77 
     78 	fds := rows.FieldDescriptions()
     79 	header := make([]string, len(fds))
     80 	for i, fd := range fds {
     81 		header[i] = fd.Name
     82 	}
     83 	if err := w.Write(header); err != nil {
     84 		return fmt.Errorf("write header: %w", err)
     85 	}
     86 
     87 	record := make([]string, len(fds))
     88 	total := 0
     89 	for rows.Next() {
     90 		vals, err := rows.Values()
     91 		if err != nil {
     92 			return fmt.Errorf("scan row: %w", err)
     93 		}
     94 		for i, v := range vals {
     95 			record[i] = formatValue(v)
     96 		}
     97 		if err := w.Write(record); err != nil {
     98 			return fmt.Errorf("write row: %w", err)
     99 		}
    100 		total++
    101 		if progressFn != nil && total%rowGroupSize == 0 {
    102 			progressFn(total)
    103 		}
    104 	}
    105 	if err := rows.Err(); err != nil {
    106 		return fmt.Errorf("rows: %w", err)
    107 	}
    108 
    109 	w.Flush()
    110 	return w.Error()
    111 }
    112 
    113 // writeParquet streams rows into a Parquet file using parquet-go.
    114 func writeParquet(rows pgx.Rows, outPath string, progressFn func(int)) error {
    115 	fds := rows.FieldDescriptions()
    116 
    117 	schema, colTypes := buildParquetSchema(fds)
    118 
    119 	f, err := os.Create(outPath)
    120 	if err != nil {
    121 		return fmt.Errorf("create parquet: %w", err)
    122 	}
    123 	defer f.Close()
    124 
    125 	writer := parquet.NewGenericWriter[map[string]any](f,
    126 		schema,
    127 		parquet.Compression(&zstd.Codec{}),
    128 		parquet.DefaultEncodingFor(parquet.ByteArray, &parquet.Plain),
    129 	)
    130 
    131 	buf := make([]map[string]any, 0, rowGroupSize)
    132 	total := 0
    133 
    134 	for rows.Next() {
    135 		vals, err := rows.Values()
    136 		if err != nil {
    137 			return fmt.Errorf("scan row: %w", err)
    138 		}
    139 
    140 		row := make(map[string]any, len(fds))
    141 		for i, v := range vals {
    142 			row[fds[i].Name] = convertValue(v, colTypes[i])
    143 		}
    144 		buf = append(buf, row)
    145 
    146 		if len(buf) >= rowGroupSize {
    147 			if _, err := writer.Write(buf); err != nil {
    148 				return fmt.Errorf("write row group: %w", err)
    149 			}
    150 			total += len(buf)
    151 			buf = buf[:0]
    152 			if progressFn != nil {
    153 				progressFn(total)
    154 			}
    155 		}
    156 	}
    157 	if err := rows.Err(); err != nil {
    158 		return fmt.Errorf("rows: %w", err)
    159 	}
    160 
    161 	// Flush remaining rows.
    162 	if len(buf) > 0 {
    163 		if _, err := writer.Write(buf); err != nil {
    164 			return fmt.Errorf("write final rows: %w", err)
    165 		}
    166 	}
    167 
    168 	return writer.Close()
    169 }
    170 
// colType tags how we convert PG values for Parquet.
type colType int

const (
	colString colType = iota // fallback: rendered via formatValue
	colBool
	colInt32
	colInt64
	colFloat32
	colFloat64
	colDate      // days since epoch → int32
	colTimestamp // microseconds since epoch → int64
)
    184 
    185 // buildParquetSchema maps PG field descriptors to a parquet schema.
    186 func buildParquetSchema(fds []pgconn.FieldDescription) (*parquet.Schema, []colType) {
    187 	cols := make([]colType, len(fds))
    188 	group := make(parquet.Group, len(fds))
    189 
    190 	for i, fd := range fds {
    191 		var node parquet.Node
    192 
    193 		switch fd.DataTypeOID {
    194 		case 16: // bool
    195 			cols[i] = colBool
    196 			node = parquet.Optional(parquet.Leaf(parquet.BooleanType))
    197 		case 21: // int2
    198 			cols[i] = colInt32
    199 			node = parquet.Optional(parquet.Leaf(parquet.Int32Type))
    200 		case 23: // int4
    201 			cols[i] = colInt32
    202 			node = parquet.Optional(parquet.Leaf(parquet.Int32Type))
    203 		case 20: // int8
    204 			cols[i] = colInt64
    205 			node = parquet.Optional(parquet.Leaf(parquet.Int64Type))
    206 		case 700: // float4
    207 			cols[i] = colFloat32
    208 			node = parquet.Optional(parquet.Leaf(parquet.FloatType))
    209 		case 701: // float8
    210 			cols[i] = colFloat64
    211 			node = parquet.Optional(parquet.Leaf(parquet.DoubleType))
    212 		case 1082: // date
    213 			cols[i] = colDate
    214 			node = parquet.Optional(parquet.Date())
    215 		case 1114, 1184: // timestamp, timestamptz
    216 			cols[i] = colTimestamp
    217 			node = parquet.Optional(parquet.Timestamp(parquet.Microsecond))
    218 		default:
    219 			// text (25), varchar (1043), char (18, 1042), numeric (1700), etc.
    220 			cols[i] = colString
    221 			node = parquet.Optional(parquet.String())
    222 		}
    223 
    224 		group[fd.Name] = node
    225 	}
    226 
    227 	return parquet.NewSchema("wrds", group), cols
    228 }
    229 
    230 var epoch = time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
    231 
    232 // convertValue converts a pgx-scanned value to the appropriate Go type for parquet-go.
    233 func convertValue(v any, ct colType) any {
    234 	if v == nil {
    235 		return nil
    236 	}
    237 
    238 	switch ct {
    239 	case colBool:
    240 		if b, ok := v.(bool); ok {
    241 			return b
    242 		}
    243 	case colInt32:
    244 		switch n := v.(type) {
    245 		case int16:
    246 			return int32(n)
    247 		case int32:
    248 			return n
    249 		case int64:
    250 			return int32(n)
    251 		}
    252 	case colInt64:
    253 		switch n := v.(type) {
    254 		case int64:
    255 			return n
    256 		case int32:
    257 			return int64(n)
    258 		case int16:
    259 			return int64(n)
    260 		}
    261 	case colFloat32:
    262 		if f, ok := v.(float32); ok {
    263 			return f
    264 		}
    265 		if f, ok := v.(float64); ok {
    266 			return float32(f)
    267 		}
    268 	case colFloat64:
    269 		if f, ok := v.(float64); ok {
    270 			return f
    271 		}
    272 		if f, ok := v.(float32); ok {
    273 			return float64(f)
    274 		}
    275 	case colDate:
    276 		if t, ok := v.(time.Time); ok {
    277 			days := int32(t.Sub(epoch).Hours() / 24)
    278 			return days
    279 		}
    280 	case colTimestamp:
    281 		if t, ok := v.(time.Time); ok {
    282 			return t.Sub(epoch).Microseconds()
    283 		}
    284 	case colString:
    285 		return formatValue(v)
    286 	}
    287 
    288 	// Fallback: stringify.
    289 	return formatValue(v)
    290 }
    291 
    292 // formatValue converts any value to its string representation.
    293 func formatValue(v any) string {
    294 	if v == nil {
    295 		return ""
    296 	}
    297 	switch val := v.(type) {
    298 	case string:
    299 		return val
    300 	case []byte:
    301 		return string(val)
    302 	case time.Time:
    303 		if val.Hour() == 0 && val.Minute() == 0 && val.Second() == 0 && val.Nanosecond() == 0 {
    304 			return val.Format("2006-01-02")
    305 		}
    306 		return val.Format(time.RFC3339)
    307 	case pgtype.Numeric:
    308 		if !val.Valid {
    309 			return ""
    310 		}
    311 		if val.NaN {
    312 			return "NaN"
    313 		}
    314 		if val.InfinityModifier == pgtype.Infinity {
    315 			return "Infinity"
    316 		}
    317 		if val.InfinityModifier == pgtype.NegativeInfinity {
    318 			return "-Infinity"
    319 		}
    320 		// Convert to big.Float for string representation.
    321 		bi := val.Int
    322 		if bi == nil {
    323 			bi = new(big.Int)
    324 		}
    325 		bf := new(big.Float).SetInt(bi)
    326 		if val.Exp < 0 {
    327 			divisor := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(-val.Exp)), nil))
    328 			bf.Quo(bf, divisor)
    329 		} else if val.Exp > 0 {
    330 			multiplier := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(val.Exp)), nil))
    331 			bf.Mul(bf, multiplier)
    332 		}
    333 		return bf.Text('f', -1)
    334 	default:
    335 		return fmt.Sprintf("%v", val)
    336 	}
    337 }