Understanding Decision Trees in Machine Learning #
A Decision Tree is a simple yet powerful tool for making decisions by visually breaking a problem down into a sequence of choices and their possible outcomes. It has a flowchart-like structure in which each step tests a condition and, depending on the answer, follows one of several paths. In machine learning, decision trees are widely used for classification (assigning data to categories) and regression (predicting numeric values).
In essence, a decision tree helps transform complex decision-making processes into a clear and structured format that is easy to interpret.
How a Decision Tree Works #
A decision tree starts at a single node and repeatedly splits into branches based on the features of the dataset. Each split is chosen to separate the data into groups that are more homogeneous (purer) with respect to the target variable.
- Root Node
This is the top-most part of the tree. It represents the entire dataset and acts as the starting point for all decisions.
- Branches
These are the lines that connect different nodes. Each branch represents a possible decision or outcome from a node.
- Internal Nodes
These nodes represent decision points where the dataset is split based on specific conditions or features.
- Leaf Nodes (Terminal Nodes)
These are the final nodes of the tree. They provide the ultimate outcome, such as a predicted class or value.
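To make these parts concrete, here is a minimal sketch of a tiny tree written by hand as nested Python dictionaries. The feature names `age` and `chol` match columns used later in this article, but the thresholds and the leaf labels ("class A", "class B", "class C") are invented for illustration and were not learned from any data. `predict_one` is a hypothetical helper, not a library function.

```python
from typing import Any, Dict

# Hypothetical tree: thresholds and labels are invented for illustration only.
tree: Dict[str, Any] = {
    # Root node: the first test, applied to every sample
    "feature": "age", "threshold": 55,
    # Branch followed when age <= 55 leads to an internal node
    "left": {
        "feature": "chol", "threshold": 240,
        "left": {"leaf": "class A"},   # leaf node
        "right": {"leaf": "class B"},  # leaf node
    },
    # Branch followed when age > 55 leads straight to a leaf
    "right": {"leaf": "class C"},
}


def predict_one(node: Dict[str, Any], sample: Dict[str, float]) -> str:
    """Walk from the root, taking one branch per test, until a leaf is reached."""
    while "leaf" not in node:
        go_left = sample[node["feature"]] <= node["threshold"]
        node = node["left"] if go_left else node["right"]
    return node["leaf"]


print(predict_one(tree, {"age": 45, "chol": 200}))  # class A
print(predict_one(tree, {"age": 45, "chol": 260}))  # class B
print(predict_one(tree, {"age": 60, "chol": 200}))  # class C
```

Every prediction follows exactly one path from the root to a leaf, so the number of tests per prediction is at most the depth of the tree.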
Types of Decision Trees #
Decision trees can be broadly divided into two main categories:
- Classification Trees
Used when the output variable is categorical (e.g., Yes/No, Spam/Not Spam).
- Regression Trees
Used when the output variable is continuous (e.g., predicting price, temperature).
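A minimal scikit-learn sketch of the two types, using a tiny made-up dataset with one feature: `DecisionTreeClassifier` predicts a class label, while `DecisionTreeRegressor` predicts a number (the mean target value of the training samples that end up in the same leaf). The data values are arbitrary and chosen only so the single split is easy to follow.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Toy data: one feature, two well-separated groups (values are arbitrary)
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y_class = np.array([0, 0, 0, 1, 1, 1])                 # categorical target
y_value = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])   # continuous target

# max_depth=1 allows a single split, so each model ends up with two leaves
clf = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y_class)
reg = DecisionTreeRegressor(max_depth=1, random_state=0).fit(X, y_value)

X_new = np.array([[2.5], [11.5]])
print(clf.predict(X_new))  # [0 1]: one class label per sample
print(reg.predict(X_new))  # leaf means: 2.0 for the left group, 11.0 for the right
```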
Key Concepts in Decision Trees #
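- Splitting
The process of dividing a node into sub-nodes based on a condition, such as whether a feature value is below a threshold.
- Impurity Measures
Metrics such as the Gini index and entropy quantify how mixed the classes are within a node; a good split produces child nodes with lower impurity than their parent.
- Information Gain
The reduction in entropy achieved by a split: the parent node's entropy minus the size-weighted average entropy of its child nodes.
- Pruning
The technique of removing branches that add little predictive value, to reduce overfitting and improve performance on unseen data.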
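The impurity measures and information gain above can be computed in a few lines. This is a plain-Python sketch (the function names are illustrative, not from any library); the Gini and entropy definitions correspond to scikit-learn's `criterion='gini'` and `criterion='entropy'` options, the latter being the one used later in this article.

```python
import math
from collections import Counter
from typing import Sequence


def gini(labels: Sequence) -> float:
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels: Sequence) -> float:
    """Entropy in bits: sum(-p_k * log2(p_k)); absent classes never appear in Counter."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent: Sequence, children: Sequence[Sequence]) -> float:
    """Parent entropy minus the size-weighted average entropy of the child nodes."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted


parent = [0, 0, 0, 1, 1, 1]
print(gini(parent), entropy(parent))                     # 0.5 1.0 (maximally mixed)
print(information_gain(parent, [[0, 0, 0], [1, 1, 1]]))  # 1.0 (perfect split)
print(information_gain(parent, [[0, 0, 1], [0, 1, 1]]))  # about 0.082 (weak split)
```

A split that produces pure children recovers all of the parent's entropy, while a split that leaves both children mixed gains very little.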




What is a Decision Tree? #
A Decision Tree is a method used to simplify decision-making by organizing choices and their possible results in a structured, visual way. It looks like a tree, where each step represents a question or condition, and the answers lead to different paths.
In machine learning, decision trees are commonly used to analyze data, make predictions, and classify information. They help break down complex problems into smaller, easier-to-understand parts.

Structure of a Decision Tree #
A decision tree begins with a single starting point and grows into multiple branches based on the data. Each part of the tree has a specific role:
- Root Node
This is the topmost node of the tree. It represents the complete dataset and is the first point where the data is split.
- Branches
These are the connecting lines between nodes. Each branch shows the outcome of a decision or condition.
- Internal Nodes
These nodes act as decision points where the data is divided further based on specific features or rules.
- Leaf Nodes (Terminal Nodes)
These are the final points in the tree. Each leaf node provides the final result, such as a predicted category or value.
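How does an internal node choose its rule? A common approach, similar in spirit to CART-style implementations such as scikit-learn's, is a greedy scan: sort one numeric feature, try a threshold midway between each pair of adjacent distinct values, and keep the threshold with the lowest size-weighted Gini impurity. The sketch below does this for a single feature; `best_threshold` is a hypothetical helper and the data are made up. A real tree repeats the scan over every feature at every node.

```python
from collections import Counter
from typing import List, Optional, Sequence, Tuple


def gini(labels: List[int]) -> float:
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_threshold(values: Sequence[float], labels: Sequence[int]) -> Tuple[Optional[float], float]:
    """Return (threshold, weighted Gini) of the best split, or (None, node Gini) if none helps."""
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best_t, best_score = None, gini(list(labels))  # baseline: do not split
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # no threshold can separate equal values
        left = [lab for _, lab in pairs[:i]]
        right = [lab for _, lab in pairs[i:]]
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best_score:
            best_t, best_score = (lo + hi) / 2, score
    return best_t, best_score


# Hypothetical feature values and binary labels
values = [120, 150, 180, 200, 230, 260]
labels = [0, 0, 0, 1, 1, 1]
print(best_threshold(values, labels))  # (190.0, 0.0)
```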
The Implementation Pipeline #
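Before we code, here is the logical flow of our system: load the dataset, explore it with a few plots, split it into training and test sets, train a decision tree classifier, visualize the tree, and evaluate its predictions on the test set. First, we import the libraries we need: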
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
Load Dataset
# Load your dataset
df = pd.read_csv('/kaggle/input/datasets/sarimr/heart-attack-dataset-analysis-and-prediction/heart.csv')
Show the First 10 Rows
df.head(10)
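Beyond eyeballing the first rows, a few quick checks help before modelling: the size of the data, whether anything is missing, and how balanced the classes are (class balance matters when reading the accuracy score later). This is a minimal sketch with a hypothetical helper `quick_summary` and a tiny made-up frame standing in for `heart.csv`; it assumes the label column is named `target`, as in this dataset.

```python
import pandas as pd


def quick_summary(frame: pd.DataFrame, target: str = "target") -> dict:
    """Size, total missing values, and class proportions of the target column."""
    return {
        "rows": len(frame),
        "columns": frame.shape[1],
        "missing_values": int(frame.isna().sum().sum()),
        "class_share": frame[target].value_counts(normalize=True).sort_index().to_dict(),
    }


# Tiny hypothetical frame; one age value is deliberately missing
demo = pd.DataFrame({
    "age": [50, 60, 70, None],
    "chol": [200, 220, 240, 260],
    "target": [1, 1, 0, 0],
})
print(quick_summary(demo))
```

On the real data, `quick_summary(df)` gives the same overview.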

Stem Plot of Cholesterol Levels
# Create the stem plot
plt.figure(figsize=(10, 6))
# Plot 'chol' (cholesterol) against the index of the DataFrame
# markerline: the dots at the top
# stemlines: the vertical lines
# baseline: the horizontal line at y=0
markerline, stemlines, baseline = plt.stem(df.index, df['chol'], linefmt='steelblue', markerfmt='o', basefmt='r-')
# Customize the plot for better readability
plt.setp(markerline, color='darkred', markersize=8)  # Change marker color and size
plt.title('Stem Plot of Cholesterol Levels', fontsize=14)
plt.xlabel('Patient Index', fontsize=12)
plt.ylabel('Cholesterol (mg/dl)', fontsize=12)
plt.xticks(df.index[::25])  # Label every 25th index; labeling all ~300 patients would overlap
plt.grid(axis='y', linestyle='--', alpha=0.6)
# Display the plot
plt.show()
3D Surface Analysis: Heart Disease Features
# Initialize the 3D figure
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
# Define the axes from your existing DataFrame 'df'
x = df['age']
y = df['trestbps']
z = df['chol']
# Create the 3D Surface Plot
# cmap='magma' or 'viridis' provides high contrast for peaks and valleys
surf = ax.plot_trisurf(x, y, z, cmap='magma', edgecolor='none', alpha=0.9)
# Labeling the axes
ax.set_xlabel('Age', fontsize=12)
ax.set_ylabel('Resting BPS (trestbps)', fontsize=12)
ax.set_zlabel('Cholesterol (chol)', fontsize=12)
ax.set_title('3D Surface Analysis: Heart Disease Features', fontsize=15)
# Add a color bar to map the Z-axis (Cholesterol) values
fig.colorbar(surf, shrink=0.5, aspect=10)
plt.tight_layout()
plt.show()
Distribution of Patient Ages
# Set the style
sns.set_theme(style="whitegrid")
# Create the histogram
plt.figure(figsize=(10, 6))
sns.histplot(df['age'], bins=10, kde=True, color='teal')
# Add labels and title
plt.title('Distribution of Patient Ages', fontsize=15)
plt.xlabel('Age', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.show()
Average Cholesterol Levels by Sex
# Set the visualization style
sns.set_theme(style="darkgrid")
# Create the Point Plot
# We'll plot 'sex' on the x-axis and 'chol' (Cholesterol) on the y-axis
plt.figure(figsize=(8, 6))
sns.pointplot(data=df, x='sex', y='chol', color='darkorange', capsize=.1)
# Customizing the labels and title
plt.title('Average Cholesterol Levels by Sex', fontsize=15)
plt.xlabel('Sex (0 = Female, 1 = Male)', fontsize=12)
plt.ylabel('Average Cholesterol (mg/dl)', fontsize=12)
# Display the plot
plt.show()
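By default, `sns.pointplot` marks the mean of `y` for each `x` category and draws error bars for a 95% confidence interval around that mean. To read the exact numbers behind the dots, a `groupby` gives the same means plus the number of rows behind each. The sketch uses a tiny made-up frame; on the real data you would run the same call on `df`.

```python
import pandas as pd

# Hypothetical values standing in for the real 'sex' and 'chol' columns
demo = pd.DataFrame({"sex": [0, 0, 1, 1, 1], "chol": [250, 270, 230, 240, 250]})

# Mean cholesterol per group (the point plot's dots) and group sizes
summary = demo.groupby("sex")["chol"].agg(["mean", "count"])
print(summary)
```

Small groups produce wide error bars, so the `count` column is worth checking before comparing the means.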
Prepare Features and Target
# Features (age, sex, cp, trestbps, chol, etc.)
X = df.drop('target', axis=1)
# Target variable
y = df['target']
# Split: 80% Training, 20% Testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
Train the Model
# Initialize the model
# We cap max_depth at 4 to limit the tree's complexity and reduce overfitting
clf = DecisionTreeClassifier(criterion='entropy', max_depth=4, random_state=42)
# Train the model
clf.fit(X_train, y_train)
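Capping `max_depth` limits complexity before the tree is grown. The Pruning idea from the key concepts can also be applied after growing a full tree: scikit-learn supports minimal cost-complexity pruning through the `ccp_alpha` parameter and the `cost_complexity_pruning_path` method. The sketch below is an optional alternative, not how the results reported in this article were produced. It uses synthetic data from `make_classification` as a stand-in for the heart dataset (an assumption) and picks alpha by cross-validation on the training split only, so the test set stays untouched.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption) so the sketch runs without heart.csv
X, y = make_classification(n_samples=300, n_features=8, n_informative=4, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Grow a full, unpruned tree and compute its cost-complexity pruning path
full_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
path = full_tree.cost_complexity_pruning_path(X_train, y_train)

# Drop the last alpha (it prunes everything down to the root) and clip
# tiny negative values that floating-point round-off could produce
alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)

# Choose alpha by 5-fold cross-validation on the training split only
cv_means = [
    cross_val_score(DecisionTreeClassifier(random_state=42, ccp_alpha=a),
                    X_train, y_train, cv=5).mean()
    for a in alphas
]
best_alpha = float(alphas[int(np.argmax(cv_means))])

# Refit with the chosen alpha and compare tree sizes
pruned_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha).fit(X_train, y_train)
print("best ccp_alpha:", best_alpha)
print("nodes (full -> pruned):", full_tree.tree_.node_count, "->", pruned_tree.tree_.node_count)
print("pruned test accuracy:", pruned_tree.score(X_test, y_test))
```

Larger alphas remove more branches; alpha 0 keeps the full tree, so the search can never produce a larger tree than the unpruned one.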

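Visualize the Decision Tree
plt.figure(figsize=(20, 10))
# class_names follow the sorted target labels: 0 -> 'No Disease', 1 -> 'Disease'
plot_tree(clf, feature_names=X.columns, class_names=['No Disease', 'Disease'], filled=True)
plt.show()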
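`plot_tree` produces a figure; for a quick look in a terminal or a log file, scikit-learn's `export_text` prints the same rules as indented text, and the fitted model's `feature_importances_` summarizes how much each feature's splits reduced impurity. The sketch uses tiny hypothetical data and feature names; with this article's model you would call `export_text(clf, feature_names=list(X.columns))` instead.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Tiny hypothetical data: only the first feature separates the two classes
X_demo = np.array([[1, 5], [2, 3], [3, 5], [10, 3], [11, 5], [12, 3]])
y_demo = np.array([0, 0, 0, 1, 1, 1])
names = ["feature_a", "feature_b"]  # hypothetical feature names

demo_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)

# Text rendering of the rules that plot_tree draws graphically
print(export_text(demo_tree, feature_names=names))

# Impurity-based importances: each feature's share of the total impurity reduction
for name, importance in zip(names, demo_tree.feature_importances_):
    print(f"{name}: {importance:.2f}")
```

Impurity-based importances are computed on the training data, so treat them as a rough guide to what the tree relied on rather than a measure of causal influence.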

Confusion Matrix
# Make predictions
y_pred = clf.predict(X_test)
# Confusion Matrix Heatmap
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
# Final Metrics
print(f"Accuracy Score: {accuracy_score(y_test, y_pred) * 100:.2f}%")
print("\nClassification Report:\n", classification_report(y_test, y_pred))
Accuracy Score: 88.52%

Classification Report:
               precision    recall  f1-score   support

           0       0.87      0.90      0.88        29
           1       0.90      0.88      0.89        32

    accuracy                           0.89        61
   macro avg       0.88      0.89      0.89        61
weighted avg       0.89      0.89      0.89        61
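To see where these numbers come from, the sketch below recomputes the class-1 (Disease) row and the accuracy from a 2x2 confusion matrix in scikit-learn's layout (rows = actual, columns = predicted). The counts (26 true negatives, 3 false positives, 4 false negatives, 28 true positives) are reconstructed from the report's supports (29 and 32), its recall values, and the 88.52% accuracy; they are not read off the plotted heatmap. `scores_from_cm` is a hypothetical helper.

```python
from typing import Dict, List


def scores_from_cm(cm: List[List[int]]) -> Dict[str, float]:
    """Binary metrics for the positive class from [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp)  # of predicted positives, the fraction that are correct
    recall = tp / (tp + fn)     # of actual positives, the fraction that were found
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


# Counts implied by the classification report (supports: 29 for class 0, 32 for class 1)
cm = [[26, 3],
      [4, 28]]
scores = scores_from_cm(cm)
print({k: round(v, 4) for k, v in scores.items()})
# {'precision': 0.9032, 'recall': 0.875, 'f1': 0.8889, 'accuracy': 0.8852}
```

Rounded to two decimals these give the class-1 row (0.90, 0.88, 0.89) and 54/61 = 88.52% accuracy; treating class 0 as the positive class in the same way reproduces the class-0 row.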