2222 RES_MATCH_RANGE , RES_SLOTS , RES_VALUE , SLOT_NAME , START , TEXT , UTTERANCES ,
2323 RES_PROBA )
2424from snips_nlu .dataset import validate_and_format_dataset
25+ from snips_nlu .dataset .utils import extract_entity_values
2526from snips_nlu .entity_parser .builtin_entity_parser import is_builtin_entity
2627from snips_nlu .exceptions import IntentNotFoundError , LoadingError
2728from snips_nlu .intent_parser .intent_parser import IntentParser
@@ -55,10 +56,11 @@ def __init__(self, config=None, **shared):
5556 self ._language = None
5657 self ._slot_names_to_entities = None
5758 self ._group_names_to_slot_names = None
59+ self ._stop_words = None
60+ self ._stop_words_whitelist = None
5861 self .slot_names_to_group_names = None
5962 self .regexes_per_intent = None
6063 self .entity_scopes = None
61- self .stop_words = None
6264
6365 @property
6466 def language (self ):
@@ -68,12 +70,12 @@ def language(self):
def language(self, value):
    """Setter: store the language and refresh the stop word set.

    Clearing the language (``None``) also clears the stop words; otherwise
    the stop words are loaded from the resources only when the config asks
    for them to be ignored during matching.
    """
    self._language = value
    if value is None:
        self._stop_words = None
        return
    if self.config.ignore_stop_words:
        self._stop_words = get_stop_words(self.resources)
    else:
        self._stop_words = set()
