jax.experimental.checkify.check_error
- jax.experimental.checkify.check_error(error)
Raise an Exception if error represents a failure. Functionalized by checkify().

The semantics of this function are equivalent to:
>>> from jax.experimental.checkify import Error
>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError
But unlike that implementation, check_error can be functionalized using the checkify() transformation.

This function is similar to check() but has a different signature: whereas check() takes a boolean predicate and a new error message string as arguments, this function takes an Error value as its argument. Both check() and this function raise a Python Exception on failure (a side-effect), and thus cannot be staged out by jit(), pmap(), scan(), etc. Both can also be functionalized by using checkify().

But unlike check(), this function acts as a direct inverse of checkify(). checkify() takes as input a function that can raise a Python Exception and produces a new function without that effect, which instead returns an Error value as output. check_error does the opposite: it accepts an Error value as input and can produce the side-effect of raising an Exception. That is, while checkify() goes from a functionalizable Exception effect to an error value, check_error goes from an error value back to a functionalizable Exception effect.

check_error is useful when you want to turn checks represented by an Error value (produced by functionalizing checks via checkify()) back into Python Exceptions.

For example, you might want to functionalize part of your program through checkify, stage out the functionalized code through jit(), and then re-inject the error value outside of the jit():

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x > 0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)
Traceback (most recent call last):
  ...
jax._src.JaxRuntimeError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
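The round trip described above (checkify() turning an Exception effect into an Error value, then check_error turning it back) can be sketched as a standalone script. This is a minimal sketch, assuming a recent JAX release in which Error.get() returns None when no check failed and the failure message string otherwise. The exception raised by check_error is caught generically, since its exact class may differ between JAX versions. The names checked_f and rethrow are hypothetical, chosen only for illustration.

```python
import jax
from jax.experimental import checkify


def f(x):
    # User-level check: fails whenever x is not strictly positive.
    checkify.check(x > 0, "must be positive!")
    return x


# checkify(): Exception effect -> Error value. The result can be jitted.
checked_f = jax.jit(checkify.checkify(f))

err_ok, out_ok = checked_f(1.0)     # no failed check
err_bad, out_bad = checked_f(-1.0)  # failed check, carried as a value

# check_error(): Error value -> Python Exception, raised outside the jit.
checkify.check_error(err_ok)  # passing Error: returns without raising

caught = None
try:
    checkify.check_error(err_bad)
except Exception as e:  # exact exception class may vary by JAX version
    caught = str(e)


def rethrow(x):
    # Hypothetical helper mirroring with_inner_jit: functionalize, stage out,
    # then re-inject the error value as an Exception effect.
    err, out = checked_f(x)
    checkify.check_error(err)
    return out


# Re-checkify: the Exception effect from check_error becomes an Error again.
err_again, _ = checkify.checkify(rethrow)(-1.0)
```

Because the compiled checked_f returns the failure as data rather than raising, the decision to raise is deferred to ordinary Python code outside the jit, and wrapping that code in checkify() again converts the raise back into a value.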