Reject raw callables in config during deserialization#22736
Reject raw callables in config during deserialization#22736Saumay wants to merge 1 commit intokeras-team:masterfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a security measure to block raw Python callables during Keras object deserialization when safe_mode is active. It introduces _reject_raw_callables, a recursive function that identifies and rejects unauthorized types such as functions, methods, and functools.partial objects within configurations. The review feedback highlights a potential performance bottleneck in the recursive traversal and suggests optimizing the function by skipping checks for PLAIN_TYPES and None values, which are common leaf nodes.
| def _reject_raw_callables(config, path="config"): | ||
| if isinstance(config, _RAW_CALLABLE_TYPES): | ||
| raise TypeError( | ||
| f"Received a raw Python callable at {path!r} during " | ||
| "deserialization. Configs must contain only serializable " | ||
| "values; wrap custom callables with " | ||
| "`@keras.saving.register_keras_serializable()` and pass the " | ||
| "serialized dict, or pass `safe_mode=False` to allow raw " | ||
| f"callables. Received: {config!r}" | ||
| ) | ||
| if isinstance(config, dict): | ||
| for k, v in config.items(): | ||
| _reject_raw_callables(v, f"{path}.{k}") | ||
| elif isinstance(config, (list, tuple)): | ||
| for i, v in enumerate(config): | ||
| _reject_raw_callables(v, f"{path}[{i}]") |
There was a problem hiding this comment.
The recursive walk in _reject_raw_callables can be a performance bottleneck for models with large configurations (e.g., layers containing large lists of constants or hyperparameters). Since _reject_raw_callables is called on every class instantiation during deserialization, and it is itself recursive, it may redundantly walk nested structures multiple times.
To improve efficiency, we should skip recursion for PLAIN_TYPES and None values, which are the most common leaf nodes in a Keras config. Adding an early exit at the top of the function and checking within the loops will significantly reduce function call overhead and string formatting for large collections.
def _reject_raw_callables(config, path="config"):
if config is None or isinstance(config, PLAIN_TYPES):
return
if isinstance(config, _RAW_CALLABLE_TYPES):
raise TypeError(
f"Received a raw Python callable at {path!r} during "
"deserialization. Configs must contain only serializable "
"values; wrap custom callables with "
"`@keras.saving.register_keras_serializable()` and pass the "
"serialized dict, or pass `safe_mode=False` to allow raw "
f"callables. Received: {config!r}"
)
if isinstance(config, dict):
for k, v in config.items():
if v is None or isinstance(v, PLAIN_TYPES):
continue
_reject_raw_callables(v, f"{path}.{k}")
elif isinstance(config, (list, tuple)):
for i, v in enumerate(config):
if v is None or isinstance(v, PLAIN_TYPES):
continue
_reject_raw_callables(v, f"{path}[{i}]")There was a problem hiding this comment.
I agree with this. Current placement walks the outer config once per class-path entry, and nested deserialize_keras_object calls on sub-objects re-walk their own subtrees, so worst case is O(depth × nodes). For typical model sizes (10-100 layers, each config 10-100 fields) this is a small constant overhead, but I agree it compounds on deeply nested configs (e.g. a Transformer stack).
Two mitigation options if this becomes a concern:
- Track walked ids in a thread-local set, skip re-walks
- Walk only at the top-level
deserialize_keras_objectentry by using a depth flag
Happy to implement either if the maintainer wants the guard before merge.
cac4449 to
adcc78c
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #22736 +/- ##
==========================================
- Coverage 84.56% 84.10% -0.46%
==========================================
Files 462 462
Lines 66767 66780 +13
Branches 10897 10903 +6
==========================================
- Hits 56459 56167 -292
- Misses 7441 7757 +316
+ Partials 2867 2856 -11
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
1a00eaf to
5dd853d
Compare
Previously `deserialize_keras_object` silently accepted raw Python functions embedded in a config's inner values, storing them on the reconstructed object and invoking them later (e.g. during `build()`). Raw callables are never a valid serialization target (JSON cannot represent them) and can lead to unexpected execution when configs come from untrusted sources. Reject raw callables (function, method, builtin, functools.partial) recursively through `inner_config` before `cls.from_config` when `safe_mode=True` (the default). Users who need to pass raw callables can opt out via `safe_mode=False`, matching the existing opt-out for lambda deserialization. Class objects and class instances are unaffected, since they are legitimate config values (layer instances, constraint instances, etc.). Fixes keras-team#22705.
5dd853d to
4dd25e2
Compare
Description
Fixes #22705.
deserialize_keras_objectsilently accepted raw Python callables embedded in a config dict (the issue's repro nests one underconfig["config"]). The callable flowed throughcls.from_config(inner_config), got stored on the reconstructed object, and was invoked later (e.g. duringbuild()), with no validation or warning.This PR walks the outer config recursively and rejects raw callables (
types.FunctionType,BuiltinFunctionType,MethodType,BuiltinMethodType,functools.partial,functools.partialmethod) whensafe_mode=True(the default). The walk runs beforecls.from_config,build_from_config, andcompile_from_configcan see the config, so callables inconfig["config"],config["build_config"],config["compile_config"], or any other outer-level key are all covered.safe_mode=Falsepreserves the previous silent-accept behavior as an explicit opt-out, mirroring the existing opt-out for lambda deserialization.Class objects and class instances themselves are not rejected (they have different semantics from raw functions). Bound methods (
MethodType) are rejected since they are effectively ad-hoc function references.Existing
__lambda__safe_mode behavior is preserved: the walk sees the serialized bytecode as a string and passes through, so the pre-existing "arbitrary code execution" error still fires downstream undersafe_mode=True.Contributor Agreement
Please check all boxes below before submitting your PR for review: