I'm trying to build a model based on ResNet50 that classifies images into 6 classes, and I want to reduce the size of the images before feeding them to ResNet50. I start by creating the ResNet50 model from Keras:
from tensorflow.keras.applications import ResNet50

ResNet = ResNet50(
    include_top=False, weights='imagenet', input_tensor=None, input_shape=(64, 109, 3),
    pooling=None, classes=6)  # classes is ignored when include_top=False
Then I create a Sequential model that wraps ResNet50, with a first layer for reducing the image size before ResNet50 and some final layers for the classification. (About the input shape: my images are 128x217, and the 3 is the number of channels that ResNet needs.)
model = models.Sequential()
model.add(GlobalAveragePooling2D(input_shape = ([128, 217, 3])))
model.add(ResNet)
model.add(GlobalAveragePooling2D())
model.add(Dense(units=512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=6, activation='softmax'))
But this doesn't work, because the output shape of the first global average pooling layer doesn't match the input shape of ResNet. The error I get is:
WARNING:tensorflow:Model was constructed with shape (None, 64, 109, 3) for input Tensor("input_6:0", shape=(None, 64, 109, 3), dtype=float32), but it was called on an input with incompatible shape (None, 3).
ValueError: Input 0 of layer conv1_pad is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: [None, 3]
I think I understand what the problem is, but I don't know how to fix it, since (None, 3) is not a valid input shape for ResNet50. How can I fix this? Thank you! :)
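You should first understand what GlobalAveragePooling2D actually does. This layer cannot be applied right after the input, because it averages each image over its entire height and width and returns a single value per channel (in your case 3 values per image, because you have 3 channels). That is why ResNet receives a tensor of shape (None, 3) instead of a 4-dimensional batch of images.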
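To make the shapes concrete, here is a small NumPy sketch of what the two kinds of pooling do to a batch shaped like your images. It mimics the channels-last behavior of the Keras layers with plain array operations, so it is an illustration rather than the Keras implementation; the random `batch` is made-up data, and the local pooling shown uses no padding.

```python
import numpy as np

# Made-up batch of one 128x217 RGB image, channels last (as Keras expects by default).
batch = np.random.rand(1, 128, 217, 3).astype("float32")

# Global average pooling: mean over height and width, one value per channel.
global_pooled = batch.mean(axis=(1, 2))
print(global_pooled.shape)  # (1, 3)

# Local 2x2 average pooling without padding: the odd last column is dropped,
# then every 2x2 block is averaged, so the spatial layout is preserved.
cropped = batch[:, :128, :216, :]
local_pooled = cropped.reshape(1, 64, 2, 108, 2, 3).mean(axis=(2, 4))
print(local_pooled.shape)  # (1, 64, 108, 3)
```

The first result has lost both spatial dimensions, which matches the (None, 3) shape in the error message. The second is still an image, only smaller.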
You have to use another method to reduce the size of the images (e.g. simply resizing them to a smaller size).
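One way to do the resizing, as a minimal sketch: the `Resizing` preprocessing layer, assuming a TensorFlow version where it is available as `tf.keras.layers.Resizing`. The target size of 64x109 is taken from the `input_shape` in the question; the random `batch` is placeholder data.

```python
import numpy as np
import tensorflow as tf

# Resize to the 64x109 that the ResNet50 in the question was built for.
resize = tf.keras.layers.Resizing(height=64, width=109)

# Placeholder batch of two 128x217 RGB images.
batch = np.random.rand(2, 128, 217, 3).astype("float32")
small = resize(batch)
print(small.shape)  # (2, 64, 109, 3)
```

The same layer could be placed as the first layer of the Sequential model, so the resizing happens inside the model instead of in a separate preprocessing step.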
Thanks for your answer! I hadn't realized what GlobalAveragePooling was actually doing. I solved it by using AveragePooling (not global) instead.
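For anyone with the same problem, here is a sketch of what the model from the question might look like with that change. The pooling settings are my assumption, since the comment above doesn't state them: `pool_size=2` with `padding="same"` turns 128x217 into 64x109, which matches the ResNet input shape from the question. `build_model` and its parameters are hypothetical names used only for this example.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_model(weights="imagenet", num_classes=6):
    """Sequential model: local average pooling, ResNet50, classification head."""
    backbone = keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=(64, 109, 3))

    return keras.Sequential([
        keras.Input(shape=(128, 217, 3)),
        # Halves height and width; padding="same" rounds 217 / 2 up to 109.
        layers.AveragePooling2D(pool_size=2, padding="same"),
        backbone,
        # Here global pooling is appropriate: it flattens the ResNet feature maps.
        layers.GlobalAveragePooling2D(),
        layers.Dense(512, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(256, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ])
```

Calling `build_model()` downloads the ImageNet weights on first use; `build_model(weights=None)` builds the same architecture with random weights, which is enough to check the shapes with `model.summary()`.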