Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 267 additions & 1 deletion src/stratum/query/plan.clj
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,138 @@
p'
(recur p' (inc n))))))

;; ============================================================================
;; Pass 1.5: Join reordering (DP-based)
;; ============================================================================

(defn- extract-join-chain
"Extract a left-deep chain of LJoin nodes into [fact-node, dims-vec].
dims-vec is in join order (innermost = first joined)."
[node]
(loop [n node, dims-rev []]
(if (instance? LJoin n)
(recur (:left n)
(conj dims-rev {:node (:right n)
:on-pairs (:on-pairs n)
:join-type (:join-type n)}))
[n (vec (rseq dims-rev))])))

(defn- estimate-node-rows
"Estimate the output row count of a logical plan node."
^long [node]
(cond
(instance? LScan node)
(long (:length node))

(instance? LFilter node)
(let [child (:input node)]
(if (instance? LScan child)
(est/estimate-output-rows (:predicates node) (:columns child) (:length child))
(long (or (:length child) 1000000))))

:else
(long (or (:length node) 1000000))))

(defn- dim-join-selectivity
"Estimate what fraction of fact rows survive joining with a dim.
For FK→PK: selectivity ≈ filtered_dim_rows / total_dim_rows."
^double [{:keys [node]}]
(let [total (double
(cond
(instance? LFilter node)
(let [inner (:input node)]
(if (instance? LScan inner) (:length inner) (estimate-node-rows node)))
(instance? LScan node) (:length node)
:else (estimate-node-rows node)))
filtered (double (estimate-node-rows node))]
(if (zero? total) 1.0 (min 1.0 (/ filtered total)))))

(defn- compute-join-deps
"For each dim, compute the set of dim indices it depends on.
Dim j depends on dim i if j's probe-side join key comes from dim i (not from fact).
This handles snowflake schemas where dim2 joins on a column from dim1."
[fact-cols dims]
(let [n (count dims)
dim-cols (mapv (fn [d] (or (scan-columns (:node d)) #{})) dims)]
(mapv (fn [j]
(let [probe-keys (set (map first (:on-pairs (nth dims j))))]
(into #{}
(keep (fn [i]
(when (and (not= i j)
(some (fn [k]
(and (not (contains? fact-cols k))
(contains? (nth dim-cols i) k)))
probe-keys))
i)))
(range n))))
(range n))))

(defn- dp-join-order
"DP join ordering for dim tables. Returns optimal index ordering.
Minimizes total probe cost = sum of rows entering each join step.
Respects column dependencies (snowflake schemas)."
[^long fact-rows fact-cols dims]
(let [n (count dims)
sels (mapv dim-join-selectivity dims)
deps (compute-join-deps fact-cols dims)]
(if (<= n 1)
(vec (range n))
(let [full (dec (bit-shift-left 1 n))
dp (object-array (inc full))]
(aset dp 0 {:cost 0.0 :order [] :rows (double fact-rows)})
(doseq [mask (range 1 (inc full))]
(let [best (reduce
(fn [best i]
(if (zero? (bit-and mask (bit-shift-left 1 i)))
best
(let [prev-mask (bit-xor mask (bit-shift-left 1 i))
i-deps (nth deps i)]
;; Check deps: all dependencies must be in prev-mask
(if (not (every? #(pos? (bit-and prev-mask (bit-shift-left 1 %))) i-deps))
best
(let [prev (aget dp (int prev-mask))]
(when prev
(let [new-cost (+ (double (:cost prev)) (double (:rows prev)))
new-rows (* (double (:rows prev)) (double (nth sels i)))]
(if (or (nil? best) (< new-cost (double (:cost best))))
{:cost new-cost :order (conj (:order prev) i) :rows new-rows}
best))))))))
nil
(range n))]
(aset dp (int mask) best)))
(or (:order (aget dp (int full)))
(vec (range n)))))))

(defn- rebuild-join-chain
"Rebuild a left-deep join tree from fact-node and reordered dims."
[fact-node dims order]
(reduce
(fn [left idx]
(let [{:keys [node on-pairs join-type]} (nth dims idx)]
(ir/->LJoin join-type on-pairs left node)))
fact-node
order))

(defn join-reorder
"Reorder left-deep INNER join chains for optimal execution cost.
Uses DP to minimize total probe cost. Puts most selective dims first.
Respects column dependencies for snowflake schemas.
Only reorders chains of all-INNER joins (outer joins stay in place)."
[plan]
(ir/walk-plan plan
(fn [node]
(if (instance? LJoin node)
(let [[fact-node dims] (extract-join-chain node)]
(if-let [fact-cols (and (>= (count dims) 2)
(every? #(= :inner (:join-type %)) dims)
(scan-columns fact-node))]
(let [order (dp-join-order (estimate-node-rows fact-node) fact-cols dims)]
(if (= order (vec (range (count dims))))
node
(rebuild-join-chain fact-node dims order)))
node))
node))))

;; ============================================================================
;; Pass 2: Zone-map annotation
;; ============================================================================
Expand Down Expand Up @@ -850,6 +982,137 @@

:else node))))

;; ============================================================================
;; Pass 6: Statistics propagation
;; ============================================================================

(defn- propagate-est-rows
"Estimate output rows for a physical node based on its children's estimates."
^long [node]
(cond
;; Leaf scans
(or (instance? PScan node) (instance? PChunkedScan node))
(long (:length node))

;; Filters reduce rows
(or (instance? PSIMDFilter node) (instance? PMaskFilter node))
(let [child-rows (long (or (::estimated-rows (meta (:input node)))
(:length (:input node))
1000000))
columns (when-let [inp (:input node)]
(:columns inp))
preds (:predicates node)]
(if (and columns (seq preds))
(est/estimate-output-rows preds columns child-rows)
child-rows))

;; Joins: FK→PK heuristic
(instance? PHashJoin node)
(let [probe-rows (long (or (::estimated-rows (meta (:probe-side node)))
(:length (:probe-side node))
1000000))
build-total (long (or (:length (:build-side node)) 1000000))
build-rows (long (or (::estimated-rows (meta (:build-side node)))
build-total))
sel (min 1.0 (/ (double build-rows) (max 1.0 (double build-total))))]
(max 1 (long (* probe-rows sel))))

;; Global agg always produces 1 row
(or (instance? PStatsOnlyAgg node)
(instance? PFusedSIMDAgg node) (instance? PFusedSIMDCount node)
(instance? PChunkedSIMDAgg node) (instance? PChunkedSIMDCount node)
(instance? PBlockSkipCount node) (instance? PFusedMultiSum node)
(instance? PPercentileAgg node) (instance? PScalarAgg node)
(instance? PFusedJoinGlobalAgg node))
1

;; Fused join+group: use metadata if present
(instance? PFusedJoinGroupAgg node)
(long (or (::estimated-rows (meta node)) 1000))

;; Pass-through: inherit from child
:else
(long (or (::estimated-rows (meta (ir/input-node node)))
(when-let [c (ir/input-node node)] (:length c))
1000000))))

(defn stats-propagation
"Propagate estimated row counts through the physical plan tree.
Attaches ::estimated-rows metadata to each node."
[plan]
(ir/walk-plan plan
(fn [node]
(let [est (propagate-est-rows node)]
(vary-meta node assoc ::estimated-rows est)))))

;; ============================================================================
;; Pass 7: Column pruning
;; ============================================================================

(defn- collect-all-refs
"Collect all column keywords referenced anywhere in the plan tree."
[plan]
(let [refs (volatile! (transient #{}))]
(ir/walk-plan plan
(fn [node]
;; Predicates
(when-let [preds (:predicates node)]
(doseq [p preds
c (pred-columns p)]
(vswap! refs conj! c)))
;; Group keys
(when-let [gks (:group-keys node)]
(doseq [gk gks] (when (keyword? gk) (vswap! refs conj! gk))))
;; Aggs
(when-let [aggs (:aggs node)]
(doseq [a aggs]
(when-let [c (:col a)] (vswap! refs conj! c))
(when-let [cs (:cols a)] (doseq [c cs] (when (keyword? c) (vswap! refs conj! c))))))
(when-let [agg (:agg node)]
(when-let [c (:col agg)] (vswap! refs conj! c)))
;; Project items
(when-let [items (:items node)]
(doseq [item items]
(when-let [r (:ref item)] (vswap! refs conj! r))
(when-let [n (:name item)] (vswap! refs conj! n))))
;; Join on-pairs
(when-let [on (:on-pairs node)]
(doseq [[l r] on] (vswap! refs conj! l) (vswap! refs conj! r)))
(when-let [js (:join-spec node)]
(doseq [[l r] (:on-pairs js)] (vswap! refs conj! l) (vswap! refs conj! r)))
;; Window specs
(when-let [specs (:specs node)]
(doseq [s specs]
(when-let [c (:col s)] (vswap! refs conj! c))
(when-let [pb (:partition-by s)] (doseq [c pb] (vswap! refs conj! c)))
(when-let [ob (:order-by s)] (doseq [[c _] ob] (vswap! refs conj! c)))))
;; Sort
(when-let [os (:order-specs node)]
(doseq [[c _] os] (vswap! refs conj! c)))
;; Extract
(when-let [ec (:extract-col node)]
(vswap! refs conj! ec))
;; Materialized expressions
(when-let [e (:expr node)]
(when (map? e)
(doseq [a (:args e)] (when (keyword? a) (vswap! refs conj! a)))))
node))
(persistent! @refs)))

(defn column-pruning
"Remove unreferenced columns from scan nodes to reduce memory and I/O."
[plan]
(let [all-refs (collect-all-refs plan)]
(ir/walk-plan plan
(fn [node]
(if (or (instance? PScan node) (instance? PChunkedScan node))
(let [cols (:columns node)
pruned (select-keys cols all-refs)]
(if (< (count pruned) (count cols))
(assoc node :columns pruned)
node))
node)))))

;; ============================================================================
;; Composite: full optimization pipeline
;; ============================================================================
Expand All @@ -860,10 +1123,13 @@
(-> plan
annotate
predicate-pushdown
join-reorder
zone-map-annotation
expr-materialization
strategy-selection
operator-fusion))
operator-fusion
stats-propagation
column-pruning))

;; ============================================================================
;; Plan explanation (for debugging and EXPLAIN output)
Expand Down
Loading