|
14 | 14 | * See the License for the specific language governing permissions and |
15 | 15 | * limitations under the License. |
16 | 16 | */ |
| 17 | +use crate::model::parse; |
17 | 18 | use crate::model::parse::{Format, In, Method, OpenAPI}; |
18 | 19 | use crate::request::validator::ValidateRequest; |
19 | 20 | use anyhow::{Context, Result}; |
@@ -153,51 +154,50 @@ impl ValidateRequest for RequestData { |
153 | 154 | } |
154 | 155 |
|
155 | 156 | fn body(&self, open_api: &OpenAPI) -> Result<()> { |
| 157 | + let body = self |
| 158 | + .body |
| 159 | + .as_ref() |
| 160 | + .ok_or_else(|| anyhow::anyhow!("Missing body"))?; |
156 | 161 | let path = open_api |
157 | 162 | .paths |
158 | 163 | .get(self.path.as_str()) |
159 | 164 | .context("Path not found")?; |
160 | | - |
161 | 165 | let path_base = path |
162 | 166 | .get(&Method::Post) |
163 | 167 | .context("Post method not defined for this path")?; |
164 | 168 |
|
165 | 169 | if let Some(request) = &path_base.request { |
166 | | - let map = &request.content; |
167 | | - |
168 | | - if let Some(body) = &self.body { |
169 | | - let request_fields: HashMap<String, Value> = serde_json::from_slice(body)?; |
170 | | - |
171 | | - for (key, value) in map { |
172 | | - if let Some(field) = request_fields.get(key) { |
173 | | - match value.schema.format { |
174 | | - Format::UUID => { |
175 | | - if let Some(v) = field.as_str() { |
176 | | - uuid::Uuid::parse_str(v).map_err(|_| { |
177 | | - anyhow::anyhow!( |
178 | | - "Invalid UUID format for query parameter '{}': '{:?}'", |
179 | | - key, |
180 | | - value |
181 | | - ) |
182 | | - })?; |
183 | | - } else { |
184 | | - return Err(anyhow::anyhow!( |
185 | | - "Invalid UUID format for query parameter '{}'", |
186 | | - key |
187 | | - )); |
188 | | - } |
189 | | - } |
190 | | - _ => { |
191 | | - return Err(anyhow::anyhow!( |
192 | | - "Unsupported format '{:?}' for query parameter '{}'", |
193 | | - value.schema.format, |
194 | | - key |
195 | | - )); |
196 | | - } |
| 170 | + let request_fields: HashMap<String, Value> = serde_json::from_slice(body)?; |
| 171 | + let refs = collect_schema_refs(&request.content); |
| 172 | + |
| 173 | + for (key, media_type) in &request.content { |
| 174 | + if let Some(field) = request_fields.get(key) { |
| 175 | + validate_field_format(key, field, media_type.schema.format.clone())?; |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + let mut requireds = HashSet::new(); |
| 180 | + |
| 181 | + if let Some(components) = &open_api.components { |
| 182 | + for schema_ref in &refs { |
| 183 | + if let Some(schema) = components.schemas.get(*schema_ref) { |
| 184 | + requireds.extend(schema.required.iter().cloned()); |
| 185 | + |
| 186 | + if let Some(properties) = &schema.properties { |
| 187 | + validate_schema_properties(&request_fields, properties)?; |
197 | 188 | } |
198 | 189 | } |
199 | 190 | } |
200 | 191 | } |
| 192 | + |
| 193 | + for key in &requireds { |
| 194 | + if !request_fields.contains_key(key) { |
| 195 | + return Err(anyhow::anyhow!( |
| 196 | + "Missing required query parameter: '{}'", |
| 197 | + key |
| 198 | + )); |
| 199 | + } |
| 200 | + } |
201 | 201 | } |
202 | 202 |
|
203 | 203 | Ok(()) |
@@ -225,3 +225,56 @@ fn validate_format(format: &Format, value: &str, key: &str) -> Result<()> { |
225 | 225 | } |
226 | 226 | Ok(()) |
227 | 227 | } |
| 228 | + |
| 229 | +fn collect_schema_refs(content: &HashMap<String, parse::BaseContent>) -> Vec<&str> { |
| 230 | + let mut refs = Vec::new(); |
| 231 | + for media_type in content.values() { |
| 232 | + if let Some(r) = &media_type.schema._ref { |
| 233 | + refs.push(r.as_str()); |
| 234 | + } |
| 235 | + if let Some(one_of) = &media_type.schema.one_of { |
| 236 | + refs.extend(one_of.iter().filter_map(|s| s._ref.as_deref())); |
| 237 | + } |
| 238 | + if let Some(all_of) = &media_type.schema.all_of { |
| 239 | + refs.extend(all_of.iter().filter_map(|s| s._ref.as_deref())); |
| 240 | + } |
| 241 | + } |
| 242 | + refs |
| 243 | +} |
| 244 | + |
| 245 | +fn validate_field_format(key: &str, value: &Value, format: Format) -> Result<()> { |
| 246 | + match format { |
| 247 | + Format::UUID => { |
| 248 | + let str_val = value.as_str().ok_or_else(|| { |
| 249 | + anyhow::anyhow!("Invalid UUID format for query parameter '{}'", key) |
| 250 | + })?; |
| 251 | + uuid::Uuid::parse_str(str_val).map_err(|_| { |
| 252 | + anyhow::anyhow!( |
| 253 | + "Invalid UUID format for query parameter '{}': '{}'", |
| 254 | + key, |
| 255 | + str_val |
| 256 | + ) |
| 257 | + })?; |
| 258 | + } |
| 259 | + _ => { |
| 260 | + return Err(anyhow::anyhow!( |
| 261 | + "Unsupported format '{:?}' for query parameter '{}'", |
| 262 | + format, |
| 263 | + key |
| 264 | + )); |
| 265 | + } |
| 266 | + } |
| 267 | + Ok(()) |
| 268 | +} |
| 269 | + |
| 270 | +fn validate_schema_properties( |
| 271 | + request_fields: &HashMap<String, Value>, |
| 272 | + properties: &HashMap<String, parse::Properties>, |
| 273 | +) -> Result<()> { |
| 274 | + for (key, prop) in properties { |
| 275 | + if let Some(value) = request_fields.get(key) { |
| 276 | + validate_field_format(key, value, prop.format.clone())?; |
| 277 | + } |
| 278 | + } |
| 279 | + Ok(()) |
| 280 | +} |
0 commit comments