Comments (7)

Horea94 commented on August 23, 2024

In the [Fruits-360 CNN.py](src/image_classification/Fruits-360 CNN.py) file, we are using
model.load_weights(model_out_dir + "/model.h5")
to load the saved weights from the .h5 file. Then, using
y_pred = model.predict(testGen, steps=(testGen.n // batch_size) + 1, verbose=verbose)
we store the predictions in the y_pred variable. It is a NumPy array with one row per image. Each row contains a probability for each of the classes, and the index of the highest probability is the predicted class. To get the class name, look at the same index in the labels array.
The predict method is not limited to generators. According to the documentation found here https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict, the input can be:

  • A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  • A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
  • A tf.data dataset.
  • A generator or keras.utils.Sequence instance. A more detailed description of unpacking behaviour for iterator types (Dataset, generator, Sequence) is given in the Unpacking behaviour for iterator-like inputs section of Model.fit.
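
As a minimal sketch of the step from probabilities to a class name: the `labels` list and the `y_pred` values below are made up for illustration. In the script they would come from the Training folder names and from `model.predict`.

```python
import numpy as np

# Hypothetical labels and predictions, for illustration only.
labels = ["Apple", "Banana", "Orange"]
y_pred = np.array([
    [0.10, 0.85, 0.05],  # probabilities for the first image
    [0.20, 0.10, 0.70],  # probabilities for the second image
])

# The index of the highest probability in each row is the predicted class.
class_indices = np.argmax(y_pred, axis=1)

# Look up the class name at the same index in the labels list.
predicted_names = [labels[i] for i in class_indices]
print(predicted_names)
```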

from fruit-images-dataset.

nckenn commented on August 23, 2024

Thank you for the quick reply, it is very much appreciated. Can you take a look at my code below? I am still new to ML, which is why I don't understand some of the terms you used. Am I on the right path?

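```python
import cv2 as cv
import numpy as np
import tensorflow as tf

# Load the input image and resize it to 100x100 pixels.
image = cv.imread('assets/img/banana.jpg')
image = cv.resize(image, (100, 100))

model = tf.keras.models.load_model('models/model.h5', custom_objects={'tf': tf})
# Batch of one image: (batch, height, width, channels).
data = np.ndarray(shape=(1, 100, 100, 3), dtype=np.float32)

image_array = np.asarray(image)
# Scale pixel values from [0, 255] to roughly [-1, 1].
normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
data[0] = normalized_image_array
predictions = model.predict(data, batch_size=1)

classes = predictions.argmax(axis=1)
print(predictions)
```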

Horea94 commented on August 23, 2024

The overall logic seems fine. However, in our model we represent images as integers rather than floats, and we do not apply normalisation. So maybe you could change data to:
data = np.ndarray(shape=(1, 100, 100, 3), dtype=np.int32)
and use image_array instead of normalized_image_array.

nckenn commented on August 23, 2024

I'm still confused :( How can I get the label for a prediction? There is no label.txt generated after running Fruits-360 CNN.py. Do I need to create it manually?

Horea94 commented on August 23, 2024

Fruits-360 CNN.py uses an in-memory array of labels. By default, this array is built as the list of all the folder names in the Training folder. If use_label_file is set to true, it will instead look for a labels.txt file that the user creates. That file should contain the label names (folder names from the Training folder) that you want to train the model for, one per line.
So when predicting with the trained model, you can build the labels array with similar logic, and use it to convert the predicted class (which will be a number) to a human-readable name.
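
A rough sketch of that logic, using only the standard library. `build_labels` is a hypothetical helper, not a function from the repository, and sorting the folder names is an assumption: whatever order you use here must match the order the labels had during training, otherwise the indices will point at the wrong names.

```python
import os

def build_labels(train_dir, label_file=None):
    """Return the list of class names (hypothetical helper, not from the repo)."""
    # If a label file is given and exists, read one label per non-empty line.
    if label_file is not None and os.path.isfile(label_file):
        with open(label_file) as f:
            return [line.strip() for line in f if line.strip()]
    # Otherwise use the folder names in the training directory.
    # Assumption: sorted order; it must match the order used during training.
    return sorted(
        name for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))
    )
```

With a predicted index in hand, for example `predicted_index = int(predictions.argmax(axis=-1)[0])`, the name would then be `build_labels("Training")[predicted_index]`, where "Training" stands for the path to your Training folder.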

nckenn commented on August 23, 2024

Now everything works fine except for one issue: when I use an image from the Training folder as the input, the prediction is accurate, but when I use an image downloaded from Google, the prediction is not correct. Any suggestions on my code?

When using this one, the prediction is correct:
[image: orange-d]

But when I use this one, the prediction is not correct:
[image: orange]

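```python
import cv2 as cv
import numpy as np
import tensorflow as tf

# Load the input image and resize it to 100x100 pixels.
image = cv.imread('assets/img/orange.jpg')
image = cv.resize(image, (100, 100))

model = tf.keras.models.load_model('models/model.h5', custom_objects={'tf': tf})
# Batch of one image, kept as integers with no normalisation.
data = np.ndarray(shape=(1, 100, 100, 3), dtype=np.int32)

image_array = np.asarray(image)
data[0] = image_array
predictions = model.predict(data, batch_size=1)

# Index of the highest probability, i.e. the predicted class.
classes = predictions.argmax(axis=-1)
print(classes)
```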

Horea94 commented on August 23, 2024

The issue is most likely not in the code.
The images in the dataset are taken under the same lighting conditions, using one fruit per class.
For the model to properly generalise on real-world images, you would need a lot more variance in the data.
Some degree of variance can be introduced with data augmentation (image flips, hue/saturation/light changes). You could also use image processing to alter the background of the training images, as the dataset only uses a white background.
Alternatively, you could add more images to your training dataset.
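
To make the augmentation idea concrete, here is a NumPy-only sketch that applies a random horizontal flip, a random brightness change, and a random background colour. The `augment` function, the brightness range and the `white_threshold` value are all assumptions chosen for illustration; this is not how the repository trains its model, and a real pipeline would more likely rely on the augmentation tools of your training framework.

```python
import numpy as np

def augment(image, rng, white_threshold=240):
    """Return a randomly altered copy of an HxWx3 uint8 image (hypothetical helper)."""
    # Treat pixels that are near-white in every channel as background.
    background = np.all(image >= white_threshold, axis=-1)
    out = image.astype(np.float32)
    # Random horizontal flip, applied to the image and the mask together.
    if rng.random() < 0.5:
        out = out[:, ::-1, :]
        background = background[:, ::-1]
    # Random brightness change, clipped to the valid pixel range.
    out = np.clip(out * rng.uniform(0.7, 1.3), 0, 255)
    # Replace the white background with a single random colour.
    out[background] = rng.integers(0, 256, size=3)
    return out.astype(np.uint8)

# Usage on a tiny synthetic image: white background, coloured square in the middle.
rng = np.random.default_rng(0)
img = np.full((4, 4, 3), 255, dtype=np.uint8)
img[1:3, 1:3] = [200, 100, 0]
result = augment(img, rng)
print(result.shape, result.dtype)
```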
