Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions R/model_edges.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ inbound_nodes <- function(model){
inbound <- map(
model_layers,
function(x){
if (length(x$inbound_nodes))
x$inbound_nodes[[1]] %>%
map_chr(c(1, 1))
else NA
if (length(x$inbound_nodes) == 1) {
x$inbound_nodes[[1]] %>% map_chr(c(1, 1))
} else if(length(x$inbound_nodes) > 1) {
# needed for shared layers
x$inbound_nodes %>% map_chr(c(1, 1))
} else {
NA
}
}
)
names(inbound) <- map(model_layers, "name")
Expand Down
32 changes: 32 additions & 0 deletions inst/examples/example_shared_layer.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
## ---- shared ----

require(keras)

# Example of a network with a shared layers
# from: https://keras.rstudio.com/articles/functional_api.html

tweet_a <- layer_input(shape = c(280, 256))
tweet_b <- layer_input(shape = c(280, 256))

# This layer can take as input a matrix and will return a vector of size 64
shared_lstm <- layer_lstm(units = 64)

# When we reuse the same layer instance multiple times, the weights of the layer are also
# being reused (it is effectively *the same* layer)
encoded_a <- tweet_a %>% shared_lstm
encoded_b <- tweet_b %>% shared_lstm

# We can then concatenate the two vectors and add a logistic regression on top
predictions <- layer_concatenate(c(encoded_a, encoded_b), axis=-1) %>%
layer_dense(units = 1, activation = 'sigmoid')

# We define a trainable model linking the tweet inputs to the predictions
model <- keras_model(inputs = c(tweet_a, tweet_b), outputs = predictions)

model %>% compile(
optimizer = 'rmsprop',
loss = 'binary_crossentropy',
metrics = c('accuracy')
)

plot_model(model)
10 changes: 10 additions & 0 deletions vignettes/introduction.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ knitr::read_chunk(
)
```

```{r cache=FALSE, include=FALSE}
knitr::read_chunk(
here::here("inst/examples/example_shared_layer.R")
)
```

### Sequential model

Simple sequential model
Expand All @@ -51,6 +57,10 @@ Simple sequential model
```{r network, fig.height=6}
```

Shared layers

```{r shared, fig.height=6}
```

## Famous architectures

Expand Down