# StataUtils.jl (17849B)
# ------------------------------------------------------------------------------------------
# StataUtils.jl
# Collection of functions that replicate some stata utilities
# ------------------------------------------------------------------------------------------


# ------------------------------------------------------------------------------------------
# List of exported functions
# tabulate
# xtile
# ------------------------------------------------------------------------------------------


# ------------------------------------------------------------------------------------------
"""
    tabulate(df::AbstractDataFrame, cols::Union{Symbol, Array{Symbol}};
        reorder_cols=true, out::Symbol=:stdout)

Frequency tabulation inspired by Stata's `tabulate` command.
Forked from TexTables.jl and inspired by https://github.com/matthieugomez/statar

# Arguments
- `df::AbstractDataFrame`: Input DataFrame to analyze
- `cols::Union{Symbol, Vector{Symbol}}`: Single column name or vector of column names to tabulate
- `group_type::Union{Symbol, Vector{Symbol}}=:value`: Specifies how to group each column:
    - `:value`: Group by the actual values in the column
    - `:type`: Group by the type of values in the column
    - `Vector{Symbol}`: Vector combining `:value` and `:type` for different columns
- `reorder_cols::Bool=true` Whether to sort the output by sortable columns
- `format_tbl::Symbol=:long` How to present the results long or wide (stata twoway)
- `format_stat::Symbol=:freq` Which statistics to present for format :freq or :pct
- `skip_stat::Union{Nothing, Symbol, Vector{Symbol}}=nothing` do not print out all statistics (only for string)
- `out::Symbol=:stdout` Output format:
    - `:stdout` Print formatted table to standard output (returns nothing)
    - `:df` Return the result as a DataFrame
    - `:string` Return the formatted table as a string

# Returns
- `Nothing` if `out=:stdout`
- `DataFrame` if `out=:df`
- `String` if `out=:string`

# Output Format
The resulting table contains the following columns:
- Specified grouping columns (from `cols`)
- `freq`: Frequency count
- `pct`: Percentage of total
- `cum`: Cumulative percentage

# Examples
See the README for more examples
```julia
# Simple frequency table for one column
tabulate(df, :country)

## Group by value type
tabulate(df, :age, group_type=:type)

# Multiple columns with mixed grouping
tabulate(df, [:country, :age], group_type=[:value, :type])

# Return as DataFrame instead of printing
result_df = tabulate(df, :country, out=:df)
```

"""
function tabulate(
        df::AbstractDataFrame, cols::Union{Symbol, Vector{Symbol}};
        group_type::Union{Symbol, Vector{Symbol}}=:value,
        reorder_cols::Bool=true,
        format_tbl::Symbol=:long,
        format_stat::Symbol=:freq,
        skip_stat::Union{Nothing, Symbol, Vector{Symbol}}=nothing,
        out::Symbol=:stdout)

    # number of grouping columns requested by the caller
    N_COLS = cols isa Symbol ? 1 : length(cols)

    # validate the table layout; a one-column tabulation can only be :long,
    # so silently (with a warning) coerce it instead of erroring
    if !(format_tbl ∈ [:long, :wide])
        if N_COLS == 1
            @warn "Converting format_tbl to :long"
            format_tbl = :long
        else
            error("Table format_tbl must be :long or :wide")
        end
    end

    if isempty(df)
        @warn "Input Dataframe is empty ..."
        return nothing
    end

    # compute frequencies / percentages once, then dispatch on the layout
    df_out, new_cols = _tabulate_compute(df, cols, group_type, reorder_cols)

    if format_tbl == :long
        return _tabulate_render_long(df_out, new_cols, N_COLS, out, skip_stat)
    else # :wide
        return _tabulate_render_wide(df_out, new_cols, N_COLS, format_stat, out)
    end
end


