💻 Using the Python Script Block

Getting familiar with the Python Script Block

Sometimes it is inconvenient (or even impossible) to implement the functionality you want using Collimator's foundational blocks. For these cases, there is the Python Script block. With this block, you can create custom feed-through (continuous-time) or discrete-time blocks, import many third-party libraries, and connect to cloud services.

How it works

The basics

Add a Python Script block in the usual way: drag it from the library, or double-click the canvas and search for Python Script. By default, a new block has one input port and one output port. More input and output ports can be added from the block menu.

Note that each port is assigned a default name. These names can be changed, but must be unique within the block.

Double-clicking the block opens a text editor with two fields, Init and Step, where you enter your Python code. Code in Init runs once before the simulation starts, and code in Step runs each time the block is called to update its outputs.
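To make the Init/Step execution model concrete, here is a small plain-Python emulation. It is a hypothetical harness, not how Collimator actually runs the block: it executes Init once in a namespace, then for each update it writes the input under its port name, executes Step, and reads the output back by name. Names defined in Init (such as imported modules) stay visible in Step, just as `np` imported in Init is used in Step in the Ventilator example later on this page. The port names `u` and `y` and the helper `run_block` are illustrative only.

```python
import math

# Hypothetical Init and Step code for a block with input "u" and output "y".
init_code = """
import math
gain = 2.0   # constant defined once in Init
y = 0.0      # initialize the output
"""

step_code = """
y = gain * math.sin(u)   # read input u, set output y
"""

def run_block(init_src, step_src, input_name, output_name, inputs):
    """Hypothetical emulation: run Init once, then Step once per input value."""
    namespace = {}
    exec(init_src, namespace)            # Init: runs once before the simulation
    outputs = []
    for value in inputs:
        namespace[input_name] = value    # input visible by its port name
        exec(step_src, namespace)        # Step: runs at every block update
        outputs.append(namespace[output_name])
    return outputs

print(run_block(init_code, step_code, "u", "y", [0.0, math.pi / 2]))
```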

The script can contain arbitrary Python code, including function and class definitions. Note, however, that at present you can import only modules from the Python standard library and from a curated list of libraries such as numpy, scipy, python-control, roboticstoolbox, torch, and jax.

Inputs and outputs

To access inputs and outputs, reference them by their port names, as ordinary Python variables. Input values are undefined in Init, and output values default to 0.0 unless you initialize them to something else in Init.

At each invocation of Step, you typically read the relevant input values, perform some calculations or logic, and assign the output values based on the results.
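A minimal sketch of that read, compute, assign pattern, for a hypothetical saturation block with inputs `u` and `limit` and outputs `y` and `saturated` (all four port names are illustrative). In a real block, Collimator supplies the input values; here they are assigned by hand at the top so the snippet runs on its own.

```python
# Stand-ins for values Collimator would supply on the input ports.
u, limit = 3.5, 2.0

# --- Init (runs once): inputs are undefined here, so only set output defaults
y = 0.0
saturated = 0.0

# --- Step (runs at each update): read inputs, compute, assign outputs
y = max(-limit, min(u, limit))        # clamp u to [-limit, limit]
saturated = 1.0 if y != u else 0.0    # 1.0 when clamping occurred

print(y, saturated)  # 2.0 1.0
```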

Time mode

The default time mode for a Python Script block is Agnostic. In this mode, the block will operate in either Continuous time or Discrete time based on the upstream blocks in your model.

You can also set it to Discrete. In Discrete mode, the block updates only when local discrete steps are taken. If you choose this mode, initialize all of your outputs with a default value in Init. This ensures that output values are available at the first discrete update and that the data types of the output signals are set correctly.
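For example, an Init field for a Discrete-mode block might give each output a default whose type and shape match the signal you intend. This sketch assumes, as the paragraph above implies, that each output's data type follows from its initial value; the output names `y`, `count`, and `state` are hypothetical.

```python
import numpy as np

# Hypothetical Init for a Discrete-mode block with outputs y, count, and state.
y = 0.0              # float scalar output
count = 0            # integer output: 0 rather than 0.0
state = np.zeros(3)  # three-element float vector output

# Inspect the type and shape each initial value carries.
for name, value in [("y", y), ("count", count), ("state", state)]:
    arr = np.asarray(value)
    print(name, arr.dtype, arr.shape)
```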

Accelerate with JAX

JAX is an integral part of Collimator, and you can use it within Python Script blocks too. The Python Script block has an "Accelerate with JAX" option that controls whether JAX is used for the block's code. Here are some tips for how to use JAX in your code.

When this option is enabled, the simulation engine uses JAX as the numerical backend for the block, which can speed up the simulation considerably. How much depends on your code: if you write it in JAX's recommended style (pure functions, no in-place array updates, and branching with jnp.where, lax.cond, or lax.switch rather than Python if statements on array values), you may see a large speed-up compared with vanilla Numpy. If you only use jax.numpy as a drop-in replacement for Numpy, you may not see much of a speed-up.
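The branching rule is the one that most often trips people up. This sketch shows general JAX behavior rather than anything specific to Collimator (how Collimator traces Step code is not described here): under `jax.jit`, a Python `if` on an array value fails because the value is abstract during tracing, while `jnp.where` selects between values inside the compiled computation. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError

def clip_python_if(x):
    # Python `if` needs a concrete bool, so this fails when x is traced by jax.jit
    if x > 1.0:
        return 1.0
    return x

def clip_jax_style(x):
    # jnp.where selects between values inside the traced computation
    return jnp.where(x > 1.0, 1.0, x)

fast_clip = jax.jit(clip_jax_style)
print(fast_clip(jnp.array(2.5)), fast_clip(jnp.array(0.5)))

try:
    jax.jit(clip_python_if)(jnp.array(2.5))
    python_if_failed = False
except ConcretizationTypeError:
    python_if_failed = True
print("Python if under jit failed:", python_if_failed)
```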

With this option disabled, JAX is not used. Instead, the code in your Python Script block is interpreted with standard Python, using vanilla Numpy as needed. Your simulation will almost certainly take longer, but if you rely on Numpy semantics, or you are having trouble getting your code to run with JAX, this is a perfectly valid approach. Note that if you write your code using JAX and then disable this option, you may or may not see simulation errors, but you will likely see slowdowns and/or unexpected behavior.
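One common example of a Numpy semantic that JAX does not share is in-place array mutation. The sketch below, with illustrative variable names, shows Numpy item assignment working, the same assignment raising a `TypeError` on a JAX array, and JAX's functional `.at[...].set(...)` update as the replacement.

```python
import numpy as np
import jax.numpy as jnp

# Numpy: arrays are mutable, so item assignment updates x_np in place
x_np = np.zeros(3)
x_np[0] = 1.0

# JAX: arrays are immutable, so item assignment raises TypeError
x_jax = jnp.zeros(3)
try:
    x_jax[0] = 1.0
    jax_setitem_failed = False
except TypeError:
    jax_setitem_failed = True

# The JAX idiom is a functional update that returns a new array
x_jax = x_jax.at[0].set(1.0)
print(x_np, x_jax, jax_setitem_failed)
```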

Example: Ventilator

The Medtech - Ventilator public project uses Python Script blocks to model various kinds of control flow. In its feedforward_integral_gain model, a Python Script block generates a control signal, pset, for a ventilator.

It has six inputs (clock_time, T, IE, rise_time, peep, and pmax) and one output (pset). Most of the inputs are connected to constants derived from model parameters, and the clock_time input is connected to a Clock block.
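Both versions of the Step code compute the same piecewise profile over each cycle of length T. Written out, it is the following, where the plateau phase only occurs if rise_time is smaller than IE·T (otherwise the ramp runs straight into the PEEP phase):

$$
t = \mathrm{clock\_time} \bmod T,\qquad
p_{\mathrm{set}}(t) =
\begin{cases}
\mathrm{peep} + (\mathrm{pmax} - \mathrm{peep})\,\dfrac{t}{\mathrm{rise\_time}}, & t \le \mathrm{rise\_time},\\
\mathrm{pmax}, & \mathrm{rise\_time} < t \le \mathrm{IE}\,T,\\
\mathrm{peep}, & t > \mathrm{IE}\,T.
\end{cases}
$$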

Numpy version

The original, Numpy-based version of the block is shown below. Note that the output pset is initialized in Init, and "Accelerate with JAX" is turned off.

Init block

import numpy as np
pset = 0.0

Step block

# Time elapsed since the beginning of the current cycle
t = np.mod(clock_time, T)

# Compute inspiratory (t_ins) and expiratory (t_exp) times
t_ins = IE*T
t_exp = (1.0 - IE)*T

# Compute pset
if t <= rise_time:     # The rising slope
    pset = peep + (pmax - peep) * t / rise_time
elif t <= t_ins:       # The plateau phase
    pset = pmax
else:                  # Expiration: PEEP phase
    pset = peep
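If you want to check this Step logic outside Collimator, you can wrap it in an ordinary function. This is a minimal sketch: the function name `ventilator_pset` and the parameter values are hypothetical, chosen only for illustration, and `t_exp` is omitted because the Step code never uses it.

```python
import numpy as np

def ventilator_pset(clock_time, T, IE, rise_time, peep, pmax):
    """Plain-Python copy of the Numpy Step logic, for offline checking."""
    t = np.mod(clock_time, T)    # time elapsed in the current cycle
    t_ins = IE * T               # end of inspiration
    if t <= rise_time:           # rising slope
        return peep + (pmax - peep) * t / rise_time
    elif t <= t_ins:             # plateau phase
        return pmax
    else:                        # expiration: PEEP phase
        return peep

# Hypothetical parameter values, for illustration only.
params = dict(T=3.0, IE=0.4, rise_time=0.2, peep=5.0, pmax=20.0)
for clock_time in [0.1, 1.0, 2.0, 3.1]:
    print(clock_time, ventilator_pset(clock_time, **params))
```

With these values, the output ramps from 5 toward 20 during the first 0.2 s of each 3 s cycle, holds 20 until 1.2 s, and drops back to 5; a clock time of 3.1 s falls in the ramp of the second cycle.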

JAX version

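The updated, JAX-optimized version is shown below. Because Python if/elif statements cannot branch on values that JAX traces, the piecewise logic is expressed with lax.switch instead: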

Init block

import jax.numpy as jnp
from jax import lax

Step block

# Time elapsed since the beginning of the current cycle
t = jnp.mod(clock_time, T)

# Compute inspiratory (t_ins) and expiratory (t_exp) times
t_ins = IE * T
t_exp = (1.0 - IE) * T

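# Conditions are checked in order, like an if / elif / else chain
conditions = [
    lambda t: t <= rise_time,   # rising slope
    lambda t: t <= t_ins,       # plateau phase
    lambda t: True,             # default/else: expiration (PEEP phase)
]

# Output for each case, in the same order as the conditions
outputs = [
    lambda t: peep + (pmax - peep) * t / rise_time,
    lambda t: pmax,
    lambda t: peep,
]

# jnp.argmax returns the index of the first True condition,
# and lax.switch applies the output function at that index
pset = lax.switch(jnp.argmax(jnp.array([cond(t) for cond in conditions])), outputs, t)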
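Another way to express the same piecewise logic in JAX style is with nested jnp.where calls. This is a swapped-in technique, not the code from the Ventilator project; the function name `pset_where` and the parameter values are hypothetical. Unlike lax.switch, jnp.where computes every branch and then selects, which is cheap here and also lets the function work elementwise on an array of clock times.

```python
import jax
import jax.numpy as jnp

@jax.jit
def pset_where(clock_time, T, IE, rise_time, peep, pmax):
    t = jnp.mod(clock_time, T)     # time elapsed in the current cycle
    t_ins = IE * T                 # end of inspiration
    ramp = peep + (pmax - peep) * t / rise_time
    # Nested jnp.where keeps the if / elif / else priority order:
    # ramp first, then plateau, otherwise PEEP.
    return jnp.where(t <= rise_time, ramp, jnp.where(t <= t_ins, pmax, peep))

# Hypothetical parameter values, for illustration only.
params = dict(T=3.0, IE=0.4, rise_time=0.2, peep=5.0, pmax=20.0)
times = jnp.array([0.1, 1.0, 2.0, 3.1])
print(pset_where(times, **params))
```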
