-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgradio_app.py
More file actions
200 lines (160 loc) · 8.07 KB
/
Copy pathgradio_app.py
File metadata and controls
200 lines (160 loc) · 8.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from keras import mixed_precision
from src.model import *
from src.config import *
# Global Keras dtype policy: layers compute in float16 while keeping their
# variables in float32. It is set here, before any model below is constructed.
mixed_precision.set_global_policy('mixed_float16')

# ---------------------------------------------------------------------------
# Model hyperparameters. They size the encoder/decoder and the text vectorizer
# built below; presumably they must match the values used when the checkpoints
# were trained -- confirm before changing any of them.
# ---------------------------------------------------------------------------
VOCAB_SIZE = 10000      # max tokens kept by the TextVectorization layer
SEQUENCE_LENGTH = 25    # vectorizer output length; also the beam-search step limit
EMBEDDING_DIM = 512     # dimensionality of the embedding layer
UNITS = 512             # number of units in the LSTM layers
IMG_HEIGHT, IMG_WIDTH = 299, 299  # size images are resized to before InceptionV3 preprocessing

# True selects the Bahdanau-attention encoder/decoder and its checkpoints;
# False selects the plain LSTM pair.
is_attention = True
def restore_models():
    """Construct the non-attention encoder/decoder pair and build its variables.

    No checkpoint is read here. ``load_models_and_vectorizer`` loads the saved
    weights after this returns. The decoder is called once on all-zero inputs
    with a batch of one so that its variables exist before ``load_weights`` runs.

    Returns:
        tuple: ``(encoder, decoder)`` -- the model from ``get_encoder()`` and an
        ``RnnDecoder(EMBEDDING_DIM, UNITS, VOCAB_SIZE)``.
    """
    print("Restoring models...")
    encoder = get_encoder()
    decoder = RnnDecoder(EMBEDDING_DIM, UNITS, VOCAB_SIZE)

    # Warm-up inputs for batch size 1: a zero feature vector, the decoder's
    # own initial state, and a single token id of 0 (not the <start> token).
    warmup_features = tf.zeros((1, EMBEDDING_DIM))
    warmup_state = decoder.reset_state(batch_size=1)
    warmup_token = tf.zeros((1, 1))

    # The output is discarded; the call exists only to build the weights.
    decoder(warmup_token, warmup_features, warmup_state)

    print("Models restored.")
    return encoder, decoder
def restore_models_with_attention():
    """Construct the Bahdanau-attention encoder/decoder pair and build its variables.

    No checkpoint is read here. ``load_models_and_vectorizer`` loads the saved
    weights after this returns. The decoder is called once on all-zero inputs
    with a batch of one so that its variables exist before ``load_weights`` runs.

    Returns:
        tuple: ``(encoder, decoder)`` -- the model from
        ``get_encoder_with_spatial_features()`` and a
        ``RnnDecoderWithAttention(EMBEDDING_DIM, UNITS, VOCAB_SIZE)``.
    """
    print("Restoring models...")
    encoder = get_encoder_with_spatial_features()
    decoder = RnnDecoderWithAttention(EMBEDDING_DIM, UNITS, VOCAB_SIZE)

    # Warm-up inputs for batch size 1: zero features, the decoder's own initial
    # state, and a single token id of 0 (not the <start> token).
    # NOTE(review): a spatial encoder presumably yields (1, locations,
    # EMBEDDING_DIM), not the 2-D (1, EMBEDDING_DIM) used here. This apparently
    # still builds the weights (they likely depend only on the last axis);
    # confirm against RnnDecoderWithAttention.
    warmup_features = tf.zeros((1, EMBEDDING_DIM))
    warmup_state = decoder.reset_state(batch_size=1)
    warmup_token = tf.zeros((1, 1))

    # The output is discarded; the call exists only to build the weights.
    decoder(warmup_token, warmup_features, warmup_state)

    print("Models restored.")
    return encoder, decoder
