Prédire les valeurs manquantes

Valeurs manquantes


Dès lors qu’on est amené à collecter et analyser des données, on est rapidement confronté au problème de données manquantes.

Que faire ?

Il y a différentes possibilités pour traiter ces données manquantes, en fonction du problème, du type de données, du volume, etc.

Par exemple :

  • Supprimer les observations avec les données manquantes (attention si elle ne sont pas manquantes aléatoirement)
  • Imputer les manquants avec la valeur moyenne, médiane etc
  • Pour des séries temporelles, imputer avec la dernière valeurs, ou la prochaine
  • La valeur manquante peut être une information en soit, il est alors intéressant de conserver l’information pour analyser les données en créant une nouvelle variable par exemple manquant oui/non.
  • Créer un algorithme pour prédire la valeur manquante à l’aide des valeurs des autres variables.

Dans ce post, je vais décrire la dernière méthode, montrer comment utiliser le deep learning pour prédire les valeurs manquantes d’un jeu de données, contenant un million d’observations avec un millions de manquants répartis sur un ensemble de variables.

Je vais pour cela utiliser les données provenant de Kaggle pour la compétition du mois de juin 2022.

Note : J’utiliserai indépendamment valeur manquante ou NA (not available)

Vous pouvez trouver les données à l’adresse suivante : https://www.kaggle.com/competitions/tabular-playground-series-jun-2022/data


Exploration des données


Le jeu de données se composent de 80 variables (plus un id), réparties en 4 groupes F_1*, F_2*, F_3* et F_4*.

skimr::skim(data)
   skim_variable n_missing complete_rate     mean         sd     p0        p25      p50        p75      p100 hist 
 1 row_id                0         1      5.00e+5 288675.      0    250000.     5.00e+5 749999.    999999    ▇▇▇▇▇
 2 F_1_0             18397         0.982 -6.87e-4      1.00   -4.66     -0.675 -7.69e-4      0.673      5.04 ▁▂▇▂▁
 3 F_1_1             18216         0.982  2.09e-3      1.00   -4.79     -0.672  2.05e-3      0.676      5.04 ▁▂▇▂▁
 4 F_1_2             18008         0.982  5.51e-4      1.00   -4.87     -0.674  1.39e-3      0.674      5.13 ▁▂▇▂▁
 5 F_1_3             18250         0.982  9.82e-4      1.00   -5.05     -0.672  3.7 e-4      0.675      5.46 ▁▂▇▁▁
 6 F_1_4             18322         0.982  2.44e-3      1.00   -5.36     -0.672  2.73e-3      0.677      4.86 ▁▁▇▂▁
 7 F_1_5             18089         0.982  6.35e-4      1.00   -5.51     -0.674  2.76e-4      0.676      4.96 ▁▁▇▂▁
 8 F_1_6             18133         0.982 -1.24e-4      1.00   -5.20     -0.675  8.14e-4      0.674      4.96 ▁▂▇▂▁
 9 F_1_7             18128         0.982 -6.39e-2      0.726  -6.99     -0.500  5.78e-4      0.444      2.53 ▁▁▁▇▂
