Fig. 6 - SLE-Lasso models

Author

Benjamin Doran

Published

January 17, 2025

julia setup
using DrWatson
@quickactivate projectdir()

using DataFramesMeta
using SpectralInference
using NewickTree
using MLJ
import JLD2
using GLMNet: GLMNet, glmnet
using Chain
using Distributions: Normal
using MLJBase: train_test_pairs
using Distances, Clustering
using Distributions
using Printf: @sprintf
using Muon, CSV, DataFrames
using Gotree_jll
using Random: seed!
using FreqTables
using NearestNeighbors
using HypothesisTests
using MultipleTesting: adjust, Bonferroni, BenjaminiHochberg
using StatsPlots, StatsBase
theme(:default, grid=false, tickdir=:out, label="")
include(srcdir("helpers.jl"))

# Input data directory and output directories (the latter are created if missing).
ddir = datadir("exp_raw", "BB669")
rdir = mkpath(projectdir("_research", "metabolite_model_outofbag"))
pdir = mkpath(plotsdir("metabolite_model_outofbag"))
supptbl_dir = mkpath(projectdir("_research", "SuppTables"))

# Species name => colour lookup built from the colour table.
speciescolordf = CSV.read(datadir("exp_raw", "BB669", "subsettreecolors.csv"), DataFrame)
species_color_dict = Dict(speciescolordf.species_name .=> speciescolordf.color);

Setup data and model

# Load the UniProt reference (h5ad) and the BB669 biobank dataset (h5mu).
uniprot = readh5ad(datadir("exp_raw", "UP7047", "2020_02_UP7047.h5ad"))
biobank = readh5mu(joinpath(ddir, "BB669.h5mu"))
# Recode `kept_species` as the result of comparing it with 1 (used below as a row mask).
biobank.obs.kept_species .= biobank.obs.kept_species .== 1
# Materialise the "UPorder_oggs" modality of the biobank as a matrix.
mtx = biobank["UPorder_oggs"].X[:, :];
# SVD of the UniProt matrix; the biobank matrix is passed through `projectinLSV`
# and combined with UniProt's singular values and Vt into an SVD object.
upsvd = svd(uniprot.X[:, :])
bbusv = SVD(projectinLSV(mtx, upsvd), upsvd.S, upsvd.Vt);
┌ Warning: Cannot join columns with the same name because var_names are intersecting.
└ @ Muon /Users/bend/.julia/packages/Muon/UKjAF/src/mudata.jl:367
# Training isolates: rows flagged `kept_species`
# (356 isolates, from species with >= 20 strain replicates according to the original note)
full_train_mask = biobank.obs.kept_species;
trnYdf = biobank.obs[full_train_mask, :];

# Spectral distances are invariant to the fold partition, because every isolate
# is projected into the UniProt latent space, so they are computed once up front.
UPfullPCs = bbusv.U[full_train_mask, :] * Diagonal(bbusv.S[:]);
partitions = getintervals(bbusv.S, alpha=1.5, q=0.75);
Dij = spectraldistances(bbusv.U, bbusv.S, partitions) ./ size(bbusv.V, 1); # 669 full CSB
subsetDij = Dij[full_train_mask, full_train_mask]; # 356 strain replicate CSB

# Metabolite fold-change data; infinite entries are replaced by zero
metabolite_names_full = replace.(biobank["metabolites_foldchange"].var_names, "_rel" => "");
bb_met_lfc = biobank["metabolites_foldchange"].X[:, :];
bb_met_lfc[findall(isinf, bb_met_lfc)] .= 0.0;
metabolicdistance = pairwise(Euclidean(), bb_met_lfc; dims=1);

# Keep metabolites that are non-zero (measurable) in more than 10% of the training isolates
measurable_metabolites_mask = [mean(iszero, col) < 0.9 for col in eachcol(bb_met_lfc[full_train_mask, :])];
keepmetabolites_mask = measurable_metabolites_mask;
metabolite_names = metabolite_names_full[keepmetabolites_mask];
metabolite_label = biobank["metabolites_foldchange"].var.label[keepmetabolites_mask]

# Split the retained metabolites into the 356 training rows and the remaining rows
metab_trnY = bb_met_lfc[full_train_mask, keepmetabolites_mask];
metab_bbextraY = bb_met_lfc[.!full_train_mask, keepmetabolites_mask];

Plot tree and metabolites

When we view the spectral tree alongside the metabolite profiles, we see clear structure: isolates within the same clade tend to have similar metabolic capacities.

We define the SLE-Lasso models on top of this tree structure: a Lasso model is used to discover which clade branches (portions of the tree) are most important for describing each metabolite.

# Cladogram of the training isolates (UPGMA on the spectral distances)
strvar_tree_hc = UPGMA_tree(subsetDij ./ size(biobank["UPorder_oggs"], 2))
subsettreestring = SpectralInference.newickstring(strvar_tree_hc, trnYdf.Strain_ID)
subsettree = readnw(subsettreestring);
plot(strvar_tree_hc,
    lw=0.5,
    permute=(:y, :x),
    yflip=true,
    xmirror=true,
    xticks=:none,
    tickdirection=:none,
    grid=false,
    framestyle=:grid,
    rightmargin=1Plots.Measures.mm,
    label="",
)

# Coloured annotation rectangles, one per contiguous run of a species along the leaf order
"""
    rectangle(w, h, x, y)

Return a `Shape` for the axis-aligned rectangle of width `w` and height `h`
whose first corner is at `(x, y)`.
"""
function rectangle(w, h, x, y)
    return Shape(x .+ [0, w, w, 0], y .+ [0, 0, h, h])
end
speciesvector = trnYdf.Species[strvar_tree_hc.order]
# positions where the species changes between neighbouring leaves; break points 10-15 are dropped
breaks = findall(speciesvector[begin:(end-1)] .!= speciesvector[2:end])[Not(10:15)]
edges = collect(zip(vcat([0], breaks), vcat(breaks, [length(speciesvector)])));
rects = [rectangle(2, (stop - start), 0, start + 0.5) for (start, stop) in edges];
rectspeciescolors = permutedims(speciescolordf.color[indexin(speciesvector[first.(edges) .+ 1], speciescolordf.species_name)]);
fancy_treeplot = plot!(permutedims(rects), fill=0.35, lw=0, c=rectspeciescolors, label="")

# Legend entries for the tree (empty series carrying the labels and colours)
specieslabels = [
    "Bacteroides uniformis",
    "Phocaeicola vulgatus",
    "Bacteroides thetaiotaomicron",
    "[Ruminococcus] gnavus",
    "Bifidobacterium breve",
    "[Eubacterium] rectale",
    "Dorea formicigenerans",
    "Coprococcus comes",
    "Blautia luti & Blautia wexlerae",
    "Anaerostipes hadrus",
]
fancy_treeplot = plot!(zeros(1, length(specieslabels)),
    labels=permutedims(reverse(specieslabels)),
    c=reverse(rectspeciescolors),
    legend=:left,
    legendfontsize=5,
);

# Metabolite heatmap with rows in tree-leaf order
treeorder = indexin(getleafnames(subsettree), trnYdf.Strain_ID);
hplot = heatmap(metab_trnY[treeorder, :],
    c=:bwr, clims=getlims(metab_trnY),
    xticks=(1:size(metab_trnY, 2), metabolite_label),
    xtickfontsize=7, xrotation=90,
    yticks=:none,
);

# Tree and heatmap side by side
plot(fancy_treeplot, hplot, size=(900, 900))

SLE Lambda value analysis

# Settings for the λ sweep below (the sweep takes about 3 min)
K = 1 # Make predictions with SPI-LASSO on 1 nearest neighbor
λ = 0.001
lambdas = exp10.(range(-4, 0, length=101))
# index of the grid value closest to λ
lambdacol = argmin(abs.(lambdas .- λ))
REPS = 5
NFOLDS = 4

"""
    adjust_rsquared(r2, n, df)

Return the adjusted R², `1 - (1 - r2) * (n - 1) / (n - 1 - df)`, for a raw
`r2`, a size term `n` and `df` degrees of freedom used by the model.
"""
function adjust_rsquared(r2, n, df)
    return 1 - (1 - r2) * ((n - 1) / (n - 1 - df))
end
adjust_rsquared (generic function with 1 method)
# Seed the global RNG so the shuffled folds are reproducible.
# This is stable within julia versions, for exact results use Julia 1.10
seed!(424242)
cv = StratifiedCV(nfolds=NFOLDS, shuffle=true);
# REPS stratified NFOLDS-fold splits (stratified by species), concatenated into one list
folds = reduce(vcat, [train_test_pairs(cv, 1:sum(full_train_mask), trnYdf.Species) for _ in 1:REPS])

# Result accumulators for the repeated cross-validation sweep. Inside the loop they
# are only grown in place (`append!` / `push!`) and never rebound, which avoids
# re-copying all earlier rows on every iteration and does not depend on
# notebook-style soft scope for global assignment.
oof_preds_df_stacked = DataFrame()          # out-of-fold predictions, one column per λ
oof_dropout_preds_df_stacked = DataFrame()  # out-of-fold predictions with dropped-out branches
inf_preds_df_stacked = DataFrame()          # in-fold (training) predictions
coefdf = DataFrame()                        # coefficients per internal tree node
models_tbl = []                             # fitted glmnet paths plus metadata
for (i, (fold_trn, fold_tst)) in enumerate(folds)
    # `folds` is a concatenation of REPS runs of NFOLDS splits
    fold_id = ((i - 1) % NFOLDS) + 1
    resample_id = ((i - 1) ÷ NFOLDS) + 1

    # K nearest training isolates (by spectral distance) of each held-out isolate.
    # Distances come from projections of taxa into UniProt, so they are constant
    # regardless of folds.
    tst_nns = map(r -> partialsortperm(r, 1:K), eachrow(subsetDij[fold_tst, fold_trn]))

    # UPGMA tree built on the training isolates of this fold
    foldhc = UPGMA_tree(subsetDij[fold_trn, fold_trn])
    foldtree = readnw(SpectralInference.newickstring(foldhc, trnYdf.Strain_ID[fold_trn]))
    ordered_treeids = getleafids(foldtree)[indexin(trnYdf.Strain_ID[fold_trn], getleafnames(foldtree))]

    # SLE ancestor encoding for the training set, restricted to internal nodes
    trnX_all = @chain begin
        spectral_lineage_encoding(foldtree, ordered_treeids)
        getfield.(:sle)
        stack
        float.(_)
    end
    isinternal_fold = map(!isleaf, prewalk(foldtree))
    trnX = trnX_all[:, isinternal_fold]

    # Number of distinct species among the leaves under each node's parent (NaN for the root)
    num_descendent_species = map(prewalk(foldtree)) do node
        if !isroot(node)
            sps = trnYdf.Species[indexin(getleafnames(parent(node)), trnYdf.Strain_ID)]
            return length(unique(sps))
        else
            NaN
        end
    end
    node_species_counts = num_descendent_species[isinternal_fold]

    # Mask for dropping out subspecies branches:
    # any branch whose parent has only a single species among its descendants
    dropout_mask = node_species_counts .<= 1.0

    # Features of each out-of-fold isolate: mean encoding of its K nearest training isolates
    oofX = reduce(vcat, [mean(trnX[nn, :], dims=1) for nn in tst_nns])

    # Fit one lasso path per metabolite
    for (target_idx, (target, mlabel)) in enumerate(zip(metabolite_names, metabolite_label))

        mpath = glmnet(trnX, metab_trnY[fold_trn, target_idx], Normal();
            lambda=lambdas,
        )

        # Coefficients with the subspecies branches zeroed out (`Matrix` already copies)
        betas_droppedout = Matrix(mpath.betas)
        betas_droppedout[dropout_mask, :] .= 0.0

        push!(models_tbl, (;
            metabolite_name = target,
            metabolite_label = mlabel,
            fold = fold_id,
            resample = resample_id,
            model = mpath
        ))

        # Predictions over the whole λ grid (one column per λ)
        infold_preds = GLMNet.predict(mpath, trnX)
        outfold_preds = GLMNet.predict(mpath, oofX)
        # NOTE(review): this plain product appears to leave out the fitted intercept
        # that `GLMNet.predict` includes — confirm this is intended before comparing
        # the dropout predictions with the full predictions.
        dropout_preds = oofX * betas_droppedout

        # in-fold predictions
        append!(inf_preds_df_stacked, DataFrame(
            :row_id => fold_trn,
            :msk_id => trnYdf.Strain_ID[fold_trn],
            :metabolite_name => target,
            :metabolite_label => mlabel,
            :fold => fold_id,
            :resample => resample_id,
            :truth => metab_trnY[fold_trn, target_idx],
            [Symbol("preds_$lam") => v for (lam, v) in zip(lambdas, eachcol(infold_preds))]...);
            promote=true)

        # out-of-fold predictions
        append!(oof_preds_df_stacked, DataFrame(
            :row_id => fold_tst,
            :msk_id => trnYdf.Strain_ID[fold_tst],
            :metabolite_name => target,
            :metabolite_label => mlabel,
            :fold => fold_id,
            :resample => resample_id,
            :truth => metab_trnY[fold_tst, target_idx],
            [Symbol("preds_$lam") => v for (lam, v) in zip(lambdas, eachcol(outfold_preds))]...);
            promote=true)

        # out-of-fold predictions dropping out subspecies branches
        append!(oof_dropout_preds_df_stacked, DataFrame(
            :row_id => fold_tst,
            :msk_id => trnYdf.Strain_ID[fold_tst],
            :metabolite_name => target,
            :metabolite_label => mlabel,
            :fold => fold_id,
            :resample => resample_id,
            :truth => metab_trnY[fold_tst, target_idx],
            [Symbol("preds_$lam") => v for (lam, v) in zip(lambdas, eachcol(dropout_preds))]...);
            promote=true)

        # coefficients of the model, one row per internal node
        append!(coefdf, DataFrame(
            :metabolite_name => target,
            :metabolite_label => mlabel,
            :fold => fold_id,
            :resample => resample_id,
            :num_species_descendents => node_species_counts,
            [Symbol("coefs_$lam") => v for (lam, v) in zip(lambdas, eachcol(mpath.betas))]...);
            promote=true)
    end
    println("on $(i)th resample")
