Exploratory Data Analysis (EDA)
A comprehensive guide to data sanity checks and EDA using Python
Before venturing into advanced data analysis or machine learning, it's essential to ensure that the data you're working with is clean and coherent. This article outlines the process of conducting Data Sanity checks and Exploratory Data Analysis (EDA), both of which are critical first steps in understanding your dataset. While the initial stages don’t involve modifying the data, these actions help uncover potential issues such as missing values, duplicates, and outliers, while providing valuable insights into the data's structure and relationships.
It's important to start by inspecting the dataset to familiarize yourself with its contents and structure. Understanding the data is key to identifying any problems that might affect your analysis. Once this preliminary examination is complete, you can decide whether data cleaning or preprocessing is needed. For now, we're focused on gaining an understanding of the data, not making changes, to better plan any necessary cleanup efforts.
Recently, I worked with a nutritional dataset from Kaggle using Google Colab, which allows for writing and executing Python code in a browser-based Jupyter notebook. The objective was to analyze nutrient patterns across a range of foods and categorize them based on their nutrient content.
Kaggle is a fantastic resource for data scientists and enthusiasts, offering a wealth of datasets across different fields. Whether you're looking to explore healthcare, finance, or more specialized areas like nutrition, Kaggle provides an excellent platform to practice data analysis and machine learning techniques. Additionally, Kaggle hosts competitions that allow you to test your skills and learn from the broader community.
The dataset I explored included nutritional values for various foods and products, detailing their protein, fat, vitamin C, and fiber content.
You can find the full code and analysis in my Google Drive folder here.
Section 1: Understanding the Dataset
Objective: Load the dataset into Python and perform initial inspection.
- Use pandas.read_csv() or equivalent loading functions.
- Inspect the first few rows using df.head() and df.tail() for a snapshot of the dataset.
- Check dataset structure and basic metadata with df.info().
- View basic statistics with df.describe() to understand the distribution of numerical columns.
1.1 Data Loading and Initial Exploration
To load a CSV file into a Google Colab Jupyter notebook, you can use Google Drive to access files stored in your account. First, mount your Google Drive by running a code snippet that authorizes the notebook to access your files. This allows seamless integration with data stored in Google Drive, making it convenient to analyze cloud-stored datasets within the Colab environment.
The first step in any analysis is understanding the data itself. I begin by loading the dataset into Python with the pandas library's read_csv() method.
# Mount Google Drive to access files
from google.colab import drive
drive.mount('/content/drive')

# Load the data using Pandas
import pandas as pd

# Source --> https://www.kaggle.com/datasets/trolukovich/nutritional-values-for-common-foods-and-products?resource=download
data_path = '/content/drive/MyDrive/CoLab/nutrition.csv'
data = pd.read_csv(data_path)
1.2 Shape and Size of Data
Once the data is loaded, we need to inspect it to get a high-level overview of its structure. This involves checking the size of the data, the column names, and the general characteristics of each column. Understanding the dataset's dimensionality is the first step.
The shape attribute tells us how many rows (observations) and columns (features) the dataset contains:
# Display the shape of the dataset
print("Shape of the dataset:", data.shape)
1.3 View Rows of Data
Sometimes the best way to understand the data is to look at it. The head() and tail() methods display the first and last few rows of the dataset, giving us a glimpse of the data. This helps us understand its structure and format and identify any issues that may need to be addressed.
# Display the first 5 and last 5 rows of the dataset
# (run in separate cells, or wrap in display(), to see both outputs)
data.head(5)
data.tail(5)
1.4 View the Features (Columns)
After viewing the rows, it's essential to understand how the features (columns) are defined in the dataset. Having seen the data, we can now confirm whether the DataFrame accurately represents it. Are the data types correct? Do any columns need to be converted to a different type? Numerical data should be stored as integers or floats, and categorical data as strings or categories. The info() method in Pandas reports this.
# Display the data types of each column
data.info()
1.5 Identify Unique Values
The unique() method lists the distinct values in a column, and nunique() counts them. This is especially useful for categorical columns, where it reveals the different categories present in the dataset.
# Display the count of unique values in each column
for column in data.columns:
    print(f"Unique values in '{column}': {data[column].nunique()}")
Section 2: Data Sanity Checks
Objective: Identify missing values, duplicates, and data integrity issues.
Before analyzing the dataset, it's essential to perform sanity checks to confirm data quality and integrity. These checks help identify and address issues such as missing values, duplicates, outliers, and inconsistent data types.
2.1 Missing Data Detection
Missing values can affect the quality of our analysis and results. It's essential to check for missing values and decide how to handle them. Common strategies include removing rows with missing values, imputing missing values, or using advanced imputation techniques.
Handling missing values could involve:
- Dropping rows or columns with too many missing values
- Imputation, where you fill missing values using the mean, median, mode, or other statistical methods
- Forward/Backward filling for time-series data
- Using advanced imputation techniques like KNN imputation or MICE
Left unhandled, missing values can bias results, so they must be addressed before any analysis. The isnull() method in Pandas helps us locate missing values in the dataset.
# Check for missing values in the dataset
missing_values = data.isnull().sum()
print("Missing values in each column:\n", missing_values)

# Optionally, visualize missing values using a heatmap
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
sns.heatmap(data.isnull(), cbar=False, cmap='viridis')
plt.title('Heatmap of Missing Values')
plt.show()
You can also visualize missing-data patterns with the missingno library (e.g., missingno.matrix(df) or missingno.heatmap(df)).
import missingno as msno

msno.matrix(data)
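No changes are made at this stage, but as a reference for the cleanup plan in Section 4, here is a minimal sketch of the simpler handling strategies listed above; the column names are placeholders rather than columns from the nutrition dataset:

# Hypothetical examples -- 'numeric_col' and 'category_col' are placeholder column names
cleaned = data.dropna()  # drop rows that contain any missing value

# Impute a numeric column with its median
data['numeric_col'] = data['numeric_col'].fillna(data['numeric_col'].median())

# Impute a categorical column with its mode
data['category_col'] = data['category_col'].fillna(data['category_col'].mode()[0])

# Forward fill, typically for time-series data
data = data.ffill()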
2.2 Duplicate Data Detection
Duplicates can skew our analysis and lead to incorrect conclusions. It's crucial to identify and remove duplicate rows from the dataset. The duplicated() method in Pandas helps us identify duplicate rows. If they do exist, we can count them and decide how to handle them.
# Check for duplicate rows in the dataset
duplicate_rows = data.duplicated().sum()
print(f"Duplicate rows in the dataset: {duplicate_rows}")
2.3 Consistency of Data Types
Data types play a crucial role in data analysis. It's essential to ensure that the data types are consistent with the values in the dataset. The dtypes attribute in Pandas helps us check the data types of each column. Cross-verify if numerical data is stored as numbers, and if categorical data is appropriately formatted.
# Check the data types of each column
data.dtypes
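If a column that should be numeric was read in as text (for example, because values carry unit suffixes), a conversion along these lines can be planned for the cleanup stage; 'serving_size' is an assumed column name used only for illustration:

# Hypothetical example -- 'serving_size' is an assumed column name
# Strip a unit suffix such as ' g' and convert the result to a number
data['serving_size_num'] = pd.to_numeric(
    data['serving_size'].astype(str).str.replace(' g', '', regex=False),
    errors='coerce'  # anything that still cannot be parsed becomes NaN
)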
2.4 Outlier Detection
Outliers can skew our analysis, so detecting them is an important part of sanity checks. Outliers can be identified visually using boxplots or statistically using the IQR method.
- Use boxplots (seaborn.boxplot or matplotlib’s boxplot) to visually inspect outliers.
- Calculate the Interquartile Range (IQR) and use it to identify outliers programmatically.
Outliers are data points that significantly differ from other observations. They can distort statistical analyses and should be investigated. Methods to handle outliers include removing them, transforming the data, or using robust statistical techniques. Outliers can be indicative of data entry errors, unusual behavior, or rare events, and detecting them early allows for informed decision-making on whether to address or retain them in further analyses.
Key statistics for outlier detection:
- Mean: Average value in the dataset.
- Median: The middle value, helpful for skewed distributions.
- Mode: Most frequent value, useful in categorical data.
- Min/Max: Helps identify extreme values.
- Standard Deviation: Shows the spread of the data.
- IQR (Interquartile Range): Captures the range of the middle 50%, useful for detecting outliers.
The IQR (Interquartile Range) method is commonly used to identify outliers. It defines outliers as points outside 1.5 times the IQR from the first and third quartiles. A box plot can help visualize outliers in the data.
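As a programmatic complement to the visual checks described next, here is a minimal sketch of the 1.5 x IQR rule applied to every numeric column:

# Flag potential outliers in each numeric column using the 1.5 * IQR rule
numeric_cols = data.select_dtypes(include='number').columns
for column in numeric_cols:
    q1 = data[column].quantile(0.25)
    q3 = data[column].quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    outliers = data[(data[column] < lower_bound) | (data[column] > upper_bound)]
    print(f"{column}: {len(outliers)} potential outliers outside [{lower_bound:.2f}, {upper_bound:.2f}]")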
For numeric columns, the most straightforward and effective way to detect outliers is by using box plots. Box plots are a visual tool that display the distribution of a dataset, highlighting the quartiles and identifying potential outliers through the use of “whiskers,” which typically extend to 1.5 times the interquartile range (IQR). Any points outside this range are flagged as potential outliers. This method provides a quick, visual check for outliers and gives an overall sense of data spread and variability.
Pairing the box plot with a scatter plot allows us to examine how values are distributed across the dataset, offering further context on potential outliers. Scatter plots provide a view of individual data points, which can help determine if outliers are isolated or part of a broader pattern. By using a combination of box plots and scatter plots, we can visually detect and investigate outliers in numeric data with minimal computational overhead.
For non-numeric or high-dimensional data (such as categorical or text columns), visualizations like bar charts can be inefficient and slow to render, especially when dealing with large datasets or columns with many unique values (high cardinality). Instead of using visual methods, we recommend a summary-based approach for these types of columns. This approach involves generating key statistics for each non-numeric column, such as:
- Unique Values: Count of distinct values in the column, which can highlight unusually high cardinality.
- Most Frequent Value: The mode of the column, which provides insight into the most common category or entry.
- Frequency of the Most Frequent Value: Knowing how often the most frequent value occurs can help spot anomalies, such as over-represented categories.
- Total Non-Null Entries: The count of valid entries in the column, useful for identifying sparsity or missing data.
This summary allows for a high-level overview of non-numeric columns without the delays and performance issues associated with plotting. This method provides a quick sanity check for categorical or textual data by identifying potential anomalies in the distribution of values.
To ensure the outlier detection process remains efficient and scalable as the size of our datasets grows, it is important to tailor the approach based on the data type. For numeric data, visual methods like box plots and scatter plots are appropriate, as they provide fast and intuitive insights. For non-numeric data, where the cardinality or the data size may be large, a summary-based approach is more efficient.
By automatically detecting and summarizing the column types, we can avoid unnecessary computations and still provide a comprehensive overview of potential outliers. This hybrid approach balances thoroughness with performance, ensuring that outlier detection does not become a bottleneck in our data validation pipeline.
import pandas as pd
import matplotlib.pyplot as plt

# Example DataFrame (Replace with your own DataFrame 'data')
# data = pd.read_csv('your_data.csv')

# Loop through all columns in the DataFrame
for column in data.columns:
    if pd.api.types.is_numeric_dtype(data[column]):  # Check if the column is numeric
        fig, axs = plt.subplots(1, 2, figsize=(16, 4))  # Set wider and shorter figure size

        # Boxplot (rotated horizontally)
        axs[0].boxplot(data[column].dropna(), vert=False)
        axs[0].set_title(f'Box plot of {column}')
        axs[0].set_xlabel(column)
        axs[0].grid(True)

        # Scatter plot (against index)
        axs[1].scatter(data.index, data[column], alpha=0.5)  # Scatter plot with index on the x-axis
        axs[1].set_title(f'Scatter plot of {column}')
        axs[1].set_xlabel('Index')
        axs[1].set_ylabel(column)
        axs[1].grid(True)

        # Adjust layout and spacing
        plt.subplots_adjust(wspace=0.4)  # Add space between the plots
        plt.tight_layout()
        plt.show()  # Show the combined plots
    else:
        # For non-numeric columns, provide a simple summary instead of a bar chart
        unique_values = data[column].nunique()
        top_value = data[column].mode()[0]  # Most frequent value
        top_value_count = data[column].value_counts().iloc[0]  # Frequency of the top value
        total_count = len(data[column].dropna())

        # Print a summary of the non-numeric column
        print(f"Column '{column}':")
        print(f" - Unique values: {unique_values}")
        print(f" - Most frequent value: '{top_value}' (Count: {top_value_count})")
        print(f" - Total non-null entries: {total_count}")

        # Add a line break between each feature for better separation
        print("\n")
Once potential outliers are identified, it is important to have a plan for handling them. Depending on the nature of the dataset and the context of the analysis, outliers can either be removed, transformed, or retained for further investigation. Outliers that result from data entry errors or inconsistencies should be corrected or removed. However, in some cases (such as fraud detection or rare event prediction), outliers may represent important signals and should be preserved. Implementing a review process for outliers ensures that we maintain data integrity while addressing any anomalies.
Incorporating outlier detection as part of our data sanity checks ensures that we identify and address anomalies before they impact downstream analysis or model performance. By using a visual approach for numeric data and a summary-based approach for non-numeric data, we maintain efficiency without sacrificing accuracy or detail. This method ensures that our data is clean, reliable, and ready for further analysis or modeling, ultimately improving the quality of our insights and decisions.
Section 3: Exploratory Data Analysis (EDA)
Objective: Explore the distribution of individual variables and the relationships between them.
Now that we've ensured the data is sane, we can begin exploring it to uncover patterns and relationships. We'll start with univariate analysis before moving on to bivariate and multivariate explorations.
3.1 Univariate Data Analysis
Univariate analysis focuses on analyzing individual variables in the dataset. It helps us understand the distribution of each feature and identify patterns or anomalies. Common univariate analysis techniques include histograms, bar plots, and frequency tables.
- Use histograms (df.hist()) to visualize the distribution of numerical features.
- Use bar plots (df['column'].value_counts().plot(kind='bar')) to visualize categorical features.
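For example, the bar-plot approach for a categorical feature might look like the sketch below; 'food_group' is a hypothetical low-cardinality column, since a column identifying individual foods would have far too many unique values to plot this way:

import matplotlib.pyplot as plt

# Hypothetical example -- 'food_group' is an assumed categorical column
data['food_group'].value_counts().plot(kind='bar', figsize=(10, 4), edgecolor='black')
plt.title('Frequency of Each Food Group')
plt.xlabel('Food Group')
plt.ylabel('Count')
plt.show()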
Univariate Analysis is a type of data analysis in data science that focuses on examining a single variable at a time. The term "univariate" combines "uni-" meaning "one" and "variate," which refers to a variable. Essentially, it’s about understanding the distribution, central tendency, and spread of one variable in isolation, without considering relationships with other variables.
Key Aspects of Univariate Data Analysis:
- Distribution
The distribution of a variable shows how its values are spread out. Common ways to visualize distributions include histograms, box plots, and density plots.
- Central Tendency
Central tendency measures where the center of the data lies. The most common measures are the mean (average), median (middle value), and mode (most frequent value).
- Spread or Variability
Exploring the variability or dispersion within the variable, commonly using measures like range, variance, and standard deviation. This helps understand how spread out the data points are from the central value.
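As a minimal sketch of these three aspects for a single variable, assuming the dataset's 'calories' column is numeric:

# Central tendency and spread for one column ('calories' is assumed to be numeric)
col = data['calories']
print("Mean:", col.mean())
print("Median:", col.median())
print("Mode:", col.mode()[0])
print("Range:", col.max() - col.min())
print("Variance:", col.var())
print("Std Dev:", col.std())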
Techniques for Univariate Analysis:
- Histogram
A histogram is a graphical representation of the distribution of a variable. It shows the frequency of values within different intervals or bins. Here is an example of how to create histograms for all numeric columns in Python:
import matplotlib.pyplot as plt

# Create histograms for all numeric columns
data.hist(figsize=(16, 12), bins=20, edgecolor='black')
plt.suptitle('Histograms of All Numeric Columns', fontsize=16)
plt.show()
- Descriptive Statistics
Descriptive statistics summarize the main characteristics of a variable, including measures of central tendency and variability. The describe() method in Pandas provides these summary statistics for the numeric columns, and adding the transpose attribute '.T' makes the output easier to read.
# Display descriptive statistics for all numeric columns
data.describe().T
- Box Plot
A box plot displays the distribution of a variable using quartiles. It shows the median, interquartile range, and potential outliers.
import seaborn as sns
import matplotlib.pyplot as plt

# Create box plots for all numeric columns
plt.figure(figsize=(16, 12))
sns.boxplot(data=data)
plt.title('Box Plots of All Numeric Columns')
plt.show()
Key components of a box plot:
- Minimum (Lower Whisker)
The minimum is the smallest value in the dataset, excluding outliers. It represents the lower end of the data. In the boxplot, it’s typically connected to Q1 by a line called the “whisker.” The lower whisker extends to the smallest data point within 1.5 times the interquartile range (IQR) below Q1.
- First Quartile (Q1)
The first quartile (Q1) is the value below which 25% of the data falls. It marks the lower boundary of the box in the boxplot.
- Median (Second Quartile Q2)
The median is the middle value of the dataset when it’s sorted in ascending order. It’s represented by the line inside the box.
- Third Quartile (Q3)
The third quartile (Q3) is the value below which 75% of the data falls. It marks the upper boundary of the box in the boxplot.
- Maximum (Upper Whisker)
The maximum is the largest value in the dataset, excluding outliers. It represents the upper end of the data. In the boxplot, it’s typically connected to Q3 by a line called the “whisker.” The upper whisker extends to the largest data point within 1.5 times the interquartile range (IQR) above Q3.
- Interquartile Range (IQR)
The interquartile range (IQR) is the range between the first and third quartiles (Q1 and Q3). It represents the middle 50% of the data. The box in the boxplot spans the IQR, showing the spread of the central data.
- Skewness and Kurtosis
Skewness measures the asymmetry of the distribution, indicating whether the data is skewed to the left or right. Kurtosis measures the shape of the distribution, showing how peaked or flat it is compared to a normal distribution. These statistics provide insights into the shape of the data distribution.
# Calculate skewness and kurtosis for all numeric columns
skewness = data.skew(numeric_only=True)
kurtosis = data.kurtosis(numeric_only=True)
print("Skewness of each column:\n", skewness)
print("Kurtosis of each column:\n", kurtosis)
- KDE Plot
A Kernel Density Estimate (KDE) plot is a smoothed alternative to a histogram that estimates the probability density of a variable. Unlike a histogram, it doesn't rely on binning; it smooths over the data, offering a continuous, more fluid representation of the distribution.
import seaborn as sns
import matplotlib.pyplot as plt

# Create KDE plots for all numeric columns
plt.figure(figsize=(16, 12))

# Calculate the number of rows needed based on the number of numeric columns
numeric_cols = data.select_dtypes(include='number').columns
num_rows = (len(numeric_cols) + 2) // 3

for i, col in enumerate(numeric_cols):
    # Adjust subplot layout to accommodate more plots
    plt.subplot(num_rows, 3, i + 1)
    sns.kdeplot(data[col], fill=True)  # fill=True replaces the deprecated shade=True
    plt.title(f'KDE Plot of {col}')

plt.tight_layout()
plt.show()
Univariate analysis summarizes the characteristics of individual variables and highlights patterns within them. It's an essential first step in data exploration, setting the stage for more complex analyses such as bivariate or multivariate analysis.
3.2 Bivariate Data Analysis
Bivariate analysis focuses on analyzing the relationship between two variables in the dataset. It helps us understand how variables interact with each other and identify correlations. Common bivariate analysis techniques include scatter plots, pair plots, and correlation matrices.
- Use scatter plots (sns.scatterplot) to visualize the relationship between two numerical features.
- Use pair plots (sns.pairplot) to visualize relationships between multiple numerical features.
- Use correlation matrices (df.corr()) to quantify the relationship between numerical features.
Bivariate Analysis is a statistical method in data science that examines the relationship between two variables. The term "bivariate" stems from the prefix "bi-" meaning "two" and "variate," which refers to variables. In essence, bivariate analysis explores how one variable interacts or correlates with another. It helps to understand how one variable changes with respect to another.
This analysis helps in understanding patterns, trends, or associations between the two variables. For example, in a dataset containing information about people's ages and their corresponding income levels, bivariate analysis could be used to investigate whether there is a relationship between age and income. Techniques commonly used for bivariate analysis include scatter plots, correlation coefficients, and cross-tabulation.
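Scatter plots and correlation matrices are covered below. Cross-tabulation, which is only mentioned here, can be sketched as follows; 'age_group' and 'income_bracket' are hypothetical categorical columns on a generic DataFrame df, echoing the age-and-income example above:

import pandas as pd

# Hypothetical example -- 'age_group' and 'income_bracket' are assumed categorical columns
crosstab = pd.crosstab(df['age_group'], df['income_bracket'], normalize='index')
print(crosstab.round(2))  # row-wise proportions of each income bracket within each age group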
- Pair Plot
Pair plots are a powerful tool for visualizing relationships between multiple numerical features. They provide a grid of scatter plots, where each plot shows the relationship between a pair of features, allowing us to quickly spot patterns, correlations, and potential outliers in the data.
However, there are some considerations to keep in mind when using pair plots:
Pros:
- Comprehensive Visualization: Pair plots provide a detailed view of relationships between all pairs of numerical features.
- Easy Identification of Patterns: They help in identifying patterns, correlations, and potential outliers.
- Useful for Initial Exploration: Pair plots are great for initial data exploration and understanding the structure of the dataset.
Cons:
- Computationally Intensive: Pair plots can be computationally expensive, especially for large datasets with many features.
- Cluttered Visuals: For datasets with many features, pair plots can become cluttered and hard to interpret.
- Limited to Numerical Data: Pair plots are primarily useful for numerical data and may not be suitable for categorical features.
Caution:
- Compute Time: Generating pair plots for large datasets can be time-consuming. Consider sampling the data or using a subset of features if compute time is a concern.
- Display Size: Pair plots can take up a lot of space, making them difficult to display on smaller screens. Ensure you have adequate display space or consider saving the plots to a file for later review.
import seaborn as sns
import matplotlib.pyplot as plt

# Create a pair plot of numerical features
sns.pairplot(data)
plt.show()
- Correlation Coefficient and Matrix
The correlation coefficient measures the strength and direction of the linear relationship between two variables. Values range from -1 to 1, where 1 indicates a perfect positive relationship, -1 indicates a perfect negative relationship, and 0 indicates no relationship.
A Correlation Matrix is a table that shows how different things (variables) are related to each other. It helps us understand whether two things increase or decrease together, or if they don’t have much of a relationship at all. A correlation matrix helps us to see connections between things in an organized way. It's a great tool to figure out where you might want to focus more attention and helps to guide further analysis.
- +1 means a perfect positive relationship: if one thing goes up, the other also goes up.
- 0 means no relationship: the two things don’t really affect each other.
- -1 means a perfect negative relationship: if one thing goes up, the other goes down.
# Calculate the correlation matrix (numeric columns only)
correlation_matrix = data.corr(numeric_only=True)

# Display the correlation matrix
print(correlation_matrix)

# Optionally, visualize the correlation matrix using a heatmap
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', linewidths=0.5)
plt.title('Correlation Matrix Heatmap')
plt.show()
A correlation matrix is useful because it lets you see all the relationships in one table. People often use it when they have lots of data to quickly find out which things are connected. For example, a scientist studying health might use a correlation matrix to find out if eating certain foods is related to having a lower risk of disease. By using a correlation matrix, they can easily see which foods are most important and worth looking into further.
Bivariate Data Analysis helps to uncover relationships between variables, providing insights that can inform decision-making and further analysis.
3.3 Multivariate Analysis
Multivariate analysis focuses on analyzing the relationship between multiple variables in the dataset. It helps us understand complex interactions between features and identify patterns or trends. Common multivariate analysis techniques include heatmaps, cluster analysis, and dimensionality reduction.
- Use heatmaps (sns.heatmap) to visualize the relationship between multiple numerical features.
- Use cluster analysis (KMeans, DBSCAN) to group similar observations based on features.
- Use dimensionality reduction (PCA, t-SNE) to reduce the number of features while preserving important information.
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Keep only numeric features -- correlation, KMeans, and PCA require numeric input
numeric_df = df.select_dtypes(include='number').dropna()

# Heatmap to visualize the relationship between multiple numerical features
sns.heatmap(numeric_df.corr(), annot=True)
plt.show()

# Cluster analysis to group similar observations based on features
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=3, n_init=10)
df.loc[numeric_df.index, 'cluster'] = kmeans.fit_predict(numeric_df)

# Dimensionality reduction to reduce the number of features
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
df_pca = pca.fit_transform(numeric_df)
df_pca = pd.DataFrame(df_pca, columns=['PC1', 'PC2'], index=numeric_df.index)
df_pca.head()
3.4 Feature Relationships and Correlations
Understanding the relationship between features is crucial for building predictive models. Correlation analysis helps us identify features that are highly correlated with the target variable. This step is essential for feature selection and model building.
# Correlation of each numeric feature with the target variable ('target_variable' is a placeholder)
df.corr(numeric_only=True)['target_variable'].sort_values(ascending=False)
3.5 Time Series Analysis (if applicable)
If the dataset includes temporal data, it’s essential to analyze trends, seasonality, and autocorrelations.
# Plotting time series data
plt.plot(df['date_column'], df['value_column'])
plt.show()
If your dataset contains time series data, time series analysis can provide valuable insights into trends and patterns over time. Time series analysis techniques include time series decomposition, forecasting, and anomaly detection.
# Time series decomposition to analyze trends and seasonality
from statsmodels.tsa.seasonal import seasonal_decompose

decomposition = seasonal_decompose(df['time_series_column'], model='additive')
decomposition.plot()
plt.show()

# Time series forecasting to predict future values
from statsmodels.tsa.arima.model import ARIMA

model = ARIMA(df['time_series_column'], order=(1, 1, 1))
model_fit = model.fit()
forecast = model_fit.forecast(steps=10)
print(forecast)
Section 4: Plan for Data Cleaning and Preprocessing
Objective: Create a plan for any data cleaning or preprocessing steps identified during the sanity checks and EDA above.
Section 5: Conclusion
By conducting thorough data sanity checks and EDA, we lay a strong foundation for further analysis. With a clear understanding of the data, the next steps could include feature engineering, advanced visualizations, or machine learning.
Automating Univariate and Bivariate Analysis in Python
During your Data Sanity checks, it's essential to classify your variables into numerical, categorical, and dependent types before starting your Exploratory Data Analysis (EDA).
In the early stages of data analysis, you will often need to determine whether your variables are numerical, categorical, or dependent. Identifying these is crucial for:
- Performing the correct statistical methods on your data
- Automating your exploratory analysis using scripts
- Generating meaningful insights into relationships between features
Once these variables are classified, you can begin the process of performing univariate (analyzing one variable) and bivariate (analyzing relationships between two variables) analysis. Automating this process will save you time and ensure consistency in your Exploratory Data Analysis (EDA).
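A simple way to build these lists is to split the columns by dtype; the dependent feature still has to be named by hand, and 'target' below is a hypothetical column name used only for illustration:

# Split columns into numerical and categorical lists by dtype
numerical_features = data.select_dtypes(include='number').columns.tolist()
categorical_features = data.select_dtypes(include=['object', 'category']).columns.tolist()

# The dependent feature has to be chosen by hand -- 'target' is a hypothetical column name
dependent_feature = 'target'
numerical_features = [col for col in numerical_features if col != dependent_feature]
categorical_features = [col for col in categorical_features if col != dependent_feature]

print("Numerical features:", numerical_features)
print("Categorical features:", categorical_features)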
Automating Univariate Analysis for Numerical and Categorical Features
In univariate analysis, you focus on understanding the distribution of a single feature. The following Python functions allow you to automate this process for both numerical and categorical features.
The univariate_analysis() function below calculates the key statistical attributes for a numerical feature, including mean, median, variance, and skewness. It also provides visual insights using KDE plots, box plots, and histograms.
Here's the full implementation of the function:
import matplotlib.pyplot as plt
import seaborn as sns

def univariate_analysis(df, features):
    for feature in features:
        # Key statistical attributes
        skewness = df[feature].skew()
        minimum = df[feature].min()
        maximum = df[feature].max()
        mean = df[feature].mean()
        mode = df[feature].mode().values[0]
        unique_count = df[feature].nunique()
        variance = df[feature].var()
        std_dev = df[feature].std()
        percentile_25 = df[feature].quantile(0.25)
        median = df[feature].median()
        percentile_75 = df[feature].quantile(0.75)
        data_range = maximum - minimum

        print(f"Univariate Analysis for {feature}")
        print(f"Skewness: {skewness:.4f}")
        print(f"Min: {minimum}")
        print(f"Max: {maximum}")
        print(f"Mean: {mean:.4f}")
        print(f"Mode: {mode}")
        print(f"Unique Count: {unique_count}")
        print(f"Variance: {variance:.4f}")
        print(f"Std Dev: {std_dev:.4f}")
        print(f"25th Percentile: {percentile_25}")
        print(f"Median (50th Pct): {median}")
        print(f"75th Percentile: {percentile_75}")
        print(f"Range: {data_range}")

        # Visual insights: KDE, box plot, and histogram side by side
        plt.figure(figsize=(18, 6))

        plt.subplot(1, 3, 1)
        sns.kdeplot(df[feature], fill=True)
        plt.title(f"KDE of {feature}")

        plt.subplot(1, 3, 2)
        sns.boxplot(df[feature])
        plt.title(f"Box Plot of {feature}")

        plt.subplot(1, 3, 3)
        sns.histplot(df[feature], bins=10, kde=True)
        plt.title(f"Histogram of {feature}")

        plt.tight_layout()
        plt.show()
This function provides a comprehensive analysis for each numerical feature by calculating statistical attributes and generating KDE, BoxPlot, and Histogram visualizations.
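One possible way to call it, reusing the dtype-based feature list from earlier (this assumes the numeric columns of interest are already stored as numbers):

# Run the automated univariate analysis on every numeric column
numerical_features = data.select_dtypes(include='number').columns.tolist()
univariate_analysis(data, numerical_features)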
Categorical Univariate Analysis
For categorical features, we analyze the distribution of categories and their relationship with the dependent feature. Here's a function that automates this process:
def univariate_analysis_categorical(df, categorical_features):
    for feature in categorical_features:
        # Key attributes of the categorical feature
        unique_categories = df[feature].nunique()
        mode = df[feature].mode().values[0]
        mode_freq = df[feature].value_counts().max()
        category_counts = df[feature].value_counts()
        category_percent = df[feature].value_counts(normalize=True) * 100
        missing_values = df[feature].isnull().sum()
        total_values = len(df[feature])
        imbalance_ratio = category_counts.max() / total_values

        print(f"Univariate Analysis for {feature}")
        print(f"Unique Categories: {unique_categories}")
        print(f"Mode (Most frequent): {mode}")
        print(f"Frequency of Mode: {mode_freq}")
        print(f"Missing Values: {missing_values}")
        print(f"Imbalance Ratio (Max/Total): {imbalance_ratio:.4f}")
        print(f"Category Counts:\n{category_counts}")

        # Frequency plot of the categories
        plt.figure(figsize=(10, 6))
        sns.countplot(x=df[feature], order=df[feature].value_counts().index)
        plt.title(f"Frequency of {feature} Categories")
        plt.xlabel(feature)
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()
This function provides a clear understanding of how categories are distributed across the data and helps identify potential imbalances.
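A matching call for the object-typed columns might look like this; note that very high-cardinality columns (such as a product-name column) will produce unwieldy count plots and may be worth excluding first:

# Run the categorical univariate analysis on the object/category columns
categorical_features = data.select_dtypes(include=['object', 'category']).columns.tolist()
univariate_analysis_categorical(data, categorical_features)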
Automating Bivariate Analysis
Bivariate analysis allows you to understand the relationship between two variables. Here's how you can automate this process.
The following function calculates key attributes for numerical and categorical features in relation to a binary (0/1) dependent feature. It prints out key insights and generates side-by-side visualizations to understand their relationships.
from scipy.stats import chi2_contingency

def bivariate_analysis(df, numerical_features, categorical_features, dependent_feature):
    if numerical_features:
        for feature in numerical_features:
            # Compare central tendency and spread across the two groups
            mean_0 = df[df[dependent_feature] == 0][feature].mean()
            mean_1 = df[df[dependent_feature] == 1][feature].mean()
            median_0 = df[df[dependent_feature] == 0][feature].median()
            median_1 = df[df[dependent_feature] == 1][feature].median()
            var_0 = df[df[dependent_feature] == 0][feature].var()
            var_1 = df[df[dependent_feature] == 1][feature].var()

            print(f"Mean {feature} for group 0: {mean_0:.2f}")
            print(f"Mean {feature} for group 1: {mean_1:.2f}")
            print(f"Median {feature} for group 0: {median_0:.2f}")
            print(f"Median {feature} for group 1: {median_1:.2f}")
            print(f"Variance of {feature} for group 0: {var_0:.2f}")
            print(f"Variance of {feature} for group 1: {var_1:.2f}")

            # Side-by-side box plot and mean bar plot
            fig, axes = plt.subplots(1, 2, figsize=(16, 6))
            sns.boxplot(x=df[dependent_feature], y=df[feature], ax=axes[0])
            axes[0].set_title(f"{feature} Distribution by {dependent_feature}")
            sns.barplot(x=df[dependent_feature], y=df[feature], estimator='mean', ax=axes[1])
            axes[1].set_title(f"Mean {feature} by {dependent_feature}")
            plt.tight_layout()
            plt.show()

    if categorical_features:
        for feature in categorical_features:
            # Chi-square test of independence between the category and the dependent feature
            category_distribution = df.groupby([feature, dependent_feature]).size().unstack(fill_value=0)
            chi2, p, dof, expected = chi2_contingency(category_distribution)
            print(f"Chi-Square Test for {feature}: Chi2 = {chi2:.4f}, p-value = {p:.4f}")

            # Side-by-side count plot and proportion bar plot
            fig, axes = plt.subplots(1, 2, figsize=(16, 6))
            sns.countplot(x=df[feature], hue=df[dependent_feature], ax=axes[0])
            axes[0].set_title(f"{feature} Count by {dependent_feature}")
            sns.barplot(x=df[feature], y=df[dependent_feature], estimator='mean', ax=axes[1])
            axes[1].set_title(f"Proportion of {dependent_feature} by {feature}")
            plt.tight_layout()
            plt.show()
This function performs bivariate analysis by calculating key attributes and generating box plots, bar plots, and count plots to help you better understand the relationship between variables.
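A hypothetical call might look like the following; 'age', 'income', 'gender', and 'target' are assumed column names on a generic DataFrame df, not columns from the nutrition dataset, and the dependent feature is expected to be binary (0/1):

# Hypothetical usage -- all column names below are assumptions for illustration
bivariate_analysis(
    df,
    numerical_features=['age', 'income'],
    categorical_features=['gender'],
    dependent_feature='target'
)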