10 F_1_8             18162         0.982 -1.38e-5      1.00   -4.57     -0.674 -4.7 e-5      0.674      4.89 ▁▂▇▂▁
11 F_1_9             18249         0.982  4.51e-4      1.00   -5.00     -0.674  1.12e-3      0.676      4.79 ▁▂▇▂▁
12 F_1_10            17961         0.982  1.85e-4      0.999  -4.79     -0.674  6.71e-4      0.674      4.91 ▁▂▇▂▁
13 F_1_11            18170         0.982 -1.13e-3      1.00   -4.61     -0.677 -1.29e-3      0.674      4.82 ▁▂▇▂▁
14 F_1_12            18203         0.982 -6.12e-2      0.712  -7.06     -0.489  5.47e-4      0.436      2.30 ▁▁▁▇▃
15 F_1_13            18398         0.982 -6.71e-2      0.746  -6.90     -0.514 -8.04e-4      0.455      2.54 ▁▁▁▇▂
16 F_1_14            18039         0.982 -9.05e-4      1.00   -4.63     -0.676 -1.61e-4      0.673      4.82 ▁▂▇▂▁
17 F_2_0                 0         1      2.69e+0      1.88    0         1      2   e+0      4         15    ▇▃▁▁▁
18 F_2_1                 0         1      2.51e+0      1.75    0         1      2   e+0      4         14    ▇▆▁▁▁
19 F_2_2                 0         1      9.77e-1      1.04    0         0      1   e+0      2         11    ▇▁▁▁▁
20 F_2_3                 0         1      2.52e+0      1.65    0         1      2   e+0      4         14    ▇▆▁▁▁
21 F_2_4                 0         1      2.94e+0      1.98    0         1      3   e+0      4         16    ▇▃▁▁▁
22 F_2_5                 0         1      1.53e+0      1.35    0         1      1   e+0      2         12    ▇▂▁▁▁
23 F_2_6                 0         1      1.49e+0      1.32    0         0      1   e+0      2         12    ▇▂▁▁▁
24 F_2_7                 0         1      2.65e+0      1.74    0         1      2   e+0      4         16    ▇▃▁▁▁
25 F_2_8                 0         1      1.18e+0      1.32    0         0      1   e+0      2         13    ▇▁▁▁▁
26 F_2_9                 0         1      1.11e+0      1.10    0         0      1   e+0      2         11    ▇▁▁▁▁
27 F_2_10                0         1      3.28e+0      1.87    0         2      3   e+0      4         17    ▇▅▁▁▁
28 F_2_11                0         1      2.47e+0      1.60    0         1      2   e+0      3         13    ▇▆▁▁▁
29 F_2_12                0         1      2.76e+0      1.70    0         2      3   e+0      4         15    ▇▃▁▁▁
30 F_2_13                0         1      2.48e+0      1.65    0         1      2   e+0      3         15    ▇▂▁▁▁
31 F_2_14                0         1      1.72e+0      1.56    0         1      1   e+0      3         13    ▇▂▁▁▁
32 F_2_15                0         1      1.78e+0      1.46    0         1      2   e+0      3         13    ▇▃▁▁▁
33 F_2_16                0         1      1.80e+0      1.46    0         1      2   e+0      3         13    ▇▃▁▁▁
34 F_2_17                0         1      1.24e+0      1.25    0         0      1   e+0      2         12    ▇▁▁▁▁
35 F_2_18                0         1      1.56e+0      1.44    0         0      1   e+0      2         15    ▇▁▁▁▁
36 F_2_19                0         1      1.60e+0      1.42    0         0      1   e+0      2         13    ▇▂▁▁▁
37 F_2_20                0         1      2.23e+0      1.56    0         1      2   e+0      3         14    ▇▅▁▁▁
38 F_2_21                0         1      2.03e+0      1.61    0         1      2   e+0      3         15    ▇▂▁▁▁
39 F_2_22                0         1      1.61e+0      1.56    0         0      1   e+0      2         16    ▇▁▁▁▁
40 F_2_23                0         1      7.09e-1      1.08    0         0      0            1         11    ▇▁▁▁▁
41 F_2_24                0         1      3.13e+0      1.82    0         2      3   e+0      4         17    ▇▅▁▁▁
42 F_3_0             18029         0.982  1.74e-3      1.00   -4.69     -0.675  3.25e-3      0.677      4.59 ▁▂▇▂▁
43 F_3_1             18345         0.982 -1.15e-3      1.00   -4.47     -0.675  4.81e-4      0.674      4.85 ▁▃▇▂▁
44 F_3_2             18056         0.982  6.05e-4      0.999  -4.89     -0.673  3.92e-4      0.675      4.76 ▁▂▇▂▁
45 F_3_3             18054         0.982  8.34e-4      1.00   -4.68     -0.674  8.54e-4      0.676      4.99 ▁▂▇▂▁
46 F_3_4             18373         0.982  1.29e-3      1.00   -5.01     -0.673  2.64e-3      0.677      4.72 ▁▂▇▂▁
47 F_3_5             18298         0.982 -2.18e-3      1.00   -4.87     -0.676 -1.60e-3      0.672      5.04 ▁▂▇▂▁
48 F_3_6             18192         0.982  5.78e-5      0.999  -5.02     -0.675  8.54e-4      0.673      4.53 ▁▂▇▃▁
49 F_3_7             18013         0.982  1.52e-3      1.00   -5.05     -0.673  1.20e-3      0.676      5.46 ▁▂▇▁▁
50 F_3_8             18098         0.982  7.73e-4      1.00   -5.51     -0.676 -1.95e-5      0.676      5.11 ▁▁▇▂▁
51 F_3_9             18106         0.982 -4.40e-4      1.00   -4.85     -0.675 -1.77e-3      0.674      5.10 ▁▂▇▂▁
52 F_3_10            18200         0.982  1.71e-3      1.00   -4.63     -0.672  1.57e-3      0.675      5.13 ▁▃▇▁▁
53 F_3_11            18388         0.982  7.33e-4      0.999  -4.60     -0.675  4.80e-4      0.675      4.68 ▁▂▇▂▁
54 F_3_12            18297         0.982  2.59e-4      1.00   -4.53     -0.674  1.83e-3      0.674      4.94 ▁▃▇▂▁
55 F_3_13            18060         0.982 -2.46e-3      1.00   -4.75     -0.676 -1.59e-3      0.674      4.71 ▁▂▇▂▁
56 F_3_14            18139         0.982  7.27e-4      0.999  -5.36     -0.673  1   e-4      0.674      4.82 ▁▁▇▃▁
57 F_3_15            18238         0.982 -1.51e-3      1.00   -4.45     -0.675 -1.36e-3      0.674      5.25 ▁▃▇▁▁
58 F_3_16            18122         0.982 -6.65e-4      1.00   -4.82     -0.675 -1.70e-3      0.675      4.84 ▁▂▇▂▁
59 F_3_17            18278         0.982 -2.14e-4      1.00   -4.81     -0.675 -4.14e-4      0.674      5.06 ▁▂▇▂▁
60 F_3_18            18089         0.982  6.27e-5      1.00   -5.20     -0.674  1.63e-4      0.675      4.96 ▁▂▇▂▁
61 F_3_19            18200         0.982 -6.49e-2      0.739  -6.07     -0.507  6.76e-4      0.451      2.67 ▁▁▂▇▁
62 F_3_20            18248         0.982  2.37e-3      0.999  -5.00     -0.671  2.45e-3      0.676      6.03 ▁▃▇▁▁
63 F_3_21            18396         0.982 -5.93e-2      0.697  -7.15     -0.480 -6.49e-4      0.428      2.39 ▁▁▁▇▂
64 F_3_22            18177         0.982  8.73e-5      0.999  -4.74     -0.674  5.9 e-5      0.674      4.97 ▁▂▇▂▁
65 F_3_23            18206         0.982  3.65e-4      1.00   -5.25     -0.674 -4.52e-4      0.675      4.81 ▁▁▇▂▁
66 F_3_24            18145         0.982 -8.17e-4      1.00   -4.89     -0.675 -4.57e-4      0.674      4.98 ▁▂▇▂▁
67 F_4_0             18128         0.982  3.27e-1      2.32  -12.9      -1.17   4.21e-1      1.91      10.7  ▁▁▇▅▁
68 F_4_1             18164         0.982 -3.31e-1      2.41  -12.5      -1.96  -3.56e-1      1.28      11.7  ▁▂▇▂▁
69 F_4_2             18495         0.982 -8.58e-2      0.837  -9.66     -0.608 -6.20e-2      0.485      2.91 ▁▁▁▇▃
70 F_4_3             18029         0.982 -1.95e-1      0.821  -9.94     -0.686 -1.37e-1      0.369      2.58 ▁▁▁▇▅
71 F_4_4             17957         0.982  3.33e-1      2.37  -12.8      -1.19   4.25e-1      1.94      11.9  ▁▁▇▃▁
72 F_4_5             18063         0.982  3.36e-1      2.35  -12.5      -1.27   3.03e-1      1.92      13.5  ▁▂▇▁▁
73 F_4_6             18325         0.982  3.77e-3      2.29  -11.1      -1.57  -7.18e-2      1.52      11.5  ▁▂▇▂▁
74 F_4_7             18014         0.982  3.34e-1      2.36  -11.7      -1.22   3.79e-1      1.93      12.5  ▁▂▇▂▁
75 F_4_8             18176         0.982 -7.18e-2      0.778 -10.1      -0.518  1.82e-2      0.475      2.61 ▁▁▁▇▇
76 F_4_9             18265         0.982 -7.99e-2      0.807  -9.86     -0.577 -2.78e-2      0.480      2.81 ▁▁▁▇▅
77 F_4_10            18225         0.982  3.83e-2      0.707 -10.4      -0.386  1.03e-1      0.530      2.55 ▁▁▁▆▇
78 F_4_11            18119         0.982  5.52e-1      5.00  -26.3      -2.79   2.03e-1      3.65      31.2  ▁▂▇▁▁
79 F_4_12            18306         0.982  3.34e-1      2.38  -11.5      -1.27   3.54e-1      1.95      11.3  ▁▂▇▂▁
80 F_4_13            17995         0.982  3.30e-1      2.36  -10.7      -1.30   2.95e-1      1.92      11.9  ▁▂▇▂▁
81 F_4_14            18267         0.982  3.72e-2      0.776  -9.98     -0.396  1.31e-1      0.574      2.58 ▁▁▁▇▇

