commit ebaa5cb0b8d657ea585b4fb1c93e10c21381bb05
parent fdb6e793588099732b112541b18e1018ca08f65c
Author: Erik Loualiche <eloualic@umn.edu>
Date: Wed, 25 Feb 2026 09:41:52 -0600
fix string xtile, optimize panel_fill, add input validation and tests
- Replace broken string xtile (searchsortedlast was applied to unsorted
cut points) with lexicographic ranking that delegates to numeric xtile
- Optimize panel_fill: collect fill chunks into a Vector{DataFrame} and
concatenate them with a single vcat instead of O(n²) repeated vcat;
hoist interpolation method selection out of the per-group loop
- Add format_stat validation in _tabulate_render_wide
- Add error-path tests for winsorize, xtile, and tabulate
- Remove stale TODO comment in Winsorize.jl
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat:
6 files changed, 92 insertions(+), 57 deletions(-)
diff --git a/src/PanelData.jl b/src/PanelData.jl
@@ -67,58 +67,56 @@ function panel_fill!(
df[!, time_var_r] .= df[!, time_var]
end
+ if method == :backwards
+ interpolate_method = BSpline(Constant(Previous))
+ elseif method == :forwards
+ interpolate_method = BSpline(Constant(Next))
+ elseif method == :nearest
+ interpolate_method = BSpline(Constant())
+ elseif method == :linear
+ interpolate_method = BSpline(Linear())
+ else
+ error(
+ """
+ Method $method not available.
+ Please choose from :backwards (default), :forwards, :nearest, :linear
+ """
+ )
+ end
+
gdf = groupby(df, [id_var])
- df_fill = DataFrame();
+ fill_chunks = DataFrame[]
for id_gdf in eachindex(gdf)
subdf = gdf[id_gdf]
- sub_fill = DataFrame()
-
- if method == :backwards
- interpolate_method = BSpline(Constant(Previous))
- elseif method == :forwards
- interpolate_method = BSpline(Constant(Next)) # # Next-neighbor interpolation
- elseif method == :nearest
- interpolate_method = BSpline(Constant()) # Nearest-neighbor interpolation
- elseif method == :linear
- interpolate_method = BSpline(Linear())
- else
- error(
- """
- Method $method not available.
- Please choose from :backwards (default), :forwards, :nearest, :linear
- """
- )
- end
-
- if nrow(subdf)>1 # condition for filling: at least one open
+
+ if nrow(subdf) > 1
sort!(subdf, time_var_r)
rowdf_init = subdf[1, :]
for rowdf in eachrow(subdf)[2:end]
-
- old_t = rowdf_init[time_var_r] # initialize the iteration
- enum_t = rowdf[time_var_r]
-
+
+ old_t = rowdf_init[time_var_r]
+ enum_t = rowdf[time_var_r]
+
t_fill = collect(range(old_t, enum_t, step=sign(enum_t-old_t) * gap))[2:end-1]
group_fill = DataFrame(
Dict(Symbol(time_var_r) => t_fill, id_var => id_gdf[1]))
N_fill = nrow(group_fill)
- scale_xs = range(1, 2, N_fill+2)[2:end-1] # the scaling matrix
+ scale_xs = range(1, 2, N_fill+2)[2:end-1]
- # this builds the interpolator and home made scales
interp_dict = Dict(
- v => interpolate([rowdf_init[v], rowdf[v]], interpolate_method)
+ v => interpolate([rowdf_init[v], rowdf[v]], interpolate_method)
for v in value_var)
var_fill = DataFrame(
Dict(v => interp_dict[v].(scale_xs) for v in value_var))
- # process the iteration and move on
- sub_fill = vcat(sub_fill, hcat(group_fill, var_fill))
- rowdf_init = rowdf;
+ push!(fill_chunks, hcat(group_fill, var_fill))
+ rowdf_init = rowdf
end
end
- df_fill = vcat(sub_fill, df_fill)
end
+
+ df_fill = isempty(fill_chunks) ? DataFrame() : vcat(fill_chunks...)
# clean up the output
if flag
diff --git a/src/StataUtils.jl b/src/StataUtils.jl
@@ -227,6 +227,7 @@ end
# ----- Wide format rendering
function _tabulate_render_wide(df_out, new_cols, N_COLS, format_stat, out)
+ format_stat ∈ (:freq, :pct) || error("format_stat must be :freq or :pct, got :$format_stat")
df_out = unstack(df_out,
new_cols[1:(N_COLS-1)], new_cols[N_COLS], format_stat,
allowmissing=true)
@@ -409,25 +410,18 @@ function xtile(
return searchsortedlast.(Ref(cuts), data)
end
-# String version
+# String version: use lexicographic rank, then delegate to numeric xtile
function xtile(
data::AbstractVector{T},
n_quantiles::Integer;
weights::Union{Weights{<:Real}, Nothing} = nothing
)::Vector{Int} where T <: AbstractString
- if weights === nothing
- weights = UnitWeights{Int}(length(data))
- end
- # Assign weights to each category
- category_weights = [sum(weights[data .== category]) for category in unique(data)]
- # Sort categories based on the weighted cumulative sum
- sorted_categories = sortperm(category_weights, rev=true)
- step = max(1, round(Int, length(sorted_categories) / n_quantiles))
- cuts = unique(data)[sorted_categories][1:step:end]
-
- return searchsortedlast.(Ref(cuts), data)
+ sorted_cats = sort(unique(data))
+ rank_map = Dict(cat => i for (i, cat) in enumerate(sorted_cats))
+ ranks = [rank_map[d] for d in data]
+ return xtile(ranks, n_quantiles; weights=weights)
end
# Dealing with missing and Numbers
diff --git a/src/Winsorize.jl b/src/Winsorize.jl
@@ -63,7 +63,7 @@ function winsorize(x::AbstractVector{T};
if any(ismissing.(replace_value))
y = Vector{Union{T, Missing}}(x) # Make a copy of x that can also store missing values
else
- y = Vector{Union{T, eltype(replace_value)}}(x) # TODO could be faster using views here ...
+ y = Vector{Union{T, eltype(replace_value)}}(x)
end
y[findall(skipmissing(x .< cutpoints[1]))] .= replace_value[1];
diff --git a/test/UnitTests/tabulate.jl b/test/UnitTests/tabulate.jl
@@ -164,3 +164,10 @@ end
@test nrow(df_tab) == 3
@test sum(df_tab.freq) == 6
end
+
+
+@testset "Tabulate - invalid format_stat in wide" begin
+ df = dropmissing(DataFrame(PalmerPenguins.load()))
+ @test_throws Exception tabulate(df, [:island, :species],
+ format_tbl=:wide, format_stat=:invalid, out=:df)
+end
diff --git a/test/UnitTests/winsorize.jl b/test/UnitTests/winsorize.jl
@@ -109,3 +109,13 @@ end
# probs path uses skipmissing which will be empty - quantile on empty should error
@test_throws Exception winsorize(x_all_missing, probs=(0.05, 0.95))
end
+
+
+@testset "winsorize - error paths" begin
+ # empty vector
+ @test_throws Exception winsorize(Float64[])
+
+ # invalid probability bounds
+ @test_throws Exception winsorize([1.0, 2.0, 3.0], probs=(-0.1, 0.9))
+ @test_throws Exception winsorize([1.0, 2.0, 3.0], probs=(0.1, 1.1))
+end
diff --git a/test/UnitTests/xtile.jl b/test/UnitTests/xtile.jl
@@ -2,16 +2,28 @@
df = dropmissing(DataFrame(PalmerPenguins.load()))
- # -- test on strings!
- a = xtile(df.species, 2);
- b = xtile(df.species, 2; weights=Weights(repeat([1], inner=nrow(df))));
- @test a==b
- @test sum(a)==520
-
- # -- test for more xtile than categories
- a = xtile(df.species, 4);
- b = xtile(df.species, 5);
- @test a==b
+ # -- test on strings (lexicographic ordering)
+ a = xtile(df.species, 2)
+ b = xtile(df.species, 2; weights=Weights(repeat([1], inner=nrow(df))))
+ @test a == b # uniform weights == no weights
+ @test all(0 .<= a .<= 2) # bins in valid range
+ # same species must get same bin
+ for sp in unique(df.species)
+ @test allequal(a[df.species .== sp])
+ end
+ # lexicographic order: Adelie < Chinstrap < Gentoo
+ @test a[findfirst(df.species .== "Adelie")] <= a[findfirst(df.species .== "Chinstrap")]
+ @test a[findfirst(df.species .== "Chinstrap")] <= a[findfirst(df.species .== "Gentoo")]
+
+ # -- string xtile with non-alphabetical categories
+ s_nonalpha = ["z", "z", "z", "a", "a", "m"]
+ result = xtile(s_nonalpha, 2)
+ @test allequal(result[1:3]) # all "z" in same bin
+ @test allequal(result[4:5]) # all "a" in same bin
+ @test all(0 .<= result .<= 2) # bins are valid
+ # lexicographic: "a" < "m" < "z"
+ @test result[4] <= result[6] # "a" <= "m"
+ @test result[6] <= result[1] # "m" <= "z"
# -- test on int
a = xtile(df.flipper_length_mm, 2);
@@ -48,8 +60,18 @@
# -- test on Union{Missing, AbstractString}
s_m = ["a", "c", "g", missing, "e", missing, "za"]
- @test isequal(xtile(s_m, 3), [1, 1, 2, missing, 1, missing, 3])
- @test isequal(xtile(s_m, 20), [1, 2, 4, missing, 2, missing, 5])
+ result_m = xtile(s_m, 3)
+ @test count(ismissing, result_m) == 2 # missing preserved
+ @test all(0 .<= skipmissing(result_m) .<= 3) # bins in valid range
+ # lexicographic: a < c < e < g < za
+ non_miss_idx = findall(!ismissing, s_m)
+ non_miss_vals = s_m[non_miss_idx]
+ non_miss_bins = result_m[non_miss_idx]
+ for (i, j) in zip(non_miss_idx[1:end-1], non_miss_idx[2:end])
+ if s_m[i] < s_m[j]
+ @test result_m[i] <= result_m[j] # ordering preserved
+ end
+ end
end
@@ -98,4 +120,8 @@ end
result = xtile(rand(100), 10)
@test all(r -> 0 <= r <= 10, result)
+ # n_quantiles validation
+ @test_throws Exception xtile([1.0, 2.0, 3.0], 0)
+ @test_throws Exception xtile([1.0, 2.0, 3.0], -1)
+
end