Coverage for agentlib/utils/validators.py: 85%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

"""
Module with validator functions used in multiple parts of the agentlib.
"""

4 

5from __future__ import annotations 

6 

7from typing import List, Dict, Any, Type 

8from copy import deepcopy 

9import logging 

10 

11import pydantic 

12 

13from agentlib.core.datamodels import ( 

14 AgentVariable, 

15 AgentVariables, 

16 Source, 

17) 

18 

# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)

20 

21 

def convert_to_list(obj):
    """Return ``obj`` as a list.

    - ``None`` becomes a new empty list.
    - A ``list`` is returned unchanged (the same object, not a copy).
    - Any other object is wrapped as a one-element list ``[obj]``.
    """
    if obj is None:
        return []
    return obj if isinstance(obj, list) else [obj]

33 

34 

def include_defaults_in_root(
    variables: List[Dict],
    field: pydantic.fields.FieldInfo,
    make_shared: bool,
    agent_id: str,
    field_name: str,
    type_: Type[AgentVariable] = AgentVariable,
) -> AgentVariables:
    """
    Validator building block to merge default variables with config variables in the root validator.

    Default variables (``field.default``) whose name matches an entry of
    ``variables`` are updated with that entry via
    ``update_default_agent_variable``. Entries of ``variables`` that match no
    default are validated as new variables of type ``type_``.

    Args:
        variables: User-supplied variable configs; every dict must have a
            ``"name"`` key.
        field: Field whose ``default`` holds the default variables (may be
            ``None``, which is treated as no defaults).
        make_shared: If true, defaults not mentioned in ``variables`` are
            marked shared. It is also used as ``shared`` for new variables
            whose config does not set it.
        agent_id: Agent id used for the ``Source`` of shared variables.
        field_name: Name of the field, forwarded to
            ``update_default_agent_variable``.
        type_: Variable class used to validate new (non-default) variables.

    Returns:
        The merged list: the (updated) defaults in their original order,
        followed by the new variables in config order.

    Neither the ``variables`` list nor the dicts inside it are modified.
    """
    # First create a copy as otherwise multiple instances of e.g. a Model class
    # would share the same defaults
    default: AgentVariables = deepcopy(field.default)
    if default is None:
        default = []
    # Shallow copy of the list so matched entries can be removed without
    # touching the caller's list.
    variables = variables.copy()
    user_variables_dict = {d["name"]: d for d in variables}
    for i, var in enumerate(default):
        if var.name not in user_variables_dict:
            # Defaults the user did not configure only change when the whole
            # field is made shared.
            if make_shared:
                var.shared = make_shared
                var.source = Source(agent_id=agent_id)
            continue
        var_to_update_with = user_variables_dict[var.name]
        variables.remove(var_to_update_with)
        default[i] = update_default_agent_variable(
            default_var=var,
            user_data=var_to_update_with,
            make_shared=make_shared,
            agent_id=agent_id,
            field_name=field_name,
        )

    # add new variables and check if they are shared
    for var_dict in variables:
        if "shared" not in var_dict:
            # Build a new dict instead of writing into the caller's config, so
            # a config reused for another instance does not keep this value.
            var_dict = {**var_dict, "shared": make_shared}
        new_var: AgentVariable = type_.validate_data(var_dict)
        if new_var.shared:
            new_var.source = Source(agent_id=agent_id)
        default.append(new_var)

    return default

81 

82 

def update_default_agent_variable(
    default_var: AgentVariable,
    user_data: dict,
    make_shared: bool,
    agent_id: str,
    field_name: str,
) -> AgentVariable:
    """Update a default variable with the user's config for it.

    Args:
        default_var: The default variable. If it is not an ``AgentVariable``
            instance, a new ``AgentVariable(name=field_name)`` is used instead.
        user_data: The user's config for this variable. It is used as the
            update directly if ``is_valid_agent_var_config`` accepts it;
            otherwise it is wrapped as ``{"value": user_data}``.
        make_shared: Used as ``shared`` only when the default's ``shared`` is
            ``None``; a ``"shared"`` entry in the user config still overrides it.
        agent_id: Agent id for the ``Source`` assigned when the result is shared.
        field_name: Name of the field; passed to the validity check (which logs
            it on failure) and used as the name of the fallback variable.

    Returns:
        The result of ``type(default_var).validate_data`` on the merged data,
        with ``source`` set to ``Source(agent_id=agent_id)`` if it is shared.
    """

    if is_valid_agent_var_config(user_data, field_name):
        update_var_with = user_data
    else:
        # NOTE(review): presumably meant for configs that are a bare value
        # rather than a variable dict — confirm against callers.
        update_var_with = {"value": user_data}

    # Fall back to a fresh variable named after the field if the default is
    # not an AgentVariable.
    # NOTE(review): SOURCE was recovered from a flattened coverage report; the
    # nesting of the alias check inside this branch is reconstructed — confirm.
    if not isinstance(default_var, AgentVariable):
        default_var = AgentVariable(name=field_name)
        if "alias" not in update_var_with:
            # need exception here, as the copy below does not cover the default alias
            # NOTE(review): raises KeyError when update_var_with is the
            # {"value": ...} fallback above, which has no "name" key.
            default_var.alias = update_var_with["name"]

    # Setting the shared attribute first allows it to be overwritten by the user config
    if default_var.shared is None:
        default_var.shared = make_shared
    agent_variable = default_var.copy(update=update_var_with)
    # validate the model again, otherwise there can be buggy sources
    # todo check how this works with attrs variables
    agent_variable = type(default_var).validate_data(agent_variable.dict())
    if agent_variable.shared:
        agent_variable.source = Source(agent_id=agent_id)

    return agent_variable

114 

115 

def is_list_of_agent_variables(ls: Any):
    """Return True if ``ls`` is a non-empty list whose first item is an AgentVariable.

    Only the first element is inspected; the rest of the list is not checked.
    """
    # TODO move somewhere more appropriate
    if not isinstance(ls, list) or not ls:
        return False
    return isinstance(ls[0], AgentVariable)

119 

120 

def is_valid_agent_var_config(
    data: dict, field_name: str, type_: Type[AgentVariable] = AgentVariable
) -> bool:
    """Check whether ``data`` can be validated as a config of ``type_``.

    Args:
        data: Candidate variable config. An empty dict is accepted without
            running validation.
        field_name: Name of the field being updated; only used in the log
            message on failure.
        type_: Variable class whose ``validate_data`` performs the check.

    Returns:
        True if ``data`` is empty or ``type_.validate_data(data)`` succeeds,
        False otherwise (the error is logged).
    """
    if data == {}:
        return True
    try:
        type_.validate_data(data)
    except Exception as err:  # any failure means "not a usable config"
        # Argument order must match the placeholders: field name first,
        # then the error.
        logger.error(
            "Could not update the default config of field '%s'. "
            "You most probably used some validator on this field. "
            "Error message: %s",
            field_name,
            err,
        )
        return False
    return True