Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dodiscover/replearning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .deci import DECI
from .gin import GIN
60 changes: 60 additions & 0 deletions dodiscover/replearning/deci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import numpy as np
import pandas as pd
from dodiscover.base import BaseCausalDiscovery

class DECI(BaseCausalDiscovery):
"""Wrapper for Causica DECI algorithm.

DECI (Deep End-to-end Causal Inference) is a framework for causal discovery
and inference using neural networks.
"""

def __init__(self, model_params=None, training_params=None):
super().__init__()
self.model_params = model_params or {}
self.training_params = training_params or {}
self.model = None

def fit(self, X, y=None):
"""Fit the DECI model to the data.

Parameters
----------
X : pd.DataFrame or np.ndarray
The input data.
y : Ignored
Not used, present for API consistency.

Returns
-------
self : object
Returns the instance itself.
"""
try:
from causica.models.deci.deci import DECI as CausicaDECI
except ImportError:
raise ImportError(
"DECI requires the 'causica' package. "
"Please install it with `pip install causica`."
)

if isinstance(X, np.ndarray):
X = pd.DataFrame(X)

# Implementation details would go here
# This is a wrapper placeholder as per request
return self

def predict_graph(self):
"""Return the discovered causal graph.

Returns
-------
graph : array-like
The adjacency matrix of the discovered graph.
"""
if self.model is None:
raise ValueError("Model has not been fitted yet.")

# Implementation to extract graph from self.model
pass