# ----- Computation: groupby, combine, sort, pct/cum transforms
function _tabulate_compute(df, cols, group_type, reorder_cols)
    group_type_error_msg = """
    \ngroup_type input must specify either ':value' or ':type' for columns;
    options are :value, :type, or a vector combining the two;
    see help for more information
    """
    if group_type == :value
        # group on raw values
        df_out = combine(groupby(df, cols), nrow => :freq, proprow =>:pct)
        new_cols = cols
    elseif group_type == :type
        # group on typeof(value) via derived "<col>_typeof" columns
        name_type_cols = Symbol.(cols, "_typeof")
        df_out = transform(df, cols .=> ByRow(typeof) .=> name_type_cols) |>
            (d -> combine(groupby(d, name_type_cols), nrow => :freq, proprow =>:pct))
        new_cols = name_type_cols
    elseif group_type isa Vector{Symbol}
        # mixed :value / :type, one entry per column
        !all(s -> s in [:value, :type], group_type) && error(group_type_error_msg)
        (size(group_type, 1) != size(cols, 1)) &&
            error("group_type and cols must be the same size; see help for more information")
        type_cols = cols[group_type .== :type]
        name_type_cols = Symbol.(type_cols, "_typeof")
        group_cols = [cols[group_type .== :value]; name_type_cols]
        df_out = transform(df, type_cols .=> ByRow(typeof) .=> name_type_cols) |>
            (d -> combine(groupby(d, group_cols), nrow => :freq, proprow =>:pct))
        new_cols = group_cols
    else
        error(group_type_error_msg)
    end
    # resort columns based on the original order
    new_cols = sort(new_cols isa Symbol ? [new_cols] : new_cols,
        by = x -> findfirst(==(replace(string(x), r"_typeof$" => "")), string.(cols)))

    if reorder_cols
        # only sort on columns whose eltype actually supports isless
        cols_sortable = [
            name
            for (name, col) in pairs(eachcol(select(df_out, new_cols)))
            if eltype(col) |> t -> hasmethod(isless, Tuple{t,t})
        ]
        if !isempty(cols_sortable)
            sort!(df_out, cols_sortable) # order before we build cumulative
        end
    end
    # cumulative share first, then rescale pct to 0-100 and cum to integer percent
    transform!(df_out, :pct => cumsum => :cum, :freq => ByRow(Int) => :freq)
    transform!(df_out,
        :pct => (x -> x .* 100),
        :cum => (x -> Int.(round.(x .* 100, digits=0))), renamecols=false)

    return df_out, new_cols
end


