Feyn Documentation

Covid-19 RNA vaccine degradation data set

by: Chris Cave, Kevin Broløs and Sam Demharter

Feyn version: 2.1+

Last updated: 23/09/2021

In this tutorial, we go through a typical QLattice workflow by analysing the OpenVaccine: COVID-19 mRNA Vaccine Degradation Prediction dataset.

The raw dataset consists of 2400 mRNA samples. Each mRNA consists of 107 nucleotides, and various measurements were performed on the first 68 nucleotides: reactivity, degradation at pH 10 with and without magnesium, and degradation at 50 °C with and without magnesium.

Some of the Covid-19 vaccines are mRNA based. However, because mRNA is unstable, these vaccines must be refrigerated under extreme conditions, which can make distributing them problematic.

The aim of this tutorial is to gain insights into the stability of general mRNA samples with the potential to apply it to Covid-19 vaccine candidates.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import feyn
from sklearn.model_selection import train_test_split

from IPython.display import display

Inspecting the raw data

The data contains the sequence of each RNA, along with the predicted structure and loop type of each base. The reactivity column measures the degradation at each base: the higher the reactivity, the more likely the RNA is to degrade at that base.

# Load the raw data (one JSON record per line)
data = pd.read_json('../data/covid_mrna.json', lines=True)
data = data.drop('index', axis=1)
# Keep only the RNAs that passed the signal-to-noise filter
data.query('SN_filter == 1', inplace=True)

# Number of scored positions (68)
length = len(data.iloc[0]['reactivity'])
first_68 = data['structure'].apply(lambda x: x[:length])

# Remove sequences whose scored structure only contains '.', i.e. unpaired bases
idx_all_dots = [i for i in first_68.index if first_68[i].count('.') == length]
data = data.drop(idx_all_dots)
data.head()
id sequence structure predicted_loop_type signal_to_noise SN_filter seq_length seq_scored reactivity_error deg_error_Mg_pH10 deg_error_pH10 deg_error_Mg_50C deg_error_50C reactivity deg_Mg_pH10 deg_pH10 deg_Mg_50C deg_50C
0 id_001f94081 GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUA... .....((((((.......)))).)).((.....((..((((((...... EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHH... 6.894 1 107 68 [0.1359, 0.20700000000000002, 0.1633, 0.1452, ... [0.26130000000000003, 0.38420000000000004, 0.1... [0.2631, 0.28600000000000003, 0.0964, 0.1574, ... [0.1501, 0.275, 0.0947, 0.18660000000000002, 0... [0.2167, 0.34750000000000003, 0.188, 0.2124, 0... [0.3297, 1.5693000000000001, 1.1227, 0.8686, 0... [0.7556, 2.983, 0.2526, 1.3789, 0.637600000000... [2.3375, 3.5060000000000002, 0.3008, 1.0108, 0... [0.35810000000000003, 2.9683, 0.2589, 1.4552, ... [0.6382, 3.4773, 0.9988, 1.3228, 0.78770000000...
2 id_006f36f57 GGAAAGUGCUCAGAUAAGCUAAGCUCGAAUAGCAAUCGAAUAGAAU... .....((((.((.....((((.(((.....)))..((((......)... EEEEESSSSISSIIIIISSSSMSSSHHHHHSSSMMSSSSHHHHHHS... 8.800 1 107 68 [0.0931, 0.13290000000000002, 0.11280000000000... [0.1365, 0.2237, 0.1812, 0.1333, 0.1148, 0.160... [0.17020000000000002, 0.178, 0.111, 0.091, 0.0... [0.1033, 0.1464, 0.1126, 0.09620000000000001, ... [0.14980000000000002, 0.1761, 0.1517, 0.116700... [0.44820000000000004, 1.4822, 1.1819, 0.743400... [0.2504, 1.4021, 0.9804, 0.49670000000000003, ... [2.243, 2.9361, 1.0553, 0.721, 0.6396000000000... [0.5163, 1.6823000000000001, 1.0426, 0.7902, 0... [0.9501000000000001, 1.7974999999999999, 1.499...
5 id_00ab2d761 GGAAAGCGCCGCGGCGGUAGCGGCAGCGAGGAGCGCUACCAAGGCA... .....(.(((((.(((((((((...........)))))))..(((.... EEEEESISSSSSISSSSSSSSSHHHHHHHHHHHSSSSSSSMMSSSH... 4.136 1 107 68 [0.1942, 0.2041, 0.1626, 0.1213, 0.10590000000... [0.2726, 0.2984, 0.21660000000000001, 0.1637, ... [0.3393, 0.2728, 0.2005, 0.1703, 0.1495, 0.134... [0.165, 0.20520000000000002, 0.179, 0.1333, 0.... [0.2864, 0.24710000000000001, 0.2222, 0.1903, ... [0.7642, 1.6641, 1.0622, 0.5008, 0.4107, 0.133... [0.9559000000000001, 1.9442, 1.0114, 0.5105000... [1.9554, 2.1298, 1.0403, 0.609, 0.5486, 0.386,... [0.22460000000000002, 1.7281, 1.381, 0.6623, 0... [0.5882000000000001, 1.1786, 0.9704, 0.6035, 0...
6 id_00abef1d7 GGAAAACAAUUGCAUCGUUAGUACGACUCCACAGCGUAAGCUGUGG... .........((((((((......((((((((((((....)))))))... EEEEEEEEESSSSSSSSIIIIIISSSSSSSSSSSSHHHHSSSSSSS... 2.485 1 107 68 [0.422, 0.5478000000000001, 0.4749000000000000... [0.4801, 0.7943, 0.42160000000000003, 0.397300... [0.9822000000000001, 1.272, 0.6940000000000001... [0.5827, 0.7555000000000001, 0.5949, 0.4511, 0... [0.9306000000000001, 1.0496, 0.5844, 0.7796000... [0.895, 2.3377, 2.2305, 2.003, 1.9006, 1.0373,... [0.46040000000000003, 3.6695, 0.78550000000000... [2.7711, 7.365, 1.6924000000000001, 1.43840000... [1.073, 2.8604000000000003, 1.9936, 1.0273, 1.... [2.0964, 3.3688000000000002, 0.6399, 2.1053, 1...
7 id_00b436dec GGAAAUCAUCGAGGACGGGUCCGUUCAGCACGCGAAAGCGUCGUGA... .....(((((((((((..(((((((((..((((....))))..)))... EEEEESSSSSSSSSSSIISSSSSSSSSIISSSSHHHHSSSSIISSS... 1.727 1 107 68 [0.4843, 0.5233, 0.4554, 0.43520000000000003, ... [0.8719, 1.0307, 0.6649, 0.34500000000000003, ... [0.7045, 0.7775000000000001, 0.5662, 0.4561, 0... [0.384, 0.723, 0.4766, 0.30260000000000004, 0.... [0.7429, 0.9137000000000001, 0.480400000000000... [1.1576, 1.5137, 1.3382, 1.5622, 1.2121, 0.295... [1.6912, 5.2652, 2.3901, 0.45890000000000003, ... [1.8641, 2.3767, 1.149, 1.0132, 0.9876, 0.0, 0... [0.49060000000000004, 4.6339, 1.95860000000000... [1.2852000000000001, 2.5460000000000003, 0.234...

There's a column in this data set called SN_filter. This is a signal-to-noise filter indicating which RNA molecules passed the evaluation criteria defined by the Stanford researchers. Some RNAs have quite a large amount of noise, and these are filtered out by SN_filter, so we drop the rows with SN_filter == 0.

There are also RNAs whose first 68 bases are all "." (unpaired). We removed these sequences during analysis because they were hard to model: they have no variation in predicted structure, but a lot of variation in reactivity. This could be because they have a complex structure that isn't represented here, because the predicted loop type is incorrect, or for some other reason entirely.
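Both filters can also be applied in a vectorized way. Below is a minimal sketch, assuming the same SN_filter and structure columns; filter_rnas is a hypothetical helper and the toy rows are made up for illustration:

```python
import pandas as pd


def filter_rnas(data: pd.DataFrame, n_scored: int) -> pd.DataFrame:
    """Keep RNAs that pass SN_filter and whose scored structure is not all '.'."""
    passed = data[data['SN_filter'] == 1]
    # Predicted structure restricted to the scored positions
    scored = passed['structure'].str[:n_scored]
    # str.count takes a regular expression, so the dot must be escaped
    all_dots = scored.str.count(r'\.') == n_scored
    return passed[~all_dots]


toy = pd.DataFrame({
    'SN_filter': [1, 1, 0],
    'structure': ['..((..))..', '..........', '..((..))..'],
})
print(filter_rnas(toy, n_scored=8).index.tolist())  # [0]
```

With the tutorial's data, the equivalent call would be filter_rnas(data, n_scored=68).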

Preparing the sequences for the QLattice

Here, we prepare the sequences for the QLattice by expanding them into individual samples, one for each nucleobase. This means that each sample consists of one nucleobase (base) and its predicted loop type (loop).
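The expansion amounts to one row per scored position of each RNA. A minimal, loop-based sketch, assuming the raw sequence, predicted_loop_type and reactivity columns; expand_to_bases is a hypothetical helper, and unlike the tutorial's code it does not skip the first five positions:

```python
import pandas as pd


def expand_to_bases(data: pd.DataFrame, n_scored: int) -> pd.DataFrame:
    """Return one row per scored position: loop type, base, source RNA, position, reactivity."""
    rows = []
    for rna_idx, rna in data.iterrows():
        for pos in range(n_scored):
            rows.append({
                'loop': rna['predicted_loop_type'][pos],
                'base': rna['sequence'][pos],
                'RNA_idx': rna_idx,
                'pos_idx': pos,
                'reactivity': rna['reactivity'][pos],
            })
    return pd.DataFrame(rows)


toy = pd.DataFrame(
    {'sequence': ['GGAC'], 'predicted_loop_type': ['EESS'], 'reactivity': [[0.1, 0.2, 0.3]]},
    index=[7],
)
print(expand_to_bases(toy, n_scored=3))
```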

To retain information about each base's position and context in the structure, we add some features describing its neighbours. We call these motifs.

The motifs are defined by a left-side (5') and a right-side (3') window of the two neighbouring bases. We then do the same for the loop type.
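For a single position, the motif windows are just string slices. A minimal sketch with a hypothetical helper motifs; it mirrors the slicing used in the tutorial's code, and near the ends of a sequence the windows are simply shorter:

```python
from typing import Tuple


def motifs(seq: str, pos: int, width: int = 2) -> Tuple[str, str]:
    """Return the left (5') and right (3') windows of `width` characters around `pos`."""
    left = seq[max(0, pos - width):pos]
    right = seq[pos + 1:pos + 1 + width]
    return left, right


# First 20 bases and loop types of sequence id_001f94081 from the table above
sequence = 'GGAAAAGCUCUAAUAACAGG'
loops = 'EEEEESSSSSSHHHHHHHSS'
print(motifs(sequence, 5))  # ('AA', 'GC')
print(motifs(loops, 5))     # ('EE', 'SS')
```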

# Full sequence length (107 bases); all sequences have the same length
end_pos = len(data.loc[0, 'predicted_loop_type'])

# Source RNA and position of every base, in the same order as the flattened lists below
RNA_idx = [j for j in data.index for i in range(0, end_pos)]
pos_idx = [i for j in data.index for i in range(0, end_pos)]

# Flatten the loop types and bases of all sequences into one list each
loop_exp = [c for s in data['predicted_loop_type'] for c in s]
base_exp = [c for s in data['sequence'] for c in s]

exp_df = pd.DataFrame({'loop': loop_exp, 'base': base_exp, 'RNA_idx': RNA_idx, 'pos_idx': pos_idx})

# Keep only the scored positions and skip the first five;
# .copy() avoids pandas' SettingWithCopyWarning when adding columns below
react_len = len(data.iloc[0].reactivity)
df = exp_df[(exp_df.pos_idx >= 5) & (exp_df.pos_idx < react_len)].copy()

# Measured reactivity at each base
df['reactivity'] = df.apply(lambda row: data.loc[row.RNA_idx].reactivity[row.pos_idx], axis=1)

# Two-base windows to the left (5') and right (3') of each base
df['sequence'] = data.loc[df.RNA_idx].set_index(df.index).sequence
df['base_left_motif'] = df.apply(lambda x: x['sequence'][:x.pos_idx][-2:], axis=1)
df['base_right_motif'] = df.apply(lambda x: x['sequence'][x.pos_idx + 1:][:2], axis=1)

# The same windows for the predicted loop type
df['loop_type'] = data.loc[df.RNA_idx].set_index(df.index).predicted_loop_type
df['loop_left_motif'] = df.apply(lambda x: x['loop_type'][:x.pos_idx][-2:], axis=1)
df['loop_right_motif'] = df.apply(lambda x: x['loop_type'][x.pos_idx + 1:][:2], axis=1)

Train, validation and holdout split

We split the data into training, validation and holdout sets according to the sequence each sample belongs to. This makes sure our validation and holdout sets aren't contaminated with samples from sequences we have trained on. The original sequence of each sample is captured in the RNA_idx column to make this splitting easier. There are other metadata columns like this, which we leave out of training.
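The same idea, splitting by RNA rather than by row, can be written as a small helper. A minimal sketch with a hypothetical split_by_group function; it uses NumPy's random generator, so it will not reproduce the exact split produced by train_test_split in the tutorial:

```python
import numpy as np
import pandas as pd


def split_by_group(df, group_col, train_frac=0.5, valid_frac=0.25, seed=42):
    """Assign whole groups (e.g. RNA molecules) to train, validation and holdout sets."""
    rng = np.random.default_rng(seed)
    groups = rng.permutation(df[group_col].unique())
    n_train = int(round(train_frac * len(groups)))
    n_valid = int(round(valid_frac * len(groups)))
    train_groups = groups[:n_train]
    valid_groups = groups[n_train:n_train + n_valid]
    holdout_groups = groups[n_train + n_valid:]
    return (
        df[df[group_col].isin(train_groups)],
        df[df[group_col].isin(valid_groups)],
        df[df[group_col].isin(holdout_groups)],
    )


# Toy data: 8 "RNAs" with 3 positions each
toy = pd.DataFrame({'RNA_idx': np.repeat(np.arange(8), 3), 'pos_idx': np.tile(np.arange(3), 8)})
tr, va, ho = split_by_group(toy, 'RNA_idx')
print(sorted(set(tr.RNA_idx)), sorted(set(va.RNA_idx)), sorted(set(ho.RNA_idx)))
```

Because whole RNAs are assigned to one set, no sequence can appear in more than one of them.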

train_idx, remain_idx = train_test_split(list(data.index),train_size = 0.5, random_state = 42)
valid_idx, holdout_idx = train_test_split(remain_idx,train_size = 0.5, random_state = 42)

train = df.query('RNA_idx == @train_idx')
valid = df.query('RNA_idx == @valid_idx')
holdout = df.query('RNA_idx == @holdout_idx')

Here is our training set. We've expanded the previous data set so that we have the features base, loop (predicted loop type), base_left_motif, base_right_motif, loop_left_motif and loop_right_motif.

train.head()
loop base RNA_idx pos_idx reactivity sequence base_left_motif base_right_motif loop_type loop_left_motif loop_right_motif
5 S A 0 5 0.4384 GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUA... AA GC EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHH... EE SS
6 S G 0 6 0.2560 GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUA... AA CU EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHH... ES SS
7 S C 0 7 0.3364 GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUA... AG UC EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHH... SS SS
8 S U 0 8 0.2168 GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUA... GC CU EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHH... SS SS
9 S C 0 9 0.3583 GGAAAAGCUCUAAUAACAGGAGACUAGGACUACGUAUUUCUAGGUA... CU UA EEEEESSSSSSHHHHHHHSSSSBSSXSSIIIIISSIISSSSSSHHH... SS SH

Approach 1: Training a QLattice to produce highly complex models

The first approach is to connect to the QLattice and simply fit models without setting a max complexity ourselves.

# Connecting to the QLattice
ql = feyn.connect_qlattice()

# Seeding the QLattice for reproducible results
ql.reset(42)
# Output variable
output = 'reactivity'

# Declaring features
features = ['base', 'loop', 'base_left_motif', 'base_right_motif', 'loop_left_motif', 'loop_right_motif']

# Declaring categorical features
stypes = {}
for f in features:
    if train[f].dtype =='object':
        stypes[f] = 'c'

To penalize more complex models that don't add enough utility compared to their simpler counterparts, we use the Bayesian information criterion (BIC) as the selection criterion before updating the QLattice with the best models.

This reduces redundancy in the models, so that each interaction is more likely to contribute something useful.
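For intuition about what the criterion trades off, here is the textbook BIC for a regression model with Gaussian errors, BIC = n ln(MSE) + k ln(n), where n is the number of samples and k the number of fitted parameters. This is a sketch under that assumption; Feyn computes BIC internally, and its exact formula and parameter count may differ:

```python
import numpy as np


def bic_regression(y_true, y_pred, n_params):
    """Textbook BIC for a regression model with Gaussian errors, up to an additive constant."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    n = len(y_true)
    mse = np.mean((y_true - y_pred) ** 2)
    # Goodness-of-fit term plus a complexity penalty that grows with ln(n)
    return n * np.log(mse) + n_params * np.log(n)


# Two hypothetical models on 1000 samples: B fits slightly better but uses 3x the parameters
y = np.zeros(1000)
bic_a = bic_regression(y, np.full(1000, 0.3), n_params=10)             # MSE = 0.09
bic_b = bic_regression(y, np.full(1000, np.sqrt(0.089)), n_params=30)  # MSE = 0.089
print(bic_a < bic_b)  # True: the small gain in fit doesn't pay for 20 extra parameters
```

Lower BIC is better, so the simpler model A would be preferred here.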

This is a regression task, which is the default for auto_run.

models = ql.auto_run(train[features+[output]], output, stypes=stypes, criterion='bic')
Loss: 8.51E-02
Epoch no. 10/10 - Tried 19618 models - Completed in 3m 59s.
[Model graph: loop_right_motif (18 categories), loop (7), base (4), base_left_motif (16) and base_right_motif (16), combined through multiply, add and tanh into a linear reactivity output]
model_base = models[0]
model_base.plot(train)
[Model summary plot: graph of the base model with inputs loop_right_motif, loop, base, base_left_motif and base_right_motif; Training Metrics R2 0.487, RMSE 0.292, MAE 0.188]

We've supplied a random seed to the QLattice, but depending on your Feyn version you might still get different results, so keep that in mind as you try this for yourself.

model_base.plot_signal(train)
[Signal plot (base model): model graph annotated with the Pearson correlation between each node and reactivity, on a scale from -1 to +1]

Looking at the plot above, one of the features doesn't appear to add much signal to the model: the base_left_motif barely increases the signal at the interaction it feeds into. This is more obvious when compared with the signal added by combining the other features in the model.
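plot_signal reports Pearson correlations with the output. As a rough, hypothetical way to check how much signal a single categorical feature carries on its own, you can replace each category by its mean target value and correlate that encoding with the target. This is only an approximation, not how Feyn computes the signal inside a model, and because the encoding is fitted on the same data it scores, it somewhat overstates the signal:

```python
import numpy as np
import pandas as pd


def categorical_signal(df, feature, target):
    """Pearson r between a per-category mean encoding of `feature` and `target`."""
    # Replace each category by the mean target value of that category
    encoded = df.groupby(feature)[target].transform('mean')
    return np.corrcoef(encoded, df[target])[0, 1]


toy = pd.DataFrame({'motif': ['a', 'a', 'a', 'b'], 'y': [1.0, 2.0, 3.0, 4.0]})
print(round(categorical_signal(toy, 'motif', 'y'), 4))  # 0.7746
```

On the tutorial's data, you could compare categorical_signal(train, 'base_left_motif', 'reactivity') with the same call for base_right_motif.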

Let's constrain the models

The previous plots tell us that we should restrict the models a bit more to force the QLattice to choose the best features, as it won't have room to use them all.

We can achieve this by setting the max complexity to 7, which allows each model to contain at most four features.
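The link between max complexity and the number of features is simple arithmetic if we assume that Feyn counts complexity as the number of edges in the model graph, i.e. nodes minus one. That assumption matches the models in this tutorial: the two-feature model has four nodes (output, multiply and two inputs) at complexity 3. A sketch under that assumption (max_features is a hypothetical helper); unary functions such as gaussian1 or tanh also use up complexity, so this is an upper bound:

```python
def max_features(max_complexity: int) -> int:
    """Upper bound on the number of inputs for a given max complexity.

    Assumes complexity = number of edges in the model graph (nodes - 1) and
    that inputs are combined by two-input interactions.
    """
    # k inputs need k - 1 two-input interactions plus one output node,
    # i.e. 2k nodes and 2k - 1 edges, so 2k - 1 <= max_complexity.
    return (max_complexity + 1) // 2


for c in (3, 6, 7):
    print(c, max_features(c))
```

This gives 2, 3 and 4 features for max complexities 3, 6 and 7, the values used in this tutorial.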

ql.reset(42)
models = ql.auto_run(train[features+[output]], output, stypes=stypes, max_complexity=7, criterion='bic')
Loss: 8.76E-02
Epoch no. 10/10 - Tried 17225 models - Completed in 2m 55s.
[Model graph: loop_right_motif (18 categories), loop (7), base_right_motif (16) and base (4), combined through multiply, add and gaussian2 into a linear reactivity output]
model_constrained = models[0]

Let's compare the two models on the training and validation sets:

print('The base model (unconstrained):')
display(model_base.plot(train, valid))
print("The constrained model (max complexity = 7):")
display(model_constrained.plot(train, valid))
The base model (unconstrained):
[Model summary plot: graph of the base model with inputs loop_right_motif, loop, base, base_left_motif and base_right_motif]

Metric   Training   Test
R2       0.487      0.496
RMSE     0.292      0.292
MAE      0.188      0.189

The constrained model (max complexity = 7):
[Model summary plot: graph of the constrained model with inputs loop_right_motif, loop, base_right_motif and base]

Metric   Training   Test
R2       0.468      0.478
RMSE     0.297      0.297
MAE      0.2        0.201

model_constrained.plot_signal(train)
[Signal plot (constrained model): model graph annotated with the Pearson correlation between each node and reactivity, on a scale from -1 to +1]

Both of these models appear to generalize well to the validation set. The constrained model also confirms our suspicion that the left motifs are less important: it uses neither base_left_motif nor loop_left_motif.

What's next?

So we seem to have a stable model with only four features.

Now it'd be interesting to see just how far down we can carve this model until it falls apart.

We'll try to get rid of the features that seem to contribute the least.

In the plots above, loop_right_motif appears to contribute the least to the growth of signal in the model, so let's try removing it.

ql.reset(42)

features = ['base', 'loop', 'base_right_motif']
models = ql.auto_run(train[features+[output]], output, stypes=stypes, max_complexity=6, criterion='bic')
Loss: 8.91E-02
Epoch no. 10/10 - Tried 16331 models - Completed in 2m 14s.
[Model graph: base (4 categories), base_right_motif (16) and loop (7), combined through add and gaussian1 into a linear reactivity output]
model_three_features = models[0]
print('The constrained model:')
display(model_constrained.plot(train, valid))
print("The three feature model:")
display(model_three_features.plot(train, valid))
The constrained model:
[Model summary plot: graph of the constrained model with inputs loop_right_motif, loop, base_right_motif and base]

Metric   Training   Test
R2       0.468      0.478
RMSE     0.297      0.297
MAE      0.2        0.201

The three feature model:
[Model summary plot: graph of the three-feature model with inputs base, base_right_motif and loop]

Metric   Training   Test
R2       0.463      0.477
RMSE     0.298      0.297
MAE      0.193      0.193

model_three_features.plot_signal(train)
[Signal plot (three-feature model): model graph annotated with the Pearson correlation between each node and reactivity, on a scale from -1 to +1]

The performance is very close. The three-feature model did get a little worse, but if we want to favour simpler, more interpretable models, it is definitely the one to go for.

It already captures so much of the signal that we can expect it to explain most of the behaviour these features are able to describe.

That said, we may already see from the signal plot that base doesn't supply much information, so let's try removing it and see how much of a difference that makes.

ql.reset(42)

features = ['loop', 'base_right_motif']
# Note we're reducing to max complexity of 3 for two features.
models = ql.auto_run(train[features+[output]], output, stypes=stypes, max_complexity=3, criterion='bic')
Loss: 9.56E-02
Epoch no. 10/10 - Tried 11366 models - Completed in 56s.
[Model graph: loop (7 categories) multiplied with base_right_motif (16), feeding a linear reactivity output]
model_two_features = models[0]
print("The three feature model:")
display(model_three_features.plot(train, valid))
print("The two feature model:")
display(model_two_features.plot(train, valid))
The three feature model:
[Model summary plot: graph of the three-feature model with inputs base, base_right_motif and loop]

Metric   Training   Test
R2       0.463      0.477
RMSE     0.298      0.297
MAE      0.193      0.193

The two feature model:
[Model summary plot: graph of the two-feature model with inputs loop and base_right_motif]

Metric   Training   Test
R2       0.424      0.445
RMSE     0.309      0.306
MAE      0.202      0.201

We see a bit of a drop, but that's in line with our expectations. Still, we seem to be getting closer to the essence of the model: the loop and the base_right_motif together.

We could reduce the model further to just the loop, but without any feature interactions its performance would be limited to the explanatory power of the loop feature alone.
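A simple reference point for a loop-only model is a baseline that predicts the mean training reactivity of each loop type. A minimal sketch with a hypothetical helper category_mean_rmse; it is not a QLattice model, so its numbers won't match the models above exactly:

```python
import numpy as np
import pandas as pd


def category_mean_rmse(train, test, feature, target):
    """RMSE on `test` of a baseline predicting the training mean of `target` per category."""
    means = train.groupby(feature)[target].mean()
    # Categories unseen during training fall back to the overall training mean
    pred = test[feature].map(means).fillna(train[target].mean())
    return float(np.sqrt(np.mean((test[target] - pred) ** 2)))


toy_train = pd.DataFrame({'loop': ['S', 'S', 'H'], 'reactivity': [1.0, 3.0, 5.0]})
toy_test = pd.DataFrame({'loop': ['S', 'H', 'E'], 'reactivity': [2.0, 4.0, 4.0]})
print(round(category_mean_rmse(toy_train, toy_test, 'loop', 'reactivity'), 4))  # 0.8165
```

On the tutorial's data, category_mean_rmse(train, valid, 'loop', 'reactivity') would give that baseline's validation RMSE.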

model_two_features.plot_signal(train)
[Signal plot (two-feature model): model graph annotated with the Pearson correlation between each node and reactivity, on a scale from -1 to +1]

Let's look at an interesting sequence

Now that we have a few models, we can look at an interesting sequence and map out the trade-offs we get at different complexities.

Secondary structure of sequence 642

This is one of the sequences from the validation set. Let's plot the predicted reactivity against the actual reactivity, starting with the simplest model and working our way up. The secondary-structure picture was generated by the forna server that ViennaRNA makes available here: http://rna.tbi.univie.ac.at/forna/

Defining a useful plotting function for RNA sequences

First, let's define a function that plots the predictions against the actuals along a sequence, to help us identify the peaks and valleys of our predictions.

def plot_RNA_seq(model, data, idx, figsize=(24, 5), grid=True, title=''):
    """Plot actual vs. predicted output along the RNA sequence selected by idx.

    idx is inserted into a query on RNA_idx, so it can be a single index or a list such as [642].
    """
    output = model.output
    fig, ax = plt.subplots(figsize=figsize)

    # All scored positions (rows) belonging to the selected RNA
    sub_seq = data.query(f'RNA_idx == {idx}')
    positions = range(len(sub_seq))

    prediction = model.predict(sub_seq)
    ax.plot(positions, sub_seq[output], label='actuals')
    ax.plot(positions, prediction, label='pred')

    # Label each position as "<loop>_<base>", e.g. "S_G"
    x_labels = [f'{loop}_{base}' for loop, base in zip(sub_seq['loop'], sub_seq['base'])]
    ax.set_xticks(positions)
    ax.set_xticklabels(x_labels)
    ax.tick_params(rotation=90)
    ax.set_title(title + str(idx))

    ax.set_ylabel(output)
    ax.set_xlabel('sequence (loop and base)')
    ax.legend()

    if grid:
        ax.grid()

    return ax
plot_RNA_seq(model_two_features, valid, [642])
display(model_two_features)
Loss: 9.56E-02
[Model graph: loop multiplied with base_right_motif, feeding a linear reactivity output]

[Plot: actuals vs. pred reactivity along sequence 642, two-feature model]

This sequence is already captured really well by the simplest, two-feature model of loop and base_right_motif. Let's compare it with the three-feature model, which also includes the base.

ax = plot_RNA_seq(model_three_features, valid, [642])
ax.annotate('Region 1', (24,0.8), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
ax.annotate('Region 2', (38,1), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
ax.annotate('Region 3', (34,0.05), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
display(model_three_features)
Loss: 8.91E-02
[Model graph: base, base_right_motif and loop, combined through add and gaussian1 into a linear reactivity output]

[Plot: actuals vs. pred reactivity along sequence 642, three-feature model, with Regions 1 to 3 annotated]

The difference isn't large, but we see a more defined peak at the I_A in region 2 of the plot, and a bit more reactivity in the S_G-S_A stretch in region 1. The I_A-I_C stretch in region 3 is also mapped much more closely in this prediction.

In general, including the base seems to correct the predictions on the stems a little.

ax = plot_RNA_seq(model_constrained, valid, [642])
ax.annotate('Region 1', (24,0.8), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
ax.annotate('Region 2', (38,1), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
ax.annotate('Region 3', (34,0.05), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
display(model_constrained)
Loss: 8.76E-02
[Model graph: loop_right_motif, loop, base_right_motif and base, combined through multiply, add and gaussian2 into a linear reactivity output]

[Plot: actuals vs. pred reactivity along sequence 642, constrained four-feature model, with Regions 1 to 3 annotated]

With the constrained, four-feature model we get a little closer to some of the reactive stems in region 1, but at the cost of some oversensitivity in other regions. Still, all three regions from before are now generally captured better.

This is the point at which to ask yourself what you're trying to accomplish with the model: whether to focus on performance on specific things, or to understand the dynamics using simpler models.

ax = plot_RNA_seq(model_base, valid, [642])
ax.annotate('Region 1', (24,0.8), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
ax.annotate('Region 2', (38,1), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
ax.annotate('Region 3', (34,0.05), bbox=dict(boxstyle="round,pad=0.3", fc='#FFE10A', lw=2))
display(model_base)
Loss: 8.51E-02
[Model graph: loop_right_motif, loop, base, base_left_motif and base_right_motif, combined through multiply, add and tanh into a linear reactivity output]

[Plot: actuals vs. pred reactivity along sequence 642, unconstrained base model, with Regions 1 to 3 annotated]

As we discovered, these extra features don't capture much more or add much to the story; at this point it's mostly curve fitting, trying to squeeze out every last bit of signal. You might see better adjustments in some places and worse performance in others, but nothing as significant as what we've seen previously.

Comparing the RMSE from more complex to simpler models

# Validation RMSE, from the most complex to the simplest model
for model in [model_base, model_constrained, model_three_features, model_two_features]:
    print(f'RMSE model features: {len(model.features)}: {model.rmse(valid):.4f}')
RMSE model features: 5: 0.2915
RMSE model features: 4: 0.2968
RMSE model features: 3: 0.2970
RMSE model features: 2: 0.3060

Depending on whether you prioritise interpretability or performance, you'd be well off picking either the three- or the four-feature model.
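To make the trade-off more concrete, the relative cost of each removed feature can be computed from the validation RMSEs printed above. A small sketch; the values are copied from that output and relative_increase is a hypothetical helper:

```python
# Validation RMSE printed above, keyed by the number of features in the model
val_rmse = {5: 0.2915, 4: 0.2968, 3: 0.2970, 2: 0.3060}


def relative_increase(rmse_by_k):
    """Percentage change in RMSE each time the model loses one feature."""
    ks = sorted(rmse_by_k, reverse=True)
    return {(a, b): 100 * (rmse_by_k[b] / rmse_by_k[a] - 1) for a, b in zip(ks, ks[1:])}


for (a, b), pct in relative_increase(val_rmse).items():
    print(f'{a} -> {b} features: {pct:+.2f}%')
```

This prints increases of about 1.82%, 0.07% and 3.03%: going from four to three features costs almost nothing on the validation set, while going down to two costs noticeably more.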

Let's take a look at how they perform on the holdout set

# Holdout RMSE, from the most complex to the simplest model
for model in [model_base, model_constrained, model_three_features, model_two_features]:
    print(f'RMSE model features: {len(model.features)}: {model.rmse(holdout):.4f}')
RMSE model features: 5: 0.2912
RMSE model features: 4: 0.2957
RMSE model features: 3: 0.2971
RMSE model features: 2: 0.3070

The models generalize well to the holdout set, and the choice of complexity level tells a similar story. This is a very good sign, and shows that we're on the right track towards a model of how the structure of an RNA sequence, combined with its bases, affects the reactivity at each position.

Concluding remarks

In this example, we showed how the QLattice can be used to produce simple models that pick out the important features describing reactivity. This gives us a better understanding of the mechanisms underlying mRNA reactivity, and because the models are simple, they generalise well to unseen data sets. This is the power of the QLattice: simple and explainable models with high predictive performance on unseen data sets!

Copyright © 2024 Abzu.ai - Feyn license: CC BY-NC-ND 4.0
Feyn®, QGraph®, and the QLattice® are registered trademarks of Abzu®