end

# Persist the cross-validation outputs under `rdir`:
# prediction and coefficient tables as CSV, fitted model paths and fold index pairs as JLD2.
CSV.write(joinpath(rdir, "oof_predictions_stacked_SLE_lambda=many.csv"), oof_preds_df_stacked)
CSV.write(joinpath(rdir, "oof_dropout_predictions_stacked_SLE_lambda=many.csv"), oof_dropout_preds_df_stacked)
CSV.write(joinpath(rdir, "infold_predictions_stacked_SLE_lambda=many.csv"), inf_preds_df_stacked)
CSV.write(joinpath(rdir, "coefs_SLE_lambda=many.csv"), coefdf)
JLD2.save(joinpath(rdir, "models_SLE_lambda=many.jld2"), Dict("models" => models_tbl))
JLD2.save(joinpath(rdir, "folds_SLE_lambda=many.jld2"), @strdict(folds))
on 1th resample
on 2th resample
on 3th resample
on 4th resample
on 5th resample
on 6th resample
on 7th resample
on 8th resample
on 9th resample
on 10th resample
on 11th resample
on 12th resample
on 13th resample
on 14th resample
on 15th resample
on 16th resample
on 17th resample
on 18th resample
on 19th resample
on 20th resample
# Reload the out-of-fold predictions and reshape them from wide (one "preds_<λ>"
# column per λ) to long format with a numeric `lambda` column and a `preds` column.
oof_preds_df_stacked = CSV.read(joinpath(rdir, "oof_predictions_stacked_SLE_lambda=many.csv"), DataFrame)
oof_preds_df_stacked =
    @chain oof_preds_df_stacked begin
        # pick the prediction columns by name instead of by position (8:108), so the
        # reshape stays correct if the λ grid or the number of id columns changes
        stack(r"^preds_")
        transform!(:variable => ByRow(s -> parse(Float64, last(split(s, "_")))) => :lambda)
        select(Not([:variable, :value]), :value => :preds)
    end
# Reload the coefficient table and reshape it from wide (one "coefs_<λ>" column
# per λ) to long format with a numeric `lambda` column and a `coef` column.
coefdf_stacked =
    @chain CSV.read(joinpath(rdir, "coefs_SLE_lambda=many.csv"), DataFrame) begin
        # pick the coefficient columns by name instead of by position (6:106), so the
        # reshape stays correct if the λ grid or the number of id columns changes
        stack(r"^coefs_")
        transform!(:variable => ByRow(s -> parse(Float64, last(split(s, "_")))) => :lambda)
        select(Not([:variable, :value]), :value => :coef)
    end
# Per (metabolite, fold, resample, λ) summary of the fitted coefficients:
# how many are non-zero and how the non-zero ones split by the number of species
# found under the parent of the branch.
mdlstatsdf_stacked =
    @chain coefdf_stacked begin
        groupby([:metabolite_label, :fold, :resample, :lambda])
        combine(
            :coef => (x -> count(!iszero, x)) => :degrees_freedom,
            # num params possible: nominal training-fold size minus one. The training
            # fraction is derived from NFOLDS rather than hard-coded as 3 / 4
            # (identical value for NFOLDS = 4).
            :fold => (x -> ((NFOLDS - 1) / NFOLDS * sum(full_train_mask) - 1)) => :orig_degrees_freedom,
            :coef => (x -> mean(!iszero, x)) => :degrees_freedom_prop,
            # shares of the non-zero coefficients whose parent spans all 11 species,
            # between 2 and 10 species, or exactly one species
            [:num_species_descendents, :coef] => ((n, c) -> mean((n.==11)[(c.!=0)])) => :phylum_level,
            [:num_species_descendents, :coef] => ((n, c) -> mean((1 .< n .< 11)[(c.!=0)])) => :species_level,
            [:num_species_descendents, :coef] => ((n, c) -> mean((n.==1)[(c.!=0)])) => :strain_level,
        )
    end
# Out-of-fold performance per (metabolite, fold, resample, λ), joined with the
# coefficient summaries and extended with the adjusted R².
mdlstatsdf =
    @chain oof_preds_df_stacked begin
        groupby([:metabolite_label, :fold, :resample, :lambda])
        combine(
            [:preds, :truth] => rsquared => :rsq,
            [:preds, :truth] => cor => :cor,
        )
        leftjoin(mdlstatsdf_stacked, on=[:metabolite_label, :fold, :resample, :lambda])
        transform!(
            [:rsq, :orig_degrees_freedom, :degrees_freedom] => ByRow(adjust_rsquared) => :rsq_adj,
            # groups without any non-zero coefficient produce NaN shares; report them as 0
            [:phylum_level, :species_level, :strain_level] .=> (x -> replace(x, NaN => 0.0)) .=> identity,
        )
        select!(
            [:metabolite_label, :fold, :resample, :lambda],
            :rsq, :rsq_adj, :cor,
            :degrees_freedom, :orig_degrees_freedom, :degrees_freedom_prop,
            :phylum_level, :species_level, :strain_level,
        )
        disallowmissing
    end
CSV.write(joinpath(rdir, "oof_modelstats_SLE_lambda=many.csv"), mdlstatsdf)

# Mean and standard deviation of every statistic over folds and resamples, per
# metabolite and λ, restricted to λ >= 1e-3. Columns 5:13 are the statistic
# columns of `mdlstatsdf` (rsq through strain_level); each yields a `_mean`
# and a `_std` column.
mdlstatsdf_mean =
    @chain mdlstatsdf begin
        subset(:lambda => ByRow(>=(1e-3)))
        groupby([:metabolite_label, :lambda])
        combine(
            5:13 .=> mean,
            5:13 .=> std,
        )
    end

# For each metabolite keep the row(s) whose mean adjusted R² is the maximum over λ.
bestlambdamodels = @chain mdlstatsdf_mean begin
    groupby(:metabolite_label)
    subset(:rsq_adj_mean => (x -> x .== maximum(x)))
end
32×20 DataFrame
7 rows omitted
Row metabolite_label lambda rsq_mean rsq_adj_mean cor_mean degrees_freedom_mean orig_degrees_freedom_mean degrees_freedom_prop_mean phylum_level_mean species_level_mean strain_level_mean rsq_std rsq_adj_std cor_std degrees_freedom_std orig_degrees_freedom_std degrees_freedom_prop_std phylum_level_std species_level_std strain_level_std
String31 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64
1 Acetate 0.0758578 0.713737 0.678213 0.853746 29.3 266.0 0.11015 0.0 0.344767 0.655233 0.0471556 0.0524868 0.0290627 2.47301 0.0 0.00929704 0.0 0.0434495 0.0434495
2 Butyrate 0.131826 0.852513 0.843584 0.927411 15.2 266.0 0.0571429 0.132768 0.341282 0.52595 0.0365737 0.0384522 0.0217612 1.47256 0.0 0.00553592 0.0129736 0.042269 0.0513115
3 Propionate 0.144544 0.537103 0.501361 0.741929 19.0 266.0 0.0714286 0.0164044 0.224591 0.759005 0.0529746 0.0564935 0.0433324 2.42791 0.0 0.00912747 0.0419871 0.0450987 0.0734243
4 Succinate 0.1 0.61316 0.567427 0.791151 27.95 266.0 0.105075 0.00740741 0.300704 0.691888 0.0363944 0.0415038 0.0250723 3.73427 0.0 0.0140386 0.0227995 0.0448213 0.0465965
5 2-Methylbutyrate 0.229087 0.594906 0.569073 0.781151 16.0 266.0 0.0601504 0.127386 0.127386 0.745227 0.0598302 0.0620661 0.0443341 2.2711 0.0 0.00853797 0.0178117 0.0178117 0.0356234
6 3-Aminoisobutyrate 0.144544 0.177975 0.133944 0.455166 13.5 266.0 0.0507519 0.0 0.158138 0.841862 0.120498 0.12507 0.118141 1.90567 0.0 0.00716417 0.0 0.0359807 0.0359807
7 5-Aminovalerate 0.275423 0.543044 0.506948 0.747554 19.4 266.0 0.0729323 0.0488415 0.269576 0.681582 0.0374495 0.039984 0.0263265 2.64376 0.0 0.00993895 0.0506661 0.0730989 0.0667503
8 Alanine 0.057544 0.0639759 0.0281666 0.304284 9.75 266.0 0.0366541 0.0 0.38438 0.61562 0.0437672 0.0461699 0.103909 2.07428 0.0 0.00779804 0.0 0.117553 0.117553
9 Aspartate 0.144544 0.44623 0.394543 0.678519 22.55 266.0 0.0847744 0.0719725 0.198808 0.72922 0.0827505 0.0913016 0.0659453 2.23548 0.0 0.00840406 0.0331121 0.0457345 0.054044
10 Benzoate 0.275423 0.0891614 0.0758291 0.397764 3.8 266.0 0.0142857 0.0 0.500119 0.499881 0.101351 0.103552 0.148358 1.39925 0.0 0.00526033 0.0 0.232841 0.232841
11 Cysteine 0.0691831 0.281043 0.17623 0.544532 33.65 266.0 0.126504 0.0 0.212563 0.787437 0.101866 0.118007 0.0956086 2.43386 0.0 0.00914986 0.0 0.0467705 0.0467705
12 Glutamate 0.0691831 0.182157 0.114249 0.456756 20.3 266.0 0.0763158 0.0 0.21916 0.78084 0.0474709 0.0518121 0.0648258 2.55672 0.0 0.00961175 0.0 0.0511875 0.0511875
13 Glycine 0.20893 0.711895 0.710018 0.860794 1.8 266.0 0.00676692 0.0 0.608333 0.391667 0.168478 0.169533 0.0820137 0.523148 0.0 0.00196672 0.0 0.260875 0.260875
21 Palmitate 0.0630957 0.054432 0.0274684 0.297454 7.35 266.0 0.0276316 0.0 0.370325 0.629675 0.0228067 0.0222852 0.0457244 1.38697 0.0 0.00521417 0.0 0.137163 0.137163
22 Phenylacetate 0.190546 0.753281 0.730152 0.873597 22.85 266.0 0.0859023 0.0 0.21961 0.78039 0.0417033 0.0445242 0.0239348 2.39022 0.0 0.00898579 0.0 0.0285486 0.0285486
23 Phenylalanine 0.0398107 0.161741 0.078399 0.428751 23.95 266.0 0.0900376 0.0 0.180211 0.819789 0.0921165 0.101669 0.119583 2.60516 0.0 0.00979382 0.0 0.0535672 0.0535672
24 Proline 0.0630957 0.283207 0.227891 0.548492 19.0 266.0 0.0714286 0.107395 0.286167 0.606439 0.0652789 0.0692945 0.063463 2.75299 0.0 0.0103496 0.0156659 0.0656436 0.0660317
25 Serine 0.229087 0.252343 0.230219 0.562255 7.35 266.0 0.0276316 0.0 0.160022 0.839978 0.197141 0.207998 0.100998 1.81442 0.0 0.00682111 0.0 0.0611798 0.0611798
26 Threonine 0.0630957 0.123956 0.0632694 0.442701 17.05 266.0 0.0640977 0.0 0.218394 0.781606 0.401533 0.433177 0.171421 2.13923 0.0 0.00804223 0.0 0.0780699 0.0780699
27 Tryptamine 0.20893 0.743699 0.726334 0.867432 17.05 266.0 0.0640977 0.0 0.169351 0.830649 0.0507361 0.0523907 0.0288667 2.85574 0.0 0.0107359 0.0 0.0573328 0.0573328
28 Tryptophan 0.057544 0.307183 0.24848 0.578114 20.7 266.0 0.0778195 0.0 0.239016 0.760984 0.0732773 0.0796236 0.0797394 1.68897 0.0 0.00634952 0.0 0.0659877 0.0659877
29 Tyramine 0.0331131 0.263256 0.20591 0.553014 19.2 266.0 0.0721805 0.0 0.34614 0.65386 0.0523266 0.0524228 0.0564838 2.82097 0.0 0.0106052 0.0 0.0526799 0.0526799
30 Tyrosine 0.190546 0.0381358 0.0291187 0.30607 2.45 266.0 0.00921053 0.0 0.330833 0.669167 0.0622661 0.0634383 0.155491 0.944513 0.0 0.0035508 0.0 0.253298 0.253298
31 Valerate 0.190546 0.379166 0.336592 0.634284 17.0 266.0 0.0639098 0.100846 0.274531 0.624623 0.051993 0.0556162 0.0474183 1.86378 0.0 0.0070067 0.0353121 0.0359245 0.0516123
32 Valine 0.0363078 0.163081 0.0880475 0.437981 21.8 266.0 0.0819549 0.0 0.34365 0.65635 0.0587783 0.0633141 0.0934369 3.01924 0.0 0.0113505 0.0 0.0546937 0.0546937
# Reload the in-fold predictions and reshape them from wide (one "preds_<λ>"
# column per λ) to long format with a numeric `lambda` column and a `preds` column.
infold_preds_df_stacked = CSV.read(joinpath(rdir, "infold_predictions_stacked_SLE_lambda=many.csv"), DataFrame)
infold_preds_df_stacked =
    @chain infold_preds_df_stacked begin
        # pick the prediction columns by name instead of by position (8:108), so the
        # reshape stays correct if the λ grid or the number of id columns changes
        stack(r"^preds_")
        transform!(:variable => ByRow(s -> parse(Float64, last(split(s, "_")))) => :lambda)
        select(Not([:variable, :value]), :value => :preds)
    end;