Toutes les variables sont continues à l’exception des variables F_2* qui ne contiennent pas de données manquantes. On doit donc prédire les manquants pour des données continues, c’est un problème de régression.

Chaque variable contient à peu près la même proportion de manquant 1.8%.

Comment sont-elles réparties ?

data$count_na <- rowSums(is.na(data))

data %>%
  count(count_na, sort = TRUE)
   count_na      n
      <dbl>  <int>
 1        1 370798
 2        0 364774
 3        2 185543
 4        3  61191
 5        4  14488
 6        5   2723
 7        6    413
 8        7     64
 9        8      4
10        9      2

Il y a 364774 observations sans manquants, 370798 avec un seul NA, et jusqu’à 9 Na pour une même observation.

Une stratégie peut commencer à se dessiner ici, qui consisterai à utiliser les 364774 observations sans manquants entrainer l’algorithme et prédire les 370798 avec manquants. Le problème sera plus complexe à partir de 2 manquants car ils ne seront probablement pas répartis sur les mêmes variables.

Si on regarde les combinaisons de manquants, il y a 42633 combinaisons possibles. Il va donc falloir simplifier le problème. Poussons un peu l’analyse.

Visualisons les corrélations :

corrplot::corrplot(cor(drop_na(data[,-1])), tl.cex = 0.5, tl.col = "black", method = "color")