def load_models_and_vectorizer():
    """Build the captioning models, load their trained weights, and restore the vectorizer.

    The module-level ``is_attention`` flag picks both the model variant and the
    checkpoint directory its weights are read from. The TextVectorization layer
    is rebuilt with ``VOCAB_SIZE`` / ``SEQUENCE_LENGTH``. Its vocabulary is read
    from ``vocabulary.txt``, one token per line. All paths are relative to the
    current working directory.

    Returns:
        tuple: ``(encoder, decoder, vectorization)``.
    """
    if is_attention:
        encoder, decoder = restore_models_with_attention()
        checkpoint_dir = './training_checkpoints_attention'  # Bahdanau attention weights
    else:
        encoder, decoder = restore_models()
        checkpoint_dir = './training_checkpoints'  # LSTM-only weights
    encoder.load_weights(f'{checkpoint_dir}/best_encoder.weights.h5')
    decoder.load_weights(f'{checkpoint_dir}/best_decoder.weights.h5')

    # Rebuild the TextVectorization layer with the training-time settings.
    vectorization = tf.keras.layers.TextVectorization(
        max_tokens=VOCAB_SIZE,
        output_sequence_length=SEQUENCE_LENGTH,
        dtype='int32',
    )
    # Decode explicitly as UTF-8 rather than the platform's locale default
    # (e.g. cp1252 on Windows), so the same file loads identically everywhere.
    # NOTE(review): assumes vocabulary.txt was written as UTF-8 -- confirm
    # against the script that saved it.
    with open('vocabulary.txt', 'r', encoding='utf-8') as f:
        vocabulary = [line.strip() for line in f]
    vectorization.set_vocabulary(vocabulary)
    return encoder, decoder, vectorization
# Load models and vectorizer once at import time; every request reuses them.
encoder, decoder, vectorization = load_models_and_vectorizer()

# Materialize the vocabulary list once (it was previously built twice) and look
# up the caption boundary markers. list.index raises ValueError if either
# marker is missing from the vocabulary.
_startup_vocabulary = vectorization.get_vocabulary()
start_token_index = _startup_vocabulary.index('<start>')
end_token_index = _startup_vocabulary.index('<end>')
def load_and_preprocess_image(image_array):
    """Turn an image array from Gradio into a batched InceptionV3 input tensor.

    Args:
        image_array: the image as handed over by ``gr.Image`` (a NumPy array);
            presumably (height, width, 3) -- TODO confirm the channel count.

    Returns:
        A tensor resized to (IMG_HEIGHT, IMG_WIDTH), passed through
        ``inception_v3.preprocess_input``, with a leading batch axis of size 1.
    """
    resized = tf.image.resize(tf.convert_to_tensor(image_array), (IMG_HEIGHT, IMG_WIDTH))
    normalized = keras.applications.inception_v3.preprocess_input(resized)
    return tf.expand_dims(normalized, 0)
def _beam_decoder_step(decoder, token_idx, features, state):
    """Run one decoder step on a single token; return (float32 log-probs, [h, c])."""
    dec_input = tf.expand_dims([token_idx], 0)
    if is_attention:
        # The attention decoder also returns attention weights, unused here.
        predictions, new_h, new_c, _ = decoder(dec_input, features, state)
    else:
        predictions, new_h, new_c = decoder(dec_input, features, state)
    # The global 'mixed_float16' policy may leave the logits in float16. Take
    # the log-softmax in float32 so low-probability tokens do not underflow and
    # the accumulated beam scores keep full precision.
    log_probs = tf.nn.log_softmax(tf.cast(predictions[0], tf.float32))
    return log_probs, [new_h, new_c]


