
nicholasjclark / mrfcov


Markov random fields with covariates

R 100.00%
network-analysis graphical-models networks markov-random-field conditional-random-fields r-package multivariate-analysis multivariate-statistics machine-learning

mrfcov's Introduction


MRFcov: Markov Random Fields with additional covariates in R


MRFcov (described by Clark et al, published as a Statistical Report in Ecology) provides R functions for approximating interaction parameters of nodes in undirected Markov Random Field (MRF) graphical networks. Models can incorporate covariates (a class of models known as Conditional Random Fields, or CRFs, following methods developed by Cheng et al 2014 and Lindberg 2016), allowing users to estimate how interactions between nodes are predicted to change across covariate gradients. Note: this is a development version. For the stable version, please install the package from CRAN.

Why Use Conditional Random Fields?

In principle, MRFcov models that use species’ occurrences or abundances as outcome variables are similar to Joint Species Distribution models in that variance can be partitioned among abiotic and biotic effects. However, key differences are that MRFcov models can:

  1. Produce directly interpretable coefficients that allow users to determine the relative importances (i.e. effect sizes) of biotic associations and environmental covariates in driving abundances or occurrence probabilities

  2. Identify association strengths, rather than simply determining whether they are “significantly different from zero”

  3. Estimate how associations are predicted to change across environmental gradients

Models such as these are also better at isolating true species ‘interactions’ from presence-absence data than traditional co-occurrence methods (such as the all-too-common null model randomisation approaches). See this blog post for a more detailed explanation and a demonstration of this claim.

MRF and CRF interaction parameters are approximated using separate regressions for individual species within a joint modelling framework. Because all combinations of covariates and additional species are included as predictor variables in node-specific regressions, variable selection is required to reduce overfitting and add sparsity. This is accomplished through LASSO penalization using functions in the glmnet package.
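The node-wise design described above can be sketched in base R. The helper build_node_design() and the toy data are hypothetical (they are not MRFcov functions); the sketch only shows how, for one focal node, the other nodes, the covariates and every covariate-by-node product become predictors. The LASSO step itself is left out.

```r
# Hypothetical helper: predictor matrix for one node's regression in a CRF-style model
build_node_design <- function(data, n_nodes, focal) {
  nodes  <- data[, seq_len(n_nodes), drop = FALSE]    # node (species) columns come first
  covs   <- data[, -seq_len(n_nodes), drop = FALSE]   # covariates follow
  others <- nodes[, setdiff(colnames(nodes), focal), drop = FALSE]
  X <- cbind(as.matrix(others), as.matrix(covs))
  # Append every covariate x node product, named "covariate_node"
  for (cv in colnames(covs)) {
    for (nd in colnames(others)) {
      X <- cbind(X, covs[[cv]] * others[[nd]])
      colnames(X)[ncol(X)] <- paste(cv, nd, sep = "_")
    }
  }
  X
}

set.seed(1)
toy <- data.frame(A = rbinom(20, 1, 0.5), B = rbinom(20, 1, 0.5),
                  C = rbinom(20, 1, 0.5), temp = as.vector(scale(rnorm(20))))
X_A <- build_node_design(toy, n_nodes = 3, focal = "A")
colnames(X_A)
# "B" "C" "temp" "temp_B" "temp_C"
```

Each such matrix could then be passed to a cross-validated LASSO, for example glmnet::cv.glmnet(X_A, toy$A, family = "binomial"); the exact settings MRFcov uses internally are not shown here.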

Installation

You can install the stable version of the MRFcov package into R from CRAN. Alternatively, install the development version (updated features, but no guarantee of good functionality) from GitHub using:

# install.packages("devtools")
devtools::install_github("nicholasjclark/MRFcov")

Brief Overview

We can explore the model’s primary functions using a test dataset that is available with the package. Load the Bird.parasites dataset, which contains binary occurrences of four avian blood parasites in New Caledonian Zosterops species (available in its original form at Dryad; Clark et al 2016). A single continuous covariate is also included (scale.prop.zos), which reflects the relative abundance of Zosterops species among different sample sites.

library(MRFcov)
data("Bird.parasites")

Visualise the dataset to see how analysis data need to be structured. In short, when estimating co-occurrence probabilities, node variables (i.e. species occurrences) should be included as binary variables (1s and 0s) in the left-most columns of data. Any covariates should follow as the right-most columns. Note that these covariates should ideally be on a similar scale: use the scale function (or similar) on continuous covariates so that they generally have mean = 0 and sd = 1.
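A minimal sketch of that layout with made-up data (the columns sp1, sp2, rainfall_mm and elevation_m are hypothetical): binary node columns first, then covariates rescaled with base R's scale() so that each has mean 0 and sd 1.

```r
set.seed(42)
# Hypothetical data: two binary nodes first, then two raw covariates
dat <- data.frame(sp1 = rbinom(30, 1, 0.4), sp2 = rbinom(30, 1, 0.6),
                  rainfall_mm = runif(30, 200, 1500),
                  elevation_m = runif(30, 0, 900))
n_nodes <- 2
cov_cols <- (n_nodes + 1):ncol(dat)

# scale() centres each covariate to mean 0 and divides it by its sd
dat[, cov_cols] <- scale(dat[, cov_cols])

colMeans(dat[, cov_cols])         # effectively 0
apply(dat[, cov_cols], 2, sd)     # 1
```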

help("Bird.parasites")
View(Bird.parasites)

You can read more about specific data-format requirements (for example, one-hot encoding of categorical covariates) in the supplied vignette:

vignette("CRF_data_prep")
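For a categorical covariate, one common way to one-hot encode it in base R is model.matrix() with the intercept removed. This is a sketch with hypothetical data; whether the vignette recommends keeping every level, as here, or dropping a reference level is not stated above, so check CRF_data_prep before relying on it.

```r
# Hypothetical site data: two binary nodes and one categorical covariate
site <- data.frame(sp1 = c(1, 0, 1, 1), sp2 = c(0, 0, 1, 1),
                   habitat = factor(c("forest", "grass", "forest", "urban")))

# "- 1" drops the intercept so every habitat level gets its own 0/1 column
onehot <- model.matrix(~ habitat - 1, data = site)

# Nodes stay on the left, encoded covariates go on the right
prepped <- cbind(site[, c("sp1", "sp2")], onehot)
colnames(prepped)
# "sp1" "sp2" "habitatforest" "habitatgrass" "habitaturban"
```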

Running MRFs and visualising interaction coefficients

Run an MRF model using the provided continuous covariate (scale.prop.zos). Here, each species-specific regression is individually optimised through cross-validated LASSO variable selection. Corresponding coefficients (e.g. the coefficient for the effect of species A on species B and the coefficient for the effect of species B on species A) are then symmetrised to form an undirected MRF graph.
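Mean symmetrisation can be illustrated with a toy coefficient matrix. This sketch assumes "symmetrised" means averaging the two directed estimates for each pair (the cv_MRF_diag code further down defaults its symmetrise argument to "mean"); the matrix values and symmetrise_mean() are invented for illustration.

```r
# Hypothetical unsymmetrised coefficients: row = response node,
# column = predictor node (diagonal unused)
B <- matrix(c( 0.0, -2.1, 0.9,
              -2.5,  0.0, 0.3,
               1.0,  0.2, 0.0),
            nrow = 3, byrow = TRUE,
            dimnames = list(c("sp1", "sp2", "sp3"), c("sp1", "sp2", "sp3")))

# Mean symmetrisation: average B[i, j] and B[j, i] for every pair
symmetrise_mean <- function(B) (B + t(B)) / 2

B_sym <- symmetrise_mean(B)
B_sym["sp1", "sp2"]   # (-2.1 + -2.5) / 2 = -2.3
```

In the output further down, the Hzosteropis-Hkillangoi coefficient (-2.3087824) appears identically in both nodes' tables, which is what a symmetrised graph produces.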

MRF_mod <- MRFcov(data = Bird.parasites, n_nodes = 4, family = 'binomial')
#> Leave-one-out cv used for the following low-occurrence (rare) nodes:
#>  Microfilaria ...
#> Fitting MRF models in sequence using 1 core ...

Visualise the estimated species interaction coefficients as a heatmap. These represent mean interactions and are useful for identifying co-occurrence patterns, but they do not show how interactions change across gradients. Note that for binary data such as these, we can also plot the observed occurrences and co-occurrences using plot_observed_vals = TRUE.
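For binary data, observed occurrence and co-occurrence counts can be tallied with a single matrix product. A base-R sketch on a hypothetical presence-absence matrix; it is not necessarily how plotMRF_hm computes its observed values.

```r
# Hypothetical presence-absence matrix: rows = samples, columns = nodes
occ <- matrix(c(1, 0, 1,
                1, 1, 0,
                0, 0, 1,
                1, 1, 1,
                0, 0, 0),
              ncol = 3, byrow = TRUE,
              dimnames = list(NULL, c("sp1", "sp2", "sp3")))

# crossprod(occ) is t(occ) %*% occ: the diagonal counts occurrences of each
# node and the off-diagonal entries count co-occurrences of each pair
co_counts <- crossprod(occ)
co_counts
#     sp1 sp2 sp3
# sp1   3   2   2
# sp2   2   2   1
# sp3   2   1   3
```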

plotMRF_hm(MRF_mod, plot_observed_vals = TRUE, data = Bird.parasites)

Exploring regression coefficients and interpreting results

We can explore regression coefficients to understand how important interactions are for predicting species’ occurrence probabilities relative to other covariates. This is perhaps the strongest property of conditional MRFs, as competing methods (such as Joint Species Distribution Models) do not provide interpretable mechanisms for comparing the relative importances of interactions and fixed covariates. MRF functions return a matrix of important coefficients for each node in the graph, along with their relative importances, calculated as B^2 / sum(B^2), where B is the vector of regression coefficients for the predictor variables. Variables with an underscore (_) indicate an interaction between a covariate and another node, suggesting that the conditional dependence between the two nodes varies across environmental gradients.
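The relative importance formula is easy to reproduce. The sketch below applies it to a simple vector and to the five raw coefficients listed for Hzosteropis in the output below; rel_importance() is a hypothetical helper. The printed Rel_importance values for Hzosteropis sum to about 0.993 rather than 1, which suggests the package's denominator also includes coefficients not listed in key_coefs, so applying the formula to the listed rows alone gives slightly larger shares (roughly 0.65 rather than 0.646 for Hkillangoi).

```r
# Relative importance as described above: B^2 / sum(B^2)
rel_importance <- function(B) B^2 / sum(B^2)

rel_importance(c(3, 4))
# 0.36 0.64

# Raw coefficients listed for Hzosteropis in the output below
b <- c(Hkillangoi = -2.3087824, scale.prop.zos_Microfilaria = -1.0347421,
       Microfilaria = 0.9146907, scale.prop.zos = -0.8985542,
       Plas = -0.3837446)
round(rel_importance(b), 3)
```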

MRF_mod$key_coefs$Hzosteropis
#>                      Variable Rel_importance Standardised_coef   Raw_coef
#> 1                  Hkillangoi     0.64623474        -2.3087824 -2.3087824
#> 5 scale.prop.zos_Microfilaria     0.12980415        -1.0347421 -1.0347421
#> 3                Microfilaria     0.10143149         0.9146907  0.9146907
#> 4              scale.prop.zos     0.09788426        -0.8985542 -0.8985542
#> 2                        Plas     0.01785290        -0.3837446 -0.3837446
MRF_mod$key_coefs$Hkillangoi
#>         Variable Rel_importance Standardised_coef   Raw_coef
#> 1    Hzosteropis     0.79853150        -2.3087824 -2.3087824
#> 2   Microfilaria     0.11897509        -0.8911791 -0.8911791
#> 3 scale.prop.zos     0.08154704        -0.7378041 -0.7378041
MRF_mod$key_coefs$Plas
#>                      Variable Rel_importance Standardised_coef   Raw_coef
#> 2                Microfilaria     0.63590587         1.8658732  1.8658732
#> 3              scale.prop.zos     0.24611774        -1.1607994 -1.1607994
#> 5 scale.prop.zos_Microfilaria     0.07969128         0.6605278  0.6605278
#> 1                 Hzosteropis     0.02689758        -0.3837446 -0.3837446
#> 4  scale.prop.zos_Hzosteropis     0.01023366        -0.2367016 -0.2367016
MRF_mod$key_coefs$Microfilaria
#>                     Variable Rel_importance Standardised_coef   Raw_coef
#> 3                       Plas      0.4423652         1.8658732  1.8658732
#> 4             scale.prop.zos      0.1589327        -1.1184028 -1.1184028
#> 5 scale.prop.zos_Hzosteropis      0.1360445        -1.0347421 -1.0347421
#> 1                Hzosteropis      0.1063078         0.9146907  0.9146907
#> 2                 Hkillangoi      0.1009129        -0.8911791 -0.8911791
#> 6        scale.prop.zos_Plas      0.0554369         0.6605278  0.6605278
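The underscore terms can be read as covariate-dependent interactions. Assuming the scale.prop.zos_Microfilaria term is the product of the covariate and the Microfilaria node inside the binomial (logit) regression for Hzosteropis, the effect of Microfilaria presence on Hzosteropis log-odds at covariate value x is the Microfilaria coefficient plus x times the interaction coefficient. A sketch using the coefficients printed above:

```r
# Coefficients copied from MRF_mod$key_coefs$Hzosteropis above
b_micro     <- 0.9146907    # Microfilaria
b_cov_micro <- -1.0347421   # scale.prop.zos_Microfilaria

# Covariate values in SD units, since scale.prop.zos is scaled
zos <- c(-1, 0, 1, 2)

# Conditional effect of Microfilaria presence on Hzosteropis log-odds
effect <- b_micro + b_cov_micro * zos
round(effect, 3)
# approx. 1.949, 0.915, -0.120, -1.155

# Covariate value at which the conditional effect crosses zero
-b_micro / b_cov_micro   # approx. 0.884
```

Under that reading, the association is positive where Zosterops are relatively scarce and turns negative at about 0.88 SD above the mean relative abundance.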

For more in-depth tutorials and examples, see the package vignettes and some of the recent papers that have used the method:

vignette("Bird_Parasite_CRF")
vignette("Gaussian_Poisson_CRFs")

Clark et al 2018 Ecology

Peel et al 2019 Emerging Microbes & Infections

Fountain-Jones et al 2019 Journal of Animal Ecology

Gallen et al 2019 Journal of Animal Ecology

Clark et al 2020 Transboundary and Emerging Diseases

Clark et al 2020 Parasites & Vectors

Clark et al 2020 Nature Climate Change

Brian & Aldridge 2021 Journal of Animal Ecology

Sallam et al 2023 Parasites & Vectors

Key references

Cheng, J., Levina, E., Wang, P. & Zhu, J. (2014). A sparse Ising model with covariates. Biometrics 70:943-953.

Clark, N.J., Wells, K. & Lindberg, O. (2018). Unravelling changing interspecific interactions across environmental gradients using Markov random fields. Ecology. DOI: https://doi.org/10.1002/ecy.2221

Clark, N.J., Wells, K., Dimitrov, D. & Clegg, S.M. (2016). Co-infections and environmental conditions drive the distributions of blood parasites in wild birds. Journal of Animal Ecology 85:1461-1470.

Clark, N.J., Tozer, S., Wood, C., Firestone, S.M., Stevenson, M., Caraguel, C., Chaber, A.L., Heller, J. & Soares Magalhães, R.J. (2020). Unravelling animal exposure profiles of human Q fever cases in Queensland, Australia using natural language processing. Transboundary and Emerging Diseases. DOI: https://doi.org/10.1111/tbed.13565

Clark, N.J., Owada, K., Ruberanziza, E., Ortu, G., Umulisa, I., Bayisenge, U., Mbonigaba, J.B., Mucaca, J.B., Lancaster, W., Fenwick, A., Soares Magalhães, R.J. & Mbituyumuremyi, A. (2020). Parasite associations predict infection risk: incorporating co-infections in predictive models for neglected tropical diseases. Parasites & Vectors 13:1-16.

Clark, N.J., Kerry, J.T. & Fraser, C.I. (2020). Rapid winter warming could disrupt coastal marine fish community structure. Nature Climate Change. DOI: https://doi.org/10.1038/s41558-020-0838-5

Fountain-Jones, N.M., Clark, N.J., Kinsley, A.C., Carstensen, M., Forester, J., Johnson, T.J., Miller, E., Moore, S., Wolf, T.M. & Craft, M.E. (2019). Microbial associations and spatial proximity predict North American moose (Alces alces) gastrointestinal community composition. Journal of Animal Ecology 89:817-828.

Lindberg, O. (2016). Markov Random Fields in Cancer Mutation Dependencies. Master of Science thesis. University of Turku, Turku, Finland.

Peel, A.J., Wells, K., Giles, J., Boyd, V., Burroughs, A., Edson, D., Crameri, G., Baker, M.L., Field, H., Wang, L-F., McCallum, H., Plowright, R.K. & Clark, N. (2019). Synchronous shedding of multiple bat paramyxoviruses coincides with peak periods of Hendra virus spillover. Emerging Microbes & Infections 8:1314-1323.

This project is licensed under the terms of the GNU General Public License (GNU GPLv3)

mrfcov's People

Contributors

nicholasjclark


mrfcov's Issues

Checking coefficient uncertainty

Hi,

I've fitted a CRF and I'm trying to interpret my results regarding the significance of the associations between species. With bootstrap_MRF and plotMRF_hm I get the confidence intervals, but I do not understand how (some of) the means can lie outside the interval. I'm sorry if this is very trivial; I'm new to this method.

BR,

Anna

Numeric column names produce errors

Any column names that start with a number (e.g. '14232') will come out of glmnet with an 'X' pasted to the front (e.g. 'X14232'). This throws errors when re-ordering coefficient matrices to match the input data, typically with a message like 'undefined columns selected'.
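The 'X' prefix matches R's syntactic-name rules as applied by make.names(); whether glmnet itself or an intermediate data-frame conversion adds it inside MRFcov is not established here. A sketch of the behaviour and of one possible workaround: rename columns before fitting and keep a lookup to map names back. sanitise_names() is a hypothetical helper, not part of MRFcov.

```r
# Names starting with a digit get an "X" prefix under R's syntactic rules
make.names(c("14232", "Plas", "2nd_species"))
# "X14232" "Plas" "X2nd_species"

# Hypothetical workaround: rename up front and keep a lookup
sanitise_names <- function(df) {
  new_names <- make.names(colnames(df), unique = TRUE)
  lookup <- setNames(colnames(df), new_names)   # syntactic name -> original name
  colnames(df) <- new_names
  list(data = df, lookup = lookup)
}

raw <- data.frame(c(1, 0, 1), c(0, 1, 1), c(0.2, -1.1, 0.9))
colnames(raw) <- c("14232", "Plas", "temp")
fixed <- sanitise_names(raw)
colnames(fixed$data)
# "X14232" "Plas" "temp"
```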

Indirect coefficients omitted when predicting networks with bootstrapped MRF models

predict_MRFnetworks.R handles indirect coefficients inconsistently for bootstrapped MRF models, so indirect coefficients are omitted from the predicted networks whenever a bootstrapped model is used.

For example, within code chunk 119-133, line 126 checks whether the length of MRF_mod$indirect_coefs is greater than 0:

# Create an MRF_mod object to feed to the predict function
MRF_mod_booted <- list()
MRF_mod_booted$graph <- MRF_mod$direct_coef_means[ , 2:(n_nodes + 1)]
MRF_mod_booted$intercepts <- as.vector(MRF_mod$direct_coef_means[ , 1])
MRF_mod_booted$direct_coefs <- MRF_mod$direct_coef_means
MRF_mod_booted$mod_family <- MRF_mod$mod_family
MRF_mod_booted$mod_type <- 'MRFcov'
if(length(MRF_mod$indirect_coefs) > 0){
  for(i in seq_along(MRF_mod$indirect_coef_mean)){
    MRF_mod_booted$indirect_coefs[[i]] <- list(MRF_mod$indirect_coef_mean[[i]],"")[1]
    }
  names(MRF_mod_booted$indirect_coefs) <- names(MRF_mod$indirect_coef_mean)
} else {
  MRF_mod_booted$indirect_coefs <- NULL
}

However, MRF_mod$indirect_coefs does not exist for bootstrapped models, so its length is always 0; the check should use MRF_mod$indirect_coef_mean instead. As a result, MRF_mod_booted$indirect_coefs is never created, causing downstream problems (i.e. predicted networks do not change with the environment).

The same issue occurs again on line 189:

 # If covariates exist, incorporate their modifications to species' interactions 
 if(length(MRF_mod$indirect_coefs) > 0){

Checking MRF_mod$indirect_coefs for bootstrapped models will therefore always evaluate to FALSE.
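A sketch of the proposed fix for the block quoted above, run against a toy stand-in for a bootstrapped model. The element names follow the issue text; the values, the covariate name temp and the use of lapply() in place of the original loop are assumptions for illustration. The same substitution (indirect_coef_mean for indirect_coefs) would apply to the check on line 189.

```r
# Toy stand-in for a bootstrapped model; element names follow the issue,
# values are made up
n_nodes <- 2
MRF_mod <- list(
  direct_coef_means = matrix(c(-1, -0.5, 0, 0.8, 0.8, 0), nrow = 2,
                             dimnames = list(c("sp1", "sp2"),
                                             c("Intercept", "sp1", "sp2"))),
  indirect_coef_mean = list(temp = matrix(c(0, 0.3, 0.3, 0), nrow = 2)),
  mod_family = "binomial"
)

# Create an MRF_mod object to feed to the predict function
MRF_mod_booted <- list()
MRF_mod_booted$graph <- MRF_mod$direct_coef_means[, 2:(n_nodes + 1)]
MRF_mod_booted$intercepts <- as.vector(MRF_mod$direct_coef_means[, 1])
MRF_mod_booted$direct_coefs <- MRF_mod$direct_coef_means
MRF_mod_booted$mod_family <- MRF_mod$mod_family
MRF_mod_booted$mod_type <- "MRFcov"

# Proposed fix: test the element that bootstrapped models actually carry
if (length(MRF_mod$indirect_coef_mean) > 0) {
  # Wrap each matrix in a one-element list, as list(x, "")[1] did;
  # lapply() keeps the covariate names
  MRF_mod_booted$indirect_coefs <- lapply(MRF_mod$indirect_coef_mean,
                                          function(m) list(m))
} else {
  MRF_mod_booted$indirect_coefs <- NULL
}
```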

unused argument lambda1

Hi There,

I am trying to evaluate the interspecific interactions between mosquito species using the unpenalized function. When I add the lambda1 argument to the code, I get an error: "unused argument (lambda1 = 0.1)"

Example using the test dataset supplied with the package:
MRF_mod <- MRFcov(data = Bird.parasites, n_nodes = 4, family = 'binomial', lambda1 = 0.1)

Have you seen this before, and can you point me in the right direction for resolving it?

Thanks!

Error with indirect coefficient symmetrization 'new columns would leave holes after existing columns'

I am trying to use the package to produce directly interpretable coefficients that allow me to determine the relative importance of species' interactions and environmental covariates in driving occurrence probabilities. I successfully explored the regression coefficients and interpreted the results when I used co-occurrences of my two species and three covariates. However, when I try to use five covariates, I get an error:

Error in `[<-.data.frame`(`*tmp*`, , (n_nodes + (n_covariates * i) + (3 - :
  new columns would leave holes after existing columns

Is there a maximum number of covariates? I attach my data file in case there is a basic issue that I am not aware of.

Thank you very much in advance. Congratulations on this fantastic package; I find it really useful and interesting. I look forward to hearing from you.
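The message itself comes from base R's [<-.data.frame method, which refuses to assign to a column index more than one past the current last column. A minimal reproduction with a made-up data frame; why five covariates push MRFcov's computed index (n_nodes + (n_covariates * i) + ...) past that limit cannot be determined from the truncated message, so this only illustrates the mechanism.

```r
tbl <- data.frame(a = 1:3)          # one existing column

msg <- tryCatch({
  tbl[, 3] <- 0                     # skips column 2, which would leave a hole
  "no error"
}, error = function(e) conditionMessage(e))
msg
# "new columns would leave holes after existing columns"

# Filling columns in order avoids the gap
tbl[, 2] <- 0
tbl[, 3] <- 0
ncol(tbl)
# 3
```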

Add citation info

Add citation information to the inst directory once the Ecology DOI is produced

Deviance not calculated in cv_MRF_diag

After encountering an odd problem with cv_MRF_diag on Poisson data, I found a simple solution (with the help of @hezibu). The problem was that NaNs were produced when dividing zero by zero while creating preds_log. Other numbers divided by zero were already handled with is.infinite, but NaNs weren't.
I used the same procedure already implemented in the function to set infinite values to zero (when numbers were divided by 0), but with is.nan instead of is.infinite.

Here is my suggested fix (my additions are in rows 299 and 324):

function (data, symmetrise, n_nodes, n_cores, sample_seed, n_folds, 
                            n_fold_runs, n_covariates, compare_null, family, plot = TRUE, 
                            cached_model, cached_predictions, mod_labels = NULL) 
{
  if (!(family %in% c("gaussian", "poisson", "binomial"))) 
    stop("Please select one of the three family options:\n         \"gaussian\", \"poisson\", \"binomial\"")
  if (missing(symmetrise)) {
    symmetrise <- "mean"
  }
  if (missing(compare_null)) {
    compare_null <- FALSE
  }
  if (missing(n_folds)) {
    n_folds <- 10
  }
  else {
    if (sign(n_folds) == 1) {
      n_folds <- ceiling(n_folds)
    }
    else {
      stop("Please provide a positive integer for n_folds")
    }
  }
  if (missing(n_fold_runs)) {
    n_fold_runs <- n_folds
  }
  else {
    if (sign(n_fold_runs) == 1) {
      n_fold_runs <- ceiling(n_fold_runs)
    }
    else {
      stop("Please provide a positive integer for n_fold_runs")
    }
  }
  if (missing(n_cores)) {
    n_cores <- 1
  }
  else {
    if (sign(n_cores) != 1) {
      stop("Please provide a positive integer for n_cores")
    }
    else {
      if (sfsmisc::is.whole(n_cores) == FALSE) {
        stop("Please provide a positive integer for n_cores")
      }
    }
  }
  if (missing(n_nodes)) {
    warning("n_nodes not specified. using ncol(data) as default, assuming no covariates", 
            call. = FALSE)
    n_nodes <- ncol(data)
    n_covariates <- 0
  }
  else {
    if (sign(n_nodes) != 1) {
      stop("Please provide a positive integer for n_nodes")
    }
    else {
      if (sfsmisc::is.whole(n_nodes) == FALSE) {
        stop("Please provide a positive integer for n_nodes")
      }
    }
  }
  if (missing(n_covariates)) {
    n_covariates <- ncol(data) - n_nodes
  }
  else {
    if (sign(n_covariates) != 1) {
      stop("Please provide a positive integer for n_covariates")
    }
    else {
      if (sfsmisc::is.whole(n_covariates) == FALSE) {
        stop("Please provide a positive integer for n_covariates")
      }
    }
  }
  if (missing(sample_seed)) {
    sample_seed <- ceiling(runif(1, 0, 1e+05))
  }
  if (missing(cached_model)) {
    cat("Generating node-optimised Conditional Random Fields model", 
        "\n", sep = "")
    if (family == "binomial") {
      mrf <- MRFcov(data = data, symmetrise = symmetrise, 
                    n_nodes = n_nodes, n_cores = n_cores, family = "binomial")
      if (compare_null) {
        cat("\nGenerating Markov Random Fields model (no covariates)", 
            "\n", sep = "")
        mrf_null <- MRFcov(data = data[, 1:n_nodes], 
                           symmetrise = symmetrise, n_nodes = n_nodes, 
                           n_cores = n_cores, family = "binomial")
      }
    }
    if (family == "poisson") {
      mrf <- MRFcov(data = data, symmetrise = symmetrise, 
                    n_nodes = n_nodes, n_cores = n_cores, family = "poisson")
      if (compare_null) {
        cat("\nGenerating Markov Random Fields model (no covariates)", 
            "\n", sep = "")
        mrf_null <- MRFcov(data = data[, 1:n_nodes], 
                           symmetrise = symmetrise, n_nodes = n_nodes, 
                           n_cores = n_cores, family = "poisson")
      }
    }
    if (family == "gaussian") {
      mrf <- MRFcov(data = data, symmetrise = symmetrise, 
                    n_nodes = n_nodes, n_cores = n_cores, family = "gaussian")
      if (compare_null) {
        cat("\nGenerating Markov Random Fields model (no covariates)", 
            "\n", sep = "")
        mrf_null <- MRFcov(data = data[, 1:n_nodes], 
                           symmetrise = symmetrise, n_nodes = n_nodes, 
                           n_cores = n_cores, family = "gaussian")
      }
    }
  }
  else {
    mrf <- cached_model$mrf
    if (compare_null) {
      mrf_null <- cached_model$mrf_null
    }
  }
  if (family == "binomial") {
    folds <- caret::createFolds(rownames(data), n_folds)
    all_predictions <- if (missing(cached_predictions)) {
      predict_MRF(data, mrf)
    }
    else {
      cached_predictions$predictions
    }
    cv_predictions <- lapply(seq_len(n_folds), function(k) {
      test_data <- data[folds[[k]], ]
      predictions <- all_predictions[[2]][folds[[k]], 
      ]
      true_pos <- (predictions == test_data[, c(1:n_nodes)])[test_data[, 
                                                                       c(1:n_nodes)] == 1]
      false_pos <- (predictions == 1)[predictions != test_data[, 
                                                               c(1:n_nodes)]]
      true_neg <- (predictions == test_data[, c(1:n_nodes)])[test_data[, 
                                                                       c(1:n_nodes)] == 0]
      false_neg <- (predictions == 0)[predictions != test_data[, 
                                                               c(1:n_nodes)]]
      pos_pred <- sum(true_pos, na.rm = TRUE)/(sum(true_pos, 
                                                   na.rm = TRUE) + sum(false_pos, na.rm = TRUE))
      neg_pred <- sum(true_neg, na.rm = TRUE)/(sum(true_neg, 
                                                   na.rm = TRUE) + sum(false_neg, na.rm = TRUE))
      sensitivity <- sum(true_pos, na.rm = TRUE)/(sum(true_pos, 
                                                      na.rm = TRUE) + sum(false_neg, na.rm = TRUE))
      specificity <- sum(true_neg, na.rm = TRUE)/(sum(true_neg, 
                                                      na.rm = TRUE) + sum(false_pos, na.rm = TRUE))
      tot_pred <- (sum(true_pos, na.rm = TRUE) + sum(true_neg, 
                                                     na.rm = TRUE))/(length(as.matrix(test_data[, 
                                                                                                c(1:n_nodes)])))
      list(mean_pos_pred = mean(pos_pred, na.rm = TRUE), 
           mean_neg_pred = mean(neg_pred, na.rm = TRUE), 
           mean_tot_pred = mean(tot_pred, na.rm = TRUE), 
           mean_sensitivity = mean(sensitivity, na.rm = TRUE), 
           mean_specificity = mean(specificity, na.rm = TRUE))
    })
    plot_dat <- purrr::map_df(cv_predictions, magrittr::extract, 
                              c("mean_pos_pred", "mean_tot_pred", "mean_sensitivity", 
                                "mean_specificity"))
    if (compare_null) {
      all_predictions <- if (missing(cached_predictions)) {
        predict_MRF(data[, 1:n_nodes], mrf_null)
      }
      else {
        cached_predictions$null_predictions
      }
      cv_predictions_null <- lapply(seq_len(n_folds), 
                                    function(k) {
                                      test_data <- data[folds[[k]], 1:n_nodes]
                                      predictions <- all_predictions[[2]][folds[[k]], 
                                      ]
                                      true_pos <- (predictions == test_data)[test_data == 
                                                                               1]
                                      false_pos <- (predictions == 1)[predictions != 
                                                                        test_data]
                                      true_neg <- (predictions == test_data)[test_data == 
                                                                               0]
                                      false_neg <- (predictions == 0)[predictions != 
                                                                        test_data]
                                      pos_pred <- sum(true_pos, na.rm = TRUE)/(sum(true_pos, 
                                                                                   na.rm = TRUE) + sum(false_pos, na.rm = TRUE))
                                      neg_pred <- sum(true_neg, na.rm = TRUE)/(sum(true_neg, 
                                                                                   na.rm = TRUE) + sum(false_neg, na.rm = TRUE))
                                      sensitivity <- sum(true_pos, na.rm = TRUE)/(sum(true_pos, 
                                                                                      na.rm = TRUE) + sum(false_neg, na.rm = TRUE))
                                      specificity <- sum(true_neg, na.rm = TRUE)/(sum(true_neg, 
                                                                                      na.rm = TRUE) + sum(false_pos, na.rm = TRUE))
                                      tot_pred <- (sum(true_pos, na.rm = TRUE) + 
                                                     sum(true_neg, na.rm = TRUE))/(length(as.matrix(test_data)))
                                      list(mean_pos_pred = mean(pos_pred, na.rm = TRUE), 
                                           mean_neg_pred = mean(neg_pred, na.rm = TRUE), 
                                           mean_tot_pred = mean(tot_pred, na.rm = TRUE), 
                                           mean_sensitivity = mean(sensitivity, na.rm = TRUE), 
                                           mean_specificity = mean(specificity, na.rm = TRUE))
                                    })
      plot_dat_null <- purrr::map_df(cv_predictions_null, 
                                     magrittr::extract, c("mean_pos_pred", "mean_tot_pred", 
                                                          "mean_sensitivity", "mean_specificity"))
      if (is.null(mod_labels)) {
        plot_dat$model <- "CRF"
        plot_dat_null$model <- "MRF (no covariates)"
      }
      else {
        plot_dat$model <- mod_labels[1]
        plot_dat_null$model <- mod_labels[2]
      }
      plot_dat <- rbind(plot_dat, plot_dat_null)
      if (plot) {
        plot_binom_cv_diag_optim(plot_dat, compare_null = TRUE)
      }
      else {
        return(plot_dat)
      }
    }
    else {
      if (plot) {
        plot_binom_cv_diag_optim(plot_dat, compare_null = FALSE)
      }
      else {
        return(plot_dat)
      }
    }
  }
  if (family == "gaussian") {
    folds <- caret::createFolds(rownames(data), n_folds)
    cv_predictions <- lapply(seq_len(n_folds), function(k) {
      test_data <- data[folds[[k]], ]
      predictions <- predict_MRF(test_data, mrf)
      Rsquared <- vector()
      MSE <- vector()
      for (i in seq_len(ncol(predictions))) {
        Rsquared[i] <- cor.test(test_data[, i], predictions[, 
                                                            i])[[4]]
        MSE[i] <- mean((test_data[, i] - predictions[, 
                                                     i])^2)
      }
      list(Rsquared = mean(Rsquared, na.rm = T), MSE = mean(MSE, 
                                                            na.rm = T))
    })
    plot_dat <- purrr::map_df(cv_predictions, magrittr::extract, 
                              c("Rsquared", "MSE"))
    if (compare_null) {
      cv_predictions_null <- lapply(seq_len(n_folds), 
                                    function(k) {
                                      test_data <- data[folds[[k]], 1:n_nodes]
                                      predictions <- predict_MRF(test_data, mrf_null)
                                      Rsquared <- vector()
                                      MSE <- vector()
                                      for (i in seq_len(ncol(predictions))) {
                                        Rsquared[i] <- cor.test(test_data[, i], 
                                                                predictions[, i])[[4]]
                                        MSE[i] <- mean((test_data[, i] - predictions[, 
                                                                                     i])^2)
                                      }
                                      list(Rsquared = mean(Rsquared, na.rm = T), 
                                           MSE = mean(MSE, na.rm = T))
                                    })
      plot_dat_null <- purrr::map_df(cv_predictions_null, 
                                     magrittr::extract, c("Rsquared", "MSE"))
      if (is.null(mod_labels)) {
        plot_dat$model <- "CRF"
        plot_dat_null$model <- "MRF (no covariates)"
      }
      else {
        plot_dat$model <- mod_labels[1]
        plot_dat_null$model <- mod_labels[2]
      }
      plot_dat <- rbind(plot_dat, plot_dat_null)
      if (plot) {
        plot_gauss_cv_diag_optim(plot_dat, compare_null = TRUE)
      }
      else {
        return(plot_dat)
      }
    }
    else {
      if (plot) {
        plot_gauss_cv_diag_optim(plot_dat, compare_null = FALSE)
      }
      else {
        return(plot_dat)
      }
    }
  }
  if (family == "poisson") {
    folds <- caret::createFolds(rownames(data), n_folds)
    cv_predictions <- lapply(seq_len(n_folds), function(k) {
      test_data <- data[folds[[k]], ]
      predictions <- predict_MRF(test_data, mrf)
      Deviance <- vector()
      MSE <- vector()
      for (i in seq_len(ncol(predictions))) {
        preds_log <- log(test_data[, i]/predictions[, 
                                                    i])
        preds_log[is.infinite(preds_log)] <- 0
        preds_log[is.nan(preds_log)] <- 0
        test_data_wzeros <- test_data
        test_data_wzeros[predictions[, i] == 0, i] <- 0
        Deviance[i] <- mean(2 * sum(test_data_wzeros[, 
                                                     i] * preds_log - (test_data_wzeros[, i] - 
                                                                         predictions[, i])))
        MSE[i] <- mean((test_data[, i] - predictions[, 
                                                     i])^2)
      }
      list(Deviance = mean(Deviance, na.rm = T), MSE = mean(MSE, 
                                                            na.rm = T))
    })
    plot_dat <- purrr::map_df(cv_predictions, magrittr::extract, 
                              c("Deviance", "MSE"))
    if (compare_null) {
      cv_predictions_null <- lapply(seq_len(n_folds), 
                                    function(k) {
                                      test_data <- data[folds[[k]], 1:n_nodes]
                                      predictions <- predict_MRF(test_data, mrf_null)
                                      Deviance <- vector()
                                      MSE <- vector()
                                      for (i in seq_len(ncol(predictions))) {
                                        preds_log <- log(test_data[, i]/predictions[, 
                                                                                    i])
                                        preds_log[is.infinite(preds_log)] <- 0
                                        preds_log[is.nan(preds_log)] <- 0
                                        test_data_wzeros <- test_data
                                        test_data_wzeros[predictions[, i] == 0, 
                                                         i] <- 0
                                        Deviance[i] <- mean(2 * sum(test_data_wzeros[, 
                                                                                     i] * preds_log - (test_data_wzeros[, i] - 
                                                                                                         predictions[, i])))
                                        MSE[i] <- mean((test_data[, i] - predictions[, 
                                                                                     i])^2)
                                      }
                                      list(Deviance = mean(Deviance, na.rm = T), 
                                           MSE = mean(MSE, na.rm = T))
                                    })
      plot_dat_null <- purrr::map_df(cv_predictions_null, 
                                     magrittr::extract, c("Deviance", "MSE"))
      if (is.null(mod_labels)) {
        plot_dat$model <- "CRF"
        plot_dat_null$model <- "MRF (no covariates)"
      }
      else {
        plot_dat$model <- mod_labels[1]
        plot_dat_null$model <- mod_labels[2]
      }
      plot_dat <- rbind(plot_dat, plot_dat_null)
      if (plot) {
        plot_poiss_cv_diag_optim(plot_dat, compare_null = TRUE)
      }
      else {
        return(plot_dat)
      }
    }
    else {
      if (plot) {
        plot_poiss_cv_diag_optim(plot_dat, compare_null = FALSE)
      }
      else {
        return(plot_dat)
      }
    }
  }
}

Does this look adequate?
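A standalone sketch of the patched deviance logic, useful for checking it against R's own Poisson deviance residuals. poisson_deviance() is a hypothetical helper, not MRFcov code; it zeroes infinite and NaN log terms as the patch does.

```r
# Hypothetical helper mirroring the patch: zero out log terms that are
# infinite (y/0 or 0/mu) or NaN (0/0) before summing
poisson_deviance <- function(y, mu) {
  log_ratio <- log(y / mu)
  log_ratio[is.infinite(log_ratio) | is.nan(log_ratio)] <- 0
  y_adj <- y
  y_adj[mu == 0] <- 0     # as in the patch: zero observations where mu == 0
  2 * sum(y_adj * log_ratio - (y_adj - mu))
}

y  <- c(0, 2, 0)
mu <- c(0, 2, 1)

2 * sum(y * log(y / mu) - (y - mu))               # naive version: NaN
poisson_deviance(y, mu)                           # 2
sum(poisson()$dev.resids(y, mu, rep(1, length(y))))  # 2, from stats
```

One caveat: when an observed count is positive but the prediction is exactly 0, this helper (like the patch) contributes 0, whereas the textbook Poisson deviance for that observation is infinite.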
