diff --git a/src/stratum/query/plan.clj b/src/stratum/query/plan.clj
index e183cef..98c3ae7 100644
--- a/src/stratum/query/plan.clj
+++ b/src/stratum/query/plan.clj
@@ -311,6 +311,157 @@
           p'
           (recur p' (inc n))))))
 
+;; ============================================================================
+;; Pass 1.5: Join reordering (DP-based)
+;; ============================================================================
+
+(defn- extract-join-chain
+  "Extract a left-deep chain of LJoin nodes into [fact-node, dims-vec].
+   Follows :left until a non-LJoin node is reached; that node is fact-node.
+   Each dim is a map {:node <right input> :on-pairs ... :join-type ...}.
+   dims-vec is in join order (innermost = first joined)."
+  [node]
+  (loop [n node, dims-rev []]
+    (if (instance? LJoin n)
+      (recur (:left n)
+             (conj dims-rev {:node (:right n)
+                             :on-pairs (:on-pairs n)
+                             :join-type (:join-type n)}))
+      [n (vec (rseq dims-rev))])))
+
+(defn- estimate-node-rows
+  "Estimate the output row count of a logical plan node.
+   LScan: its :length. LFilter directly over an LScan: est/estimate-output-rows
+   on the filter predicates. Anything else: the :length of the node (or of the
+   filter's child), falling back to 1000000 when no :length is present."
+  ^long [node]
+  (cond
+    (instance? LScan node)
+    (long (:length node))
+
+    (instance? LFilter node)
+    (let [child (:input node)]
+      (if (instance? LScan child)
+        (est/estimate-output-rows (:predicates node) (:columns child) (:length child))
+        (long (or (:length child) 1000000))))
+
+    :else
+    (long (or (:length node) 1000000))))
+
+(defn- dim-join-selectivity
+  "Estimate what fraction of fact rows survive joining with a dim.
+   For FK→PK: selectivity ≈ filtered_dim_rows / total_dim_rows.
+   The result is capped at 1.0, and is 1.0 when total_dim_rows is zero."
+  ^double [{:keys [node]}]
+  (let [total (double
+               (cond
+                 (instance? LFilter node)
+                 (let [inner (:input node)]
+                   (if (instance? LScan inner) (:length inner) (estimate-node-rows node)))
+                 (instance? LScan node) (:length node)
+                 :else (estimate-node-rows node)))
+        filtered (double (estimate-node-rows node))]
+    (if (zero? total) 1.0 (min 1.0 (/ filtered total)))))
+
+(defn- compute-join-deps
+  "For each dim, compute the set of dim indices it depends on.
+   Dim j depends on dim i if j's probe-side join key comes from dim i (not from fact).
+   This handles snowflake schemas where dim2 joins on a column from dim1.
+   The first element of each on-pair is treated as the probe-side key.
+   Returns a vector of index sets, parallel to dims."
+  [fact-cols dims]
+  (let [n (count dims)
+        dim-cols (mapv (fn [d] (or (scan-columns (:node d)) #{})) dims)]
+    (mapv (fn [j]
+            (let [probe-keys (set (map first (:on-pairs (nth dims j))))]
+              (into #{}
+                    (keep (fn [i]
+                            (when (and (not= i j)
+                                       (some (fn [k]
+                                               (and (not (contains? fact-cols k))
+                                                    (contains? (nth dim-cols i) k)))
+                                             probe-keys))
+                              i)))
+                    (range n))))
+          (range n))))
+
+(defn- dp-join-order
+  "DP join ordering for dim tables. Returns optimal index ordering.
+   Minimizes total probe cost = sum of rows entering each join step.
+   Respects column dependencies (snowflake schemas).
+   dp[mask] holds the cheapest {:cost :order :rows} that joins exactly the dims
+   whose bits are set in mask, or nil when no dependency-respecting order exists.
+   Falls back to the original order (0..n-1) when n <= 1, when n exceeds
+   max-dp-dims (the table has 2^n entries), or when no valid order is found."
+  [^long fact-rows fact-cols dims]
+  (let [n (count dims)
+        max-dp-dims 16]
+    (if (or (<= n 1) (> n max-dp-dims))
+      (vec (range n))
+      (let [sels (mapv dim-join-selectivity dims)
+            deps (compute-join-deps fact-cols dims)
+            full (dec (bit-shift-left 1 n))
+            dp (object-array (inc full))]
+        (aset dp 0 {:cost 0.0 :order [] :rows (double fact-rows)})
+        (doseq [mask (range 1 (inc full))]
+          (let [best (reduce
+                      (fn [best i]
+                        (if (zero? (bit-and mask (bit-shift-left 1 i)))
+                          best
+                          (let [prev-mask (bit-xor mask (bit-shift-left 1 i))
+                                i-deps (nth deps i)]
+                            ;; Check deps: all dependencies must be in prev-mask
+                            (if (not (every? #(pos? (bit-and prev-mask (bit-shift-left 1 %))) i-deps))
+                              best
+                              (let [prev (aget dp (int prev-mask))]
+                                (if (nil? prev)
+                                  ;; prev-mask is unreachable: keep the best candidate
+                                  ;; found so far instead of discarding it.
+                                  best
+                                  (let [new-cost (+ (double (:cost prev)) (double (:rows prev)))
+                                        new-rows (* (double (:rows prev)) (double (nth sels i)))]
+                                    (if (or (nil? best) (< new-cost (double (:cost best))))
+                                      {:cost new-cost :order (conj (:order prev) i) :rows new-rows}
+                                      best))))))))
+                      nil
+                      (range n))]
+            (aset dp (int mask) best)))
+        (or (:order (aget dp (int full)))
+            (vec (range n)))))))
+
+(defn- rebuild-join-chain
+  "Rebuild a left-deep join tree from fact-node and reordered dims.
+   order is a seq of indices into dims; the first index is joined first."
+  [fact-node dims order]
+  (reduce
+   (fn [left idx]
+     (let [{:keys [node on-pairs join-type]} (nth dims idx)]
+       (ir/->LJoin join-type on-pairs left node)))
+   fact-node
+   order))
+
+(defn join-reorder
+  "Reorder left-deep INNER join chains for optimal execution cost.
+   Uses DP to minimize total probe cost. Puts most selective dims first.
+   Respects column dependencies for snowflake schemas.
+   Only reorders chains of all-INNER joins (outer joins stay in place).
+   Chains with fewer than two dims, or whose fact node has no scan-columns,
+   are returned unchanged."
+  [plan]
+  (ir/walk-plan plan
+                (fn [node]
+                  (if (instance? LJoin node)
+                    (let [[fact-node dims] (extract-join-chain node)]
+                      (if-let [fact-cols (and (>= (count dims) 2)
+                                              (every? #(= :inner (:join-type %)) dims)
+                                              (scan-columns fact-node))]
+                        (let [order (dp-join-order (estimate-node-rows fact-node) fact-cols dims)]
+                          (if (= order (vec (range (count dims))))
+                            node
+                            (rebuild-join-chain fact-node dims order)))
+                        node))
+                    node))))
+
 ;; ============================================================================
 ;; Pass 2: Zone-map annotation
 ;; ============================================================================
@@ -850,6 +1001,145 @@
                     :else
                     node))))
 
+;; ============================================================================
+;; Pass 6: Statistics propagation
+;; ============================================================================
+
+(defn- propagate-est-rows
+  "Estimate output rows for a physical node based on its children's estimates.
+   A child's estimate is read from its ::estimated-rows metadata, then from its
+   :length, and otherwise defaults to 1000000."
+  ^long [node]
+  (cond
+    ;; Leaf scans
+    (or (instance? PScan node) (instance? PChunkedScan node))
+    (long (:length node))
+
+    ;; Filters reduce rows
+    (or (instance? PSIMDFilter node) (instance? PMaskFilter node))
+    (let [child-rows (long (or (::estimated-rows (meta (:input node)))
+                               (:length (:input node))
+                               1000000))
+          columns (when-let [inp (:input node)]
+                    (:columns inp))
+          preds (:predicates node)]
+      (if (and columns (seq preds))
+        (est/estimate-output-rows preds columns child-rows)
+        child-rows))
+
+    ;; Joins: FK→PK heuristic
+    (instance? PHashJoin node)
+    (let [probe-rows (long (or (::estimated-rows (meta (:probe-side node)))
+                               (:length (:probe-side node))
+                               1000000))
+          build-total (long (or (:length (:build-side node)) 1000000))
+          build-rows (long (or (::estimated-rows (meta (:build-side node)))
+                               build-total))
+          sel (min 1.0 (/ (double build-rows) (max 1.0 (double build-total))))]
+      (max 1 (long (* probe-rows sel))))
+
+    ;; Global agg always produces 1 row
+    (or (instance? PStatsOnlyAgg node)
+        (instance? PFusedSIMDAgg node) (instance? PFusedSIMDCount node)
+        (instance? PChunkedSIMDAgg node) (instance? PChunkedSIMDCount node)
+        (instance? PBlockSkipCount node) (instance? PFusedMultiSum node)
+        (instance? PPercentileAgg node) (instance? PScalarAgg node)
+        (instance? PFusedJoinGlobalAgg node))
+    1
+
+    ;; Fused join+group: use metadata if present
+    (instance? PFusedJoinGroupAgg node)
+    (long (or (::estimated-rows (meta node)) 1000))
+
+    ;; Pass-through: inherit from child
+    :else
+    (long (or (::estimated-rows (meta (ir/input-node node)))
+              (when-let [c (ir/input-node node)] (:length c))
+              1000000))))
+
+(defn stats-propagation
+  "Propagate estimated row counts through the physical plan tree.
+   Attaches ::estimated-rows metadata to each node.
+   NOTE(review): parents read this metadata from their children, so this
+   presumably relies on ir/walk-plan visiting children first — confirm."
+  [plan]
+  (ir/walk-plan plan
+                (fn [node]
+                  (let [est (propagate-est-rows node)]
+                    (vary-meta node assoc ::estimated-rows est)))))
+
+;; ============================================================================
+;; Pass 7: Column pruning
+;; ============================================================================
+
+(defn- collect-all-refs
+  "Collect all column keywords referenced anywhere in the plan tree.
+   Returns a persistent set. Expression maps under :expr are walked
+   recursively through :args so nested references are collected as well."
+  [plan]
+  (let [refs (volatile! (transient #{}))]
+    (ir/walk-plan plan
+                  (fn [node]
+                    ;; Predicates
+                    (when-let [preds (:predicates node)]
+                      (doseq [p preds
+                              c (pred-columns p)]
+                        (vswap! refs conj! c)))
+                    ;; Group keys
+                    (when-let [gks (:group-keys node)]
+                      (doseq [gk gks] (when (keyword? gk) (vswap! refs conj! gk))))
+                    ;; Aggs
+                    (when-let [aggs (:aggs node)]
+                      (doseq [a aggs]
+                        (when-let [c (:col a)] (vswap! refs conj! c))
+                        (when-let [cs (:cols a)] (doseq [c cs] (when (keyword? c) (vswap! refs conj! c))))))
+                    (when-let [agg (:agg node)]
+                      (when-let [c (:col agg)] (vswap! refs conj! c)))
+                    ;; Project items
+                    (when-let [items (:items node)]
+                      (doseq [item items]
+                        (when-let [r (:ref item)] (vswap! refs conj! r))
+                        (when-let [n (:name item)] (vswap! refs conj! n))))
+                    ;; Join on-pairs
+                    (when-let [on (:on-pairs node)]
+                      (doseq [[l r] on] (vswap! refs conj! l) (vswap! refs conj! r)))
+                    (when-let [js (:join-spec node)]
+                      (doseq [[l r] (:on-pairs js)] (vswap! refs conj! l) (vswap! refs conj! r)))
+                    ;; Window specs
+                    (when-let [specs (:specs node)]
+                      (doseq [s specs]
+                        (when-let [c (:col s)] (vswap! refs conj! c))
+                        (when-let [pb (:partition-by s)] (doseq [c pb] (vswap! refs conj! c)))
+                        (when-let [ob (:order-by s)] (doseq [[c _] ob] (vswap! refs conj! c)))))
+                    ;; Sort
+                    (when-let [os (:order-specs node)]
+                      (doseq [[c _] os] (vswap! refs conj! c)))
+                    ;; Extract
+                    (when-let [ec (:extract-col node)]
+                      (vswap! refs conj! ec))
+                    ;; Materialized expressions, including nested sub-expressions
+                    (when-let [e (:expr node)]
+                      (when (map? e)
+                        (doseq [a (tree-seq map? :args e)] (when (keyword? a) (vswap! refs conj! a)))))
+                    node))
+    (persistent! @refs)))
+
+(defn column-pruning
+  "Remove unreferenced columns from scan nodes to reduce memory and I/O.
+   Only PScan and PChunkedScan nodes are rewritten: their :columns is narrowed
+   (via select-keys) to the keys returned by collect-all-refs."
+  [plan]
+  (let [all-refs (collect-all-refs plan)]
+    (ir/walk-plan plan
+                  (fn [node]
+                    (if (or (instance? PScan node) (instance? PChunkedScan node))
+                      (let [cols (:columns node)
+                            pruned (select-keys cols all-refs)]
+                        (if (< (count pruned) (count cols))
+                          (assoc node :columns pruned)
+                          node))
+                      node)))))
+
 ;; ============================================================================
 ;; Composite: full optimization pipeline
 ;; ============================================================================
@@ -860,10 +1150,13 @@
   (-> plan
       annotate
       predicate-pushdown
+      join-reorder
       zone-map-annotation
       expr-materialization
       strategy-selection
-      operator-fusion))
+      operator-fusion
+      stats-propagation
+      column-pruning))
 
 ;; ============================================================================
 ;; Plan explanation (for debugging and EXPLAIN output)