14.3. Sentiment Analysis with Attention

We build upon the RNN-based sentiment analysis from Chapter 12 by incorporating an attention mechanism.

Our strategy is simple:

  1. Instead of directly feeding the final hidden state of the RNN unit to the dense layer, we use it as a “query” to the attention layer.

  2. The attention layer uses the query to compute a context vector, which reflects the sentence’s contextual information. This context vector is then passed to the dense layer to generate the sentiment prediction.

Fig.14-4: RNN-Based Sentiment Analysis with Attention

Here is a breakdown of how it works:

This model employs the final hidden state as the query $q$ for the attention layer.

The attention layer additionally receives all hidden states, denoted as $h_{i}$, as both key vectors $k_{i}$ and value vectors $v_{i}$. It then calculates attention weights $a_{i}$ based on the similarity between each key vector $k_{i}$ and the query vector $q$.

These attention weights $a_{i}$ are multiplied by the corresponding value vectors $v_{i}$, and the weighted vectors $a_{i}v_{i}$ are summed ($\sum_{i} a_{i} v_{i}$) to form the final context vector.
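Concretely, the raw similarity scores are normalized with a softmax, so the weights are non-negative and sum to one:

$e_{i} = \mathrm{score}(q, k_{i}), \qquad a_{i} = \dfrac{\exp(e_{i})}{\sum_{j} \exp(e_{j})}, \qquad c = \sum_{i} a_{i} v_{i}$

With the Bahdanau (additive) attention used in this chapter, the score function is $\mathrm{score}(q, k_{i}) = v_{a}^{\top} \tanh(W_{1} q + W_{2} k_{i})$, where $W_{1}$, $W_{2}$, and $v_{a}$ are learned parameters.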

Fig.14-5: Attention Mechanism in RNN-Based Sentiment Analysis

By weighting all hidden states rather than relying on the final hidden state alone, this approach is expected to capture the context of longer sentences more effectively.

14.3.1. Create Model

Complete Python code is available at: SentimentAnalysis-GRU-tf-attention.py

Our model comprises a many-to-many GRU network, an attention layer, and a dense output layer.

  • Many-to-Many GRU Layer: We set return_sequences=True and return_state=True for the GRU network to obtain both the final hidden state (state) and the entire sequence of hidden states (output). The final state is fed to the attention layer as the query, while the full sequence serves as its key and value vectors.

  • Attention Layer: We use Bahdanau (additive) attention; a minimal sketch of this layer follows this list.

  • Dense Output Layer: We use a sigmoid activation function in the dense layer to produce a probability of positive sentiment.
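The BahdanauAttention layer used below is provided by a separate Attention module in the full script. For reference, here is a minimal sketch of such a layer in tf.keras; it follows the standard additive-attention formulation and is an assumption, not necessarily the exact code of that module:

# ----------------------------------------
# Bahdanau (additive) attention — a minimal
# sketch; the actual Attention module may differ.
# ----------------------------------------
import tensorflow as tf


class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)  # projects the query q
        self.W2 = tf.keras.layers.Dense(units)  # projects the keys/values h_i
        self.V = tf.keras.layers.Dense(1)       # scores each time step

    def call(self, query, values):
        # query: (batch, units); values: (batch, time, units)
        # Add a time axis to the query so it broadcasts over all time steps.
        query_with_time_axis = tf.expand_dims(query, 1)
        # e_i = v_a^T tanh(W1 q + W2 h_i): shape (batch, time, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        # a_i = softmax(e_i) over the time axis
        attention_weights = tf.nn.softmax(score, axis=1)
        # context vector c = sum_i a_i * v_i: shape (batch, units)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights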

# ========================================
# Create Model
# ========================================

import tensorflow as tf

import Attention  # module providing BahdanauAttention (see the full script)

hidden_nodes = 128
output_nodes = 1

embedding_dim = 64


class SentimentAnalysis(tf.keras.Model):
    def __init__(self, hidden_units, output_units, vocab_size, embedding_dim, rate=0.0):
        super().__init__()

        # Map word indices to dense embedding vectors.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        # Many-to-many GRU: return_sequences=True yields the full sequence of
        # hidden states (keys/values); return_state=True yields the final
        # hidden state (query).
        self.gru = tf.keras.layers.GRU(
            hidden_units,
            activation="tanh",
            recurrent_activation="sigmoid",
            kernel_initializer="glorot_normal",
            recurrent_initializer="orthogonal",
            return_sequences=True,
            return_state=True,
        )
        self.attention = Attention.BahdanauAttention(hidden_units)
        self.dense = tf.keras.layers.Dense(output_units, activation="sigmoid")

    def call(self, x):
        x = self.embedding(x)
        # output: all hidden states (batch, time, units); state: final state.
        output, state = self.gru(x)
        # The final state is the query; the full sequence supplies keys/values.
        context_vector, attention_weights = self.attention(state, output)
        x = self.dense(context_vector)
        return x, attention_weights


# vocab_size and max_len are set during preprocessing (see the full script).
model = SentimentAnalysis(hidden_nodes, output_nodes, vocab_size, embedding_dim)

model.build(input_shape=(None, max_len))
model.summary()
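Because call() returns both the prediction and the attention weights, a custom gradient-tape step is a straightforward way to train this model. The following is a minimal sketch; the optimizer and loss setup here are assumptions for illustration, and the full script defines the actual training loop:

# ----------------------------------------
# Training step — a minimal sketch; see the
# full script for the actual training loop.
# ----------------------------------------
optimizer = tf.keras.optimizers.Adam()
bce = tf.keras.losses.BinaryCrossentropy()


@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        y_pred, _ = model(x_batch)  # attention weights are not used in the loss
        loss = bce(y_batch, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss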

14.3.2. Sentiment Analysis with Attention

The results of our RNN-based sentiment analysis with attention on the validation dataset are presented below:

$ python SentimentAnalysis-GRU-tf-attention.py

Model: "sentiment_analysis"
 Layer (type)                Output Shape              Param #
=================================================================
 embedding (Embedding)       multiple                  143936

 gru (GRU)                   multiple                  74496

 bahdanau_attention (Bahdana  multiple                 33153
 uAttention)

 dense_1 (Dense)             multiple                  129

=================================================================
Total params: 251,714
Trainable params: 251,714
Non-trainable params: 0

... snip ...

Text:an hour seriously? .
Correct value   =>  Negative
Estimated value =>  Negative

Text:it really is impressive that the place hasnt closed down .
Correct value   =>  Negative
Estimated value =>  Negative

Text:i promise they wont disappoint .
Correct value   =>  Positive
Estimated value =>  Negative
*** Wrong ***

Text:cute quaint simple honest .
Correct value   =>  Positive
Estimated value =>  Positive

Text:the waitresses are very friendly .
Correct value   =>  Positive
Estimated value =>  Positive

Text:i live in the neighborhood so i am disappointed i wont be back here because it is a convenient location .
Correct value   =>  Negative
Estimated value =>  Positive
*** Wrong ***

Text:the nachos are a must have! .
Correct value   =>  Positive
Estimated value =>  Negative
*** Wrong ***

Text:the pan cakes everyone are raving about taste like a sugary disaster tailored to the palate of a six year old .
Correct value   =>  Negative
Estimated value =>  Negative

Text:this is some seriously good pizza and im an expert/connisseur on the topic .
Correct value   =>  Positive
Estimated value =>  Positive

Text:tasted like dirt .
Correct value   =>  Negative
Estimated value =>  Negative

Text:if you want a sandwich just go to any firehouse!!!!! .
Correct value   =>  Positive
Estimated value =>  Positive

Text:and the beans and rice were mediocre at best .
Correct value   =>  Negative
Estimated value =>  Negative

Text:they will customize your order any way youd like my usual is eggplant with green bean stir fry love it! .
Correct value   =>  Positive
Estimated value =>  Positive

Text:i will be back many times soon .
Correct value   =>  Positive
Estimated value =>  Negative
*** Wrong ***

Text:what did bother me was the slow service .
Correct value   =>  Negative
Estimated value =>  Negative

Text:the food was very good .
Correct value   =>  Positive
Estimated value =>  Positive

Text:this place is not worth your time let alone vegas .
Correct value   =>  Negative
Estimated value =>  Negative

Text:delicious nyc bagels good selections of cream cheese real lox with capers even .
Correct value   =>  Positive
Estimated value =>  Positive

Text:very bad experience! .
Correct value   =>  Negative
Estimated value =>  Negative

Text:do not waste your money here! .
Correct value   =>  Negative
Estimated value =>  Negative


16 out of 20 sentences are correct.
Accuracy: 0.760000

In this example, 16 of the 20 displayed sentences were classified correctly; the script reports an overall validation accuracy of 0.76.

After 100 trials, the average accuracy was 0.759 (standard deviation = 0.034).

This result suggests that the model with the attention mechanism outperforms the attention-free model introduced in Chapter 12.

However, this task is too simple to fully demonstrate what the attention mechanism can do.

The next section dives into machine translation, an application in which attention reveals its full strength.