# In-fold performance per (metabolite, fold, resample, λ), joined with the
# coefficient summaries and extended with the adjusted R².
infold_mdlstatsdf =
    @chain infold_preds_df_stacked begin
        groupby([:metabolite_label, :fold, :resample, :lambda])
        combine(
            [:preds, :truth] => rsquared => :rsq,
            [:preds, :truth] => cor => :cor,
        )
        leftjoin(mdlstatsdf_stacked, on=[:metabolite_label, :fold, :resample, :lambda])
        transform!(
            [:rsq, :orig_degrees_freedom, :degrees_freedom] => ByRow(adjust_rsquared) => :rsq_adj,
            # groups without any non-zero coefficient produce NaN shares; report them as 0
            [:phylum_level, :species_level, :strain_level] .=> (x -> replace(x, NaN => 0.0)) .=> identity,
        )
        select!(
            [:metabolite_label, :fold, :resample, :lambda],
            :rsq, :rsq_adj, :cor,
            :degrees_freedom, :orig_degrees_freedom, :degrees_freedom_prop,
            :phylum_level, :species_level, :strain_level,
        )
        disallowmissing
    end;

Plot Acetate

# Species name => colour lookup for the scatter plots
speciescolordf = CSV.read(datadir("exp_raw", "BB669", "subsettreecolors.csv"), DataFrame)
species_color_dict = Dict(speciescolordf.species_name .=> speciescolordf.color);
compound = "Acetate"
# Out-of-fold predictions for `compound` at its best λ (first resample only),
# annotated with the species and donor of each isolate
pltdf =
    @chain bestlambdamodels begin
        select(:metabolite_label, :lambda)
        leftjoin(oof_preds_df_stacked, on=[:metabolite_label, :lambda])
        leftjoin(select(biobank.obs, :Strain_ID, :Species, :Donor), on=:msk_id => :Strain_ID)
        subset(
            :metabolite_label => ByRow(==(compound)),
            :resample => ByRow(==(1)),
        )
        disallowmissing
    end;
# Truth vs out-of-fold prediction for Acetate at four regularisation strengths
# (indices into the λ grid sorted from largest to smallest); one PDF per λ plus a 2×2 overview.
ps = []
for lambda_index in [1, 17, 26, 51]
    compound = "Acetate"
    L = reverse(lambdas)[lambda_index]
    pltdf =
        @chain oof_preds_df_stacked begin
            leftjoin(biobank.obs[:, [:Strain_ID, :Species, :Donor]], on=:msk_id => :Strain_ID)
            subset(
                :metabolite_label => ByRow(==(compound)),
                :resample => ByRow(==(2)),
                :lambda => ByRow(==(L))
            )
            disallowmissing
        end

    # per-point colours (grey for species missing from the colour table)
    speciescolors_ordered = get.(Ref(species_color_dict), pltdf.Species, :grey)
    pltlims = extrema(vcat(pltdf.preds, pltdf.truth))
    rsq = rsquared(pltdf.preds, pltdf.truth)

    p = plot(
        # the displayed number is log10(λ), so the title says so instead of "lambda"
        title="log10(λ) = $(round(log10(L), digits=2)), r² = $(round(rsq, digits=2))",
        ylabel="$compound (log2FC)",
        xlabel="out-of-fold prediction (log2FC)",
        legend=:outerright,
        size=(800, 600),
        lims=pltlims,
        widen=true,
        margin=5Plots.mm,
    )
    plot!(identity, -20, 20, linestyle=:dash, color=:grey, label="1:1 line")
    @df pltdf scatter!(:preds, :truth,
        group=:Species,
        color=speciescolors_ordered,
        markerstrokewidth=0.1,
        markersize=8,
    )
    push!(ps, p)
    # saves the current figure, i.e. the scatter just drawn
    savefig(joinpath(pdir, "scatter_$(compound)_lambda=$(L).pdf"))
end
plot(ps..., layout=grid(2,2), size=(800,800), legend=false)
# Single scatter of truth vs out-of-fold prediction for Acetate at the 51st-largest λ
# of the grid, using the second resample only.
compound = "Acetate"
L = reverse(lambdas)[51]
pltdf =
    @chain oof_preds_df_stacked begin
        leftjoin(biobank.obs[:, [:Strain_ID, :Species, :Donor]], on=:msk_id => :Strain_ID)
        subset(
            :metabolite_label => ByRow(==(compound)),
            :resample => ByRow(==(2)),
            :lambda => ByRow(==(L))
        )
        disallowmissing
    end;

# per-point colours (grey for species missing from the colour table)
speciescolors_ordered = get.(Ref(species_color_dict), pltdf.Species, :grey);
# shared axis limits spanning both predictions and measurements
pltlims = extrema(vcat(pltdf.preds, pltdf.truth))
rsq = rsquared(pltdf.preds, pltdf.truth)
plot(
    title="lambda = $L, r² = $(round(rsq, digits=2))",
    ylabel="$compound (log2FC)",
    xlabel="out-of-fold prediction (log2FC)",
    legend=:outerright,
    size=(800, 600),
    lims=pltlims,
    widen=true,
    margin=5Plots.mm,
)
# dashed identity line as a visual reference for perfect prediction
plot!(identity, -20,20, linestyle=:dash, color=:grey, label="1:1 line")
@df pltdf scatter!(:preds, :truth,
    group=:Species,
    color=speciescolors_ordered,
    markerstrokewidth=0.1,
    markersize=8,
    # markeralpha=0.5,
)
# savefig(joinpath(pdir, "scatter_$(compound)_lambda=$(L).pdf"))
# For four λ values: standard deviation of the out-of-fold Acetate predictions
# within each species (second resample), summarised by its mean and std across species.
preds_std_pltdf = DataFrame(map([1, 17, 26, 51]) do lambda_index
    met = "Acetate"
    lam = reverse(lambdas)[lambda_index]
    per_species_sd =
        @chain oof_preds_df_stacked begin
            leftjoin(biobank.obs[:, [:Strain_ID, :Species, :Donor]], on=:msk_id => :Strain_ID)
            subset(
                :metabolite_label => ByRow(==(met)),
                :resample => ByRow(==(2)),
                :lambda => ByRow(==(lam))
            )
            disallowmissing
            groupby(:Species)
            combine(:preds => std)
        end

    (;
        lambda=lam,
        lambda_log10=log10(lam),
        preds_std_mean=mean(per_species_sd.preds_std),
        preds_std_std=std(per_species_sd.preds_std),
    )
end)
4×4 DataFrame
Row lambda lambda_log10 preds_std_mean preds_std_std
Float64 Float64 Float64 Float64
1 1.0 0.0 0.0198439 0.000378193
2 0.229087 -0.64 0.074131 0.0863078
3 0.1 -1.0 0.122492 0.123941
4 0.01 -2.0 0.450392 0.202781

Zooming in, we can look at A. hadrus and compare its measured relative acetate concentrations with our predictions of those concentrations.

Plot A.hadrus zoomin - Acetate

# Rebuild the tree of the training isolates
subsettree = readnw(
    SpectralInference.newickstring(
        UPGMA_tree(Dij[full_train_mask, full_train_mask] ./ size(biobank["UPorder_oggs"], 2)),
        biobank.obs.Strain_ID[full_train_mask],
    ),
)

# tree data: extract the subtree of the isolates whose species name contains "hadrus"
treeorder = indexin(getleafnames(subsettree), trnYdf.Strain_ID);
hadrusnames = trnYdf.Strain_ID[contains.(trnYdf.Species, "hadrus")]
hadrustree = readnw(NewickTree.nwstr(NewickTree.extract(subsettree, hadrusnames)))
# from here on the names follow the leaf order of the extracted subtree
hadrusnames = getleafnames(hadrustree)
hadrus_treeidxs = indexin(hadrusnames, getleafnames(subsettree));

# tree plot
tp = plot(hadrustree,
    fs=6,
    ylabel="A. hadrus",
    leftmargin=5Plots.mm,
    rightmargin=15Plots.mm,
    ticks=false,
    framestyle=:grid,
)
# λ values to compare and their log10 labels
Ls = reverse(lambdas)[[1, 17, 26, 51]]
Ls_labs = round.(log10.(Ls), digits=2)

# Measured value plus the out-of-fold prediction at each selected λ for every
# A. hadrus isolate (second resample), with rows in the leaf order of `hadrustree`.
pltdf =
    @chain oof_preds_df_stacked begin
        leftjoin(biobank.obs[:, [:Strain_ID, :Species, :Donor]], on=:msk_id => :Strain_ID)
        subset(
            :metabolite_label => ByRow(==(compound)),
            :Species => ByRow(==("Anaerostipes hadrus")),
            :resample => ByRow(==(2)),
            # keep rows whose λ is one of `Ls`; `ByRow` needs a predicate, and a bare
            # vector is not callable
            :lambda => ByRow(in(Ls)),
        )
        disallowmissing
        unstack([:msk_id, :truth], :lambda, :preds)
        # order the λ columns like `Ls`; the names are derived from `Ls` instead of
        # being typed out as literal strings
        select(:msk_id, :truth, string.(Ls)...)
        _[indexin(hadrusnames, _.msk_id), :]
    end