7779
7880 @property
7981 def slot_names_to_entities (self ):
@@ -142,13 +144,15 @@ def fit(self, dataset, force_retrain=True):
142144 self .slot_names_to_entities = get_slot_name_mappings (dataset )
143145 self .group_names_to_slot_names = _get_group_names_to_slot_names (
144146 self .slot_names_to_entities )
147+ self ._stop_words_whitelist = _get_stop_words_whitelist (
148+ dataset , self ._stop_words )
145149
146150 # Do not use ambiguous patterns that appear in more than one intent
147151 all_patterns = set ()
148152 ambiguous_patterns = set ()
149153 intent_patterns = dict ()
150154 for intent_name , intent in iteritems (dataset [INTENTS ]):
151- patterns = self ._generate_patterns (intent [UTTERANCES ],
155+ patterns = self ._generate_patterns (intent_name , intent [UTTERANCES ],
152156 entity_placeholders )
153157 patterns = [p for p in patterns
154158 if len (p ) < self .config .max_pattern_length ]
@@ -221,7 +225,6 @@ def placeholder_fn(entity_name):
221225 return _get_entity_name_placeholder (entity_name , self .language )
222226
223227 results = []
224- cleaned_text = self ._preprocess_text (text )
225228
226229 for intent , entity_scope in iteritems (self .entity_scopes ):
227230 if intents is not None and intent not in intents :
@@ -233,7 +236,9 @@ def placeholder_fn(entity_name):
233236 all_entities = builtin_entities + custom_entities
234237 mapping , processed_text = replace_entities_with_placeholders (
235238 text , all_entities , placeholder_fn = placeholder_fn )
236- cleaned_processed_text = self ._preprocess_text (processed_text )
239+ cleaned_text = self ._preprocess_text (text , intent )
240+ cleaned_processed_text = self ._preprocess_text (processed_text ,
241+ intent )
237242 for regex in self .regexes_per_intent [intent ]:
238243 res = self ._get_matching_result (text , cleaned_processed_text ,
239244 regex , intent , mapping )
@@ -300,14 +305,19 @@ def get_slots(self, text, intent):
300305 slots = []
301306 return slots
302307
303- def _preprocess_text (self , string ):
308+ def _get_intent_stop_words (self , intent ):
309+ whitelist = self ._stop_words_whitelist .get (intent , set ())
310+ return self ._stop_words .difference (whitelist )
311+
312+ def _preprocess_text (self , string , intent ):
304313 """Replaces stop words and characters that are tokenized out by
305314 whitespaces"""
306315 tokens = tokenize (string , self .language )
307316 current_idx = 0
308317 cleaned_string = ""
318+ stop_words = self ._get_intent_stop_words (intent )
309319 for token in tokens :
310- if self . stop_words and normalize_token (token ) in self . stop_words :
320+ if stop_words and normalize_token (token ) in stop_words :
311321 token .value = "" .join (" " for _ in range (len (token .value )))
312322 prefix_length = token .start - current_idx
313323 cleaned_string += "" .join ((" " for _ in range (prefix_length )))
@@ -352,18 +362,21 @@ def _get_matching_result(self, text, processed_text, regex, intent,
352362 key = lambda s : s [RES_MATCH_RANGE ][START ])
353363 return extraction_result (parsed_intent , parsed_slots )
354364
355- def _generate_patterns (self , intent_utterances , entity_placeholders ):
365+ def _generate_patterns (self , intent , intent_utterances ,
366+ entity_placeholders ):
356367 unique_patterns = set ()
357368 patterns = []
369+ stop_words = self ._get_intent_stop_words (intent )
358370 for utterance in intent_utterances :
359371 pattern = self ._utterance_to_pattern (
360- utterance , entity_placeholders )
372+ utterance , stop_words , entity_placeholders )
361373 if pattern not in unique_patterns :
362374 unique_patterns .add (pattern )
363375 patterns .append (pattern )
364376 return patterns
365377
366- def _utterance_to_pattern (self , utterance , entity_placeholders ):
378+ def _utterance_to_pattern (self , utterance , stop_words ,
379+ entity_placeholders ):
367380 slot_names_count = defaultdict (int )
368381 pattern = []
369382 for chunk in utterance [DATA ]:
@@ -379,7 +392,7 @@ def _utterance_to_pattern(self, utterance, entity_placeholders):
379392 else :
380393 tokens = tokenize_light (chunk [TEXT ], self .language )
381394 pattern += [regex_escape (t .lower ()) for t in tokens
382- if normalize (t ) not in self . stop_words ]
395+ if normalize (t ) not in stop_words ]
383396
384397 pattern = r"^%s%s%s$" % (WHITESPACE_PATTERN ,
385398 WHITESPACE_PATTERN .join (pattern ),
@@ -417,12 +430,18 @@ def from_path(cls, path, **shared):
417430
def to_dict(self):
    """Returns a json-serializable dict"""
    whitelist = self._stop_words_whitelist
    serialized_whitelist = None
    if whitelist is not None:
        # Sets are not JSON-serializable; sort values for a stable dump.
        serialized_whitelist = {
            intent: sorted(values)
            for intent, values in whitelist.items()}
    return {
        "config": self.config.to_dict(),
        "language_code": self.language,
        "patterns": self.patterns,
        "group_names_to_slot_names": self.group_names_to_slot_names,
        "slot_names_to_entities": self.slot_names_to_entities,
        "stop_words_whitelist": serialized_whitelist,
    }
427446
428447 @classmethod
@@ -439,6 +458,12 @@ def from_dict(cls, unit_dict, **shared):
439458 parser .group_names_to_slot_names = unit_dict [
440459 "group_names_to_slot_names" ]
441460 parser .slot_names_to_entities = unit_dict ["slot_names_to_entities" ]
461+ if parser .fitted :
462+ whitelist = unit_dict .get ("stop_words_whitelist" , dict ())
463+ # pylint:disable=protected-access
464+ parser ._stop_words_whitelist = {
465+ intent : set (values ) for intent , values in iteritems (whitelist )}
466+ # pylint:enable=protected-access
442467 return parser
443468
444469
@@ -487,3 +512,14 @@ def sort_key_fn(slot):
def _get_entity_name_placeholder(entity_label, language):
    """Return the uppercase ``%...%`` placeholder standing in for an entity
    in the processed text (e.g. ``%SNIPSNUMBER%``)."""
    joined = "".join(tokenize_light(entity_label, language))
    return "%{}%".format(joined.upper())
515+
516+
def _get_stop_words_whitelist(dataset, stop_words):
    """Map each intent to the stop words that also occur among its
    (normalized) entity values.

    Those words must NOT be stripped for that intent, or entity values
    containing them could no longer be matched. Intents with an empty
    whitelist are omitted from the result.
    """
    values_by_intent = extract_entity_values(
        dataset, apply_normalization=True)
    whitelists = dict()
    for intent, entity_values in iteritems(values_by_intent):
        shared = stop_words.intersection(entity_values)
        if shared:
            whitelists[intent] = shared
    return whitelists
0 commit comments