diff --git a/R/model_edges.R b/R/model_edges.R index b44420e..b88b0de 100644 --- a/R/model_edges.R +++ b/R/model_edges.R @@ -27,10 +27,14 @@ inbound_nodes <- function(model){ inbound <- map( model_layers, function(x){ - if (length(x$inbound_nodes)) - x$inbound_nodes[[1]] %>% - map_chr(c(1, 1)) - else NA + if (length(x$inbound_nodes) == 1) { + x$inbound_nodes[[1]] %>% map_chr(c(1, 1)) + } else if(length(x$inbound_nodes) > 1) { + # needed for shared layers + x$inbound_nodes %>% map_chr(c(1, 1)) + } else { + NA + } } ) names(inbound) <- map(model_layers, "name") diff --git a/inst/examples/example_shared_layer.R b/inst/examples/example_shared_layer.R new file mode 100644 index 0000000..ff88ea8 --- /dev/null +++ b/inst/examples/example_shared_layer.R @@ -0,0 +1,32 @@ +## ---- shared ---- + +require(keras) + +# Example of a network with a shared layers +# from: https://keras.rstudio.com/articles/functional_api.html + +tweet_a <- layer_input(shape = c(280, 256)) +tweet_b <- layer_input(shape = c(280, 256)) + +# This layer can take as input a matrix and will return a vector of size 64 +shared_lstm <- layer_lstm(units = 64) + +# When we reuse the same layer instance multiple times, the weights of the layer are also +# being reused (it is effectively *the same* layer) +encoded_a <- tweet_a %>% shared_lstm +encoded_b <- tweet_b %>% shared_lstm + +# We can then concatenate the two vectors and add a logistic regression on top +predictions <- layer_concatenate(c(encoded_a, encoded_b), axis=-1) %>% + layer_dense(units = 1, activation = 'sigmoid') + +# We define a trainable model linking the tweet inputs to the predictions +model <- keras_model(inputs = c(tweet_a, tweet_b), outputs = predictions) + +model %>% compile( + optimizer = 'rmsprop', + loss = 'binary_crossentropy', + metrics = c('accuracy') +) + +plot_model(model) diff --git a/vignettes/introduction.Rmd b/vignettes/introduction.Rmd index af18b84..0427e00 100644 --- a/vignettes/introduction.Rmd +++ b/vignettes/introduction.Rmd @@ -39,6 +39,12 @@ knitr::read_chunk( ) ``` +```{r cache=FALSE, include=FALSE} +knitr::read_chunk( + here::here("inst/examples/example_shared_layer.R") +) +``` + ### Sequential model Simple sequential model @@ -51,6 +57,10 @@ Simple sequential model ```{r network, fig.height=6} ``` +Shared layers + +```{r shared, fig.height=6} +``` ## Famous architectures