matrice de corrélation pour toutes les variables

En regardant la matrice de corrélation, on s’aperçoit qu’il n’y a aucune corrélation entre les groupes de variables, et seuls les groupes F_2 et F_4 présentent des corrélations mais uniquement internes au groupe.

Ça laisse à penser qu’il va être difficile de prédire les variables F_1* et F_3*, et que les variables F_2* ne seront probablement pas d’une grande aide.

Un moyen de le confirmer, est de créer un premier algorithme pour prédire ces variables lorsqu’il n’y a qu’un seul manquant. J’ai donc créé rapidement un algorithme qui boucle sur chaque variable, et créé un modèle (avec LightGBM) par variable à prédire. On s’aperçoit tout de suite qu’il n’y parvient pas, et ce pour chacune des variables F_1* et F_3*.

On voit tout de suite que le RMSE augmente pour les données de validation au lieu de diminuer :

[1]:  train's rmse:0.99907  valid's rmse:0.999992
[11]:  train's rmse:0.983548  valid's rmse:1.00045
[21]:  train's rmse:0.968703  valid's rmse:1.00171
[31]:  train's rmse:0.954811  valid's rmse:1.00238
[41]:  train's rmse:0.941672  valid's rmse:1.00242
[51]:  train's rmse:0.929105  valid's rmse:1.00233

