20 essential PySpark operations

Table Of Contents:
Setting Up PySpark
Loading data
Basic operations
Column operations
Row operations
Aggregate functions
Window functions
Joins
Performance Optimisation
Best Practices and Tips
Conclusion
References
As a Machine Learning Engineer working with big data, I understand the importance of efficient data processing and high-quality data preparation for ML models. Spark is the tool that lets us work with data at that scale. If you are looking for an intro to Spark, here’s a guide.
We will be using the PySpark library throughout this article. PySpark is the Python API for Apache Spark. In this article, we'll explore essential PySpark operations using a sample dataset and conclude with best practices to keep in mind while working with data. Feel free to grab a cup of coffee or tea :)
a. Setting Up PySpark
Before diving into data processing, we need to configure our Spark environment properly. The SparkSession is your entry point to all Spark functionality. Feel free to spin up a notebook, or download Spark on your local machine.
- To install PySpark on your notebook, run:
!pip install pyspark
To download Spark on your local machine, you can follow this guide.
Once we have the Spark environment, we can initialise the Spark session.
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder \
    .appName("SparkLearningDemo") \
    .getOrCreate()
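A quick way to sanity-check the session (purely optional, and not needed for the rest of the article):
# Confirm the session is alive and see which Spark version you are running
print(spark.version)
# When you are completely done with Spark, shut the session down
# spark.stop()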
b. Loading data
To work with Spark, we need to load data into a Spark DataFrame. The data can exist in several formats, so let’s go over a few examples of how to handle each one:
- CSV Files
CSV files are one of the most common data formats. When reading CSV files in PySpark, you can specify various options to handle different scenarios like headers, schema inference, and malformed records:
# Reading CSV with options
df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("fashion_products.csv")
- Pandas DF
Sometimes you'll need to convert existing pandas DataFrames to Spark DataFrames, especially when working with legacy code or smaller datasets that were initially processed using pandas.
Best practice: Always specify schema when creating a dataframe.
import pandas as pd
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType  # IntegerType is used in the joins section below
# Define schema for better control and performance
schema = StructType([
    StructField("product_id", StringType(), False),
    StructField("category", StringType(), True),
    StructField("price", DoubleType(), True),
    StructField("color", StringType(), True)
])
# Create pandas DataFrame
new_products_pd = pd.DataFrame({
    'product_id': ['P1001', 'P1002', 'P1003'],
    'category': ['Midi-dress', 'Skorts', 'Shirt'],
    'price': [1650.0, 1800.0, 700.0],
    'color': ['Red', 'White', 'Grey']
})
# Convert to Spark DataFrame with schema
df = spark.createDataFrame(new_products_pd, schema=schema)
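The conversion also works in the other direction via toPandas(); keep in mind that this collects the entire dataset onto the driver, so it is only safe for data that comfortably fits in driver memory:
# Convert a (small) Spark DataFrame back to pandas
products_pd = df.toPandas()
print(products_pd.head())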
- Parquet Files
Parquet is a columnar storage format that offers superior performance and compression compared to CSV. It's particularly useful for big data processing and analytics:
# Reading Parquet
df = spark.read.parquet("fashion_data.parquet")
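Writing a DataFrame out as Parquet is equally simple; the output path and write mode below are illustrative:
# Writing Parquet; "overwrite" replaces any existing output at this path
df.write.mode("overwrite").parquet("fashion_data_out.parquet")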
If you want to learn more about the Parquet format, I highly recommend this in-depth article.
c. Basic operations
Let’s go over some basic operations to help you understand the data's structure, content, and basic statistics.
- show() and head() - display the first few rows of the df; show() prints them to the console, while head() returns them as Row objects.
- printSchema() - prints the schema of the df.
- describe() - computes basic statistics (count, mean, stddev, min, max) of the df; call .show() on the result to display them.
- count() - returns the total number of rows. We'll also see count used as an aggregate function below.
# Display sample data
# We can use the flag truncate=False to get a full view of the content
df.show(5, truncate=False)
# Schema information
df.printSchema()
# Basic statistics
df.describe().show()
#Count
df.count()
d. Column operations
- withColumn() - create a new column, update an existing column, or change the datatype of an existing column.
from pyspark.sql.functions import col, when, round, expr, array, struct
# Complex column transformations
df_transformed = df.withColumn("discount_price",
                               round(col("price") * 0.9, 2))
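The import line above also pulls in when(), which is useful for conditional columns. Here is a small sketch; the price_tier column and the 1500 threshold are made up purely for illustration:
# Conditional column with when()/otherwise()
df_transformed = df_transformed.withColumn("price_tier",
    when(col("price") > 1500, "premium").otherwise("regular"))
# Changing a datatype with cast(), written here into a new illustrative column
df_transformed = df_transformed.withColumn("price_int", col("price").cast("integer"))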
- select() - Select certain columns from the df
df_subset = df.select("product_id", "category", "price")
- drop() - Drop certain columns from the df
df_dropped = df.drop("discount_price")
e. Row operations
- filter() or where() - Filters rows based on a condition. You can also filter on multiple conditions by combining them with logical operators: & (and), | (or), and ~ (not).
# Complex filtering
premium_products = df \
    .filter((col("price") > 1000) &
            (col("category").isin("Dress", "Shirt")))
- distinct() - Select distinct rows. We do not have any duplicate rows in our sample dataset, so as an exercise, create a df containing duplicate rows and try it out (see the sketch after the next example).
- sort() or orderBy() - Sort rows by one or more columns. By default these methods sort in ascending order:
df.sort(col("price").desc())
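A small sketch of both operations; the duplicated rows below are invented just for the exercise:
# A tiny df with an exact duplicate row, only for illustration
dup_df = spark.createDataFrame(
    [("P1001", "Dress"), ("P1001", "Dress"), ("P1002", "Skirt")],
    ["product_id", "category"])
dup_df.distinct().show()                      # drops exact duplicate rows
dup_df.dropDuplicates(["product_id"]).show()  # de-duplicates on chosen columns
# Sorting on multiple columns: category ascending, then price descending
df.orderBy(col("category").asc(), col("price").desc()).show()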
Since we are moving on to more advanced operations, let’s update our sample dataset:
schema = StructType([
    StructField("product_id", StringType(), False),
    StructField("category", StringType(), True),
    StructField("price", DoubleType(), True),
    StructField("brand", StringType(), True),
    StructField("color", StringType(), True),
    StructField("season", StringType(), True)
])
# Create data rows
data = [
    ("P1001", "Dress", 1650.0, "Gucci", "Red", "Summer"),
    ("P1002", "Dress", 1800.0, "Zara", "White", "Winter"),
    ("P1003", "Shirt", 700.0, "H&M", "Grey", "Summer"),
    ("P1004", "Shirt", 850.0, "Zara", "Blue", "Winter"),
    ("P1005", "Skirt", 950.0, "Gucci", "Black", "Fall"),
    ("P1006", "Dress", 2100.0, "Gucci", "Red", "Summer"),
    ("P1007", "Shirt", 600.0, "H&M", "White", "Spring"),
    ("P1008", "Skirt", 1100.0, "Zara", "Black", "Fall"),
    ("P1009", "Dress", 1750.0, "Zara", "Blue", "Winter"),
    ("P1010", "Shirt", 800.0, "H&M", "Grey", "Spring")
]
# Create DataFrame
df = spark.createDataFrame(data, schema=schema)
f. Aggregate Functions
Aggregate functions perform calculations on a group of rows to return a single value. Think of them as operations that take multiple values and "aggregate" them into a single result. They're particularly useful when you want to summarise data and understand data patterns.
- groupBy() - Groups the rows of the df by one or more columns so that an aggregate function can be applied to each group.
df.groupBy("category").count().show()
- collect_list() - Collects the values of a column within each group into a list.
import pyspark.sql.functions as F
df.groupBy("category").agg(F.collect_list("product_id")).show(truncate=False)
g. Window Functions
Window functions allow you to perform calculations across a set of rows related to the current row. For example, if we want to find the lowest price within each category and brand, we can proceed like below:
- row_number() - When used with a window specification, assigns a sequential row number within each partition.
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
window_spec = Window.partitionBy("category", "brand").orderBy("price")
# Adding ranking
df_analyzed = df.withColumn("row_number", row_number().over(window_spec)) \
    .select("category", "brand", "row_number", "price")
df_analyzed.show()
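To actually return the lowest-priced product per category and brand (the goal stated above), keep only the rows that ranked first; a minimal follow-up sketch:
# Keep only the cheapest row in each (category, brand) partition
cheapest_per_group = df_analyzed.filter(col("row_number") == 1)
cheapest_per_group.show()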
h. Joins
PySpark supports several types of joins to combine DataFrames:
- inner - Returns only the matching rows between the two DataFrames
- left or left_outer - Returns all rows from the left DataFrame plus matching rows from the right
- right or right_outer - Returns all rows from the right DataFrame plus matching rows from the left
- full or full_outer - Returns all rows from both DataFrames
# Create orders data
orders_data = [
    ("O1001", "P1001", 2, "2024-01-15"),
    ("O1002", "P1003", 1, "2024-01-16"),
    ("O1003", "P1005", 3, "2024-01-16"),
    ("O1004", "P1002", 1, "2024-01-17"),
    ("O1005", "P1010", 2, "2024-01-17"),
    ("O1006", "P1012", 1, "2024-01-18")
]
orders_schema = StructType([
    StructField("order_id", StringType(), False),
    StructField("product_id", StringType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("order_date", StringType(), True)
])
# Create products data
products_data = [
    ("P1001", "Dress", 1650.0, "Gucci"),
    ("P1002", "Dress", 1800.0, "Zara"),
    ("P1003", "Shirt", 700.0, "H&M"),
    ("P1004", "Shirt", 850.0, "Zara"),
    ("P1005", "Skirt", 950.0, "Gucci"),
    ("P1008", "Skirt", 1100.0, "Zara"),
    ("P1009", "Dress", 1750.0, "Zara"),
    ("P1010", "Shirt", 800.0, "H&M")
]
products_schema = StructType([
    StructField("product_id", StringType(), False),
    StructField("category", StringType(), True),
    StructField("price", DoubleType(), True),
    StructField("brand", StringType(), True)
])
# Create DataFrames
orders_df = spark.createDataFrame(orders_data, orders_schema)
products_df = spark.createDataFrame(products_data, products_schema)
# Example joins
print("\\nInner Join (only matching orders and products):")
orders_df.join(products_df, "product_id", how="inner").show()
print("Left Join")
orders_df.join(products_df, "product_id", "left").show()
print("Right join")
orders_df.join(products_df, "product_id", "right").show()
print("Full outer join")
orders_df.join(products_df, "product_id", "full").show()
- Broadcast Joins
Broadcast joins optimise joins between a large and a small DataFrame by shipping a copy of the smaller one to every worker node, which avoids shuffling the large DataFrame. I cannot overstate how much time broadcasting has saved in my pipelines.
from pyspark.sql.functions import broadcast
# Small DataFrame for product categories
category_df = spark.createDataFrame([
    ("Dress", "Women's Clothing"),
    ("Shirt", "Men's Clothing")
], ["category", "department"])
# Broadcast join for performance
df_with_dept = df.join(broadcast(category_df), "category")
Best practice: Consider a broadcast join when one DataFrame is significantly smaller than the other and small enough to fit comfortably in each executor's memory. Spark's automatic broadcast threshold (spark.sql.autoBroadcastJoinThreshold) defaults to 10 MB; larger tables can still be broadcast explicitly if memory allows.
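To confirm that Spark really chose a broadcast strategy, inspect the physical plan; you should see a BroadcastHashJoin operator:
# The physical plan should show a BroadcastHashJoin for this join
df_with_dept.explain()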
i. Performance Optimisation
Performance optimisation is crucial when working with large datasets. These techniques help you manage memory usage and improve processing speed:
- Partitioning - Helps distribute data across the cluster. We can change the number of partitions (and the partitioning keys) via repartition(); to reduce the number of partitions without a full shuffle, use coalesce().
- Caching/Persist - Keeps frequently accessed data in memory for faster processing. Use cache() or persist(). To validate that the df has actually been cached, check the Storage tab of the Spark UI.
- getNumPartitions() - Returns the current number of partitions.
# Repartition for better parallelism
df_repartitioned = df.repartition(10, "category")
# To validate the number of partitions
df_repartitioned.rdd.getNumPartitions()
# Coalesce for reducing partitions
df_coalesced = df.coalesce(5)
# Caching for repeated operations
df.cache()      # shorthand for df.persist() with the default storage level
# df.persist()  # equivalent; persist() also accepts an explicit storage level
# After use, remove from memory
df.unpersist()
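If you need more control than cache(), persist() accepts an explicit storage level; a small sketch, with MEMORY_AND_DISK as just one reasonable choice:
from pyspark import StorageLevel
# Spill partitions to disk when they do not fit in memory
df.persist(StorageLevel.MEMORY_AND_DISK)
df.count()      # caching happens lazily, so an action triggers it
df.unpersist()  # release it once you are done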
j. Best Practices and Tips
Following these best practices will help you write more efficient and maintainable PySpark code:
- Always specify schemas explicitly when creating DataFrames to ensure data type consistency
- Use appropriate partition sizes based on data volume to optimise processing
- Leverage caching for frequently accessed data to improve performance
- Use broadcast joins for small lookup tables to reduce shuffle operations
- Use appropriate storage formats (Parquet preferred) for better performance
k. Conclusion
This guide covers the essential PySpark operations needed for processing data efficiently. Remember to always consider your specific use case when applying these operations and monitor performance metrics to ensure optimal processing.
l. References