# ----- Long format rendering
function _tabulate_render_long(df_out, new_cols, N_COLS, out, skip_stat)
    # unicode bar per row, proportional to freq
    transform!(df_out, :freq => (x -> text_histogram(x, width=24)) => :freq_hist)

    # highlighter with gradient for the freq/pct/cum columns (rest is cyan)
    col_highlighters = Tuple(vcat(
        map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_COLS),
        hl_custom_gradient(cols=(N_COLS+1), colorscheme=:Oranges_9, scale=maximum(df_out.freq)),
        hl_custom_gradient(cols=(N_COLS+2), colorscheme=:Greens_9, scale=ceil(Int, maximum(df_out.pct))),
        hl_custom_gradient(cols=(N_COLS+3), colorscheme=:Greens_9, scale=100),
    ))

    # when skip_stat is provided and output is string, filter columns
    if out == :string && !isnothing(skip_stat)
        all_stats = [:freq, :pct, :cum, :freq_hist]
        skip_list = skip_stat isa Vector ? skip_stat : [skip_stat]
        col_stat = setdiff(all_stats, skip_list)
        N_COL_STAT = length(col_stat)

        stat_headers = Dict(:freq=>"Freq.", :pct=>"Percent", :cum=>"Cum", :freq_hist=>"Hist.")
        stat_formats = Dict(:freq=>"%d", :pct=>"%.1f", :cum=>"%d", :freq_hist=>"%s")
        stat_colorschemes = Dict(
            :freq => (:Oranges_9, maximum(df_out.freq)),
            :pct => (:Greens_9, ceil(Int, maximum(df_out.pct))),
            :cum => (:Greens_9, 100),
        )

        header = vcat(string.(new_cols), [stat_headers[k] for k in col_stat])
        formatters = Tuple(vcat(
            [ft_printf("%s", i) for i in 1:N_COLS],
            [ft_printf(stat_formats[k], N_COLS + i) for (i, k) in enumerate(col_stat)]
        ))
        # rebuild highlighters for the filtered column layout
        filtered_highlighters = Tuple(vcat(
            map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_COLS),
            [haskey(stat_colorschemes, k) ?
                hl_custom_gradient(cols=N_COLS+i, colorscheme=stat_colorschemes[k][1], scale=stat_colorschemes[k][2]) :
                Highlighter((data, row, col) -> col == N_COLS+i, crayon"white")
             for (i, k) in enumerate(col_stat)]
        ))
        alignment = vcat(repeat([:l], N_COLS), repeat([:c], N_COL_STAT))
        cell_alignment = reduce(push!,
            map(i -> (i,1)=>:l, 1:N_COLS+N_COL_STAT-1),
            init=Dict{Tuple{Int64, Int64}, Symbol}())

        df_render = select(df_out, new_cols, col_stat)
        return _render_pretty_table(df_render, out;
            hlines=[1], vlines=[N_COLS],
            alignment=alignment, cell_alignment=cell_alignment,
            header=header, formatters=formatters, highlighters=filtered_highlighters)
    end

    # default: all stat columns
    header = [string.(new_cols); "Freq."; "Percent"; "Cum"; "Hist."]
    formatters = Tuple(vcat(
        [ft_printf("%s", i) for i in 1:N_COLS],
        [ft_printf("%d", N_COLS+1), ft_printf("%.1f", N_COLS+2),
         ft_printf("%d", N_COLS+3), ft_printf("%s", N_COLS+4)]
    ))
    alignment = vcat(repeat([:l], N_COLS), :c, :c, :c, :c)
    cell_alignment = reduce(push!,
        map(i -> (i,1)=>:l, 1:N_COLS+3),
        init=Dict{Tuple{Int64, Int64}, Symbol}())

    return _render_pretty_table(df_out, out;
        hlines=[1], vlines=[N_COLS],
        alignment=alignment, cell_alignment=cell_alignment,
        header=header, formatters=formatters, highlighters=col_highlighters)
end


# ----- Wide format rendering
function _tabulate_render_wide(df_out, new_cols, N_COLS, format_stat, out)
    format_stat ∈ (:freq, :pct) || error("format_stat must be :freq or :pct, got :$format_stat")
    # pivot: last grouping column spreads into value columns (stata twoway)
    df_out = unstack(df_out,
        new_cols[1:(N_COLS-1)], new_cols[N_COLS], format_stat,
        allowmissing=true)

    N_GROUP_COLS = N_COLS - 1
    N_VAR_COLS = size(df_out, 2) - N_GROUP_COLS

    if format_stat == :freq

        # frequency: add row and column totals
        total_row_des = "Total by $(string(new_cols[N_COLS]))"
        total_col_des = join(vcat("Total by ", join(string.(new_cols[1:(N_COLS-1)]), ", ")))

        sum_cols = sum.(skipmissing.(eachcol(df_out[:, range(1+N_GROUP_COLS; length=N_VAR_COLS)])))
        row_vector = vcat([total_row_des], repeat(["-"], max(0, N_GROUP_COLS-1)), sum_cols)
        df_out = vcat(df_out,
            DataFrame(permutedims(row_vector)[:, end+1-size(df_out,2):end], names(df_out)))
        sum_rows = sum.(skipmissing.(eachrow(df_out[:, range(1+N_GROUP_COLS; length=N_VAR_COLS)])))
        col_vector = rename(DataFrame(total = sum_rows), "total" => total_col_des)
        df_out = hcat(df_out, col_vector)
        rename!(df_out, [i => "-"^i for i in 1:N_GROUP_COLS])

        # gradient per value column (total row excluded from the scale), green total column
        col_highlighters = Tuple(vcat(
            map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_GROUP_COLS),
            [hl_custom_gradient(cols=i, colorscheme=:Greens_9,
                scale = ceil(Int, maximum(skipmissing(df_out[1:end-1, i]))))
             for i in range(1+N_GROUP_COLS; length=N_VAR_COLS)],
            Highlighter((data, row, col) -> col == size(df_out, 2), crayon"green")
        ))

        formatters = Tuple(vcat(
            [ft_printf("%s", i) for i in 1:N_GROUP_COLS],
            [ft_printf("%d", j) for j in range(1+N_GROUP_COLS; length=N_VAR_COLS)],
            [ft_printf("%d", 1+N_GROUP_COLS+N_VAR_COLS)]
        ))

        hlines = [1, size(df_out, 1)]
        vlines = [N_GROUP_COLS, N_GROUP_COLS+N_VAR_COLS]
        alignment = vcat(repeat([:l], N_GROUP_COLS), repeat([:c], N_VAR_COLS), [:l])

    elseif format_stat == :pct

        col_highlighters = Tuple(vcat(
            map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_GROUP_COLS),
            [hl_custom_gradient(cols=i, colorscheme=:Greens_9,
                scale = ceil(Int, maximum(skipmissing(df_out[:, i]))))
             for i in range(1+N_GROUP_COLS; length=N_VAR_COLS)],
        ))

        formatters = Tuple(vcat(
            [ft_printf("%s", i) for i in 1:N_GROUP_COLS],
            [ft_printf("%.1f", j) for j in range(1+N_GROUP_COLS; length=N_VAR_COLS)]
        ))

        hlines = [1]
        vlines = [0, N_GROUP_COLS, N_GROUP_COLS+N_VAR_COLS]
        alignment = vcat(repeat([:l], N_GROUP_COLS), repeat([:c], N_VAR_COLS))

    end

    cell_alignment = reduce(push!,
        map(i -> (i,1)=>:l, 1:N_GROUP_COLS),
        init=Dict{Tuple{Int64, Int64}, Symbol}())

    return _render_pretty_table(df_out, out;
        hlines=hlines, vlines=vlines,
        alignment=alignment, cell_alignment=cell_alignment,
        formatters=formatters, highlighters=col_highlighters,
        show_subheader=false)