Et en traçant l’évolution du RMSE pour train et valid :

train et valid RMSE

On voit clairement sur les données de validations que les prédictions se dégradent.

La meilleure stratégie pour les variables F_1* et F_3* est d’imputer avec leurs moyennes, après quelques tests, c’est ce qui donne le meilleur résultat. 40 variables peuvent donc être traitées de façon très simple.

Voir le code pour l’imputation avec la moyenne ou la médiane sur github.

Il ne reste qu’à trouver une solution pour les variables F_4*, qui ne sont plus qu’au nombre de 15.

Regardons à nouveau la matrice de corrélation uniquement pour ces variables :

corrplot::corrplot(cor(drop_na(select(data, contains("F_4")))), tl.cex = 0.5, tl.col = "black", method = "square", type = "upper")
corrplot::corrplot.mixed(cor(drop_na(select(data, contains("F_4")))), tl.cex = 0.5, tl.col = "black", tl.offset = 0.2, upper = "color", number.cex = 0.8, tl.pos = "lt")

matrice de corrélation pour toutes les variables f_4

On réévalue les manquants uniquement pour ces variables :

   count_na      n
      <dbl>  <int>
1         0 759268
2         1 211342
3         2  27127
4         3   2124
5         4    135
6         5      4

On n’a plus qu’un maximum de 5 manquants par observation, et une proportion de données complètes pour entrainer l’algorithme de 76%.

Bien entendu lorsqu’il y a plus d’un manquant, ils peuvent être répartis sur 14 variables, ce qui donne dans notre cas 703 combinaisons.

Par exemple, pour prédire les manquants de la variable F_4_0, on peut n’avoir aucune autre manquant, ou un manquant sur la variable F_4_2, ou sur la variable F_4_3, ou encore 2 autres manquants, sur les variables F_4_3 et F_4_5 etc.


Stratégies


Stratégie 1 : Prédiction par variable

Le principe est d’utiliser un algorithme qui sait faire des prédictions malgré des valeurs manquantes, comme LightGBM. On créé un modèle par variable à prédire, qu’on entraine sur les données qui n’ont pas de manquants

On prédit toutes les valeurs manquantes de cette variable indépendamment du fait qu’il puisse manquer des valeurs sur d’autres variables. Bien entendu, plus il y a de valeurs manquantes, plus la prédiction va être mauvaise. On peut jouer sur certains paramètres de lightGBM pour améliorer la tolérance aux manquants comme feature_fraction qui permet de réduire le nombre de variables utilisées pour chaque arbre de décision. C’est à double tranchant, car ça peut réduire la performance pour les observations complètes. Il faut donc trouver la bonne valeur du paramètre par hyperparameter tuning.

Pour améliorer les prédictions, on peut compléter par un réseau de neurones : on utilise alors un deuxième data frame dont les valeurs manquantes ont été remplacées par celle prédites par lightGBM (les réseaux de neurones ne peuvent pas gérer les NA).

  • Avantages : on ne créé que 15 modèles, c’est donc assez rapide, le RMSE est de 0.86 ce qui est correct
  • Inconvénients : les prédictions de lightGBM sont plus imprécises lorsqu’il y a des manquants, et ces valeurs sont utilisées par le réseau de neurones, on peut donc amplifier l’incertitude et la marge d’amélioration est faible.


Stratégie 2 : Altérer les données d’entrainement pour ajouter des manquants

