Faraz Gerrard Jamal
3 min read · Aug 17, 2022

Calculate a column’s value in each row from the previous row’s value of the same column in a PySpark DataFrame

Say I have a column ‘val’ whose value is calculated row by row: each row takes the value computed for the previous row, applies some logic to it, and stores the result as its own value. This can be written as follows:

val(x) = f(val(x-1), col_a(x), col_b(x)) where x is the row number (indexed at 0)

val(0) = f(col_a(0), col_b(0)) {some fixed value calculated based on two columns}

val(0) represents the first value in a partition.

[ f here represents some arbitrary function]
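To make the recurrence concrete, here is a minimal plain-Python sketch (not Spark code). It assumes f is "a + b minus the previous val", the same rule used for the sample data in this post; the names `step` and `running_vals` are made up for illustration.

```python
def step(prev_val, a, b):
    """Assumed f: a + b for the first row, a + b - previous val afterwards."""
    if prev_val is None:
        return a + b
    return a + b - prev_val


def running_vals(rows):
    """rows is an ordered list of (a, b) pairs from a single partition."""
    vals = []
    prev = None
    for a, b in rows:
        prev = step(prev, a, b)  # each row depends on the value just computed
        vals.append(prev)
    return vals


# The three rows of the Sales partition
sales = [(3000, 2500), (4600, 1650), (4100, 1100)]
print(running_vals(sales))  # [5500, 750, 4450]
```

The loop carries `prev` from one row to the next, and that carried state is exactly what is awkward to express with column-wise Spark operations.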

Let’s try using the lag function on a sample dataframe (the ‘val’ column in the table shows the desired result):

from pyspark.sql import Window
from pyspark.sql.functions import col, lag

windowSpec = Window.partitionBy("department")

# Desired output:
# +-------------+----------+----+----+----+
# |employee_name|department|a   |b   |val |
# +-------------+----------+----+----+----+
# |James        |Sales     |3000|2500|5500| # val(0) = a(0) + b(0) = 5500 [first value within a partition]
# |Michael      |Sales     |4600|1650|750 | # val(1) = a(1) + b(1) - val(0) = 750
# |Robert       |Sales     |4100|1100|4450| # val(2) = a(2) + b(2) - val(1) = 4450
# |Maria        |Finance   |3000|7000|xxxx| # ... and so on; this is how I want the calculations to take place.
# |James        |Finance   |3000|5000|xxxx|
# |Scott        |Marketing |3300|4300|xxxx|
# |Jen          |Marketing |3900|3700|xxxx|
# +-------------+----------+----+----+----+

# Does not give the desired result: "val" does not exist yet when lag() tries to read it.
df = df.withColumn("val", col("a") + col("b") - lag("val", 1).over(windowSpec))

Tracking the previously calculated value from the same column is hard to do in Spark, because lag can only read values that already exist in a column, not values that are being computed in the same pass. It is not impossible, though, and there are ways (hacks) to achieve it. One of them uses an array of structs and the aggregate function.

This approach makes two assumptions about the data:

  • There is an ID column that captures the sort order of the data, because Spark does not guarantee row order in a distributed dataframe.
  • There is a grouping key, so that the rows can be collected and processed group by group.

If your data lacks these columns, they can be added using the ‘monotonically_increasing_id’ and ‘lit’ functions respectively. Note that monotonically_increasing_id produces increasing and unique IDs, but not consecutive ones.

# input data with aforementioned assumptions
data_sdf.show()

# +---+---+-------+---------+----+----+
# | gk|idx| name| dept| a| b|
# +---+---+-------+---------+----+----+
# | gk| 1| James| Sales|3000|2500|
# | gk| 2|Michael| Sales|4600|1650|
# | gk| 3| Robert| Sales|4100|1100|
# | gk| 4| Maria| Finance|3000|7000|
# | gk| 5| James| Finance|3000|5000|
# | gk| 6| Scott|Marketing|3300|4300|
# | gk| 7| Jen|Marketing|3900|3700|
# +---+---+-------+---------+----+----+

from pyspark.sql import functions as func

# create structs with all columns and collect them into an array
# use the array of structs to do the val calcs
# NOTE - keep the ID field ahead of any other varying field (`gk` is constant here)
# for `array_sort` to work as required.
arr_of_structs_sdf = (
    data_sdf
    .withColumn('allcol_struct', func.struct(*data_sdf.columns))
    .groupBy('gk')
    .agg(func.array_sort(func.collect_list('allcol_struct')).alias('allcol_struct_arr'))
)

This turns the data into a dataframe with one row per grouping key (a single row here), in which ‘allcol_struct_arr’ holds an array of structs, one struct per record of the original dataframe. func.array_sort compares structs field by field, so with ‘gk’ constant the array ends up ordered by ‘idx’, which preserves the original row order.
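Field-by-field comparison is also how Python orders tuples, so the effect of field position can be illustrated with a plain-Python analogy. This is not Spark code, and the three rows are a hand-picked subset of the sample data.

```python
# Tuples laid out like the structs: (gk, idx, name), deliberately out of order
rows = [
    ("gk", 4, "Maria"),
    ("gk", 1, "James"),
    ("gk", 2, "Michael"),
]

# gk is identical everywhere, so the second field (idx) decides the order
by_idx = sorted(rows)
print([name for _, _, name in by_idx])  # ['James', 'Michael', 'Maria']

# With the name placed before idx, the names decide the order instead
name_first = sorted((name, idx) for _, idx, name in rows)
print([idx for _, idx in name_first])  # [1, 4, 2]
```

The second result shows why the ID field has to come before any other field that varies: otherwise the original row order is lost.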

# a function to create struct schema string
struct_fields = lambda x: ', '.join([str(x)+'.'+k+' as '+k for k in data_sdf.columns])

# use `aggregate` to do the val calc
(
    arr_of_structs_sdf
    .withColumn(
        'new_allcol_struct_arr',
        func.expr('''
            aggregate(slice(allcol_struct_arr, 2, size(allcol_struct_arr)),
                      array(struct({0}, (allcol_struct_arr[0].a+allcol_struct_arr[0].b) as val)),
                      (x, y) -> array_union(x,
                                            array(struct({1}, ((y.a+y.b)-element_at(x, -1).val) as val))
                                            )
                      )
        '''.format(struct_fields('allcol_struct_arr[0]'), struct_fields('y'))
        )
    )
    .selectExpr('inline(new_allcol_struct_arr)')
    .show(truncate=False)
)
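To see what the struct_fields helper substitutes into the SQL expression, it can be run on its own. In this sketch the `columns` list is a stand-in for `data_sdf.columns` so that no Spark session is needed, and the helper is written as a regular function rather than a lambda.

```python
# Stand-in for data_sdf.columns, so the helper can be tried without Spark
columns = ['gk', 'idx', 'name', 'dept', 'a', 'b']


def struct_fields(prefix):
    """Build 'prefix.col as col' for every column, comma-separated."""
    return ', '.join(f'{prefix}.{k} as {k}' for k in columns)


print(struct_fields('y'))
# y.gk as gk, y.idx as idx, y.name as name, y.dept as dept, y.a as a, y.b as b
```

Each rebuilt struct therefore carries over all original fields under their original names, and the expression adds ‘val’ as one extra field at the end.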

The ‘aggregate’ function takes an initial value and updates it as it iterates through the array. Here the initial value is a one-element array holding the first row, with ‘val’ set to the sum of columns ‘a’ and ‘b’. Using slice, we start the iteration from the second row, since the first row has already been used for the initial value.

We then keep appending new elements to the array with array_union (which also drops duplicates, so it relies on each struct being unique; the ‘idx’ field guarantees that here) and refer to the most recent ‘val’ via:

element_at(x, -1).val #'-1' points to the last element in the array
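The same fold can be mimicked in plain Python with functools.reduce, which may help when reasoning about what the SQL lambda does at each step. This is only an illustrative sketch, not Spark code: the rows are hard-coded from the sample data and `merge` is a made-up name.

```python
from functools import reduce

# Sample rows in idx order (only the fields needed for the calculation)
rows = [
    {"idx": 1, "name": "James",   "a": 3000, "b": 2500},
    {"idx": 2, "name": "Michael", "a": 4600, "b": 1650},
    {"idx": 3, "name": "Robert",  "a": 4100, "b": 1100},
    {"idx": 4, "name": "Maria",   "a": 3000, "b": 7000},
    {"idx": 5, "name": "James",   "a": 3000, "b": 5000},
    {"idx": 6, "name": "Scott",   "a": 3300, "b": 4300},
    {"idx": 7, "name": "Jen",     "a": 3900, "b": 3700},
]


def merge(acc, y):
    """Counterpart of the SQL lambda (x, y) -> ...: append y with its val."""
    last_val = acc[-1]["val"]  # plays the role of element_at(x, -1).val
    return acc + [{**y, "val": y["a"] + y["b"] - last_val}]


first = rows[0]
initial = [{**first, "val": first["a"] + first["b"]}]  # start value built from row 1
result = reduce(merge, rows[1:], initial)  # rows[1:] plays the role of slice(..., 2, ...)

print([r["val"] for r in result])
# [5500, 750, 4450, 5550, 2450, 5150, 2450]
```

The accumulator is the growing list of finished rows, so the previous ‘val’ is always available as the last element.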

Finally, we use selectExpr('inline(new_allcol_struct_arr)') to explode the array of structs back into rows and columns, giving the dataframe below:

# +---+---+-------+---------+----+----+----+
# |gk |idx|name |dept |a |b |val |
# +---+---+-------+---------+----+----+----+
# |gk |1 |James |Sales |3000|2500|5500|
# |gk |2 |Michael|Sales |4600|1650|750 |
# |gk |3 |Robert |Sales |4100|1100|4450|
# |gk |4 |Maria |Finance |3000|7000|5550|
# |gk |5 |James |Finance |3000|5000|2450|
# |gk |6 |Scott |Marketing|3300|4300|5150|
# |gk |7 |Jen |Marketing|3900|3700|2450|
# +---+---+-------+---------+----+----+----+
