by: Chris Cave & Meera Machado
(Feyn version 1.4 or newer)
Suppose you have the following dataset (here we generate one). We also assume that you already know how to access your QLattice and that you connect to it through a configuration file.
from sklearn.datasets import make_classification
import pandas as pd
from feyn.tools import split

# Generate a dataset and put it into a dataframe
X, y = make_classification()
# X.shape is a (rows, columns) tuple, so the column count is X.shape[1]
data = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
data['target'] = y
When extracting a QGraph from the QLattice, you can control the depth of the graphs in the QGraph by setting the max_depth parameter.
import feyn

qlattice = feyn.QLattice()

# Capping the max graph depths to 3 - this will contain graphs of depths 1 through 3
qgraph = qlattice.get_regressor(data.columns, 'target', max_depth=3)
Reducing max_depth decreases the number of features being used and produces a simpler model, whereas increasing max_depth allows for more complex model architectures. A smaller max_depth lets the user answer the question: what is the simplest model that could describe the relation between the input and output variables?
Aside from controlling the max_depth of the graphs in the QGraph, one can impose other conditions on these models by using the QGraph.filter function. In other words, the filter ensures that the only graphs being trained are the ones that satisfy the condition(s) imposed by the filter. Note that calling the filter function does not modify the original QGraph; it returns a new QGraph object. Below are some filter examples.
Using a filtered QGraph
Once the QGraph is filtered, the workflow proceeds as normal. To see how to use this in practice, refer to our workflow on formulating hypotheses.
MaxDepth and Depth
The Depth filter works in a similar manner to the max_depth parameter above. However, it ensures that the graphs' depths are exactly equal to the set value.
qgraph_filtered = qgraph.filter(feyn.filters.Depth(3))
# Now all graphs in the QGraph have exactly depth 3.
qgraph_filtered.head()
The MaxDepth filter works in exactly the same manner as the max_depth parameter described above.
qgraph_filtered = qgraph.filter(feyn.filters.MaxDepth(2))
# Now all graphs have a depth of at most 2.
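The difference between the two conditions (exactly equal versus at most) can be seen on toy data. The sketch below does not use Feyn: each "graph" is represented only by its depth, and `depth_filter` and `max_depth_filter` are hypothetical helper names chosen for illustration.

```python
# Hypothetical stand-ins: each "graph" is represented only by its depth.
depths = [1, 2, 2, 3, 3, 3]


def depth_filter(depths, value):
    # Depth-style condition: keep only graphs whose depth equals `value`.
    return [d for d in depths if d == value]


def max_depth_filter(depths, value):
    # MaxDepth-style condition: keep graphs whose depth is at most `value`.
    return [d for d in depths if d <= value]


exact = depth_filter(depths, 2)
capped = max_depth_filter(depths, 2)

print(exact)   # only the depth-2 entries
print(capped)  # the depth-1 and depth-2 entries
```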
MaxEdges and Edges
One can also constrain the number of edges in the graphs. MaxEdges limits the graphs to at most a given number of edges, while Edges requires exactly that number.
qgraph_filtered = qgraph.filter(feyn.filters.MaxEdges(9)) # Number of edges capped at 9
qgraph_filtered = qgraph.filter(feyn.filters.Edges(12))
# Only graphs with exactly 12 edges are allowed
qgraph_filtered.head()
Another built-in filter option, Contains, lets the user focus on graphs with a chosen feature. Suppose we have a dataset with features x1, x2, x3 and x4. To select only models that include x3, use the Contains filter.
qgraph_filtered = qgraph.filter(feyn.filters.Contains('x3'))
# All graphs will have the *x3* feature.
qgraph_filtered.head()
Moreover, it is possible to apply multiple filters. For example, if we want to find a simple model with a particular feature, we can combine the MaxDepth and Contains filters.
# Each graph has max depth 3 and contains the x3 feature.
qgraph_filtered = qgraph.filter(feyn.filters.MaxDepth(3)) \
                        .filter(feyn.filters.Contains('x3'))
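Chaining filters amounts to requiring every condition at once. The sketch below illustrates this on toy objects rather than on a real QGraph: `ToyGraph`, `max_depth`, `contains` and `apply_all` are hypothetical names, and the assumption is that each filter is a pure predicate on individual graphs, as these toy ones are. Under that assumption the order of the filters does not change which graphs are kept.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for a graph: a depth and a set of feature names."""
    depth: int
    features: frozenset


graphs = [
    ToyGraph(1, frozenset({"x1"})),
    ToyGraph(2, frozenset({"x1", "x3"})),
    ToyGraph(3, frozenset({"x2", "x3"})),
    ToyGraph(4, frozenset({"x3", "x4"})),
]


def max_depth(n):
    # Toy MaxDepth-style filter: keep graphs with depth at most n.
    return lambda gs: [g for g in gs if g.depth <= n]


def contains(name):
    # Toy Contains-style filter: keep graphs that use the named feature.
    return lambda gs: [g for g in gs if name in g.features]


def apply_all(gs, *filters):
    # Apply each filter to the output of the previous one.
    for f in filters:
        gs = f(gs)
    return gs


a = apply_all(graphs, max_depth(3), contains("x3"))
b = apply_all(graphs, contains("x3"), max_depth(3))

print([g.depth for g in a])
print(a == b)
```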
Make your own
Lastly, one can make their own filter. It should be created as a class with an __init__ and a __call__ function. Refer to the example below, where we create the filter ContainsOR.
from typing import List

class ContainsOR(feyn.filters.QGraphFilter):  # Always subclass QGraphFilter
    def __init__(self, features):  # To create the filter: ContainsOR(features)
        self.features = features

    def __call__(self, graphs: List[feyn._graph.Graph]):  # __call__ always receives the list of graphs
        # Keep the graphs that contain at least one of the given features
        return [g for g in graphs if any(f in g for f in self.features)]

# Usage
qgraph_filtered = qgraph.filter(ContainsOR(['x1', 'x3']))
# Only graphs that contain *x1* or *x3* are allowed.
The ContainsOR filter above takes a list of features as input and returns a QGraph whose graphs contain at least one of these features.
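The membership logic of such a filter can be checked in isolation before it is used on a real QGraph. The following sketch runs without Feyn installed: `ToyGraph` is a hypothetical stand-in that only supports the `feature in graph` test, and this copy of ContainsOR deliberately omits the feyn.filters.QGraphFilter base class, so it is a test harness for the logic and not a drop-in replacement.

```python
class ToyGraph:
    """Hypothetical stand-in for a graph; supports `feature in graph`."""

    def __init__(self, *features):
        self.features = set(features)

    def __contains__(self, feature):
        return feature in self.features


class ContainsOR:
    # Same selection logic as the filter above, without the feyn base class
    # so that the sketch is runnable on its own.
    def __init__(self, features):
        self.features = features

    def __call__(self, graphs):
        # Keep graphs that contain at least one of the requested features.
        return [g for g in graphs if any(f in g for f in self.features)]


graphs = [ToyGraph("x1", "x2"), ToyGraph("x2", "x4"), ToyGraph("x3")]
kept = ContainsOR(["x1", "x3"])(graphs)

print([sorted(g.features) for g in kept])
```

The second toy graph is dropped because it contains neither x1 nor x3, while the other two each match one of the requested features.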