Cette stratégie a été proposée par un des compétiteurs sur Kaggle. Il s’agit de remplacer les NA par une valeur (-1 mais on peut choisir autre chose) pour les variables indépendantes afin de pouvoir utiliser le deep learning. Mais bien sur on ne pourra pas faire de prédiction si ces valeurs n’ont pas été vues au préalable lors de la phase d’entrainement. Il faut donc altérer les données d’entrainement pour remplacer de vrais données aléatoirement par des -1.

Pour que ça fonctionne, il faut ajouter autant de -1 qu’il y a de NA, donc on créé des modèles par nombre de NA dans les variables indépendantes. Par exemple s’il y a 2 variables indépendantes avec NA, on créé pour chaque observation des données d’entrainements, 2 valeurs -1 qu’on répartit aléatoirement entre les variables en s’assurant qu’il n’y en ait que 2 NA par observation.

  • Avantages : ca améliore légèrement les prédictions
  • Inconvénients : C’est plus complexe à coder (en particulier pour attribuer aléatoirement les -1 en fonction du nombre de NA), et plus long à entrainer.


Stratégie 3 : Prédire plusieurs variables simultanément (régression avec plusieurs sorties)

C’est la stratégie qui a donné de meilleur résultat et que je vais développer ensuite dans ce post.

L’idée ici est de prédire pour chaque combinaison de manquant, un modèle qui aura autant de sorties que de manquants pour prédire tous les manquants simultanément.

C’est finalement assez facile à faire, car il s’agit de créer un réseau de neurone avec autant de neurones qu’il y a de NA pour la couche de sortie.


Mise en oeuvre


Pour pouvoir appliquer cette stratégie, il faut grouper les données par combinaison de valeur manquantes. On créé donc une variable qui liste des noms des variables contenant des NA pour chaque observation.

list_variables <- colnames(data)
list_cols <- list_variables[grep("F_4",list_variables)]

data <- data %>%
   select(row_id, all_of(list_cols))

On créé une fonction na_col() qui crée une nouvelle variable contenant le nom de la variable si la valeur de la variable d’origine est NA, ou vide sinon.

na_col <- function(var, data){
  
  var_ts <- sym(var)
  new_var_ts <- sym(glue::glue(var, "_na"))
  
  data %>%
    select({{var_ts}}) %>%
    mutate("{{new_var_ts}}" := ifelse(is.na({{var_ts}}), var, "")) %>%
    select(-{{var_ts}})
    
}

On utilise map_dfr() et reduce() du package {purrr} pour appliquer la fonction na_col() à toutes les variables et obtenir une seule variable avec la liste de toutes les variables contenant des NAs par observation.

data_mut <- map_dfc(list_cols, na_col, data) %>%
  mutate(na_cols = reduce(., paste, sep = " ")) %>%
  mutate(na_cols = str_squish(na_cols))

data <- data %>% 
  bind_cols(select(data_mut, na_cols))

data$cnt <- rowSums(is.na(data))   # nb NA 

Explication :

  • map_dfc() applique la fonction na_col() sur toutes les variables et retourne un dataframe contenant autant de colonnes qu’il y a de variables, et contenant le nom de la variable si sa valeur est NA, ou vide sinon.
  • reduce() combiné à paste, permet de ne créer qu’une seule variable contenant la concaténation de toutes ces valeurs.
  • str_squish() supprime tous les espaces en trop, pour ne conserver que les séparateurs.

On a donc à présent un dataframe contenant uniquement les variables F_4_* ainsi que la variable na_cols, qui contient la liste des variables contenant des NA, séparées par un espace, et la variable cnt qui contient le nombre de NA par observation.

Regardons un aperçu du nombre de combinaisons uniques :

unique_combi <- unique(data$na_cols)[-1]
head(unique_combi, 20)
 [1] "F_4_2"        "F_4_4"        "F_4_3 F_4_14" "F_4_12"       "F_4_14"       "F_4_3"        "F_4_8 F_4_12" "F_4_1"        "F_4_8 F_4_14"
