download_test.go (2406B)
1 package cmd 2 3 import "testing" 4 5 func TestBuildQuery(t *testing.T) { 6 tests := []struct { 7 name string 8 setup func() 9 want string 10 wantErr bool 11 }{ 12 { 13 name: "raw query passthrough", 14 setup: func() { 15 dlQuery = "SELECT * FROM crsp.dsf LIMIT 10" 16 dlSchema = "" 17 dlTable = "" 18 }, 19 want: "SELECT * FROM crsp.dsf LIMIT 10", 20 }, 21 { 22 name: "schema and table", 23 setup: func() { 24 dlQuery = "" 25 dlSchema = "crsp" 26 dlTable = "dsf" 27 dlColumns = "*" 28 dlWhere = "" 29 dlLimit = 0 30 }, 31 want: `SELECT * FROM "crsp"."dsf"`, 32 }, 33 { 34 name: "with columns", 35 setup: func() { 36 dlQuery = "" 37 dlSchema = "comp" 38 dlTable = "funda" 39 dlColumns = "gvkey,datadate,sale" 40 dlWhere = "" 41 dlLimit = 0 42 }, 43 want: `SELECT "gvkey", "datadate", "sale" FROM "comp"."funda"`, 44 }, 45 { 46 name: "with where and limit", 47 setup: func() { 48 dlQuery = "" 49 dlSchema = "crsp" 50 dlTable = "dsf" 51 dlColumns = "*" 52 dlWhere = "date >= '2020-01-01'" 53 dlLimit = 1000 54 }, 55 want: `SELECT * FROM "crsp"."dsf" WHERE date >= '2020-01-01' LIMIT 1000`, 56 }, 57 { 58 name: "missing schema and table", 59 setup: func() { 60 dlQuery = "" 61 dlSchema = "" 62 dlTable = "" 63 }, 64 wantErr: true, 65 }, 66 { 67 name: "column with spaces trimmed", 68 setup: func() { 69 dlQuery = "" 70 dlSchema = "crsp" 71 dlTable = "dsf" 72 dlColumns = " permno , date , prc " 73 dlWhere = "" 74 dlLimit = 0 75 }, 76 want: `SELECT "permno", "date", "prc" FROM "crsp"."dsf"`, 77 }, 78 } 79 80 for _, tt := range tests { 81 t.Run(tt.name, func(t *testing.T) { 82 tt.setup() 83 got, err := buildQuery() 84 if (err != nil) != tt.wantErr { 85 t.Fatalf("buildQuery() error = %v, wantErr %v", err, tt.wantErr) 86 } 87 if !tt.wantErr && got != tt.want { 88 t.Errorf("buildQuery() = %q, want %q", got, tt.want) 89 } 90 }) 91 } 92 } 93 94 func TestResolveFormat(t *testing.T) { 95 tests := []struct { 96 path string 97 flag string 98 want string 99 }{ 100 {"out.parquet", "", "parquet"}, 101 {"out.csv", "", "csv"}, 102 {"out.CSV", "", "csv"}, 103 {"out.parquet", "csv", "csv"}, 104 {"out.csv", "parquet", "parquet"}, 105 {"out.txt", "", "parquet"}, 106 {"out", "CSV", "csv"}, 107 } 108 109 for _, tt := range tests { 110 t.Run(tt.path+"_"+tt.flag, func(t *testing.T) { 111 got := resolveFormat(tt.path, tt.flag) 112 if got != tt.want { 113 t.Errorf("resolveFormat(%q, %q) = %q, want %q", tt.path, tt.flag, got, tt.want) 114 } 115 }) 116 } 117 }