Titanic: A general binary classification case

by: Meera Machado & Chris Cave

Feyn version: 1.4.+

In this tutorial, we'll be using Feyn and the QLattice to solve a binary classification problem by exploring models that predict the probability of surviving the sinking of the RMS Titanic during her maiden voyage in April 1912.

import numpy as np
import pandas as pd

import feyn
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
%matplotlib inline

Importing dataset and quick check-up

The Titanic passenger dataset was acquired through the data.world and Encyclopedia Titanica websites.

df = pd.read_csv('../data/titanic.csv')
print("hello")
df.head()
   pclass  survived                                             name     sex      age  sibsp  parch  ticket      fare    cabin embarked boat   body                        home.dest
0       1         1                    Allen, Miss. Elisabeth Walton  female  29.0000      0      0   24160  211.3375       B5        S    2    NaN                     St Louis, MO
1       1         1                   Allison, Master. Hudson Trevor    male   0.9167      1      2  113781  151.5500  C22 C26        S   11    NaN  Montreal, PQ / Chesterville, ON
2       1         0                     Allison, Miss. Helen Loraine  female   2.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN  Montreal, PQ / Chesterville, ON
3       1         0             Allison, Mr. Hudson Joshua Creighton    male  30.0000      1      2  113781  151.5500  C22 C26        S  NaN  135.0  Montreal, PQ / Chesterville, ON
4       1         0  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN  Montreal, PQ / Chesterville, ON

Dealing with missing data

# Checking which columns have nan values:
df.isna().any()
pclass       False
survived     False
name         False
sex          False
age           True
sibsp        False
parch        False
ticket       False
fare         False
cabin         True
embarked     False
boat          True
body          True
home.dest     True
dtype: bool

Among all the features containing NaN values, age is the one of most interest.

df[df.age.isna()]
     pclass  survived                    name   sex  age  sibsp  parch  ticket    fare cabin embarked boat  body home.dest
816       3         0  Gheorgheff, Mr. Stanio  male  NaN      0      0  349254  7.8958   NaN        C  NaN   NaN       NaN
940       3         0     Kraeff, Mr. Theodor  male  NaN      0      0  349253  7.8958   NaN        C  NaN   NaN       NaN

Note that the only gentlemen whose ages are missing share the same feature values (with the exception of ticket). In this case, we take a simple approach to guessing their ages: a random draw from a normal distribution with mean $\langle x \rangle$ and standard deviation $\sigma_x$, where $\langle x \rangle$ and $\sigma_x$ are, respectively, the mean age and standard deviation of all people sharing the same feature values.

age_dist = df[(df.pclass == 3) & (df.embarked == 'C') & (df.sex == 'male') & 
              (df.sibsp == 0) & (df.parch == 0) & (df.survived == 0)].age.dropna()

mean_age = np.mean(age_dist)
std_age = np.std(age_dist)

np.random.seed(42)
age_guess = np.random.normal(mean_age, std_age, size=2)
# As a first pass, we drop features that look irrelevant
df_mod = df.drop(['boat', 'body', 'home.dest', 'name', 'ticket', 'cabin'], axis=1)
df_mod.loc[df[df.age.isna()].index, 'age'] = age_guess
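
As a quick sanity check (a minimal sketch using pandas only), we can confirm that the imputation left no missing ages behind:

# The age column should contain no NaN values after the imputation
assert not df_mod.age.isna().any()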

Training session

Splitting data in train, validation and holdout sets

We wish to make a prediction on the probability of surviving the Titanic sinking, so we set survived to be our target variable.

target = 'survived'

# Train and test
train, test = train_test_split(df_mod, test_size=0.4, random_state=42, stratify=df_mod[target])

# Validation and holdout:
valid, hold = train_test_split(test, test_size=0.4, stratify=test[target], random_state=42)
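
It is worth confirming that the stratified splits preserved the class balance. A minimal sketch:

# Check the size and survival rate of each split
for name, part in [('train', train), ('valid', valid), ('hold', hold)]:
    print('%5s: %4d rows, survival rate %.3f' % (name, len(part), part[target].mean()))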

Pre-training time:

First we connect to a QLattice through its unique URL and API token. Since this is a local domain, no URL or token is necessary. A QLattice can also be reset, wiping all its learnings, if the user wishes to begin from a clean slate.

ql = feyn.QLattice() # Connecting
ql.reset() # Resetting

Each feature from the dataset, including the target variable, interacts with the QLattice through Registers. These are divided into input and output and accommodate numerical and categorical variables.
In the following cell, we assign the categorical features of train to what we call semantic types, or stypes. By making this mapping between the columns and their types, the QLattice knows to use a categorical register. The default mapping is numerical and doesn't need to be stated, so only the categorical columns (declared as 'categorical', 'cat' or 'c') need to be specified.

stypes = {}

for var in train.columns:
    if train[var].dtype == object:
        stypes[var] = 'c'
        
stypes['pclass'] = 'c' # Making sure that pclass, which is of int type, is considered a categorical variable

After naming the Registers it is possible to extract a QGraph from the QLattice. A QGraph has the input and output registers as parameters along with the max_depth option. The latter determines the maximum size of a Graph inside the QGraph.

qgraph = ql.get_classifier(train.columns, target, max_depth=3, stypes=stypes)

It should be noted that get_classifier means that every graph in the QGraph has a logistic function in the output cell.

Simply put, a QGraph is a collection of Graphs:

qgraph.head(3)
[Renderings of three example graphs: embarked → inverse → survived; sibsp, embarked → gaussian → tanh → tanh → survived; embarked, parch → multiply → survived]

And each of these Graphs is in fact a possible model for the problem at hand.

It should be stressed that no training has taken place yet. As one may have noticed, aside from the column names, no data has actually been fed into the QLattice!

Asking questions

This is where we diverge from traditional machine learning frameworks.

Akin to the scientific method, the first thing we do is ask a question based on the data. The QLattice then serves as a hypothesis generator for the question asked.

We do know that women were more likely to survive the Titanic disaster:

train.groupby('sex').agg({'survived': 'mean'})
        survived
sex
female  0.719858
male    0.192843

Let's use this to formulate our first question to the QLattice.

Question 1: How does sex predict survival?

So we now have a question and we want to use the QLattice to generate hypotheses for it.

This occurs in the following steps:

  1. Extract a QGraph (as exemplified above);
  2. Fit the training data to the Graphs from the QGraph. One may:
    • Choose a loss function from squared_error, absolute_error and categorical_cross_entropy;
    • Set a number of threads for fitting the Graphs;
    • Choose what should be displayed while training;
  3. Update the QLattice with the best Graphs;
  4. Repeat the process from step 2: now more Graphs in the QGraph have architectures similar to the best ones from qgraph.best().

# Defining some parameters:
nloops = 20
loss_function = feyn.losses.squared_error

qgraph = ql.get_classifier(['sex'], target, max_depth=1, stypes=stypes) # (1)

# And training begins
for loop in range(nloops): # (4)
    qgraph.fit(train, loss_function=loss_function, threads=4, show='graph') # (2)  
    ql.update(qgraph.best()) # (3)
[Fitting 20: 100% completed. Best loss so far: 0.171682. Rendering of the best graph: sex → gaussian → survived]

The model above is the one that returns the smallest mean squared error loss during fitting.

best_graph = qgraph[0]

Categorical variables are assigned weights, which are learned during the fitting loop.

best_graph[0].state.categories
[('female', -0.4544159291904795), ('male', 0.618932067010086)]

Model evaluation

It is now time to verify the survival probabilities predicted by the top graph above, so we employ a set of metrics to evaluate the chosen model:

Confusion matrix

How many people were correctly classified as survivors (True) and victims (False)?

fig = plt.figure(figsize=(14, 6))
ax = fig.add_subplot(121)
best_graph.plot_confusion_matrix(train, ax=ax, title='training set')
ax = fig.add_subplot(122)
best_graph.plot_confusion_matrix(valid, ax=ax, title='validation set')

[Plot: confusion matrices for the training and validation sets]

print('Overall training accuracy: %.4f' %best_graph.accuracy_score(train))
print('Overall validation accuracy: %.4f' %best_graph.accuracy_score(valid))
Overall training accuracy: 0.7758
Overall validation accuracy: 0.7643
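
The accuracy_score method presumably thresholds the probability scores at 0.5. As a sketch (assuming the graph's predict method returns those scores), the validation number can be reproduced by hand:

# Manual accuracy check: threshold the probability scores at 0.5
manual_preds = best_graph.predict(valid) > 0.5
print('Manual validation accuracy: %.4f' % (manual_preds == valid[target]).mean())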

Histograms of the probability scores

The idea here is to check the distributions of the probability scores for the positive (survivors) and negative (non-survivors) classes.

fig = plt.figure(figsize=(20, 6))
ax = fig.add_subplot(131)
best_graph.plot_probability_scores(train, title='Total probability distribution', ax=ax)
ax = fig.add_subplot(132)
best_graph.plot_probability_scores(train.query('sex == "female"'), title='Probability distribution for females', ax=ax)
ax = fig.add_subplot(133)
best_graph.plot_probability_scores(train.query('sex == "male"'), title='Probability distribution for males', ax=ax)
plt.show()

[Plots: probability score distributions for all passengers, females only, and males only]

The histograms above give an initial idea of the model's performance through the probability score distributions. With sex as the only input, the model naively predicts all women as survivors and all men as victims. Notice how the probability scores sit close to 0.72 for females and 0.19 for males, which correspond to the respective mean survival values.
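
As a sanity check (a sketch assuming the graph's predict method returns probability scores), the mean predicted score per sex should indeed land near those group means:

# Mean probability score per sex should match the group survival rates
scores = best_graph.predict(train)
print('Mean score, females: %.4f' % scores[(train.sex == 'female').values].mean())
print('Mean score, males:   %.4f' % scores[(train.sex == 'male').values].mean())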

Receiver Operating Characteristic (ROC) curve

Yet another common metric for evaluating binary classification models. It helps with setting an optimal threshold for the probability scores. Additionally, it can be an auxiliary tool in model selection.

plt.figure(figsize=(8, 5))
best_graph.plot_roc_curve(train, label='train')
best_graph.plot_roc_curve(valid, label='valid')

[Plot: ROC curves for the training and validation sets]
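
One way to extract a threshold from the curve is Youden's J statistic, the point that maximizes TPR - FPR. A minimal sketch with scikit-learn, again assuming a predict method that returns probability scores:

# Pick the threshold that maximizes TPR - FPR on the training set
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(train[target], best_graph.predict(train))
optimal_threshold = thresholds[np.argmax(tpr - fpr)]
print('Optimal threshold: %.4f' % optimal_threshold)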

This simple model can serve as a baseline for our next investigations. We do know that some women died and some men survived the Titanic. Can we find them?

"Women and children first!" Shall we explore age as well?

Question 2: How does sex combine with age to predict survival?

ql.reset()
qgraph = ql.get_classifier(['sex', 'age'], target, max_depth=1, stypes=stypes)
for _ in range(20):
    qgraph.fit(train, threads=4)
    ql.update(qgraph.best())
[Fitting 20: 100% completed. Best loss so far: 0.168662. Rendering of the best graph: sex, age → multiply → survived]
best_graph_age_sex = qgraph[0]

We can expect a wider range of probability scores compared to the previous model because the graph has taken age as an input feature.

fig = plt.figure(figsize=(20, 6))
ax = fig.add_subplot(131)
best_graph_age_sex.plot_probability_scores(train, title='Total probability distribution', ax=ax)
ax = fig.add_subplot(132)
best_graph_age_sex.plot_probability_scores(train.query('sex == "female"'), title='Probability distribution for females', ax=ax)
ax = fig.add_subplot(133)
best_graph_age_sex.plot_probability_scores(train.query('sex == "male"'), title='Probability distribution for males', ax=ax)
plt.show()

[Plots: probability score distributions for all passengers, females only, and males only]

Although the graph predicts that all women survive and all men do not, we see that age has added extra nuance. Let's check the partial plots:

best_graph_age_sex.plot_partial(train, by='age')

[Plot: partial dependence of the probability score on age, for each sex]

Women's probability scores already begin above 0.5 and increase with age while men's probability scores begin under 0.5 and decrease with age.

Let's compare the ROC curves between these graphs:

plt.figure(figsize=(8, 5))
best_graph.plot_roc_curve(train, label='Sex only')
best_graph_age_sex.plot_roc_curve(train, label='Sex and age')

[Plot: ROC curves for the sex-only and sex-and-age models]

The AUC score has improved slightly, but we can probably do better. Let's investigate further.
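
To put numbers on that comparison, here is a sketch using scikit-learn (assuming, as above, a predict method that returns probability scores):

# Compare AUC scores of the two models on the training set
from sklearn.metrics import roc_auc_score

for label, graph in [('Sex only', best_graph), ('Sex and age', best_graph_age_sex)]:
    print('%11s AUC: %.4f' % (label, roc_auc_score(train[target], graph.predict(train))))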

Rich people are privileged. Shall we see if pclass improves on this?

Question 3: How do sex, age and pclass affect the survival probability?

ql.reset()
qgraph = ql.get_classifier(['sex', 'age', 'pclass'], target, stypes=stypes, max_depth=2)
for _ in range(30):
    qgraph.fit(train, threads=4)
    ql.update(qgraph.best())
[Fitting 30: 100% completed. Best loss so far: 0.149388. Rendering of the best graph: pclass, sex and age combined through add and multiply cells into survived]
best_graph_age_sex_pclass = qgraph[0]

Here, let's go straight to the ROC curve:

plt.figure(figsize=(8, 5))
best_graph.plot_roc_curve(train, label='Sex only')
best_graph_age_sex.plot_roc_curve(train, label='Sex and age')
best_graph_age_sex_pclass.plot_roc_curve(train, label='Sex, age and pclass')

[Plot: ROC curves for the three models]

best_graph_age_sex_pclass.plot_partial(train, by='age', fixed={'sex': 'female', 'pclass': [1,2,3]})

[Plot: partial dependence on age for female passengers in each class]

For women we see that the same behaviour as before is present for the 1st and 2nd classes: as age increases, so does their chance of survival. However, the opposite happens for women in the 3rd class! As their age increases, their chance of survival decreases.

best_graph_age_sex_pclass.plot_partial(train, by='age', fixed={'sex': 'male', 'pclass': [1,2,3]})

[Plot: partial dependence on age for male passengers in each class]

The trend for men remains the same throughout the classes: as age increases, the chance of survival decreases. However, young men in the 1st and 2nd classes have a probability score above 0.5.

best_graph_age_sex_pclass.plot_partial2d(train, fixed={'sex': 'female'})
best_graph_age_sex_pclass.plot_partial2d(train, fixed={'sex': 'male'})

[Plots: 2D partial dependence on age and pclass, for females (top) and males (bottom)]

The top plot represents the female passengers while the bottom one represents the male passengers. This is another way of visualizing the impact of age, pclass and sex on the chances of survival. The plot_partial2d function allows us to see the boundaries traced by the functions in the model.


Suggestions for next steps

Now that we have been through a first exploration of applying Feyn to a binary classification problem, there are a few extra things one could try:

  1. Increasing the complexity of the graphs in the QGraph by allowing an extra feature;
  2. Separating females from males and checking whether simple models (max_depth <= 2) can predict who survives (see the sketch after this list);
  3. Asking questions about the impact on survival when a passenger had family onboard (parch or sibsp).
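
As an illustration of suggestion 2, here is a minimal sketch using the same Feyn 1.4 API as above (the choice of features and number of loops are assumptions, not a prescription):

# Fit simple graphs on the female passengers only
females = train[train.sex == 'female']

ql.reset()
qgraph = ql.get_classifier(['age', 'pclass'], target, max_depth=2, stypes=stypes)
for _ in range(20):
    qgraph.fit(females, threads=4)
    ql.update(qgraph.best())

best_graph_females = qgraph[0]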

It is interesting to note from the Titanic dataset that no matter how much one trains a model, there is always a factor of luck which cannot be predicted. From a 3rd class man who pretended to be a woman to get onto a lifeboat to a 1st class woman who chose to stay with her husband aboard the Titanic: humans are fully capable of bending their fate.