[10] "F_4_8"        "F_4_4 F_4_14" "F_4_5"        "F_4_9"        "F_4_2 F_4_13" "F_4_10"       "F_4_3 F_4_13" "F_4_0"        "F_4_7 F_4_9" 
[19] "F_4_6"        "F_4_13"   

On utilise ensuite la partie du dataframe ne contenant aucun NA pour entrainer l’algorithme, avec un split pour avoir des données d’entrainement et de validation.

train_basis <- data %>%
  filter(cnt == 0) %>%
  select(-na_cols, -cnt)

split <- floor(0.80*NROW(train_basis))

On peut ensuite boucler sur toutes les combinaisons de variables pour créer autant de modèles qu’il y a de combinaisons. Cependant, au vu du nombre de modèle à créer, il est préférable de boucler par variable à prédire, et ensuite par combinaison comprenant cette variable, ce qui permet d’exécuter plusieurs modèles en parallèle. C’est ce que permet cette structure :

for(variable in list_cols){
  
  combi_var <- unique_combi[str_detect(unique_combi, variable)]
  
  for(combi in combi_var){
    
    set_cols <- str_split(combi, " ", simplify = TRUE)[,1]

    # ** MODEL **

  }

}

La variable set_cols contient la liste des colonnes contenant des NA pour cette itération de la boucle.


On peut donc se concentrer sur le modèle.

Les données d’entrainement contiennent toutes les colonnes ne contenant pas de manquant, et la cible est une matrice contenant les valeurs pour les différentes variables à prédire (celles contenues dans set_cols)

train_df <- train_basis[1:split,] %>%
   select(-all_of(set_cols), -row_id)
train_target <- train_basis[1:split,] %>% 
   select(all_of(set_cols)) %>% 
   as.matrix()

On fait de même pour les données de validations :

valid_df <- train_basis[split:NROW(train_basis),] %>%
  select(-all_of(set_cols), -row_id)
valid_target <- train_basis[split:NROW(train_basis),] %>%
  select(all_of(set_cols)) %>%
  as.matrix()

Et enfin on prépare les données de test en filtrant uniquement les données contenant des NA pour cette combinaison de variable, et en ne sélectionnant que les variables qui n’ont pas de manquant :

test_df <- data %>%
  filter(na_cols %in% combi) %>%
  select(-all_of(set_cols), -cnt, -na_cols)

test_row_id <- test_df$row_id
test_df <- test_df %>% select(-row_id)

On normalise les données :

preProcValues <- caret::preProcess(select(train_basis[-1], -all_of(set_cols)), method = c("center", "scale"))

trainTransformed <- predict(preProcValues, train_df)
validTransformed <- predict(preProcValues, valid_df)
testTransformed <- predict(preProcValues, test_df)

train_mx <- as.matrix(trainTransformed)

On définit le modèle avec une couche d’entrée comprenant autant de neurones qu’il y a des variables sans NA, et en couche de sortie autant de neurones qu’il y a de variables avec NA. Ce modèle va donc produire une matrice comprenant une colonne par variable à prédire.

model <- keras_model_sequential() %>% 
  layer_dense(units = 128, activation = "swish", input_shape = length(train_df)) %>%
  layer_batch_normalization() %>%
  layer_dense(units = 64, activation = "swish") %>%
  layer_dense(units = 32, activation = "swish") %>%
  layer_dense(units = 8, activation = "swish") %>%
  layer_dense(length(set_cols), activation = "linear")

On définit l’optimiseur et on peut compiler le modèle avec comme fonction de perte et métrique “mean_squared_error”.

optimizer <- optimizer_adam(learning_rate = 0.001)

model %>% 
  compile(
    loss = 'mean_squared_error',
    optimizer = optimizer,
    metrics = "mean_squared_error"
  )