end


# ----- Unified pretty_table output handler (stdout / df / string)
function _render_pretty_table(df, out::Symbol; show_subheader=true, pt_kwargs...)
    # shared styling for every output mode
    common = (
        border_crayon = crayon"bold yellow",
        header_crayon = crayon"bold light_green",
        show_header = true,
        show_subheader = show_subheader,
    )

    if out ∈ [:stdout, :df]
        pretty_table(df; common..., vcrop_mode=:middle, pt_kwargs...)
        return out == :stdout ? nothing : df
    else # :string
        return pretty_table(String, df; common..., crop=:none, pt_kwargs...)
    end
end
# --------------------------------------------------------------------------------------------------


# --------------------------------------------------------------------------------------------------
# Build a PrettyTables highlighter that colors column `cols` on a gradient from
# `colorscheme`, scaling cell values onto (0, scale); missing cells get gray.
function hl_custom_gradient(;
        cols::Int=0,
        colorscheme::Symbol=:Oranges_9,
        scale::Int=1)

    Highlighter(
        (data, i, j) -> j == cols,
        (h, data, i, j) -> begin
            if ismissing(data[i, j])
                return Crayon(foreground=(128, 128, 128)) # Use a default color for missing values
            end
            color = get(colorschemes[colorscheme], data[i, j], (0, scale))
            return Crayon(foreground=(round(Int, color.r * 255),
                round(Int, color.g * 255),
                round(Int, color.b * 255)))
        end
    )

end
# --------------------------------------------------------------------------------------------------


# --------------------------------------------------------------------------------------------------
# From https://github.com/mbauman/Sparklines.jl/blob/master/src/Sparklines.jl
# Unicode characters:
# █ (Full block, U+2588)
# ⣿ (Full Braille block, U+28FF)
# ▓ (Dark shade, U+2593)
# ▒ (Medium shade, U+2592)
# ░ (Light shade, U+2591)
# ◼ (Small black square, U+25FC)

