Filtering a QGraph
by: Chris Cave & Meera Machado
(Feyn version 1.4 or newer)
Assume you have a dataset like the following (here we generate one). We also assume you already know how to access your QLattice; this code assumes you use a configuration file.
from sklearn.datasets import make_classification
import pandas as pd
from feyn.tools import split
# Generate a dataset and put it into a dataframe
X, y = make_classification()
data = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
data['target'] = y
The max_depth parameter
When extracting a QGraph from the QLattice, you can control the depth of the graphs in the QGraph by setting the max_depth parameter:
import feyn
qlattice = feyn.QLattice()
# Capping the max graph depths to 3 - this will contain graphs of depths 1 through 3
qgraph = qlattice.get_regressor(data.columns, 'target', max_depth=3)
The benefit of reducing max_depth is that it tends to reduce the number of features used and produces simpler models. Increasing max_depth, on the other hand, allows for more complex model architectures. A small max_depth lets you answer the question: what is the simplest model that could describe the relation between the input and output variables?
Filters
Aside from controlling the max_depth of the graphs in the QGraph, you can impose other conditions on these models by using the QGraph.filter function. The filter ensures that the only graphs being trained are those that satisfy the condition(s) it imposes. Note that calling filter does not modify the original QGraph; instead, it returns a new QGraph object. Below are some filter examples.
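To make these semantics concrete without a QLattice, here is a minimal pure-Python sketch. ToyGraph and ToyQGraph are hypothetical stand-ins, not Feyn classes: a filter is modeled as a callable that takes a list of graphs and returns the ones to keep (the same shape as the custom-filter __call__ shown later), and filter returns a new object rather than changing the original.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for a Feyn graph: only the attributes filters inspect."""
    depth: int
    edges: int
    features: Tuple[str, ...]


# A filter maps a list of graphs to the sublist that should be kept.
GraphFilter = Callable[[List[ToyGraph]], List[ToyGraph]]


class ToyQGraph:
    """Hypothetical stand-in for a QGraph: a collection of graphs."""

    def __init__(self, graphs: List[ToyGraph]):
        self._graphs = list(graphs)

    def filter(self, f: GraphFilter) -> "ToyQGraph":
        # Build and return a new collection; self._graphs is never modified.
        return ToyQGraph(f(self._graphs))

    def __len__(self) -> int:
        return len(self._graphs)


qgraph = ToyQGraph([
    ToyGraph(depth=1, edges=3, features=("x1",)),
    ToyGraph(depth=2, edges=7, features=("x1", "x3")),
    ToyGraph(depth=3, edges=12, features=("x2", "x4")),
])

only_shallow = qgraph.filter(lambda graphs: [g for g in graphs if g.depth <= 2])
print(len(qgraph), len(only_shallow))  # 3 2
```

The original toy collection still holds all three graphs after filtering, mirroring the behaviour described for QGraph.filter.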
Using a filtered QGraph
Once the QGraph is filtered, the workflow proceeds as normal. To see how this is used in practice, refer to our workflow on formulating hypotheses.
MaxDepth and Depth
The Depth filter works in a similar manner to the max_depth parameter above. However, instead of capping the depth, it requires every graph's depth to equal the given value.
qgraph_filtered = qgraph.filter(feyn.filters.Depth(3)) # Now all graphs in the QGraph have exactly depth 3.
qgraph_filtered.head()
The MaxDepth filter works in exactly the same manner as the max_depth parameter passed to QLattice.get_regressor above.
qgraph_filtered = qgraph.filter(feyn.filters.MaxDepth(2)) # Now all graphs have depth at most 2.
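The difference between "exactly" and "at most" can be sketched with plain Python. The Depth and MaxDepth classes below are hypothetical re-creations named after the Feyn filters, operating on a hypothetical ToyGraph; they are not Feyn's implementation, only an illustration of the two conditions.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for a Feyn graph; only depth matters here."""
    name: str
    depth: int


class Depth:
    """Sketch of an 'exactly this depth' filter (not Feyn's implementation)."""

    def __init__(self, depth: int):
        self.depth = depth

    def __call__(self, graphs: List[ToyGraph]) -> List[ToyGraph]:
        return [g for g in graphs if g.depth == self.depth]


class MaxDepth:
    """Sketch of an 'at most this depth' filter (not Feyn's implementation)."""

    def __init__(self, max_depth: int):
        self.max_depth = max_depth

    def __call__(self, graphs: List[ToyGraph]) -> List[ToyGraph]:
        return [g for g in graphs if g.depth <= self.max_depth]


graphs = [ToyGraph("a", 1), ToyGraph("b", 2), ToyGraph("c", 3)]
print([g.name for g in Depth(2)(graphs)])     # ['b']
print([g.name for g in MaxDepth(2)(graphs)])  # ['a', 'b']
```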
MaxEdges and Edges
One can also constrain the number of edges in the graphs. MaxEdges limits graphs to at most the given number of edges, while Edges requires exactly that number.
qgraph_filtered = qgraph.filter(feyn.filters.MaxEdges(9)) # Number of edges capped at 9
Or
qgraph_filtered = qgraph.filter(feyn.filters.Edges(12)) # Only graphs with 12 edges are allowed
qgraph_filtered.head()
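The same two conditions apply to edge counts. In this sketch, max_edges and edges are hypothetical helper functions (not part of Feyn) that return filter callables over a hypothetical ToyGraph, so you can check which graphs each rule keeps.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for a Feyn graph; only the edge count matters here."""
    name: str
    edges: int


Filter = Callable[[List[ToyGraph]], List[ToyGraph]]


def max_edges(n: int) -> Filter:
    """Sketch of a MaxEdges-style rule: keep graphs with at most n edges."""
    return lambda graphs: [g for g in graphs if g.edges <= n]


def edges(n: int) -> Filter:
    """Sketch of an Edges-style rule: keep graphs with exactly n edges."""
    return lambda graphs: [g for g in graphs if g.edges == n]


graphs = [ToyGraph("a", 5), ToyGraph("b", 9), ToyGraph("c", 12)]
print([g.name for g in max_edges(9)(graphs)])  # ['a', 'b']
print([g.name for g in edges(12)(graphs)])     # ['c']
```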
Contains
Another built-in filter, Contains, lets you focus on graphs that include a chosen feature. Suppose we have a dataset with features x1, x2, x3 and x4. To select only models that include x3, use the Contains filter.
qgraph_filtered = qgraph.filter(feyn.filters.Contains('x3')) # All graphs will have the *x3* feature.
qgraph_filtered.head()
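As a rough illustration of what "contains a feature" means, the sketch below gives a hypothetical ToyGraph a __contains__ method so that `feature in graph` works, which is the same membership test the custom filter later in this guide relies on. The Contains class here is a hypothetical re-creation, not Feyn's code.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for a Feyn graph; supports `feature in graph`."""
    name: str
    features: Tuple[str, ...]

    def __contains__(self, feature: str) -> bool:
        return feature in self.features


class Contains:
    """Sketch of a Contains-style filter (not Feyn's implementation)."""

    def __init__(self, feature: str):
        self.feature = feature

    def __call__(self, graphs: List[ToyGraph]) -> List[ToyGraph]:
        # Keep only graphs that use the requested feature
        return [g for g in graphs if self.feature in g]


graphs = [
    ToyGraph("a", ("x1", "x2")),
    ToyGraph("b", ("x3",)),
    ToyGraph("c", ("x3", "x4")),
]
print([g.name for g in Contains("x3")(graphs)])  # ['b', 'c']
```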
Multiple filters
It is also possible to apply multiple filters. For example, to find a simple model that uses a particular feature, we can combine the MaxDepth and Contains filters.
# Each graph has max depth 3 and contains the x3 feature.
qgraph_filtered = qgraph.filter(feyn.filters.MaxDepth(3)) \
.filter(feyn.filters.Contains('x3'))
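Chaining filters amounts to a logical AND: each filter only sees the graphs the previous one kept. The sketch below models this with hypothetical helpers (max_depth, contains, chain) on a hypothetical ToyGraph; it is not Feyn code, just a way to check the combined condition by hand.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for a Feyn graph."""
    name: str
    depth: int
    features: Tuple[str, ...]


Filter = Callable[[List[ToyGraph]], List[ToyGraph]]


def max_depth(d: int) -> Filter:
    """Keep graphs whose depth is at most d."""
    return lambda graphs: [g for g in graphs if g.depth <= d]


def contains(feature: str) -> Filter:
    """Keep graphs that use the given feature."""
    return lambda graphs: [g for g in graphs if feature in g.features]


def chain(graphs: List[ToyGraph], *filters: Filter) -> List[ToyGraph]:
    # Feed each filter the output of the previous one, like .filter(...).filter(...)
    for f in filters:
        graphs = f(graphs)
    return graphs


graphs = [
    ToyGraph("a", 2, ("x3",)),
    ToyGraph("b", 4, ("x3", "x1")),
    ToyGraph("c", 3, ("x2",)),
    ToyGraph("d", 3, ("x1", "x3")),
]
first = chain(graphs, max_depth(3), contains("x3"))
second = chain(graphs, contains("x3"), max_depth(3))
print([g.name for g in first])   # ['a', 'd']
print([g.name for g in second])  # ['a', 'd']
```

Because both toy filters decide graph by graph, swapping their order gives the same result. A filter that looked at the list as a whole (for example, keeping only the first N graphs) would not have this property.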
Make your own filter
Lastly, you can make your own filter. It should be created as a class that subclasses feyn.filters.QGraphFilter and defines an __init__ and a __call__ method. In the example below, we create the filter ContainsOR.
import feyn
from typing import List

class ContainsOR(feyn.filters.QGraphFilter):  # Custom filters subclass QGraphFilter
    def __init__(self, features):  # Instantiate with ContainsOR(features)
        self.features = features

    def __call__(self, graphs: List[feyn._graph.Graph]):  # Receives the list of graphs to filter
        # Keep a graph if it contains at least one of the requested features
        filtered_list = [g for g in graphs if any(feature in g for feature in self.features)]
        return filtered_list

# Usage
qgraph_filtered = qgraph.filter(ContainsOR(['x1', 'x3']))  # Only graphs that contain *x1* or *x3* are allowed.
The ContainsOR filter above takes a list of features as input; applying it with qgraph.filter returns a QGraph whose graphs contain at least one of these features.
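You can check the logic of a custom filter before running it on a real QGraph. The sketch below drops the feyn base class and applies the same ContainsOR body to a hypothetical ToyGraph that supports `feature in graph`. ContainsAND is a hypothetical variant added for comparison: replacing any with all keeps only graphs that contain every listed feature, which matches chaining several Contains filters.

```python
from typing import Iterable, List


class ToyGraph:
    """Hypothetical stand-in for a Feyn graph that supports `feature in graph`."""

    def __init__(self, name: str, features: Iterable[str]):
        self.name = name
        self.features = set(features)

    def __contains__(self, feature: str) -> bool:
        return feature in self.features


class ContainsOR:
    """Same logic as the ContainsOR filter above, without the feyn base class."""

    def __init__(self, features: List[str]):
        self.features = features

    def __call__(self, graphs: List[ToyGraph]) -> List[ToyGraph]:
        return [g for g in graphs if any(f in g for f in self.features)]


class ContainsAND:
    """Hypothetical variant: keep graphs that contain every listed feature."""

    def __init__(self, features: List[str]):
        self.features = features

    def __call__(self, graphs: List[ToyGraph]) -> List[ToyGraph]:
        return [g for g in graphs if all(f in g for f in self.features)]


graphs = [
    ToyGraph("a", ["x1"]),
    ToyGraph("b", ["x2", "x4"]),
    ToyGraph("c", ["x1", "x3"]),
    ToyGraph("d", ["x3"]),
]
print([g.name for g in ContainsOR(["x1", "x3"])(graphs)])   # ['a', 'c', 'd']
print([g.name for g in ContainsAND(["x1", "x3"])(graphs)])  # ['c']
```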