On peut enfin entrainer le modèle à l’aide de la fonction fit(), en passant en paramètre les données, le nombre d’epochs et la taille du lot. On définit par ailleurs deux fonctions de callback, une pour arrêter l’entrainement s’il n’y a pas d’amélioration (early stopping) et une pour diminuer le taux d’apprentissage si on atteint un plateau (reduce lr on plateau).

model %>% fit(
  train_mx, 
  train_target, 
  epochs = EPOCHS, 
  batch_size = BATCH_SIZE, 
  validation_split = 0.1,
  callbacks = list(
    callback_early_stopping(monitor='val_mean_squared_error', patience=8, verbose = 1, mode = 'min', restore_best_weights = TRUE),
    callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.5, patience = 3, verbose = 1)
  )
)

On fait alors les prédictions sur les données de validation pour évaluer le RMSE sur ces données que l’algorithme n’a pas vu.

pred_valid <- model %>% predict(as.matrix(validTransformed))

predictions_valid <- as.data.frame(pred_valid)
colnames(predictions_valid) <- set_cols
predictions_valid <- predictions_valid %>%
  mutate(row_id = row_number()) %>%
  pivot_longer(cols = -row_id)
valid_target <- as.data.frame(valid_target) %>%
  mutate(row_id = row_number()) %>%
  pivot_longer(cols = -row_id)

rmse <- yardstick::rmse_vec(valid_target$value, predictions_valid$value)
print(glue("RMSE - combination {reduce(combi, paste, sep = ", ")}: {rmse}"))

Ce qui produit :

...
RMSE - combination F_4_8 F_4_10 F_4_14: 0.296459822230795
RMSE - combination F_4_6 F_4_7 F_4_10: 0.798801761253467
RMSE - combination F_4_0 F_4_2 F_4_13: 0.68729877310422
RMSE - combination F_4_2 F_4_7 F_4_8 F_4_10: 0.727564592227748
RMSE - combination F_4_4 F_4_6 F_4_10 F_4_11: 1.13184400510586
RMSE - combination F_4_1 F_4_9 F_4_10 F_4_11: 0.815886927914656
...

Enfin on prédit sur les données de test qui sont les données réellement manquantes, puis on met en forme les prédictions pour ajouter au fichier chaque couple variable / id avec la valeur prédite.

    test_predictions <- model %>% predict(as.matrix(testTransformed))
    test_predictions <- as.data.frame(test_predictions)
    colnames(test_predictions) <- set_cols
    test_predictions$row_id <- test_row_id
    
    test_predictions <- test_predictions %>%
      pivot_longer(cols = -row_id, names_to = "variable", values_to = "prediction")
      
    
    predictions <- tibble(`row-col` = glue("{test_predictions$row_id}-{test_predictions$variable}"), value = test_predictions$prediction)
    
    submission <- submission %>% bind_rows(predictions)


Conclusion


Voilà ! On vient donc de prédire un million de valeurs manquantes dans une table contenant un million de ligne et 80 variables.

Il y a une différence notable entre la première solution utilisant LightGBM et l’utilisation d’un réseau de neurones multi-sorties, cela étant la première solution nécessite environ 2 heures d’entrainement, alors que la dernière nécessite près de 30 heures (sans GPU). La différence est importante dans le cadre d’une compétition, en pratique peut-être moins, c’est à définir en fonction du problème, et la première solution donne un résultat satisfaisant comparé à une imputation avec la moyenne.

En benchmark, voici le RMSE pour différentes stratégies :

  • Imputation avec la moyenne : 0.97937
  • LGBM + NN : 0.85xx
  • Altération : 0.84xx
  • NN multi-sorties : 0.83xx

Savoir comment imputer les valeurs manquantes est utile pour pouvoir effectuer une analyse malgré l’absence de certaines données, ce qui est un problème très fréquent en pratique.


Christophe Nicault
Christophe Nicault
Stratégie des Systèmes d’Information
Transformation Numérique
Data Science

Je travaille sur la stratégie des systèmes d’information, les projets informatiques et la science des données.