# Render one left-padded unicode bar (width chars) per frequency, scaled to the maximum.
function text_histogram(frequencies; width=12)
    blocks = [" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"]
    max_freq = maximum(frequencies)
    max_freq == 0 && return fill(" " ^ width, length(frequencies))
    scale = (width * 8 - 1) / max_freq # Subtract 1 to ensure we don't exceed width

    function bar(f)
        units = round(Int, f * scale)
        full_blocks = div(units, 8)
        remainder = units % 8
        rpad(repeat("█", full_blocks) * blocks[remainder + 1], width)
    end
    bar.(frequencies)
end
# --------------------------------------------------------------------------------------------------
# --------------------------------------------------------------------------------------------------

"""
    xtile(data::AbstractVector{T}, n_quantiles::Integer;
          weights::Union{Weights, Nothing}=nothing)::Vector{Int} where T <: Real

Create quantile groups (numbered 1 to `n_quantiles`) using weighted quantile
functionality from StatsBase.

# Arguments
- `data`: Values to group
- `n_quantiles`: Number of groups
- `weights`: Optional weights of weight type (StatsBase)

# Examples
```julia
sales = rand(10_000);
a = xtile(sales, 10);
b = xtile(sales, 10, weights=Weights(repeat([1], length(sales))) );
@assert a == b
```
"""
function xtile(
    data::AbstractVector{T},
    n_quantiles::Integer;
    weights::Union{Weights{<:Real}, Nothing} = nothing
)::Vector{Int} where T <: Real

    N = length(data)
    n_quantiles < 1 && error("n_quantiles must be >= 1")
    n_quantiles > N && (@warn "More quantiles than data")

    # upper cut points at probabilities 1/n, 2/n, ..., 1 (prob 0 dropped)
    probs = range(0, 1, length=n_quantiles + 1)[2:end]
    if weights === nothing
        weights = UnitWeights{T}(N)
    end
    cuts = quantile(collect(data), weights, probs)

    # BUG FIX: the original used searchsortedlast, which maps every value strictly
    # below the first cut point to group 0 (e.g. xtile(1:10, 2) -> 0s and a lone 2)
    # instead of groups 1..n. searchsortedfirst returns the index of the first
    # upper cut >= x, i.e. the 1-based quantile group with inclusive upper bounds,
    # matching Stata's xtile convention.
    return searchsortedfirst.(Ref(cuts), data)
end

# String version: use lexicographic rank, then delegate to numeric xtile
function xtile(
    data::AbstractVector{T},
    n_quantiles::Integer;
    weights::Union{Weights{<:Real}, Nothing} = nothing
)::Vector{Int} where T <: AbstractString

    # map each distinct string to its 1-based rank in sorted order
    sorted_cats = sort(unique(data))
    rank_map = Dict(cat => i for (i, cat) in enumerate(sorted_cats))
    ranks = [rank_map[d] for d in data]

    return xtile(ranks, n_quantiles; weights=weights)
end

# Dealing with missing and Numbers: missings pass through as missing,
# non-missing entries are grouped by the appropriate method above.
function xtile(
    data::AbstractVector{T},
    n_quantiles::Integer;
    weights::Union{Weights{<:Real}, Nothing} = nothing
)::Vector{Union{Int, Missing}} where {T <: Union{Missing, AbstractString, Number}}

    # Determine the non-missing type
    non_missing_type = Base.nonmissingtype(T)

    # Identify valid (non-missing) data
    data_notmissing_idx = findall(!ismissing, data)

    if isempty(data_notmissing_idx) # If all values are missing, return all missing
        return fill(missing, length(data))
    end

    # Use @view to avoid unnecessary allocations but convert explicitly to non-missing type
    valid_data = convert(Vector{non_missing_type}, @view data[data_notmissing_idx])
    # keep only the weights aligned with the surviving observations
    valid_weights = weights === nothing ? nothing : Weights(@view weights[data_notmissing_idx])

    # Compute quantile groups on valid data
    valid_result = xtile(valid_data, n_quantiles; weights=valid_weights)

    # Allocate result array with correct type
    result = Vector{Union{Int, Missing}}(missing, length(data))
    result[data_notmissing_idx] .= valid_result # Assign computed quantile groups

    return result
end
# --------------------------------------------------------------------------------------------------