def generate_caption_beam_search(image_array, encoder, decoder, beam_width=3):
    """Generate a caption for a single image using beam search.

    Takes the image array directly, as Gradio provides it, rather than a file
    path. At each of at most ``SEQUENCE_LENGTH`` steps, every live hypothesis
    is expanded with its ``beam_width`` most likely next tokens. Hypotheses
    that emit ``<end>`` are set aside as completed; the best ``beam_width`` of
    the rest continue.

    Args:
        image_array: image array from Gradio (see ``load_and_preprocess_image``).
        encoder: image encoder; called once on the preprocessed image.
        decoder: caption decoder; must provide ``reset_state(batch_size=...)``.
        beam_width: number of hypotheses kept, and tokens tried, per step.

    Returns:
        str: the best caption's words joined by single spaces, with the
        ``<start>`` / ``<end>`` markers removed. Candidates are ranked by
        cumulative log-probability divided by sequence length.
    """
    img_tensor = load_and_preprocess_image(image_array)
    features = encoder(img_tensor)
    hidden = decoder.reset_state(batch_size=1)

    # Live hypotheses: (token indices, cumulative log-prob, decoder state).
    # Scores start at 0.0 because log-probabilities are summed.
    beam = [([start_token_index], 0.0, hidden)]
    # Hypotheses that emitted <end>: (token indices, cumulative log-prob).
    completed_sequences = []

    for _ in range(SEQUENCE_LENGTH):
        candidates = []
        for seq, score, h_state in beam:
            # The last token of a hypothesis is the input for its next step.
            log_probs, next_state = _beam_decoder_step(decoder, seq[-1], features, h_state)
            top_k_probs, top_k_indices = tf.nn.top_k(log_probs, k=beam_width)
            # One device-to-host transfer per tensor instead of one per element.
            for word_idx, log_prob in zip(top_k_indices.numpy().tolist(),
                                          top_k_probs.numpy().tolist()):
                new_seq = seq + [word_idx]
                new_score = score + log_prob
                if word_idx == end_token_index:
                    completed_sequences.append((new_seq, new_score))
                else:
                    candidates.append((new_seq, new_score, next_state))

        # Every expansion ended with <end>: nothing left to extend.
        if not candidates:
            break
        # Keep the best beam_width unfinished hypotheses (higher log-prob is better).
        candidates.sort(key=lambda c: c[1], reverse=True)
        beam = candidates[:beam_width]

    # No hypothesis reached <end> within SEQUENCE_LENGTH steps: fall back to
    # the surviving beam.
    if not completed_sequences:
        completed_sequences.extend(beam)

    # Rank by length-normalized score (mean log-prob per token) so short
    # captions are not favoured merely for summing fewer negative terms. The
    # length counts the <start> marker (and <end> when present). max() keeps
    # the first of equally scored candidates, as the previous stable sort did.
    best_seq_indices = max(completed_sequences, key=lambda c: c[1] / len(c[0]))[0]

    vocab = vectorization.get_vocabulary()
    markers = {start_token_index, end_token_index}
    result_words = [vocab[i] for i in best_seq_indices if i not in markers]
    return ' '.join(result_words)
def predict_caption(image):
    """Gradio callback: return a beam-search caption (beam width 3) for *image*.

    Args:
        image: the value of the ``gr.Image`` input; ``None`` when the user
            submits without uploading anything.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message to the
            user instead of an opaque TensorFlow conversion error.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    return generate_caption_beam_search(image, encoder, decoder, beam_width=3)
# Assemble the Gradio UI: one image input and one textbox output, served by
# predict_caption. Example images could be supplied via the `examples`
# argument of gr.Interface.
_image_input = gr.Image(label="Upload an Image")
_caption_output = gr.Textbox(label="Predicted Caption")
demo = gr.Interface(
    fn=predict_caption,
    inputs=_image_input,
    outputs=_caption_output,
    title="🤖️🖼️ Image Captioning Project",
    description="Image captioning model built with TensorFlow. Upload an image and see the model generate a description.",
)
# --- Entry point ------------------------------------------------------------
if __name__ == "__main__":
    # share=True also requests a public Gradio share link in addition to the
    # local server; set it to False to serve locally only.
    _LAUNCH_OPTIONS = {"share": True}
    demo.launch(**_LAUNCH_OPTIONS)