rojberr / titanic-machine-learning

Predict Titanic survival

Jupyter Notebook 87.31% Python 10.01% Dockerfile 0.21% Shell 0.72% HTML 1.75%
id3 id3-algorithm machine-learning

titanic-machine-learning

This report summarizes a machine learning project that predicts the survival (yes/no) of passengers on the Titanic. It uses the Iterative Dichotomiser 3 (ID3) decision tree algorithm.

The dataset is from Kaggle and contains information about the passengers on the Titanic. https://www.kaggle.com/competitions/titanic/overview

The goal is to predict whether a passenger survived or not based on the features in the dataset.

Usage

To predict which passengers survived the Titanic disaster, you can use the Docker image that runs the ID3 model. Run the container:

docker run -p 5000:5000 ghcr.io/rojberr/titanic-machine-learning:master

... then open http://127.0.0.1:5000/ in your browser, enter your passenger's data, and predict whether he or she survived.

To build locally and run the container use:

docker build -t predict .
docker run -p 5000:5000 predict

... then open http://127.0.0.1:5000/ with your browser.

Algorithm

The iterative_dichotomiser_3.ipynb Jupyter notebook contains the implementation of the ID3 algorithm.

It uses a recursive approach to build the tree.

The algorithm is based on the information gain that is calculated using the entropy.

Information Gain = Entropy(parent) - [weighted average of Entropy(children)]
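The formula above can be illustrated with a minimal Python sketch. It is not the notebook's code: the function names `entropy` and `information_gain`, the row-as-dict representation, and the toy data are all assumptions made for illustration. Each child's entropy is weighted by the share of passengers that falls into it.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def information_gain(rows, labels, feature):
    """Entropy of the parent minus the size-weighted entropy of the children."""
    # Group the labels by the value each row has for `feature`.
    children = {}
    for row, label in zip(rows, labels):
        children.setdefault(row[feature], []).append(label)
    weighted = sum(len(c) / len(labels) * entropy(c) for c in children.values())
    return entropy(labels) - weighted


# Toy data, made up for illustration (not taken from the Titanic dataset).
rows = [{"Sex": "female"}, {"Sex": "female"}, {"Sex": "male"}, {"Sex": "male"}]
labels = [1, 1, 0, 0]
print(information_gain(rows, labels, "Sex"))  # 1.0: the split separates the classes perfectly
```

In this toy case the parent entropy is 1 bit (two survivors, two non-survivors) and both children are pure, so the gain equals the full parent entropy.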

The whole process consists of the following steps:

  1. Load the data from the CSV file put-titanic-homework.csv
  2. Data preprocessing - saving preprocessed data in a datastore
  3. Create a DecisionTreeClassifier() object
    • initialize the object with the values, the columns (feature names), and the labels (survived = 1, not survived = 0)
  4. Train the model (create the tree) by calling the id3() method on the initialized object (the equivalent of fit() in sklearn)
  5. The model will:
    • calculate the entropy for the whole dataset
    • take all passengers and calculate the entropy for each feature
    • calculate the information gain for each feature
    • choose the feature that maximizes the information gain
    • add the best information gain feature as a node in the tree
    • for each possible choice in this feature:
      • select the passengers that match this choice (split the dataset based on the choice)
      • calculate the entropy for the new dataset
      • calculate the information gain for the new dataset
      • choose the feature that maximizes the information gain
      • add the best information gain feature as a node in the tree
      • repeat the process until:
        • all passengers in the subset have the same outcome (survived or not) -> then add a leaf with the 1 or 0 label respectively
        • there are no more features to split on -> then add a leaf with the most common label in the subset
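
The recursion in step 5 can be sketched in a few lines of Python. This is a simplified illustration under stated assumptions, not the notebook's DecisionTreeClassifier: the free functions `id3` and `predict`, the tuple-based tree representation, and the toy passengers are all hypothetical.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def information_gain(rows, labels, feature):
    """Parent entropy minus the size-weighted entropy of the children."""
    groups = {}
    for row, label in zip(rows, labels):
        groups.setdefault(row[feature], []).append(label)
    weighted = sum(len(g) / len(labels) * entropy(g) for g in groups.values())
    return entropy(labels) - weighted


def id3(rows, labels, features):
    """Return a leaf label, or a (feature, {value: subtree}) node."""
    # Stop: every passenger in the subset has the same label.
    if len(set(labels)) == 1:
        return labels[0]
    # Stop: no features left, so fall back to the most common label.
    if not features:
        return Counter(labels).most_common(1)[0][0]
    # Pick the feature with the highest information gain.
    best = max(features, key=lambda f: information_gain(rows, labels, f))
    remaining = [f for f in features if f != best]
    branches = {}
    for value in sorted({row[best] for row in rows}):
        # Split the dataset on this value and recurse on the subset.
        subset = [(r, l) for r, l in zip(rows, labels) if r[best] == value]
        sub_rows = [r for r, _ in subset]
        sub_labels = [l for _, l in subset]
        branches[value] = id3(sub_rows, sub_labels, remaining)
    return (best, branches)


def predict(tree, row):
    """Walk the tree until a leaf label is reached."""
    while isinstance(tree, tuple):
        feature, branches = tree
        tree = branches[row[feature]]  # raises KeyError for unseen values
    return tree


# Toy passengers, made up for illustration.
rows = [
    {"Sex": "female", "Pclass": 1},
    {"Sex": "female", "Pclass": 3},
    {"Sex": "female", "Pclass": 3},
    {"Sex": "male", "Pclass": 1},
    {"Sex": "male", "Pclass": 3},
    {"Sex": "male", "Pclass": 3},
]
labels = [1, 1, 1, 1, 0, 0]
tree = id3(rows, labels, ["Sex", "Pclass"])
print(tree)  # ('Sex', {'female': 1, 'male': ('Pclass', {1: 1, 3: 0})})
print(predict(tree, {"Sex": "male", "Pclass": 1}))  # 1
```

The sketch does not handle feature values that were never seen during training; a real implementation would need a fallback, such as the majority label of the parent node.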

Results

The algorithm was compared to baseline results taken from the Kaggle website.

The baseline results are as follows:

  • Naive Bayes (72.6%)
  • Logistic Regression (82.1%)
  • Decision Tree (77.6%)
  • K Nearest Neighbor (80.5%)
  • Random Forest (80.6%)
  • Support Vector Classifier (83.2%)
  • Xtreme Gradient Boosting (81.8%)
  • Soft Voting Classifier - All Models (82.8%)

This ID3 implementation scored 77.55% accuracy when trained on put-titanic-homework.csv and tested on the kaggle-titanic-test.csv dataset.

Note that the baseline results are based on the training dataset from the Kaggle website, whereas our ID3 tree was trained on a slightly simpler dataset with fewer features.

You can verify this yourself by installing the pip dependencies and running the Bash script that I prepared:

./code/pipe.sh

The result should be:

[INFO]: This script trained ID3 tree and saved it in ./model.pkl file as pickle.
|  └ Sex
| |  └ female
| | |  └ SibSp
...
| | | | |  └ Pclass
| | | | | |  └ 1
...

[INFO]: This script plotted ID3 tree model and saved it in ./plot.txt file.
    PassengerId  Pclass                                               Name     Sex     Age  SibSp  Parch  Survived  Predicted
0             1       3                            Braund, Mr. Owen Harris    male  middle      1      0         0          0
1             2       1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  middle      1      0         1          1
..          ...     ...                                                ...     ...     ...    ...    ...       ...        ...
98           99       2               Doling, Mrs. John T (Ada Julia Bone)  female  middle      0      1         1          1
99          100       2                                  Kantor, Mr. Sinai    male  middle      1      0         0          0
[100 rows x 9 columns]

The percentage of equal values in the last two columns is: 93.00%
[INFO]: This script tested predictions output based on ID3 tree and saved it in ./test_output_1.csv file.
     PassengerId  Survived  Pclass                                               Name     Sex   Age  SibSp  Parch            Ticket     Fare Cabin Embarked  Predicted
0              1         0       3                            Braund, Mr. Owen Harris    male  22.0      1      0         A/5 21171   7.2500   NaN        S        0.0
1              2         1       1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1      0          PC 17599  71.2833   C85        C        1.0
..           ...       ...     ...                                                ...     ...   ...    ...    ...               ...      ...   ...      ...        ...
889          890         1       1                              Behr, Mr. Karl Howell    male  26.0      0      0            111369  30.0000  C148        C        0.0
890          891         0       3                                Dooley, Mr. Patrick    male  32.0      0      0            370376   7.7500   NaN        Q        0.0
[891 rows x 13 columns]
The percentage of equal values in the last two columns is: 77.55%


[INFO]: This script tested predictions output based on ID3 tree and saved it in ./test_output_2.csv file.
    PassengerId  Pclass                                               Name     Sex     Age  SibSp  Parch  Survived  Predicted
0             1       3                            Braund, Mr. Owen Harris    male  middle      1      0         0          0
1             2       1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  middle      1      0         1          1
..          ...     ...                                                ...     ...     ...    ...    ...       ...        ...
98           99       2               Doling, Mrs. John T (Ada Julia Bone)  female  middle      0      1         1          1
99          100       2                                  Kantor, Mr. Sinai    male  middle      1      0         0          0
[100 rows x 9 columns]
[INFO]: This script predicted output based on ID3 tree and saved it in ./output.csv file.
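
The "percentage of equal values in the last two columns" reported above is the share of rows where the prediction matches the true label. A minimal sketch of that calculation follows; the helper `percent_equal` and the four sample rows are hypothetical and are not taken from the project's scripts.

```python
def percent_equal(rows, col_a, col_b):
    """Percentage of rows whose values in `col_a` and `col_b` are equal."""
    if not rows:
        return 0.0
    matches = sum(1 for row in rows if row[col_a] == row[col_b])
    return 100.0 * matches / len(rows)


# Made-up rows for illustration; note that 0 == 0.0 in Python, so integer
# labels and float predictions compare as equal.
rows = [
    {"Survived": 0, "Predicted": 0.0},
    {"Survived": 1, "Predicted": 1.0},
    {"Survived": 1, "Predicted": 0.0},
    {"Survived": 0, "Predicted": 0.0},
]
print(f"{percent_equal(rows, 'Survived', 'Predicted'):.2f}%")  # 75.00%
```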

TODO:

  • build the model artifact as a pickle
  • build interference.py - a Python function that preprocesses the input features and passes them to the model
  • REST API wrapper (FastAPI / Flask)
  • modify the Dockerfile
  • create a training pipeline - pulls from the datastore, creates the model, and saves it
  • add a testing pipeline - takes the model and tests it
  • add an inference pipeline - somebody inputs data and expects a prediction
  • plot in JPG
  • try the model on the dataset from Kaggle
  • add a page/flow for error analysis (I want to build better models by trying things and seeing what went wrong, and whether the model got better or worse)
  • test the model on the training dataset to estimate how well it would do on the dev set

Implement such an architecture

img.png
