Why is my function not running in parallel even with DaskTaskRunner / RayTaskRunner?

I think I found the issue: if the function is not decorated with @task, it runs serially, even inside a flow that uses DaskTaskRunner or RayTaskRunner!
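
To see why the decorator matters, here is a toy model in plain Python. It is not Prefect's implementation. A hypothetical toy_task decorator hands each call to an executor and returns a future, which roughly mimics how calling a @task inside a flow passes the work to the flow's task runner. An undecorated function is just an ordinary call, so the loop waits for each one to finish before starting the next. Threads and time.sleep are assumptions chosen so the sketch runs anywhere; the real runners use Dask or Ray worker processes.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from functools import wraps

# Hypothetical stand-in for a task runner: a shared pool of 4 worker threads.
RUNNER = ThreadPoolExecutor(max_workers=4)


def toy_task(fn):
    """Hypothetical decorator (not Prefect's @task): calling the wrapped
    function submits it to RUNNER and returns a Future instead of blocking."""
    @wraps(fn)
    def submit(*args, **kwargs) -> Future:
        return RUNNER.submit(fn, *args, **kwargs)
    return submit


def slow_square(x):
    time.sleep(0.2)  # stands in for one slow plot() call
    return x * x


# Wrap explicitly so both the plain and the "decorated" versions exist.
slow_square_task = toy_task(slow_square)

# Undecorated: each call blocks until it returns, so 4 calls take ~0.8 s.
start = time.perf_counter()
serial = [slow_square(x) for x in range(4)]
serial_time = time.perf_counter() - start

# "Decorated": all 4 calls are submitted first, then awaited, so they overlap.
start = time.perf_counter()
futures = [slow_square_task(x) for x in range(4)]
concurrent = [f.result() for f in futures]
concurrent_time = time.perf_counter() - start

print(serial, f"{serial_time:.2f}s")
print(concurrent, f"{concurrent_time:.2f}s")
RUNNER.shutdown()
```

The key point is that the submission happens inside the wrapper; a function the runner never wraps is invisible to it.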

The Proof

With the decorator, however, the calls are properly distributed across the worker processes.

With Dask, I checked htop instead, because updating the logging formatter to show the PID is a bit tedious (notice the CPU usage above 100%).

With Ray, the output automatically prefixes each line with the worker's PID, and four distinct PIDs show up:

(begin_task_run pid=93880) 2013-01-01T12:00:00.000000000
(begin_task_run pid=93881) 2013-01-01T00:00:00.000000000
(begin_task_run pid=93878) 2013-01-01T18:00:00.000000000
(begin_task_run pid=93879) 2013-01-01T06:00:00.000000000
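
For the Dask run, a lighter-weight check than htop or a custom logging formatter is to have the task print its own process ID. The helper below, tag_with_pid, is hypothetical (not a Prefect or Dask API); it just imitates the "(pid=...)" prefix that Ray adds.

```python
import os


def tag_with_pid(message: str) -> str:
    """Hypothetical helper: prefix a message with the current process ID,
    similar to the '(name pid=...)' prefix Ray adds to worker output."""
    return f"(pid={os.getpid()}) {message}"


# Inside plot(), `print(time)` could become:
#     print(tag_with_pid(str(time)))
print(tag_with_pid("2013-01-01T00:00:00.000000000"))
```

If several distinct PIDs appear across the printed lines, the calls ran in different worker processes. If you do go the logging route instead, the standard library's LogRecord already exposes the PID as the `%(process)d` format field.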

Setup

import dask
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.patches
import matplotlib.textpath
from matplotlib.font_manager import FontProperties
from cartopy import crs as ccrs
from cartopy import feature as cfeature
from prefect import flow, task
from prefect.task_runners import DaskTaskRunner, RayTaskRunner


@task
def plot(ds, time):
    print(time)
    # Use an explicit Figure instead of pyplot's implicit "current figure":
    # pyplot's global state is not thread-safe, and a concurrent runner may
    # call plot() from several threads at once
    fig = plt.figure()
    ax = fig.add_subplot(projection=ccrs.Orthographic())
    ax.add_feature(cfeature.LAKES.with_scale("10m"))
    ax.add_feature(cfeature.OCEAN.with_scale("10m"))
    ax.add_feature(cfeature.LAND.with_scale("10m"))
    ax.add_feature(cfeature.STATES.with_scale("10m"))
    ax.add_feature(cfeature.COASTLINE.with_scale("10m"))
    ax.pcolormesh(ds["lon"], ds["lat"], ds["air"], transform=ccrs.PlateCarree())

    # generate a matplotlib path representing the word "cartopy"
    fp = FontProperties(family='Bitstream Vera Sans', weight='bold')
    logo_path = matplotlib.textpath.TextPath((-175, -35), 'cartopy',
                                             size=1, prop=fp)

    # add a background image
    im = ax.stock_img()
    # clip the image according to the logo_path. mpl v1.2.0 does not support
    # the transform API that cartopy makes use of, so we have to convert the
    # projection into a transform manually
    plate_carree_transform = ccrs.PlateCarree()._as_mpl_transform(ax)
    im.set_clip_path(logo_path, transform=plate_carree_transform)

    # add the path as a patch, drawing black outlines around the text
    patch = matplotlib.patches.PathPatch(logo_path,
                                         facecolor='none', edgecolor='black',
                                         transform=ccrs.PlateCarree())
    ax.add_patch(patch)
    ax.set_global()

    # File name like "2013-01-01T00:00"; with no extension given, matplotlib
    # appends the default savefig.format (".png")
    fig.savefig(str(time)[:16])
    plt.close(fig)

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 100))

Benchmark

All timings below were measured in a Jupyter notebook:

ConcurrentTaskRunner - 2m 32.7s

@flow()
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()

DaskTaskRunner - 1m 29.3s

@flow(task_runner=DaskTaskRunner(cluster_kwargs={"n_workers": 4, "processes": True}))
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()

RayTaskRunner - 1m 35.2s

@flow(task_runner=RayTaskRunner(init_kwargs={"num_cpus": 4}))
def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()

Native Dask (with the @task decorator removed from plot) - 1m 31.4s (about the same as DaskTaskRunner!)

def process():
    jobs = []
    for time in ds['time'].values:
        job = dask.delayed(plot)(ds.sel(time=time), time)
        jobs.append(job)
    dask.compute(jobs, scheduler="processes", num_workers=4)

process()

Serial mode, no @flow and plot without @task - 3m 15.9s (I think this could make an exciting blog post or announcement: just by wrapping the code in @flow and @task, which defaults to the ConcurrentTaskRunner, the run is about 43 seconds faster, 2m 32.7s vs. 3m 15.9s!)

def process():
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)

process()
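
The reported wall-clock times can be turned into speedups relative to the serial run. The timings below are exactly the ones reported above; the to_seconds and speedups helpers are just conveniences for this sketch.

```python
import re

# Wall-clock times from the benchmark above (4 workers where applicable).
TIMINGS = {
    "Serial mode": "3m 15.9s",
    "ConcurrentTaskRunner": "2m 32.7s",
    "DaskTaskRunner": "1m 29.3s",
    "RayTaskRunner": "1m 35.2s",
    "Native Dask": "1m 31.4s",
}


def to_seconds(duration: str) -> float:
    """Convert a duration such as '2m 32.7s' (minutes optional) into seconds."""
    match = re.fullmatch(r"(?:(\d+)m\s*)?([\d.]+)s", duration.strip())
    if match is None:
        raise ValueError(f"unrecognised duration: {duration!r}")
    minutes = int(match.group(1) or 0)
    return minutes * 60 + float(match.group(2))


def speedups(timings: dict[str, str], baseline: str) -> dict[str, float]:
    """Speedup of each run relative to the baseline (baseline time / run time)."""
    base = to_seconds(timings[baseline])
    return {name: base / to_seconds(t) for name, t in timings.items()}


for name, factor in speedups(TIMINGS, "Serial mode").items():
    print(f"{name:22s} {factor:.2f}x")
```

This gives roughly 1.28x for the ConcurrentTaskRunner and about 2.1 to 2.2x for the Dask, Ray and native Dask runs. With 4 workers, perfect scaling would be an upper bound of about 4x.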

Conclusion:
Wrap your functions with @task to get parallelization; an undecorated function called inside a flow runs serially, whichever task runner the flow uses.
