wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

download_test.go (2406B)


      1 package cmd
      2 
      3 import "testing"
      4 
      5 func TestBuildQuery(t *testing.T) {
      6 	tests := []struct {
      7 		name    string
      8 		setup   func()
      9 		want    string
     10 		wantErr bool
     11 	}{
     12 		{
     13 			name: "raw query passthrough",
     14 			setup: func() {
     15 				dlQuery = "SELECT * FROM crsp.dsf LIMIT 10"
     16 				dlSchema = ""
     17 				dlTable = ""
     18 			},
     19 			want: "SELECT * FROM crsp.dsf LIMIT 10",
     20 		},
     21 		{
     22 			name: "schema and table",
     23 			setup: func() {
     24 				dlQuery = ""
     25 				dlSchema = "crsp"
     26 				dlTable = "dsf"
     27 				dlColumns = "*"
     28 				dlWhere = ""
     29 				dlLimit = 0
     30 			},
     31 			want: `SELECT * FROM "crsp"."dsf"`,
     32 		},
     33 		{
     34 			name: "with columns",
     35 			setup: func() {
     36 				dlQuery = ""
     37 				dlSchema = "comp"
     38 				dlTable = "funda"
     39 				dlColumns = "gvkey,datadate,sale"
     40 				dlWhere = ""
     41 				dlLimit = 0
     42 			},
     43 			want: `SELECT "gvkey", "datadate", "sale" FROM "comp"."funda"`,
     44 		},
     45 		{
     46 			name: "with where and limit",
     47 			setup: func() {
     48 				dlQuery = ""
     49 				dlSchema = "crsp"
     50 				dlTable = "dsf"
     51 				dlColumns = "*"
     52 				dlWhere = "date >= '2020-01-01'"
     53 				dlLimit = 1000
     54 			},
     55 			want: `SELECT * FROM "crsp"."dsf" WHERE date >= '2020-01-01' LIMIT 1000`,
     56 		},
     57 		{
     58 			name: "missing schema and table",
     59 			setup: func() {
     60 				dlQuery = ""
     61 				dlSchema = ""
     62 				dlTable = ""
     63 			},
     64 			wantErr: true,
     65 		},
     66 		{
     67 			name: "column with spaces trimmed",
     68 			setup: func() {
     69 				dlQuery = ""
     70 				dlSchema = "crsp"
     71 				dlTable = "dsf"
     72 				dlColumns = " permno , date , prc "
     73 				dlWhere = ""
     74 				dlLimit = 0
     75 			},
     76 			want: `SELECT "permno", "date", "prc" FROM "crsp"."dsf"`,
     77 		},
     78 	}
     79 
     80 	for _, tt := range tests {
     81 		t.Run(tt.name, func(t *testing.T) {
     82 			tt.setup()
     83 			got, err := buildQuery()
     84 			if (err != nil) != tt.wantErr {
     85 				t.Fatalf("buildQuery() error = %v, wantErr %v", err, tt.wantErr)
     86 			}
     87 			if !tt.wantErr && got != tt.want {
     88 				t.Errorf("buildQuery() = %q, want %q", got, tt.want)
     89 			}
     90 		})
     91 	}
     92 }
     93 
     94 func TestResolveFormat(t *testing.T) {
     95 	tests := []struct {
     96 		path string
     97 		flag string
     98 		want string
     99 	}{
    100 		{"out.parquet", "", "parquet"},
    101 		{"out.csv", "", "csv"},
    102 		{"out.CSV", "", "csv"},
    103 		{"out.parquet", "csv", "csv"},
    104 		{"out.csv", "parquet", "parquet"},
    105 		{"out.txt", "", "parquet"},
    106 		{"out", "CSV", "csv"},
    107 	}
    108 
    109 	for _, tt := range tests {
    110 		t.Run(tt.path+"_"+tt.flag, func(t *testing.T) {
    111 			got := resolveFormat(tt.path, tt.flag)
    112 			if got != tt.want {
    113 				t.Errorf("resolveFormat(%q, %q) = %q, want %q", tt.path, tt.flag, got, tt.want)
    114 			}
    115 		})
    116 	}
    117 }