
Efficient UD(A)Fs with PySpark

16 min read

This post is older than 5 years – the content might be outdated.

Nowadays, Spark is surely one of the most prevalent technologies in the fields of data science and big data. Luckily, even though it is developed in Scala and runs in the Java Virtual Machine (JVM), it comes with Python bindings, also known as PySpark, whose API was heavily influenced by Pandas. In terms of functionality, modern PySpark has about the same capabilities as Pandas for typical ETL and data wrangling, e.g. groupby, aggregations and so on. As a general rule of thumb, one should consider an alternative to Pandas whenever the data set has more than 10,000,000 rows, which, depending on the number of columns and data types, translates to about 5-10 GB of memory usage. At that point PySpark might be an option that does the job, but of course there are others, for instance Dask, which won’t be addressed in this post.

If you are new to Spark, one important thing to note is that Spark has two remarkable features besides its programmatic data wrangling capabilities. One is that Spark comes with SQL as an alternative way of defining queries and the other is Spark MLlib for machine learning. Both topics are beyond the scope of this post but should be taken into account if you are considering PySpark as an alternative to Pandas and scikit-learn for larger data sets.

But enough praise for PySpark: there are still some ugly sides as well as rough edges to it, and we want to address some of them here, in a constructive way of course. First of all, due to its relatively young age, PySpark lacks some features that Pandas provides, for example in areas such as reshaping/pivoting or time series. Also, it is not as straightforward to use advanced mathematical functions from SciPy within PySpark. That’s why sooner or later you might run into a scenario where you want to apply some Pandas or SciPy operations to your data frame in PySpark. Unfortunately, there is no built-in mechanism for using Pandas transformations in PySpark. In fact, doing so requires a lot of boilerplate code with many error-prone details to consider. Therefore we make a wish to the coding fairy, cross two fingers that someone else already solved this and start googling… and here we are 😉

The remainder of this blog post walks you through the process of writing efficient Pandas UDAFs in PySpark. Along the way, we abstract all the necessary boilerplate code into a single Python decorator, which allows us to conveniently specify our PySpark Pandas function. To give more insight into performance considerations, this post also contains a little journey into the internals of PySpark.

UDAFs with RDDs

To start with a recap, an aggregation function is a function that operates on a set of rows and produces a result, for example a sum() or count() function. A User-Defined Aggregation Function (UDAF) is typically used for more complex aggregations that are not natively shipped with the analysis tool in question. In our case, this means we provide some Python code that takes a set of rows and produces an aggregate result. At the time of writing – with PySpark 2.2 as the latest version – there is no “official” way of defining an arbitrary UDAF. Also, the tracking Jira issue SPARK-10915 does not indicate that this will change in the near future. Depending on your use case, this might even be a reason to discard PySpark completely as a viable solution. However, as you might have guessed from the title of this article, there are workarounds to the rescue.

This is where the RDD API comes in. As a reminder, a Resilient Distributed Dataset (RDD) is the low-level data structure of Spark, and a Spark DataFrame is built on top of it. As we are mostly dealing with DataFrames in PySpark, we can get access to the underlying RDD with the help of the rdd attribute and convert it back with toDF(). This RDD API allows us to specify arbitrary Python functions that get executed on the data. To give an example, let’s say we have a DataFrame df of one billion rows with a boolean is_sold column and we want to filter for rows with sold products. With the RDD API, one could accomplish this by filtering df.rdd with a lambda function that checks is_sold on each row and converting the result back with toDF().
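A minimal sketch of that RDD-based filter, assuming df is a Spark DataFrame with a boolean column is_sold. The names is_sold and sold_rows_via_rdd are hypothetical, and a named predicate is used instead of an inline lambda only so that it can be tried without Spark; it behaves like lambda row: row.is_sold.

```python
def is_sold(row):
    """Row-level predicate that Spark ships to its Python workers (hypothetical name)."""
    # bool(None) is False, so rows with a null is_sold are dropped as well
    return bool(row.is_sold)


def sold_rows_via_rdd(df):
    """Filter df through its RDD: every row is evaluated by a Python function."""
    # Reusing the original schema avoids schema inference, which fails on an empty result
    return df.rdd.filter(is_sold).toDF(df.schema)
```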

Although not explicitly declared as such, this lambda function is essentially a user-defined function (UDF). For this exact use case, we could also use the higher-level DataFrame filter() method, which produces the same result.
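For comparison with the RDD route, a sketch of the DataFrame version, again assuming a DataFrame df with a boolean column is_sold (sold_rows_via_dataframe is a hypothetical name). Here the condition is a Column expression that is handed to the JVM rather than a Python function:

```python
def sold_rows_via_dataframe(df):
    """Filter df with a Column expression that Spark evaluates inside the JVM."""
    # df.is_sold is a Column object, not a Python value; rows where it is null are dropped
    return df.filter(df.is_sold)
```

Calling sold_rows_via_dataframe(df).explain() prints the physical plan, in which the condition shows up in a Filter node instead of a call into Python.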

Before we go into the details of how to implement UDAFs using the RDD API, there is something important to keep in mind which might sound counterintuitive given the title of this post: in PySpark you should avoid all kinds of Python UDFs – like RDD functions or data frame UDFs – as much as possible! Whenever a built-in DataFrame method is available, it will be much faster than its RDD counterpart. To get a better understanding of this substantial performance difference, we will now take a little detour and investigate what happens behind the scenes in those two filter examples.

PySpark internals

PySpark is actually a wrapper around the Spark core written in Scala. When you start your SparkSession in Python, PySpark uses Py4J in the background to launch a JVM and create a Java SparkContext. All PySpark operations, for example our df.filter() method call, get translated behind the scenes into corresponding calls on the respective Spark DataFrame object within the JVM SparkContext. This is in general extremely fast, and the overhead can be neglected as long as you don’t call the function millions of times. So in our df.filter() example, the DataFrame operation and the filter condition are sent to the Java SparkContext, where they get compiled into an overall optimized query plan. Once the query is executed, the filter condition is evaluated on the distributed DataFrame within Java, without any callback to Python! If our workflow loads the DataFrame from Hive and saves the resulting DataFrame as a Hive table, all data operations are performed in a distributed fashion within the Java Spark workers throughout the entire query execution, which allows Spark to be very fast for queries on large data sets.

Okay, so why is the RDD filter() method so much slower then? The reason is that the lambda function cannot be applied directly to the DataFrame residing in JVM memory. What actually happens internally is that Spark spins up Python workers next to the Spark executors on the cluster nodes. At execution time, the Spark workers send our lambda function to those Python workers. Next, the Spark workers start serializing their RDD partitions and pipe them to the Python workers via sockets, where our lambda function gets evaluated on each row. For the resulting rows, the whole serialization/deserialization procedure happens again in the opposite direction so that the actual filter() can be applied to the result set.
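To make the cost more tangible, here is a toy, single-process model of that round trip. It is not Spark’s actual code path: both function names are hypothetical, and plain pickle stands in for the serializers Spark uses between JVM and Python (Spark ships the function itself with cloudpickle, which also handles lambdas). Every row is serialized and deserialized on the way to the worker, and every kept row once more on the way back, no matter how cheap the predicate is:

```python
import pickle


def toy_python_worker(payload, predicate):
    """Toy Python worker: deserialize a batch, apply the predicate, serialize the result."""
    rows = pickle.loads(payload)
    kept = [row for row in rows if predicate(row)]
    return pickle.dumps(kept)


def toy_rdd_filter(partition, predicate):
    """Toy executor side: ship a partition to the 'worker' and read back its result."""
    payload = pickle.dumps(partition)               # JVM -> Python (a socket in Spark)
    result = toy_python_worker(payload, predicate)  # evaluated in the Python process
    return pickle.loads(result)                     # Python -> JVM
```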

The entire data flow when using arbitrary Python functions in PySpark is also shown in the following image, which has been taken from the old PySpark Internals wiki:

UDAF Data Flow in PySpark

Even if all of this sounded awkwardly technical to you, you probably get the point: executing Python functions in a distributed Java system is very expensive in terms of execution time due to the excessive copying of data back and forth.

To give a short summary of this low-level excursion: as long as we avoid all kinds of Python UDFs, a PySpark program will be approximately as fast as a Scala-based Spark program. If we cannot avoid UDFs, we should at least try to make them as efficient as possible, which is what we show in the remainder of this post. Before we move on, though, one side note should be kept in mind. The general problem of accessing data frames from different programming languages in the realm of data analytics is currently being addressed by Wes McKinney, the creator of Pandas. He is also the initiator of the Apache Arrow project, which tries to standardize the way columnar data is stored in memory so that everyone using Arrow no longer needs the cumbersome object translation via serialization and deserialization. Hopefully with version 2.3, as shown in the issues SPARK-13534 and SPARK-21190, Spark will make use of Arrow, which should drastically speed up our Python UDFs. Still, even in that case we should always prefer built-in Spark functions whenever possible.

PySpark UDAFs with Pandas

As mentioned before our detour into the internals of PySpark, defining an arbitrary UDAF requires an operation that works on multiple rows and produces one or multiple resulting rows. This functionality is provided by the RDD method mapPartitions, which applies an arbitrary Python function my_func to each partition of the RDD underlying a DataFrame df.
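In code, the general pattern is df.rdd.mapPartitions(my_func).toDF(). A minimal sketch with a hypothetical partition function that just counts the rows it receives: it consumes an iterator and yields plain tuples, which toDF() accepts together with a list of column names.

```python
def rows_per_partition(rows):
    """Example partition function: consumes an iterator of rows, yields one tuple."""
    yield (sum(1 for _ in rows),)


def count_rows_per_partition(df):
    """Apply rows_per_partition to every partition of df's underlying RDD."""
    return df.rdd.mapPartitions(rows_per_partition).toDF(["n_rows"])
```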

If you want to read up further on RDDs and partitions, you can check out the chapter Partitions and Partitioning of the excellent book Mastering Apache Spark 2 by Jacek Laskowski. In most cases we would want to control the number of partitions, like 100, or even distribute the rows by a column, let’s say country, so that all rows with the same value end up in the same partition. In both cases, we call repartition on df before switching to the RDD and calling mapPartitions.
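Both variants as a sketch, assuming a DataFrame df and a partition function such as my_func; the wrapper names are hypothetical. Note that repartitioning by a column only guarantees that rows sharing a value land in the same partition: one partition may still contain several countries, and some partitions may be empty.

```python
def apply_with_n_partitions(df, partition_func, num_partitions=100):
    """Shuffle df into num_partitions partitions, then apply partition_func to each."""
    return df.repartition(num_partitions).rdd.mapPartitions(partition_func).toDF()


def apply_by_column(df, partition_func, column="country"):
    """Hash-partition df by column, then apply partition_func to each partition."""
    # A partition may hold several values of `column` (or none at all), so
    # partition_func should still group by `column` itself and cope with no rows
    return df.repartition(column).rdd.mapPartitions(partition_func).toDF()
```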


Having solved one problem, as is so often the case in life, we have introduced another one. As we are now working with the low-level RDD interface, our function my_func will be passed an iterator of PySpark Row objects and needs to return them as well. A Row object itself is only a container for the column values of one row, as you might have guessed. When we return such a Row, the data types of its values must be interpretable by Spark so that they can be translated back to Scala. This is a lot of low-level stuff to deal with, since in most cases we would love to implement our UDF/UDAF with the help of Pandas, keeping in mind that one partition should hold fewer than 10 million rows.

So first we need to define a nice function that will convert a Row iterator into a Pandas DataFrame:
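A minimal sketch of such a function, together with the helpers peek and rtype. It assumes the rows are pyspark.sql.Row objects, which expose their column names via __fields__; the hypothetical helper field_names also accepts namedtuples, so the conversion can be tried without Spark.

```python
import itertools
import logging

import pandas as pd

logger = logging.getLogger(__name__)


def peek(iterator):
    """Return (first_element, iterator_over_all_elements), or (None, None) if empty."""
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return None, None
    return first, itertools.chain([first], iterator)


def rtype(value):
    """Name of a value's type, e.g. 'int' or 'str', for debug output."""
    return type(value).__name__


def field_names(row):
    """Column names of a row: pyspark.sql.Row uses __fields__, namedtuples use _fields."""
    return list(getattr(row, "__fields__", None) or row._fields)


def rows_to_pandas(rows):
    """Convert the iterator of rows of one partition into a pandas DataFrame."""
    first_row, rows = peek(rows)
    if first_row is None:
        logger.warning("Partition is empty, returning an empty pandas DataFrame")
        return pd.DataFrame()
    columns = field_names(first_row)
    # Debug aid: value and Python type of every column in the first row
    first_row_info = [
        f"{name} ({rtype(value)}): {value}" for name, value in zip(columns, first_row)
    ]
    logger.debug("First partition row: %s", first_row_info)
    df = pd.DataFrame.from_records(rows, columns=columns)
    logger.debug("Converted partition to shape %s with dtypes:\n%s", df.shape, df.dtypes)
    return df
```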

This function actually does only one thing, which is calling df = pd.DataFrame.from_records(rows, columns=first_row.__fields__) in order to generate a DataFrame. The rest of the code makes sure that the iterator is not empty and, for debugging purposes, peeks into the first row and prints the value as well as the data type of each column. In practice, this has proven extremely helpful whenever something goes wrong and one needs to debug what is going on inside the UDF/UDAF. The helper peek looks at the first row without losing it from the iterator, and rtype returns the name of a value’s type.

The next part is to actually convert the result of our UDF/UDAF back to an iterator of Row objects. Since our result will most likely be a Pandas DataFrame or Series, we define the following:
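A sketch of such a conversion, under a few assumptions: the UDF/UDAF returns a pandas DataFrame, a Series or None; only NumPy integer, float and boolean scalars are converted (the helper to_python is hypothetical); and if PySpark is not installed, plain dicts stand in for Row objects so the sketch can be run locally.

```python
import numpy as np
import pandas as pd

try:
    from pyspark.sql import Row
except ImportError:  # assumption: local fallback only; Spark itself needs real Row objects
    Row = dict


def to_python(value):
    """Turn NumPy integer, float and bool scalars into Python built-ins."""
    if isinstance(value, (np.integer, np.floating, np.bool_)):
        return value.item()
    return value


def pandas_to_rows(result):
    """Convert a pandas DataFrame, Series or None into an iterator of Row objects."""
    if result is None:
        return iter([])
    if isinstance(result, pd.Series):
        # An aggregate Series (index = column names) becomes a single-row DataFrame
        result = result.to_frame().T
    if result.empty:
        return iter([])
    records = result.to_dict(orient="records")
    return (
        Row(**{str(name): to_python(value) for name, value in record.items()})
        for record in records
    )
```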

This looks a bit more complicated, but essentially we convert a Pandas Series to a DataFrame if necessary and handle the edge cases of an empty DataFrame or None as return value. We then convert the DataFrame to records, convert some NumPy data types to their Python equivalents and create an iterator over Row objects from the converted records.

With these functions at hand we can define a Python decorator that will allow us to automatically call the functions rows_to_pandas and pandas_to_rows at the right time:
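A sketch of such a decorator, here called pandas_udaf (a hypothetical name). To keep the block self-contained, it carries compact versions of the two conversions (without the debug output described above), falls back to dicts if PySpark is not installed, and leaves out the logger setup. Like args[-1], it takes the partition iterator from the last argument.

```python
import functools

import numpy as np
import pandas as pd

try:
    from pyspark.sql import Row
except ImportError:  # assumption: local fallback only; Spark itself needs real Row objects
    Row = dict


def _rows_to_pandas(rows):
    """Compact stand-in for rows_to_pandas."""
    rows = list(rows)
    if not rows:
        return pd.DataFrame()
    columns = list(getattr(rows[0], "__fields__", None) or rows[0]._fields)
    return pd.DataFrame.from_records(rows, columns=columns)


def _to_python(value):
    """Turn NumPy integer, float and bool scalars into Python built-ins."""
    if isinstance(value, (np.integer, np.floating, np.bool_)):
        return value.item()
    return value


def _pandas_to_rows(result):
    """Compact stand-in for pandas_to_rows."""
    if result is None:
        return iter([])
    if isinstance(result, pd.Series):
        result = result.to_frame().T
    return (
        Row(**{str(k): _to_python(v) for k, v in record.items()})
        for record in result.to_dict(orient="records")
    )


def pandas_udaf(func):
    """Let func work on one pandas DataFrame per partition instead of a Row iterator."""
    @functools.wraps(func)
    def wrapper(*args):
        # mapPartitions passes the partition iterator as the last argument; any
        # leading arguments (e.g. self when func is a method) are passed through
        *leading, rows = args
        result = func(*leading, _rows_to_pandas(rows))
        return _pandas_to_rows(result)
    return wrapper
```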

The code is pretty much self-explanatory if you have ever written a Python decorator; otherwise, you should read up on decorators, since it takes some time to wrap your head around them. Basically, we set up a default logger, create a Pandas DataFrame from the Row iterator, pass it to our UDF/UDAF and convert its return value back to a Row iterator. The only additional thing that might still raise questions is the usage of args[-1]. This is due to the fact that func might also be a method of an object. In this case, the first argument would be self, but in either case the last argument is the actual argument that mapPartitions passes to us. The code of setup_logger depends on your Spark installation. In case you are using Spark on Apache YARN, it might look like this:
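One possible sketch, assuming that YARN exposes the container’s log directories as a comma-separated LOG_DIRS environment variable, so that the log file ends up next to the executor logs; outside of YARN it falls back to stderr. The logger name and the file name are hypothetical.

```python
import logging
import os
import sys


def setup_logger(name="pyspark_udaf", level=logging.DEBUG):
    """Configure (once) and return a logger for code running inside Spark executors."""
    logger = logging.getLogger(name)
    if logger.handlers:
        # Python workers may be reused across tasks, so only configure the logger once
        return logger
    log_dirs = os.environ.get("LOG_DIRS")  # assumption: set by YARN in each container
    if log_dirs:
        log_file = os.path.join(log_dirs.split(",")[0], "pyspark_udaf.log")
        handler = logging.FileHandler(log_file)
    else:
        handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(handler)
    logger.setLevel(level)
    return logger
```

Calling setup_logger() at the start of the decorator’s wrapper then configures logging once per Python worker.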

Now, with all parts in place, let’s assume the code above resides in a Python module. A future post will cover the topic of deploying dependencies in a systematic way for production requirements. For now, we just presume that this module as well as all its dependencies like Pandas, NumPy, etc. are accessible by the Spark driver as well as the executors. This allows us to easily define an example UDAF my_func that collects some basic statistics for each country:
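A sketch of the pandas side of such a function (country_stats is a hypothetical name). It is kept undecorated so that it runs on its own. Instead of df.groupby("country").apply(...), it describes each country’s group explicitly, which yields the same statistics and avoids the deprecation warning newer pandas versions emit when apply operates on the grouping columns. Since a partition may contain several countries or none at all, it groups by country itself and returns None for an empty partition.

```python
import pandas as pd


def country_stats(df):
    """Describe all numeric columns per country; df holds one Spark partition."""
    if df.empty:
        # Empty partitions can occur after repartitioning by a column
        return None
    stats = {
        country: group.drop(columns="country").describe()
        for country, group in df.groupby("country")
    }
    # Outer index level: country, inner level: statistic (count, mean, std, ...)
    return pd.concat(stats, names=["country", "stat"]).reset_index()
```

In the module, this function would carry the decorator, and the decorated version plays the role of my_func.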

Of course, it is not really useful in practice to compute statistics with a UDAF that could also be retrieved with basic PySpark functionality, but this is just an example. We now generate a DataFrame of dummy data and apply the function to each partition as above:
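A sketch of that step. make_dummy_data and run_example are hypothetical helpers, the column names are made up, and run_example needs a live SparkSession plus a partition function such as the decorated my_func.

```python
import numpy as np
import pandas as pd


def make_dummy_data(n_rows=1000, seed=42):
    """Random pandas DataFrame with a country column and two numeric features."""
    rng = np.random.default_rng(seed)
    return pd.DataFrame({
        "country": rng.choice(["DE", "FR", "US"], size=n_rows),
        "feature1": rng.random(n_rows),
        "feature2": rng.normal(size=n_rows),
    })


def run_example(spark, partition_func):
    """Apply partition_func to the dummy data, partitioned by country (needs Spark)."""
    sdf = spark.createDataFrame(make_dummy_data())
    return sdf.repartition("country").rdd.mapPartitions(partition_func).toDF().toPandas()
```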

The code above can easily be tested in a Jupyter notebook with PySpark, where the SparkSession spark is predefined.


Overall, this proposed method allows the definition of a UDF as well as a UDAF, since it is up to the function my_func whether it returns (1) a DataFrame with as many rows as the input DataFrame (think Pandas transform), (2) a DataFrame of only a single row or, optionally, a Series (think Pandas aggregate), or (3) a DataFrame with an arbitrary number of rows (think Pandas apply), possibly even with varying columns. Therefore, this approach should be applicable to a variety of use cases where the built-in PySpark functionality is not sufficient.

To wrap it up, this blog post gives you a template on how to write PySpark UD(A)Fs while abstracting all the boilerplate in a dedicated module. We also went down the rabbit hole to explore the technical difficulties the Spark developers face in providing Python bindings to a distributed JVM-based system. In this respect we are really looking forward to closer integration of Apache Arrow and Spark in the upcoming Spark 2.3 and future versions.

This article originally appeared on

9 comments

    1. Hi @disqus_nywrOI9eN3:disqus, thanks for your feedback, glad you like it. Good point, one should definitely add more type conversions to cover all cases. Let me know if you find more missing ones, then I will add them all at once in the next update of this post.

  1. Amazing! Many thanks. I added this to deal with empty partitions, before passing the data frame to the UDAF:
    if args[-1].shape[0] == 0:
        df = pd.DataFrame()
    else:
        df = func(*args)

    1. Hi, great that you like it and thanks for your feedback. I think one can do this; on the other hand, maybe your UDF wants to react to empty dataframes with some special action. With your approach, the actual UDF never sees the empty dataframe. But in the end it really depends on your use case.

  2. Hi, please explain why there is a drop for 'country' here. I tried keeping the partition column (in this case 'country') when applying the function and got unexpected results. Though the final counts were matching, the data is not good: I got duplicates for some rows and some records were missing. After adding drop it worked. Can you explain what really happens with the drop?

    df = df.groupby('country').apply(lambda x: x.drop('country', axis=1).describe())

    1. Hi Srileka, with drop I just remove the country column from the dataframe to ensure that describe gives meaningful values. Does this help? Could you elaborate a bit more on what duplicates you are seeing? Please also note that the article you are reading is quite old and PySpark has evolved a lot since 2017. I would not recommend this approach anymore. Please check out my successor article "More efficient UDFs with PySpark" from 2019 and then read up on all the news about Spark 3.0 to be up to date about UD(A)Fs.

      1. Thanks for the response, Florian. On the duplicates, please refer to the example below:
        Input df:
        row_num, other columns
        1, a1, a2, a3, a4
        2, b1, b2, b3, b4
        3, c1, c2, c3, c4
        4, d1, d2, d3, d4
        5, e1, e2, e3, e4
        Output df:
        row_num, other columns
        1, a1, a2, a3, a4
        1, a1, a2, a3, a4
        3, c1, c2, c3, c4
        4, d1, d2, d3, d4
        4, d1, d2, d3, d4
        The rows 2 & 5 are missing and 1 & 4 are exactly duplicated. Final count of output df is same as input df (missing count compensates dups count). But when I drop the partition id (p_id) as soon as I call the function before any transformations, it works. Please note, if I drop it by end of the function it doesn’t work – still getting dups and missing. I have set the index on row_num and resetting at the end of UDAF. So trying to understand what happens with drop p_id – why it works only when we drop the p_id in the beginning of the function. No matter what I do with the index, only drop p_id does the magic.

        Thanks for the new link. Yes Spark has evolved a lot, but I have an existing code in similar framework which goes through an enhancement now. So trying to achieve with minimal change.

        1. Thanks for the example, don’t exactly understand why drop fixes that. Sorry, but I cannot look much deeper into this since the code is so old and we don’t use it anymore.
