import math
import re
import numpy as np
import polars as pl
from pydantic import PositiveInt
# Pre-compiled pattern for extracting column names from X["col"] references
# inside rule strings. Capture group 1 is the column name: one or more
# characters that are not a double quote. Only double-quoted names match;
# a single-quoted reference such as X['col'] is not recognised.
_FEATURE_PATTERN = re.compile(r'X\["([^"]+)"\]')
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
def filter_rules_by_feature_overlap(
    R: pl.DataFrame,
    importance: dict[str, float],
    min_difference: PositiveInt = 1,
    rule_column: str = "rule",
) -> pl.DataFrame:
    """Filter out rules that are too similar based on column usage, keeping the most important.

    Two rules are "similar" when the larger of their one-sided column
    differences (columns used only by one rule, columns used only by the
    other) is smaller than ``min_difference``. Rules with identical column
    sets therefore have a difference of 0 and are always similar for any
    ``min_difference >= 1``.

    Rules are processed sequentially in row order (greedy). A candidate rule
    is compared against *every* already-kept rule it is similar to:

    * if it is similar to none, it is kept;
    * if its importance is strictly greater than the importance of all
      similar kept rules, it replaces all of them;
    * otherwise it is dropped (ties favour the earlier rule).

    This guarantees that no two rules in the result are similar to each
    other. Because the procedure is greedy, a rule that was dropped is not
    reconsidered if the rule that beat it is later evicted itself.

    Parameters
    ----------
    R : pl.DataFrame
        DataFrame with a column containing rule strings (X["column_name"] patterns).
    importance : dict
        Dictionary mapping rule strings to their importance values.
        Keys: rule strings matching those in R[rule_column]
        Values: importance values for each rule (missing rules default to 0.0)
    min_difference : PositiveInt, default=1
        Minimum number of different columns required between two rules.
        If two rules differ by fewer than this many columns, only the one
        with highest importance is kept. Must be >= 1; the value is not
        validated inside this function.
    rule_column : str, default="rule"
        Name of the column containing rule strings.

    Returns
    -------
    pl.DataFrame
        Filtered DataFrame with similar rules removed (keeping highest importance).
        Surviving rows keep their original relative order. When ``R`` has at
        most one row it is returned unchanged.

    Examples
    --------
    >>> import polars as pl
    >>> rules_X = pl.DataFrame({
    ...     'rule': ['(X["a"] > 1) & (X["b"] < 2)',
    ...              '(X["a"] > 1) & (X["c"] < 3)',
    ...              '(X["a"] > 1) & (X["b"] < 2)'],
    ...     'score': [0.9, 0.85, 0.8]
    ... })
    >>> importance = {'(X["a"] > 1) & (X["b"] < 2)': 0.7,
    ...               '(X["a"] > 1) & (X["c"] < 3)': 0.9}
    >>> filter_rules_by_feature_overlap(rules_X, importance, min_difference=1)

    Assuming the feature extractor returns the X["..."] column names, the
    third row uses the same columns as the first and is not more important,
    so only the first two rows are kept.
    """
    rules = R[rule_column].to_list()
    if len(rules) <= 1:
        return R

    # Column set used by each rule, aligned with ``rules`` by position.
    rule_columns = [set(extract_feature_names_from_rule(rule)) for rule in rules]

    indices_to_keep: list[int] = []
    for i, cols_i in enumerate(rule_columns):
        # Collect ALL kept rules this one is too similar to. Stopping at the
        # first match would let the candidate replace a weak rule while
        # remaining similar to a stronger kept rule.
        similar_indices = [
            j
            for j in indices_to_keep
            if max(len(cols_i - rule_columns[j]), len(rule_columns[j] - cols_i))
            < min_difference
        ]

        if not similar_indices:
            indices_to_keep.append(i)
            continue

        importance_i = importance.get(rules[i], 0.0)
        strongest_rival = max(importance.get(rules[j], 0.0) for j in similar_indices)

        # Strict comparison: on a tie the earlier rule stays.
        if importance_i > strongest_rival:
            displaced = set(similar_indices)
            indices_to_keep = [j for j in indices_to_keep if j not in displaced]
            # ``i`` exceeds every stored index, so the list stays sorted and
            # the output preserves the original row order.
            indices_to_keep.append(i)

    return R[indices_to_keep]
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
def select_best_rule_per_column_combination(
    metrics: pl.DataFrame, ranking_metric: str = "precision"
) -> list[str]:
    """
    Select the rule with the highest metric score for each unique column combination.

    The column combination of a rule is the sorted set of names found in its
    X["column_name"] references. Rules that reference no columns all share
    the empty combination and compete with each other.

    Rows whose metric value is null are ranked last, so they are only
    selected when every rule in their group has a null metric. Ties on the
    metric are resolved in favour of the row that appears first in
    ``metrics``.

    NOTE(review): NaN metric values are not treated specially and follow the
    sort order of polars; confirm whether NaN scores can reach this function.

    Parameters
    ----------
    metrics : pl.DataFrame
        DataFrame containing rule performance metrics. Must have a "rule" column
        and the metric specified in ranking_metric.
    ranking_metric : str, default="precision"
        Name of the metric column to use for selecting the best rule in each group.

    Returns
    -------
    list[str]
        Filtered rules with only the best rule for each column combination,
        ordered by descending metric value.

    Raises
    ------
    ValueError
        If ``metrics`` has no "rule" column or no ``ranking_metric`` column.

    Examples
    --------
    >>> metrics = pl.DataFrame({
    ...     "rule": ['(X["a"] > 1)', '(X["a"] > 2)', '(X["b"] < 3)'],
    ...     "precision": [0.95, 0.98, 0.96]
    ... })
    >>> select_best_rule_per_column_combination(metrics, ranking_metric="precision")
    ['(X["a"] > 2)', '(X["b"] < 3)']
    """
    # Validate inputs
    if "rule" not in metrics.columns:
        raise ValueError("metrics DataFrame must contain a 'rule' column")
    if ranking_metric not in metrics.columns:
        raise ValueError(f"ranking_metric metric '{ranking_metric}' not found in metrics columns")

    rules = metrics["rule"].to_list()

    # One grouping key per row: the string form of the sorted, de-duplicated
    # column names. Built in Python rather than with map_elements.
    combo_keys = [
        str(tuple(sorted(set(_FEATURE_PATTERN.findall(rule))))) for rule in rules
    ]

    # Pick a helper column name that cannot overwrite an input column (in
    # particular the ranking metric itself).
    combo_col = "column_combination"
    while combo_col in metrics.columns:
        combo_col = f"_{combo_col}"

    best_rules = (
        metrics.with_columns(pl.Series(combo_col, combo_keys))
        # nulls_last keeps unscored rules from outranking scored ones;
        # maintain_order makes tie-breaking deterministic (input order).
        .sort(ranking_metric, descending=True, nulls_last=True, maintain_order=True)
        .group_by(combo_col, maintain_order=True)
        .first()
    )
    return best_rules["rule"].to_list()