pltmtx = Matrix(pltdf[:, 2:end])
hp = heatmap(pltmtx, 
    c=:bwr, clims=getlims(pltmtx),
    xticks=(1:length(Ls)+1, ["measured", Ls_labs...]),
    xrotation=45,
    yticks=false,
    framestyle=:grid,
)

# combined plot
layout = @layout [a{0.4w} b]
plot(tp, hp, layout=layout, link=:y, size=(600, 600))
# One two-column heatmap per λ: measured values next to the predictions at that λ,
# all sharing the colour limits of the full matrix
ps = [
    heatmap(pltmtx[:, [1, col]],
        c=:bwr, clims=getlims(pltmtx),
        xticks=(1:2, ["measured", Ls_labs[col-1]]),
        colorbar=false,
    )
    for col in 2:5
]
plot(ps..., layout=grid(2,2))
savefig(joinpath(pdir, "ahadrus_predictions_acetate_seperated_heatmaps.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/ahadrus_predictions_acetate_seperated_heatmaps.pdf"
# λ values to compare and their log10 labels
Ls = reverse(lambdas)[[1, 17, 26, 51]]
Ls_labs = round.(log10.(Ls), digits=2)

# Measured value plus the out-of-fold prediction at each selected λ for every
# A. hadrus isolate (second resample), with rows in the leaf order of `hadrustree`.
pltdf =
    @chain oof_preds_df_stacked begin
        leftjoin(biobank.obs[:, [:Strain_ID, :Species, :Donor]], on=:msk_id => :Strain_ID)
        subset(
            :metabolite_label => ByRow(==(compound)),
            :Species => ByRow(==("Anaerostipes hadrus")),
            :resample => ByRow(==(2)),
            # keep rows whose λ is one of `Ls`; `ByRow` needs a predicate, and a bare
            # vector is not callable
            :lambda => ByRow(in(Ls)),
        )
        disallowmissing
        unstack([:msk_id, :truth], :lambda, :preds)
        # order the λ columns like `Ls`; the names are derived from `Ls` instead of
        # being typed out as literal strings
        select(:msk_id, :truth, string.(Ls)...)
        _[indexin(hadrusnames, _.msk_id), :]
    end
pltmtx = Matrix(pltdf[:, 2:end])
hp = heatmap(pltmtx, 
    c=:bwr, clims=getlims(pltmtx),
    xticks=(1:length(Ls)+1, ["measured", Ls_labs...]),
    xrotation=45,
    yticks=false,
    framestyle=:grid,
)

# combined plot
layout = @layout [a{0.4w} b]
plot(tp, hp, layout=layout, link=:y, size=(600, 600))
savefig(joinpath(pdir, "ahadrus_predictions_acetate.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/ahadrus_predictions_acetate.pdf"

Plot performance by regularization

# Show the first five rows of the per-fold model statistics
first(mdlstatsdf, 5)
5×13 DataFrame
Row metabolite_label fold resample lambda rsq rsq_adj cor degrees_freedom orig_degrees_freedom degrees_freedom_prop phylum_level species_level strain_level
String31 Int64 Int64 Float64 Float64 Float64 Float64 Int64 Float64 Float64 Float64 Float64 Float64
1 Acetate 1 1 0.0001 0.701449 -5.593 0.846187 253 266.0 0.951128 0.00790514 0.13834 0.853755
2 Acetate 1 1 0.000109648 0.701549 -5.59078 0.846249 253 266.0 0.951128 0.00790514 0.13834 0.853755
3 Acetate 1 1 0.000120226 0.701657 -6.18734 0.846315 254 266.0 0.954887 0.00787402 0.137795 0.854331
4 Acetate 1 1 0.000131826 0.701774 -6.18453 0.846385 254 266.0 0.954887 0.00787402 0.137795 0.854331
5 Acetate 1 1 0.000144544 0.701899 -6.18152 0.846459 254 266.0 0.954887 0.00787402 0.137795 0.854331
# Show the last five rows of the per-fold model statistics
last(mdlstatsdf, 5)
5×13 DataFrame
Row metabolite_label fold resample lambda rsq rsq_adj cor degrees_freedom orig_degrees_freedom degrees_freedom_prop phylum_level species_level strain_level
String31 Int64 Int64 Float64 Float64 Float64 Float64 Int64 Float64 Float64 Float64 Float64 Float64
1 Valine 4 5 0.691831 -0.021575 -0.021575 -2.85521e-17 0 266.0 0.0 0.0 0.0 0.0
2 Valine 4 5 0.758578 -0.021575 -0.021575 -2.85521e-17 0 266.0 0.0 0.0 0.0 0.0
3 Valine 4 5 0.831764 -0.021575 -0.021575 -2.85521e-17 0 266.0 0.0 0.0 0.0 0.0
4 Valine 4 5 0.912011 -0.021575 -0.021575 -2.85521e-17 0 266.0 0.0 0.0 0.0 0.0
5 Valine 4 5 1.0 -0.021575 -0.021575 -2.85521e-17 0 266.0 0.0 0.0 0.0 0.0
# Wide tables (rows: λ, columns: metabolites) of the averaged statistics
pltdf = mdlstatsdf_mean;
meandf = unstack(pltdf, :lambda, :metabolite_label, :rsq_adj_mean)
strainlvl_df = unstack(pltdf, :lambda, :metabolite_label, :strain_level_mean)
stddf = unstack(pltdf, :lambda, :metabolite_label, :rsq_adj_std);

# Empty canvas with a log-scaled λ axis
plot(
    title="SLE model",
    xlabel="λ",
    ylabel="adjusted R²",
    xscale=:log10,
    xlims=(1e-3, 1),
    ylims=(-0.25, 1),
    size=(800, 400),
    margin=5Plots.mm,
)

# adj rsq traces: one line per metabolite (every column after `lambda`)
for cname in names(meandf)[2:end]
    plot!(meandf.lambda, meandf[!, cname],
        c=:lightblue,
        xticks=[1e-3, 1e-2, 1e-1, 1e-0],
        yticks=[-0.25, 0, 0.5, 1],
    )
end

# Mark and label the best λ of every metabolite
@df bestlambdamodels scatter!(:lambda, :rsq_adj_mean, ms=5,
    c=:orange,
)
mets_to_annotate = [
    "Acetate", "Butyrate", "Propionate", "Succinate", "Glycine", "Lysine",
    "Phenylalanine", "Tyramine", "Tryptophan",
]
@df bestlambdamodels begin
    annotate!(:lambda, :rsq_adj_mean, text.(:metabolite_label .* "  ", 5, :right))
end
annotate!(1, 1, text("32 metabolites", 6, :right))
plot!()
savefig(joinpath(pdir, "SLE_lasso_adjrsq_lambda_strainlevel_simplified.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/SLE_lasso_adjrsq_lambda_strainlevel_simplified.pdf"
# Histogram of log10(λ) of the best model per metabolite, in bins of width 0.1 on [-3, 0]
plot(
    size=(600, 150),
    xlims=(-3, 0),
    margin=5Plots.mm,
    yticks=0:3:6
)
histogram!(log10.(bestlambdamodels.lambda), bins=-3:0.1:0, c=:orange)
savefig(joinpath(pdir, "SLE_lasso_peak_adjrsq_lambda_histogram.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/SLE_lasso_peak_adjrsq_lambda_histogram.pdf"

strain level branches by lambda

# Per λ (strictly above 1e-3): mean and std, over all metabolites, folds and resamples,
# of the strain- and species-level shares, plus the mean adjusted R²
pltdf = @chain mdlstatsdf begin
    subset(:lambda => (lam -> lam .> 1e-3))
    groupby(:lambda)
    combine(
        :strain_level .=> [mean, std],
        :species_level .=> [mean, std],
        :rsq_adj => mean,
    )
end
75×6 DataFrame
50 rows omitted
Row lambda strain_level_mean strain_level_std species_level_mean species_level_std rsq_adj_mean
Float64 Float64 Float64 Float64 Float64 Float64
1 0.00109648 0.871257 0.0200739 0.124428 0.0191009 -1.3691
2 0.00120226 0.871614 0.0201968 0.12415 0.0192005 -1.31653
3 0.00131826 0.871607 0.0203881 0.124124 0.0192575 -1.26223
4 0.00144544 0.871831 0.0199868 0.123921 0.0190958 -1.2066
5 0.00158489 0.872342 0.020021 0.123609 0.01937 -1.15839
6 0.0017378 0.872332 0.0197473 0.123729 0.0194016 -1.1029
7 0.00190546 0.87226 0.0197135 0.123856 0.0195209 -1.04971
8 0.0020893 0.872532 0.0195717 0.123701 0.0195293 -0.999268
9 0.00229087 0.872663 0.0195391 0.123581 0.0196881 -0.9461
10 0.00251189 0.873215 0.0200185 0.123029 0.0203484 -0.891564
11 0.00275423 0.873548 0.0213664 0.122647 0.0216841 -0.839968
12 0.00301995 0.873609 0.0217357 0.122589 0.0219744 -0.78452
13 0.00331131 0.8732 0.0221041 0.122956 0.0224467 -0.72992
64 0.363078 0.236376 0.317135 0.238696 0.313711 0.212032
65 0.398107 0.199054 0.294403 0.248549 0.326514 0.203829
66 0.436516 0.168395 0.280825 0.25913 0.339121 0.195533
67 0.47863 0.146054 0.267577 0.265849 0.349479 0.186798
68 0.524807 0.12389 0.254061 0.274229 0.359781 0.176741
69 0.57544 0.0988073 0.226709 0.276658 0.370469 0.165939
70 0.630957 0.0684917 0.179011 0.280082 0.378822 0.154982
71 0.691831 0.0585419 0.160946 0.2793 0.384192 0.14471
72 0.758578 0.0486012 0.146953 0.272653 0.398665 0.133771
73 0.831764 0.0343973 0.122032 0.237085 0.390501 0.122831
74 0.912011 0.0272135 0.106901 0.19901 0.372846 0.112608
75 1.0 0.0233073 0.0997593 0.1925 0.374488 0.101416
# Empty canvas with a log-scaled λ axis
plot(
    xlabel="λ",
    ylabel="fraction of\nstrain-level\nbranches",
    xscale=:log10,
    xlims=(1e-3, 1),
    ylims=(0, 1),
    yticks=[0,0.5,1],
    size=(700, 200),
    margin=5Plots.mm,
    leftmargin=10Plots.mm,
)
# Mean strain-level share per λ with a ±1 std ribbon
@df pltdf plot!(:lambda, :strain_level_mean,
    ribbon=:strain_level_std,
    fillalpha=0.3,
    c=:grey,
)
savefig(joinpath(pdir, "SLE_lasso_fractionstrainlevel_lambda_plot.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/SLE_lasso_fractionstrainlevel_lambda_plot.pdf"

Strain level branches for peak models

# Per-metabolite, per-lambda summary table written earlier to the results directory.
stbl = CSV.read(joinpath(rdir, "supptable_for_adjrsq_by_lambda_plt.csv"), DataFrame);
# For every metabolite keep only the row(s) whose mean adjusted R² is the group maximum.
pltdf = subset(
    groupby(stbl, :metabolite_label),
    :rsq_adj_mean => (v -> v .== maximum(v)),
)
32×8 DataFrame
7 rows omitted
Row metabolite_label lambda rsq_mean rsq_std rsq_adj_mean rsq_adj_std degrees_freedom_mean strain_level_mean
String31 Float64 Float64 Float64 Float64 Float64 Float64 Float64
1 Acetate 0.0758578 0.713737 0.0471556 0.678213 0.0524868 29.3 0.655233
2 Butyrate 0.131826 0.852513 0.0365737 0.843584 0.0384522 15.2 0.52595
3 Propionate 0.144544 0.537103 0.0529746 0.501361 0.0564935 19.0 0.759005
4 Succinate 0.1 0.61316 0.0363944 0.567427 0.0415038 27.95 0.691888
5 2-Methylbutyrate 0.229087 0.594906 0.0598302 0.569073 0.0620661 16.0 0.745227
6 3-Aminoisobutyrate 0.144544 0.177975 0.120498 0.133944 0.12507 13.5 0.841862
7 5-Aminovalerate 0.275423 0.543044 0.0374495 0.506948 0.039984 19.4 0.681582
8 Alanine 0.057544 0.0639759 0.0437672 0.0281666 0.0461699 9.75 0.61562
9 Aspartate 0.144544 0.44623 0.0827505 0.394543 0.0913016 22.55 0.72922
10 Benzoate 0.275423 0.0891614 0.101351 0.0758291 0.103552 3.8 0.499881
11 Cysteine 0.0691831 0.281043 0.101866 0.17623 0.118007 33.65 0.787437
12 Glutamate 0.0691831 0.182157 0.0474709 0.114249 0.0518121 20.3 0.78084
13 Glycine 0.20893 0.711895 0.168478 0.710018 0.169533 1.8 0.391667
21 Palmitate 0.0630957 0.054432 0.0228067 0.0274684 0.0222852 7.35 0.629675
22 Phenylacetate 0.190546 0.753281 0.0417033 0.730152 0.0445242 22.85 0.78039
23 Phenylalanine 0.0398107 0.161741 0.0921165 0.078399 0.101669 23.95 0.819789
24 Proline 0.0630957 0.283207 0.0652789 0.227891 0.0692945 19.0 0.606439
25 Serine 0.229087 0.252343 0.197141 0.230219 0.207998 7.35 0.839978
26 Threonine 0.0630957 0.123956 0.401533 0.0632694 0.433177 17.05 0.781606
27 Tryptamine 0.20893 0.743699 0.0507361 0.726334 0.0523907 17.05 0.830649
28 Tryptophan 0.057544 0.307183 0.0732773 0.24848 0.0796236 20.7 0.760984
29 Tyramine 0.0331131 0.263256 0.0523266 0.20591 0.0524228 19.2 0.65386
30 Tyrosine 0.190546 0.0381358 0.0622661 0.0291187 0.0634383 2.45 0.669167
31 Valerate 0.190546 0.379166 0.051993 0.336592 0.0556162 17.0 0.624623
32 Valine 0.0363078 0.163081 0.0587783 0.0880475 0.0633141 21.8 0.65635
# Histogram of strain_level_mean over the rows of pltdf (one row per metabolite
# at its peak-adjusted-R² lambda in the table printed above), x axis fixed to [0, 1].
@df pltdf histogram(:strain_level_mean,
    bins=10,
    fillcolor=:lightgrey,
    size=(700, 200),
    xlims=(0,1),
    widen=true,
    xlabel="fraction of strain level branches",
    ylabel="count",
    margin=5Plots.mm,
)
# Write the current figure into the plot directory (pdir).
savefig(joinpath(pdir, "SLE_lasso_adjrsq_strainlevel_histogram.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/SLE_lasso_adjrsq_strainlevel_histogram.pdf"

Fig. S14 - Leave-one-out analysis

# Print the two parallel metabolite name vectors used below: display labels
# (e.g. "2-Methylbutyrate") and identifier-style names (e.g. "_2_Methylbutyrate").
@show metabolite_label
@show metabolite_names;
metabolite_label = ["Acetate", "Butyrate", "Propionate", "Succinate", "2-Methylbutyrate", "3-Aminoisobutyrate", "5-Aminovalerate", "Alanine", "Aspartate", "Benzoate", "Cysteine", "Glutamate", "Glycine", "Hexanoate", "Isobutyrate", "Isoleucine", "Isovaleric-Acid", "Leucine", "Lysine", "Methionine", "Palmitate", "Phenylacetate", "Phenylalanine", "Proline", "Serine", "Threonine", "Tryptamine", "Tryptophan", "Tyramine", "Tyrosine", "Valerate", "Valine"]
metabolite_names = ["Acetate", "Butyrate", "Propionate", "Succinate", "_2_Methylbutyrate", "_3_Aminoisobutyrate", "_5_Aminovalerate", "Alanine", "Aspartate", "Benzoate", "Cysteine", "Glutamate", "Glycine", "Hexanoate", "Isobutyrate", "Isoleucine", "Isovaleric_Acid", "Leucine", "Lysine", "Methionine", "Palmitate", "Phenylacetate", "Phenylalanine", "Proline", "Serine", "Threonine", "Tryptamine", "Tryptophan", "Tyramine", "Tyrosine", "Valerate", "Valine"]

Predictive capacity under a ‘leave-one-out’ training scheme where entire species are removed from the training set. SLE-LASSO models are trained leaving out 1 of 11 species and tested on the single species that was left out. Each of the 10 species used for training contains 20 or more strains. Mean predictive capacity ± 1 standard deviation (error bars) are shown for the training sets (top) and the validation sets (bottom). ‘inf’ represents adjusted r2 values that could not be calculated.

using Distributions: Normal
# Settings shared by the leave-one-species-out analysis and the single-fold example below.
# Takes 1 min
K = 1 # Make predictions with SPI-LASSO on 1 nearest neighbor
# NOTE(review): λ is not referenced by any cell visible below — confirm whether it is still needed.
λ = 1e-3
# Number of repeats and folds passed to StratifiedCV when the resampling folds are built later on.
REPS = 5
NFOLDS = 4
# Penalty path handed to glmnet: 100 log-spaced values from 10^0 down to 10^-4.
lambdas = exp10.(range(0, -4, length=100));

# One (train, test) index pair per species, in sorted species order:
# the test rows are all strains of that species, the training rows are everything else.
folds_l1o_species = [
    (findall(!=(heldout), trnYdf.Species), findall(==(heldout), trnYdf.Species))
    for heldout in sort(unique(trnYdf.Species))
]

results = []

# The tree, the SLE encoding and the nearest-neighbour test features depend only on
# which species is held out, never on the metabolite being modelled, so they are
# built once per fold here and reused for every metabolite in the loop below.
fold_designs = map(folds_l1o_species) do (fold_trn, fold_tst)
    # For every held-out strain, positions (within fold_trn) of its K closest training strains.
    # using projections of taxa into UniProt so these loadings are constant regardless of folds
    tst_nns = map(r -> partialsortperm(r, 1:K), eachrow(subsetDij[fold_tst, fold_trn]))

    # UPGMA tree building...
    foldhc = UPGMA_tree(subsetDij[fold_trn, fold_trn])
    foldtree = readnw(SpectralInference.newickstring(foldhc, trnYdf.Strain_ID[fold_trn]))

    # Make SLE ancestor encoding: one 0/1 column per tree node, 1 for the
    # training strains that are leaves below that node
    trnXdf_all = map(prewalk(foldtree)) do node
        tmp = zeros(length(fold_trn))
        tmp[indexin(getleafnames(node), trnYdf.Strain_ID[fold_trn])] .= 1
        "node__$(id(node))" => tmp
    end |> DataFrame
    isinternal_fold = map(!isleaf, prewalk(foldtree))
    # reorder nodes by tree depth
    treedists = mapinternalnodes(foldtree) do node
        network_distance(foldtree, node)
    end
    # keep internal nodes only, sorted by treedists, with positional column names
    trnXdf = trnXdf_all[:, isinternal_fold]
    trnXdf = trnXdf[:, sortperm(treedists)]
    rename!(trnXdf, ["node__$i" for i in 1:size(trnXdf, 2)])

    # Get features for each out-of-fold isolate: column-wise mean of the
    # encodings of its K nearest training strains
    tstXdf = map(tst_nns) do nn
        trnXdf[nn, :] |>
        df -> combine(df, [c => mean for c in 1:size(df, 2)])
    end |> x -> vcat(x...)

    (; trnX=Matrix(trnXdf), tstX=Matrix(tstXdf))
end

for (metname, metlabel) in zip(metabolite_names, metabolite_label)
    # column of metab_trnY holding this metabolite
    metcol = metabolite_names .== metname

    for ((fold_trn, fold_tst), design) in zip(folds_l1o_species, fold_designs)
        trnX = design.trnX
        tstX = design.tstX

        trnY = metab_trnY[fold_trn, metcol] |> vec
        tstY = metab_trnY[fold_tst, metcol] |> vec

        # @show metname, size(trnX), size(trnY)
        # Gaussian lasso path over the shared penalty grid
        modelpath = glmnet(trnX, trnY, Normal(); lambda=lambdas)
        # NOTE(review): n is the number of columns (SLE features) of trnX, and the same
        # n is used for the test-set adjustment — confirm this is what adjust_rsquared expects.
        n = size(trnX, 2)
        # number of non-zero coefficients at each penalty value
        dfs = map(x -> sum(x .!= 0), eachcol(modelpath.betas))

        # one R² per penalty value (predict returns one column per lambda)
        trnR2 = rsquared.(eachcol(GLMNet.predict(modelpath, trnX)), Ref(trnY))
        tstR2 = rsquared.(eachcol(GLMNet.predict(modelpath, tstX)), Ref(tstY))

        push!(results, (;
            # model=modelpath,
            metabolite_name=metname,
            metabolite_label=metlabel,
            # every test row belongs to the same species; `only` errors if that is not the case
            leftout_species=(only ∘ unique)(trnYdf.Species[fold_tst]),
            train_r2=trnR2,
            test_r2=tstR2,
            train_adjr2=adjust_rsquared.(trnR2, n, dfs),
            test_adjr2=adjust_rsquared.(tstR2, n, dfs),
            lambda=modelpath.lambda,
            model_dfs=dfs
        ))
    end
end
# Expand every result record (vector-valued fields become rows) and stack them into one table.
resultsdf = reduce(vcat, [DataFrame(; rec...) for rec in results]);
first(resultsdf, 5)
5×9 DataFrame
Row metabolite_name metabolite_label leftout_species train_r2 test_r2 train_adjr2 test_adjr2 lambda model_dfs
String String String Float64 Float64 Float64 Float64 Float64 Int64
1 Acetate Acetate Anaerostipes hadrus -2.22045e-16 -8.78239 -2.22045e-16 -8.78239 1.0 0
2 Acetate Acetate Anaerostipes hadrus -2.22045e-16 -8.78239 -2.22045e-16 -8.78239 0.911163 0
3 Acetate Acetate Anaerostipes hadrus -2.22045e-16 -8.78239 -2.22045e-16 -8.78239 0.830218 0
4 Acetate Acetate Anaerostipes hadrus -2.22045e-16 -8.78239 -2.22045e-16 -8.78239 0.756463 0
5 Acetate Acetate Anaerostipes hadrus -2.22045e-16 -8.78239 -2.22045e-16 -8.78239 0.689261 0
# For each (metabolite, left-out species) pair keep the penalty with the highest
# test adjusted R²; ties are resolved in favour of the largest lambda.
bestlambda_leaveoneout_df = @chain resultsdf begin
    groupby([:metabolite_label, :leftout_species])
    subset(:test_adjr2 => (r2 -> r2 .== maximum(r2)))
    groupby([:metabolite_label, :leftout_species])
    subset(:lambda => (pen -> pen .== maximum(pen)))
    sort(:test_adjr2, rev=true)
end;
# Mean and standard deviation across left-out species, per metabolite,
# ordered by mean test adjusted R² (best first).
pltdf = @chain bestlambda_leaveoneout_df begin
    groupby(:metabolite_label)
    combine(
        :test_adjr2 => mean => :test_adjr2_mean,
        :test_adjr2 => std => :test_adjr2_std,
        :train_adjr2 => mean => :train_adjr2_mean,
        :train_adjr2 => std => :train_adjr2_std,
    )
    sort(:test_adjr2_mean, rev=true)
end;
# Top panel: mean training adjusted R² per metabolite with ±1 std error bars;
# tick positions are shifted by 0.5 and labelled with the metabolite names.
p1 = @df pltdf bar(:metabolite_label, :train_adjr2_mean, yerror=:train_adjr2_std,
    xticks=((1:length(:metabolite_label)) .- 0.5, :metabolite_label),
    xrotation=30,
    xtickfontsize=6,
    bottommargin=5Plots.mm,
    ylabel="adj. R² (training set)",
    fillcolor=:grey,
    ylims=(0,1),
    widen=true,
)

# Bottom panel: same layout for the held-out species; y limits are left automatic.
p2 = @df pltdf bar(:metabolite_label, :test_adjr2_mean, yerror=:test_adjr2_std,
    xticks=((1:length(:metabolite_label)) .- 0.5, :metabolite_label),
    xrotation=30,
    xtickfontsize=6,
    bottommargin=5Plots.mm,
    ylabel="adj. R² (testing set)",
    fillcolor=:grey,
    # ylims=(-1,1),
    # widen=true,
);
# Stack the two panels vertically with a shared x axis and save the figure.
plot(p1, p2, layout=grid(2,1), size=(650,500), link=:x)
savefig(joinpath(pdir, "leave-one-species-out-results.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/leave-one-species-out-results.pdf"

Fig. S15 - Overfitting & Coeffs

# In-fold (training) model statistics: restrict to lambda >= 1e-3, then take the mean
# and std of columns 5 through 13 for every (metabolite, lambda) combination.
infold_pltdf = combine(
    groupby(
        subset(infold_mdlstatsdf, :lambda => ByRow(>=(1e-3))),
        [:metabolite_label, :lambda],
    ),
    5:13 .=> mean,
    5:13 .=> std,
)
first(infold_pltdf, 5)
5×20 DataFrame
Row metabolite_label lambda rsq_mean rsq_adj_mean cor_mean degrees_freedom_mean orig_degrees_freedom_mean degrees_freedom_prop_mean phylum_level_mean species_level_mean strain_level_mean rsq_std rsq_adj_std cor_std degrees_freedom_std orig_degrees_freedom_std degrees_freedom_prop_std phylum_level_std species_level_std strain_level_std
String31 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64
1 Acetate 0.001 0.937059 0.80382 0.96803 179.75 266.0 0.675752 0.00386866 0.133389 0.862742 0.0113282 0.0367428 0.00584001 4.08946 0.0 0.0153739 0.0050941 0.0144972 0.011977
2 Acetate 0.00109648 0.936936 0.806912 0.967971 178.2 266.0 0.669925 0.00361343 0.133422 0.862965 0.0113433 0.0359351 0.00584756 4.7195 0.0 0.0177425 0.00484482 0.015301 0.0127648
3 Acetate 0.00120226 0.93678 0.812522 0.967896 175.3 266.0 0.659023 0.00364206 0.13221 0.864148 0.0113366 0.0354377 0.0058449 4.90005 0.0 0.0184213 0.00488046 0.0140508 0.0122095
4 Acetate 0.00131826 0.936587 0.815338 0.967804 173.7 266.0 0.653008 0.00367659 0.132861 0.863463 0.011322 0.0346107 0.00583847 4.41409 0.0 0.0165943 0.00492837 0.0149418 0.0129915
5 Acetate 0.00144544 0.936372 0.818787 0.967702 171.55 266.0 0.644925 0.00372389 0.132196 0.864081 0.0113298 0.0345434 0.00584282 5.11422 0.0 0.0192264 0.00498694 0.0161098 0.0139428
# Same summary as for the in-fold statistics, computed from mdlstatsdf: restrict to
# lambda >= 1e-3, then mean and std of columns 5 through 13 per (metabolite, lambda).
pltdf = combine(
    groupby(
        subset(mdlstatsdf, :lambda => ByRow(>=(1e-3))),
        [:metabolite_label, :lambda],
    ),
    5:13 .=> mean,
    5:13 .=> std,
)
first(pltdf, 5)
5×20 DataFrame
Row metabolite_label lambda rsq_mean rsq_adj_mean cor_mean degrees_freedom_mean orig_degrees_freedom_mean degrees_freedom_prop_mean phylum_level_mean species_level_mean strain_level_mean rsq_std rsq_adj_std cor_std degrees_freedom_std orig_degrees_freedom_std degrees_freedom_prop_std phylum_level_std species_level_std strain_level_std
String31 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64
1 Acetate 0.001 0.629924 -0.154249 0.813449 179.75 266.0 0.675752 0.00386866 0.133389 0.862742 0.0815298 0.264763 0.0402405 4.08946 0.0 0.0153739 0.0050941 0.0144972 0.011977
2 Acetate 0.00109648 0.63077 -0.131152 0.813759 178.2 266.0 0.669925 0.00361343 0.133422 0.862965 0.0813201 0.257235 0.0402418 4.7195 0.0 0.0177425 0.00484482 0.015301 0.0127648
3 Acetate 0.00120226 0.631667 -0.0890729 0.814106 175.3 266.0 0.659023 0.00364206 0.13221 0.864148 0.0810454 0.231735 0.0401964 4.90005 0.0 0.0184213 0.00488046 0.0140508 0.0122095
4 Acetate 0.00131826 0.632732 -0.0664488 0.81454 173.7 266.0 0.653008 0.00367659 0.132861 0.863463 0.0808514 0.227508 0.0401483 4.41409 0.0 0.0165943 0.00492837 0.0149418 0.0129915
5 Acetate 0.00144544 0.633894 -0.0388593 0.814984 171.55 266.0 0.644925 0.00372389 0.132196 0.864081 0.0805077 0.220566 0.0401091 5.11422 0.0 0.0192264 0.00498694 0.0161098 0.0139428

Train & test performance panel

# Wide tables of mean adjusted R²: one row per lambda, one column per metabolite.
oof_meandf = unstack(pltdf, :lambda, :metabolite_label, :rsq_adj_mean)
infold_meandf = unstack(infold_pltdf, :lambda, :metabolite_label, :rsq_adj_mean)
# Per metabolite, the row(s) of pltdf where mean adjusted R² peaks, best metabolite first.
bestlambdamodels = @chain pltdf begin
    groupby(:metabolite_label)
    subset(:rsq_adj_mean => (x -> x .== maximum(x)))
    sort(:rsq_adj_mean, rev=true)
end


# One small panel per metabolite, in the order of bestlambdamodels.
ps = []
foreach(bestlambdamodels.metabolite_label) do cname
    p = plot(
        title=cname,
        xlims=(1e-3, 1),
        ylims=(-0.25, 1),
        yticks=[-0.25, 0, 0.5, 1],
        xticks=[1e-3, 1e-2, 1e-1, 1e-0],
        xscale=:log10,
        tickfontsize=5,
        titlefontsize=8,
        margin=2Plots.mm,
    )
    # dashed line: values from infold_meandf
    plot!(infold_meandf.lambda, infold_meandf[!, cname],
        c=:grey, ls=:dash,
    )
    # solid line: values from oof_meandf
    plot!(oof_meandf.lambda, oof_meandf[!, cname],
        c=:grey,
    )
    # orange marker at the peak of the solid curve
    met_mask = bestlambdamodels.metabolite_label .== cname
    scatter!(bestlambdamodels.lambda[met_mask], bestlambdamodels.rsq_adj_mean[met_mask], 
        mc=:orange,
        msw=0.25,
    )
    push!(ps, p)
end
# Arrange all panels on a 4 × 8 grid.
plot(ps..., layout=grid(4, 8), size=(1500, 750))

Adjusted r2 (y-axis) versus penalty value of SLE-LASSO model (x-axis) for predicting metabolite concentrations across all strains. Adjusted r2 value plotted for both training and test set with peak predictive capacity of the test set delineated as an orange dot in each plot (see key).

savefig(joinpath(pdir, "train-vs-test_mean-adjusted-r2_linecharts.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/train-vs-test_mean-adjusted-r2_linecharts.pdf"
# Training-side adjusted R² (mean and std) per metabolite and penalty value,
# with the column names used in the supplementary table.
infold_adjr2 = select(infold_pltdf,
    :metabolite_label => :Metabolite,
    :lambda => :Penalty_value,
    :rsq_adj_mean => :training_adj_rsq_mean,
    :rsq_adj_std => :training_adj_rsq_std,
)

# Join the out-of-sample summary with the training summary, flag the row(s) where the
# out-of-sample mean peaks within each metabolite, and write the result as a TSV file.
let
    outofsample_adjr2 = select(pltdf,
        :metabolite_label => :Metabolite,
        :lambda => :Penalty_value,
        :rsq_adj_mean => :outofsample_adj_rsq_mean,
        :rsq_adj_std => :outofsample_adj_rsq_std,
    )
    joined = leftjoin(outofsample_adjr2, infold_adjr2, on=[:Metabolite, :Penalty_value])
    flagged = transform!(
        groupby(joined, :Metabolite),
        :outofsample_adj_rsq_mean => ((v) -> ifelse.(v .== maximum(v), "peak value", "")) => :Peak_outofsample_adj_rsq
    )
    CSV.write(
        projectdir("_research", "SuppTables", "Supplementary_Table_6.tsv"),
        flagged,
        delim="\t",
    )
end
"/Users/bend/projects/Doran_etal_2023/_research/SuppTables/Supplementary_Table_6.tsv"

Coefficients of SLE-Lasso models

# Long-format coefficient table: columns 6 through 106 of the CSV (one per penalty value)
# are stacked into rows; lambda is parsed from the text after the last "_" of each
# stacked column name, and the stacked value is renamed to `coef`.
coefdf_stacked = let
    wide = CSV.read(joinpath(rdir, "coefs_SLE_lambda=many.csv"), DataFrame)
    long = stack(wide, 6:106)
    transform!(long, :variable => ByRow(s -> parse(Float64, last(split(s, "_")))) => :lambda)
    select(long, Not([:variable, :value]), :value => :coef)
end
17194240×7 DataFrame
17194215 rows omitted
Row metabolite_name metabolite_label fold resample num_species_descendents lambda coef
String31 String31 Int64 Int64 Float64 Float64 Float64
1 Acetate Acetate 1 1 NaN 0.0001 0.0
2 Acetate Acetate 1 1 11.0 0.0001 -0.451036
3 Acetate Acetate 1 1 3.0 0.0001 -0.32701
4 Acetate Acetate 1 1 2.0 0.0001 -0.107274
5 Acetate Acetate 1 1 1.0 0.0001 -0.169603
6 Acetate Acetate 1 1 1.0 0.0001 0.863021
7 Acetate Acetate 1 1 1.0 0.0001 0.122345
8 Acetate Acetate 1 1 1.0 0.0001 0.027476
9 Acetate Acetate 1 1 1.0 0.0001 0.574951
10 Acetate Acetate 1 1 1.0 0.0001 -0.134015
11 Acetate Acetate 1 1 1.0 0.0001 0.400922
12 Acetate Acetate 1 1 1.0 0.0001 -0.823842
13 Acetate Acetate 1 1 1.0 0.0001 0.947637
17194229 Valine Valine 4 5 1.0 1.0 0.0
17194230 Valine Valine 4 5 1.0 1.0 0.0
17194231 Valine Valine 4 5 1.0 1.0 0.0
17194232 Valine Valine 4 5 1.0 1.0 0.0
17194233 Valine Valine 4 5 1.0 1.0 0.0
17194234 Valine Valine 4 5 1.0 1.0 0.0
17194235 Valine Valine 4 5 1.0 1.0 0.0
17194236 Valine Valine 4 5 1.0 1.0 0.0
17194237 Valine Valine 4 5 1.0 1.0 0.0
17194238 Valine Valine 4 5 1.0 1.0 0.0
17194239 Valine Valine 4 5 1.0 1.0 0.0
17194240 Valine Valine 4 5 1.0 1.0 0.0
# Mean absolute coefficient of species-level vs. sub-species-level branches,
# per metabolite and penalty value.
coef_pltdf = @chain coefdf_stacked begin
    # only branches that are active in the model, on the plotted lambda range
    @rsubset(:coef != 0.0, :lambda >= 0.001)
    # branches with at most one descendant species count as sub-species branches;
    # NaN entries compare false and therefore fall in the species-level group
    @rtransform(:is_subspecies = :num_species_descendents <= 1.0)
    # average |coef| within every individual model (metabolite × fold × resample × lambda)
    @groupby(:metabolite_label, :fold, :resample, :lambda, :is_subspecies)
    @combine(
        :mean_coef_magnitude = mean(abs, :coef),
    )
    # one column per branch class; a class absent from a model becomes `missing`
    unstack(:is_subspecies, :mean_coef_magnitude)
    rename(
        "false" => :speciesbranches_mean_coef_magnitude,
        "true" => :subspeciesbranches_mean_coef_magnitude,
    )
    # average over folds and resamples, ignoring models without branches of that class
    @groupby(:metabolite_label, :lambda)
    @combine(
        :mean_coef_magnitude_subspecies = (mean ∘ skipmissing)(:subspeciesbranches_mean_coef_magnitude),
        :mean_coef_magnitude_species = (mean ∘ skipmissing)(:speciesbranches_mean_coef_magnitude),
    )
    sort([:metabolite_label, :lambda])
    # coalesce.(NaN)
end;
# Metabolites ordered by peak mean adjusted R² (best first); one panel is drawn for each.
ordered_metabolite_labels = sort(bestlambdamodels, :rsq_adj_mean, rev=true).metabolite_label;
ps = []
for mlabel in ordered_metabolite_labels
    p = plot(
        title=mlabel,
        xscale=:log10,
        xticks=[0.001, 0.01, 0.1, 1],
        xlims=(0.001, 1),
        ylims=(0, Inf),
        widen=true,
    )
    # solid grey: mean |coef| of species-level branches
    @df @rsubset(coef_pltdf, :metabolite_label == mlabel) plot!(
        :lambda, :mean_coef_magnitude_species;
        linestyle=:solid,
        color=:grey
    )
    # dashed black: mean |coef| of sub-species branches
    @df @rsubset(coef_pltdf, :metabolite_label == mlabel) plot!(
        :lambda, :mean_coef_magnitude_subspecies;
        linestyle=:dash,
        color=:black,
    )
    # vertical marker at this metabolite's peak lambda from bestlambdamodels
    vline!([bestlambdamodels.lambda[bestlambdamodels.metabolite_label.==mlabel]],
        color=:lightgrey, linewidth=0.5,
    )
    push!(ps, p)
end
# Arrange all panels on a 4 × 8 grid.
plot(ps..., layout=grid(4, 8), size=(1200, 600), titlefontsize=6, tickfontsize=6)

Mean magnitude of coefficients in SLE-model (y-axis) versus penalty value (x-axis) for each metabolite (panel). Dashed curves correspond to coefficients of SLE branches defining differences amongst strains belonging to the same species (‘intra-species variation’); solid curves correspond to coefficients of SLE branches defining differences amongst strains belonging to different species (‘inter-species variation’). Solid vertical gray line in each plot delineates the penalty value for which the peak predictive SLE-LASSO model is observed as shown in Figure 7D.

savefig(joinpath(pdir, "averaged_coeff_magnitude_dash=subspecies.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/averaged_coeff_magnitude_dash=subspecies.pdf"

Fig. S16 - Single fold example

# Rebuild the stratified resampling folds (NFOLDS folds, repeated REPS times) from a fixed seed.
seed!(424242)
cv = StratifiedCV(nfolds=NFOLDS, shuffle=true);
folds = vcat([train_test_pairs(cv, 1:sum(full_train_mask), trnYdf.Species) for i in 1:REPS]...)

# Only the fitted model of this one fold is collected here. The prediction and
# coefficient tables (oof_preds_df_stacked, inf_preds_df_stacked, coefdf) are
# deliberately NOT re-initialised: nothing in this example fills them, and the
# Fig. S17 cells further down read oof_preds_df_stacked.
models_tbl = []
(i, (fold_trn, fold_tst)) = (1, folds[1])
# (i, (fold_trn, fold_tst)) = (2, folds[2])

# Use tree to get lineage traces for each training and test sample
# using projections of taxa into UniProt so these loadings are constant regardless of folds
foldPCs = UPfullPCs[fold_trn, :]
# for every test strain, positions (within fold_trn) of its K closest training strains
tst_nns = map(r -> partialsortperm(r, 1:K), eachrow(subsetDij[fold_tst, fold_trn]))

# UPGMA tree building...
foldhc = UPGMA_tree(subsetDij[fold_trn, fold_trn])
foldtree = readnw(SpectralInference.newickstring(foldhc, trnYdf.Strain_ID[fold_trn]))
ladderize!(foldtree, rev=false)
# leaf ids of the tree, arranged in the row order of the training set
ordered_treeids = getleafids(foldtree)[indexin(trnYdf.Strain_ID[fold_trn], getleafnames(foldtree))]

# Make SLE ancestor encoding for training set
trnX_all = @chain begin
    spectral_lineage_encoding(foldtree, ordered_treeids)
    getfield.(:sle)
    stack
    float.(_)
end
# keep only the columns that belong to internal nodes
isinternal_fold = map(!isleaf, prewalk(foldtree))
trnX = trnX_all[:, isinternal_fold]

# For every non-root node: number of distinct species among the leaves below its parent
# (NaN for the root, which has no parent).
num_descendent_species = map(prewalk(foldtree)) do node
    if !isroot(node)
        sps = trnYdf.Species[indexin(getleafnames(parent(node)), trnYdf.Strain_ID)]
        return length(unique(sps))
    else
        NaN
    end
end

# Get features for each out-of-fold isolate
oofX = map(tst_nns) do nn
    trnX[nn, :] |>
    mtx -> mean(mtx, dims=1)
end |> x -> vcat(x...)

# fit lasso model
(target_idx, (target, mlabel)) = (findfirst(==("Acetate"), metabolite_label), ("Acetate", "Acetate"))

mdl = GLMNet.glmnetcv(trnX, metab_trnY[fold_trn, target_idx], Normal();
    lambda=lambdas,
)

push!(models_tbl, (;
    metabolite_name=target,
    metabolite_label=mlabel,
    # position of fold i within its repeat, and the repeat it belongs to
    fold=((i - 1) % NFOLDS) + 1,
    resample=((i - 1) ÷ NFOLDS) + 1,
    fold_trn,
    fold_tst,
    foldtree=foldtree,
    model=mdl,
))

println("on $(i)th resample")
on 1th resample
# best lambda model for acetate
# (`only` requires exactly one Acetate row in bestlambdamodels)
best_lambda_from_fig7 = only(@rsubset(bestlambdamodels, :metabolite_label == "Acetate").lambda)
# index on the fitted path whose lambda is closest to that peak lambda
bestlambda_idx = last(findmin(x->abs(x-best_lambda_from_fig7), mdl.lambda))
29
log10(best_lambda_from_fig7)
-1.12
# One coefficient slot per tree node (2n - 1 nodes for n training strains) in prewalk order;
# internal nodes receive the lasso coefficients at bestlambda_idx, leaves stay at zero.
branch_coeffs = zeros(2length(fold_trn)-1)
branch_coeffs[isinternal_fold] .= mdl.path.betas[:, bestlambda_idx]
clims = getlims(branch_coeffs);
branch_coeffs_dict = Dict(zip(NewickTree.id.(prewalk(foldtree)), branch_coeffs));
# plot(size=(400, 600))
# Tree with branches coloured and thickened by coefficient; the first prewalk entry
# (presumably the root, which has no incoming branch) is skipped via [2:end].
tp = plot(foldtree, 
    fs=2,
    line_z=permutedims(branch_coeffs[2:end]),
    clims=clims,
    colormap=:bam,
    linewidth=4 .+ permutedims(abs.(branch_coeffs[2:end])),
)
vline!([0.5, 2.04], color=:grey, alpha=0.5, linestyle=:dash)
# Second rendering with thinner branches and a colorbar, saved on its own.
plot(foldtree, 
    fs=2,
    line_z=permutedims(branch_coeffs[2:end]),
    clims=clims,
    colormap=:bam,
    linewidth=1 .+ permutedims(abs.(branch_coeffs[2:end])),
    colorbar=true,
)
savefig(joinpath(pdir, "treeplot_with_colorbar.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/treeplot_with_colorbar.pdf"
# Row positions of the training set arranged in the leaf order of the fold tree.
trainingset_treeorder = indexin(getleafnames(foldtree), trnYdf.Strain_ID[fold_trn]);
# Two columns in tree order: measured values and the model's predictions on the training set.
# NOTE(review): predict is called on the cross-validated model without selecting
# bestlambda_idx — confirm which lambda these predictions correspond to.
pltmtx = hcat(
    metab_trnY[fold_trn, target_idx][trainingset_treeorder], 
    GLMNet.predict(mdl, trnX)[trainingset_treeorder]
)
# Heatmap panel used in the composite figure (no colorbar).
hplt = heatmap(pltmtx,
    colormap=:bwr,
    clims=getlims(pltmtx),
    xticks = (1:2, ["measurements", "training predictions"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)
# Same matrix rendered again with the colorbar shown, saved on its own.
pltmtx = hcat(
    metab_trnY[fold_trn, target_idx][trainingset_treeorder], 
    GLMNet.predict(mdl, trnX)[trainingset_treeorder]
)
heatmap(pltmtx,
    colormap=:bwr,
    clims=getlims(pltmtx),
    xticks = (1:2, ["measurements", "training predictions"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    # colorbar=:none,
)
savefig(joinpath(pdir, "prediction_heatmap_with_colorbar.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/prediction_heatmap_with_colorbar.pdf"
# Single-column matrix with the species of every training strain, in tree order.
pltmtx = hcat(
    trnYdf.Species[fold_trn][trainingset_treeorder]
)
# One colour per species, looked up in the order in which species first appear in pltmtx.
# NOTE(review): assumes heatmap assigns colours to categories in this same order — confirm.
colormap = getindex.(Ref(species_color_dict), unique(pltmtx))
spplt = heatmap(pltmtx,
    colormap=colormap,
    # clims=getlims(pltmtx),
    xticks = (1:1, ["species"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)
# Tree (85 % of the width), measurement/prediction heatmap and species strip (5 %)
# side by side with linked, flipped y axes.
layout = @layout [a{0.85w} b c{0.05w}]
plot(tp, hplt, spplt, layout=layout, link=:y, size=(500, 600), yflip=true)
savefig(joinpath(pdir, "foldtree_coeffs_predictions_species.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/foldtree_coeffs_predictions_species.pdf"
# Store every node's coefficient in the node's support field so the values
# travel with the tree when subtrees are extracted below.
for (treenode, coeff) in zip(prewalk(foldtree), branch_coeffs)
    NewickTree.setsupport!(treenode.data, coeff)
end;

A. hadrus subtree

# Training strains whose species name contains "hadrus", and the subtree of the fold tree
# restricted to them (round-tripped through a Newick string).
subtreeleafids = trnYdf.Strain_ID[fold_trn][occursin.("hadrus", trnYdf.Species[fold_trn])];
subtree = readnw(nwstr(NewickTree.extract(foldtree, subtreeleafids)))
# training-set row positions in the leaf order of the subtree
trainingset_treeorder = indexin(getleafnames(subtree), trnYdf.Strain_ID[fold_trn]);
# Coefficients are read back from the support values stored on the nodes; NaN becomes 0.
subtree_coeffs = map(prewalk(subtree)) do node
    NewickTree.support(node)
end |> x->replace(x, NaN=>0.0);
# plot(size=(400, 600))
# Subtree coloured by coefficient, using half the colour range of the full-tree plot.
tp = plot(subtree, 
    fs=5,
    line_z=permutedims(subtree_coeffs[2:end]),
    # clims=getlims(subtree_coeffs[2:end]),
    clims=clims./2,
    colormap=:bam,
    linewidth=5 .+ permutedims(abs.(subtree_coeffs[2:end])),
)
# vline!([0.5, 2.04], color=:grey, alpha=0.5, linestyle=:dash)

# Measured values and training predictions for the subtree strains, in subtree leaf order.
pltmtx = hcat(
    metab_trnY[fold_trn, target_idx][trainingset_treeorder], 
    GLMNet.predict(mdl, trnX)[trainingset_treeorder]
)
hplt = heatmap(pltmtx,
    colormap=:bwr,
    clims=getlims(pltmtx),
    xticks = (1:2, ["measurements", "training predictions"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)

# Species strip for the same strains.
pltmtx = hcat(
    trnYdf.Species[fold_trn][trainingset_treeorder]
)
colormap = getindex.(Ref(species_color_dict), unique(pltmtx))
spplt = heatmap(pltmtx,
    colormap=colormap,
    # clims=getlims(pltmtx),
    xticks = (1:1, ["species"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)

# Compose tree, heatmap and species strip side by side and save.
layout = @layout [a{0.85w} b c{0.05w}]
plot(tp, hplt, spplt, layout=layout, link=:y, size=(500, 400), yflip=true)
savefig(joinpath(pdir, "hadrus_foldtree_coeffs_predictions_species.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/hadrus_foldtree_coeffs_predictions_species.pdf"

P. vulgatus subtree

# Training strains whose species name contains "vulgatus", and the subtree of the fold tree
# restricted to them (round-tripped through a Newick string).
subtreeleafids = trnYdf.Strain_ID[fold_trn][occursin.("vulgatus", trnYdf.Species[fold_trn])];
subtree = readnw(nwstr(NewickTree.extract(foldtree, subtreeleafids)))
# training-set row positions in the leaf order of the subtree
trainingset_treeorder = indexin(getleafnames(subtree), trnYdf.Strain_ID[fold_trn]);
# Coefficients are read back from the support values stored on the nodes; NaN becomes 0.
subtree_coeffs = map(prewalk(subtree)) do node
    NewickTree.support(node)
end |> x->replace(x, NaN=>0.0);
# Lookup table of strain ids and accessions, sorted by strain id.
# NOTE(review): this lists strains matching "rectale" although the cell sits in the
# P. vulgatus section — confirm it is intentional.
@chain biobank.obs[occursin.("rectale", biobank.obs.Species), [:Strain_ID, :Accession]] begin
    sort(:Strain_ID)
end
20×2 DataFrame
Row Strain_ID Accession
String String
1 MSK.13.48 JAAISJ000000000
2 MSK.13.50 JAAISI000000000
3 MSK.13.59 JAAISH000000000
4 MSK.16.22 JAAIMQ000000000
5 MSK.16.45 JAAIMP000000000
6 MSK.17.13 JAAIMK000000000
7 MSK.17.19 JAAIMJ000000000
8 MSK.17.3 JAAIMG000000000
9 MSK.17.42 JAAIME000000000
10 MSK.17.57 JAAIMC000000000
11 MSK.17.70 JAAILY000000000
12 MSK.17.78 JAAILX000000000
13 MSK.17.79 JAAILW000000000
14 MSK.22.19 JAAISF000000000
15 MSK.22.23 JAAISE000000000
16 MSK.22.28 JAAISD000000000
17 MSK.22.51 JAAISB000000000
18 MSK.22.92 JAJFBX000000000
19 MSK.9.13 JAAISA000000000
20 MSK.9.15 JAAIRZ000000000
# plot(size=(400, 600))
# Subtree coloured by coefficient, using half the colour range of the full-tree plot.
tp = plot(subtree, 
    fs=5,
    line_z=permutedims(subtree_coeffs[2:end]),
    # clims=getlims(subtree_coeffs[2:end]),
    clims=clims./2,
    colormap=:bam,
    linewidth=5 .+ permutedims(abs.(subtree_coeffs[2:end])),
)
# vline!([0.5, 2.04], color=:grey, alpha=0.5, linestyle=:dash)

# Measured values and training predictions for the subtree strains, in subtree leaf order.
pltmtx = hcat(
    metab_trnY[fold_trn, target_idx][trainingset_treeorder], 
    GLMNet.predict(mdl, trnX)[trainingset_treeorder]
)
hplt = heatmap(pltmtx,
    colormap=:bwr,
    clims=getlims(pltmtx),
    xticks = (1:2, ["measurements", "training predictions"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)

# Species strip for the same strains.
pltmtx = hcat(
    trnYdf.Species[fold_trn][trainingset_treeorder]
)
colormap = getindex.(Ref(species_color_dict), unique(pltmtx))
spplt = heatmap(pltmtx,
    colormap=colormap,
    # clims=getlims(pltmtx),
    xticks = (1:1, ["species"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)

# Compose tree, heatmap and species strip side by side and save.
layout = @layout [a{0.85w} b c{0.05w}]
plot(tp, hplt, spplt, layout=layout, link=:y, size=(500, 500), yflip=true)
savefig(joinpath(pdir, "vulgatus_foldtree_coeffs_predictions_species.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/vulgatus_foldtree_coeffs_predictions_species.pdf"

B. breve & E. rectale subtree

# Training strains whose species name contains "rectale" or "breve", and the subtree of
# the fold tree restricted to them (round-tripped through a Newick string).
subtreeleafids = trnYdf.Strain_ID[fold_trn][
    occursin.("rectale", trnYdf.Species[fold_trn]) .|| occursin.("breve", trnYdf.Species[fold_trn])
];
subtree = readnw(nwstr(NewickTree.extract(foldtree, subtreeleafids)))
# training-set row positions in the leaf order of the subtree
trainingset_treeorder = indexin(getleafnames(subtree), trnYdf.Strain_ID[fold_trn]);
# Coefficients are read back from the support values stored on the nodes; NaN becomes 0.
subtree_coeffs = map(prewalk(subtree)) do node
    NewickTree.support(node)
end |> x->replace(x, NaN=>0.0);
# plot(size=(400, 600))
# Subtree coloured by coefficient, using half the colour range of the full-tree plot.
tp = plot(subtree, 
    fs=5,
    line_z=permutedims(subtree_coeffs[2:end]),
    # clims=getlims(subtree_coeffs[2:end]),
    clims=clims./2,
    colormap=:bam,
    linewidth=5 .+ permutedims(abs.(subtree_coeffs[2:end])),
)
# vline!([0.5, 2.04], color=:grey, alpha=0.5, linestyle=:dash)

# Measured values and training predictions for the subtree strains, in subtree leaf order.
pltmtx = hcat(
    metab_trnY[fold_trn, target_idx][trainingset_treeorder], 
    GLMNet.predict(mdl, trnX)[trainingset_treeorder]
)
hplt = heatmap(pltmtx,
    colormap=:bwr,
    clims=getlims(pltmtx),
    xticks = (1:2, ["measurements", "training predictions"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)

# Species strip for the same strains.
pltmtx = hcat(
    trnYdf.Species[fold_trn][trainingset_treeorder]
)
colormap = getindex.(Ref(species_color_dict), unique(pltmtx))
spplt = heatmap(pltmtx,
    colormap=colormap,
    # clims=getlims(pltmtx),
    xticks = (1:1, ["species"]),
    yticks=false,
    xrotation=45,
    bottommargin=5Plots.mm,
    colorbar=:none,
)

# Compose tree, heatmap and species strip side by side and save.
layout = @layout [a{0.85w} b c{0.05w}]
plot(tp, hplt, spplt, layout=layout, link=:y, size=(500, 500), yflip=true)
savefig(joinpath(pdir, "rectale-breve_foldtree_coeffs_predictions_species.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/rectale-breve_foldtree_coeffs_predictions_species.pdf"

Fig. S17 - panel of predictions

# Peak model per metabolite (lambda and its R² summaries), best metabolite first.
best_models_df = @chain bestlambdamodels begin
    select(:metabolite_label, :lambda, :rsq_mean, :rsq_adj_mean)
    sort(:rsq_adj_mean, rev=true)
end

# Out-of-fold predictions of the peak models, summarised per strain.
best_model_preds_df = @chain oof_preds_df_stacked begin
    # keep every peak model and attach its predictions (matched on metabolite and lambda)
    leftjoin(best_models_df, _; on=[:metabolite_label, :lambda])
    groupby([:msk_id, :metabolite_label, :lambda])
    combine(
        # these are constant within a group; `only ∘ unique` errors if they are not
        :truth => (only ∘ unique) => identity,
        :rsq_mean => (only ∘ unique) => identity,
        :rsq_adj_mean => (only ∘ unique) => identity,
        # :preds => first,
        # mean and spread of the predictions collected for the same strain
        :preds => mean,
        :preds => std,
    )
    # look up each strain's species in the training metadata
    DataFrames.transform(
        :msk_id => (x -> trnYdf.Species[indexin(x, trnYdf.Strain_ID)]) => :Species
    )
    # species colour, grey for species without an entry in the colour table
    DataFrames.transform(
        :Species => ByRow(x -> get(species_color_dict, x, :grey)) => :Species_color
    )
end;
first(best_model_preds_df, 5)
5×10 DataFrame
Row msk_id metabolite_label lambda truth rsq_mean rsq_adj_mean preds_mean preds_std Species Species_color
String15? String31 Float64 Float64 Float64 Float64 Float64 Float64 String String7
1 MSK.19.38 Tyramine 0.0331131 0.742278 0.263256 0.20591 0.147467 0.0372331 [Ruminococcus] gnavus #3288BD
2 MSK.22.14 Tyramine 0.0331131 -0.0850042 0.263256 0.20591 0.0181758 0.0163976 Phocaeicola vulgatus #9E0142
3 MSK.22.19 Tyramine 0.0331131 -0.116961 0.263256 0.20591 0.135059 0.0351086 [Eubacterium] rectale #D53E4F
4 MSK.18.5 Tyramine 0.0331131 -0.0123604 0.263256 0.20591 -0.0979244 0.00693549 Blautia luti #5E4FA2
5 MSK.19.91 Tyramine 0.0331131 0.377349 0.263256 0.20591 -0.091508 0.00715769 Bacteroides uniformis #FEE08B
# One measured-vs-predicted scatter per metabolite, in the order of best_models_df.
ps = []
for row in eachrow(best_models_df)
    predsdf = subset(best_model_preds_df,
        :metabolite_label => ByRow(==(row.metabolite_label))
    )
    # x: mean prediction with ±1 std error bars, y: measured value, coloured by species
    p = @df predsdf scatter(
        :preds_mean, :truth,
        xerror=:preds_std,
        msw=0.25,
        # markerstrokecolor=:black,
        seriescolor=:Species_color,
        # lambda and adjusted R² are constant within a metabolite; `only ∘ unique` checks that
        title=row.metabolite_label * "\nλ=$(round((only ∘ unique)(:lambda), digits=2)), R²=$(round((only ∘ unique)(:rsq_adj_mean), digits=2))",
        rasterize=true,
    )
    push!(ps, p)
end
# Arrange all panels on a 4 × 8 grid with equal axis scaling.
plot(ps..., 
    layout=grid(4, 8), 
    size=(1500, 750), 
    titlefontsize=8,
    tickfontsize=5,
    ratio=1,
    widen=1.1,
    xrotation=45,
)

Measured (y-axis) versus predicted (x-axis) relative concentration of metabolite where prediction is computed from the peak predictive SLE-LASSO model from Figure 7D. Dots are individual strains colored by species (see color key).

split plot because illustrator sucks

# First 16 panels on a 2 × 8 grid, saved as the top half of the figure.
plot(ps[1:16]..., 
    layout=grid(2, 8), 
    size=(1500, 400), 
    titlefontsize=8,
    tickfontsize=5,
    ratio=1,
    widen=1.1,
    # xrotation=45,
    topmargin=2Plots.mm,
    bottommargin=5Plots.mm,
)
savefig(joinpath(pdir, "outoffold_mean-adjusted-r2_scattercharts_tophalf.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/outoffold_mean-adjusted-r2_scattercharts_tophalf.pdf"
# Remaining panels (17 onward) on a 2 × 8 grid, saved as the bottom half of the figure.
plot(ps[17:end]..., 
    layout=grid(2, 8), 
    size=(1500, 400), 
    titlefontsize=8,
    tickfontsize=5,
    ratio=1,
    widen=1.1,
    # xrotation=45,
    topmargin=2Plots.mm,
    bottommargin=5Plots.mm,
)
savefig(joinpath(pdir, "outoffold_mean-adjusted-r2_scattercharts_bottomhalf.pdf"))
"/Users/bend/projects/Doran_etal_2023/plots/metabolite_model_outofbag/outoffold_mean-adjusted-r2_scattercharts_bottomhalf.pdf"