Retain root nodes of results (RDF or sparql) MPX-10
[mp-sparql-moved-to-github.git] / src / filter_sparql.cpp
index 8aa9bab..bfb18c7 100644 (file)
@@ -60,6 +60,7 @@ namespace metaproxy_1 {
         public:
             std::string db;
             std::string uri;
+            std::string schema;
             yaz_sparql_t s;
             ~Conf();
         };
@@ -77,6 +78,7 @@ namespace metaproxy_1 {
             friend class Session;
             Odr_int hits;
             std::string db;
+            ConfPtr conf;
             xmlDoc *doc;
         };
         class SPARQL::Session {
@@ -88,7 +90,7 @@ namespace metaproxy_1 {
                                Z_APDU *apdu_req,
                                mp::odr &odr,
                                const char *sparql_query,
-                               const char *uri);
+                               ConfPtr conf);
             Z_Records *fetch(
                 FrontendSetPtr fset,
                 ODR odr, Odr_oid *preferredRecordSyntax,
@@ -145,6 +147,8 @@ void yf::SPARQL::configure(const xmlNode *xmlnode, bool test_only,
                     conf->db = mp::xml::get_text(attr->children);
                 else if (!strcmp((const char *) attr->name, "uri"))
                     conf->uri = mp::xml::get_text(attr->children);
+                else if (!strcmp((const char *) attr->name, "schema"))
+                    conf->schema = mp::xml::get_text(attr->children);
                 else
                     throw mp::filter::FilterException(
                         "Bad attribute " + std::string((const char *)
@@ -265,14 +269,24 @@ void yf::SPARQL::release_session(Package &package) const
     }
 }
 
-static xmlNode *get_result(xmlDoc *doc, Odr_int *sz, Odr_int pos)
+static bool get_result(xmlDoc *doc, Odr_int *sz, Odr_int pos,
+                       xmlDoc **ndoc)
 {
     xmlNode *ptr = xmlDocGetRootElement(doc);
+    xmlNode *q0;
     Odr_int cur = 0;
 
+    if (ndoc)
+        *ndoc = xmlNewDoc(BAD_CAST "1.0");
+
     if (ptr->type == XML_ELEMENT_NODE &&
         !strcmp((const char *) ptr->name, "RDF"))
     {
+        if (ndoc)
+        {
+            q0 = xmlCopyNode(ptr, 2);
+            xmlDocSetRootElement(*ndoc, q0);
+        }
         ptr = ptr->children;
 
         while (ptr && ptr->type != XML_ELEMENT_NODE)
@@ -292,7 +306,14 @@ static xmlNode *get_result(xmlDoc *doc, Odr_int *sz, Odr_int pos)
                         !strcmp((const char *) ptr->name, "solution"))
                     {
                         if (cur++ == pos)
+                        {
+                            if (ndoc)
+                            {
+                                xmlNode *q1 = xmlCopyNode(ptr, 1);
+                                xmlAddChild(q0, q1);
+                            }
                             break;
+                        }
                     }
             }
             else
@@ -302,7 +323,14 @@ static xmlNode *get_result(xmlDoc *doc, Odr_int *sz, Odr_int pos)
                         !strcmp((const char *) ptr->name, "Description"))
                     {
                         if (cur++ == pos)
-                            break;
+                        {
+                            if (ndoc)
+                            {
+                                xmlNode *q1 = xmlCopyNode(ptr, 1);
+                                xmlAddChild(q0, q1);
+                            }
+                            return true;
+                        }
                     }
             }
         }
@@ -315,6 +343,11 @@ static xmlNode *get_result(xmlDoc *doc, Odr_int *sz, Odr_int pos)
                 break;
         if (ptr)
         {
+            if (ndoc)
+            {
+                q0 = xmlCopyNode(ptr, 2);
+                xmlDocSetRootElement(*ndoc, q0);
+            }
             for (ptr = ptr->children; ptr; ptr = ptr->next)
                 if (ptr->type == XML_ELEMENT_NODE &&
                     !strcmp((const char *) ptr->name, "results"))
@@ -322,18 +355,31 @@ static xmlNode *get_result(xmlDoc *doc, Odr_int *sz, Odr_int pos)
         }
         if (ptr)
         {
+            xmlNode *q1 = 0;
+            if (ndoc)
+            {
+                q1 = xmlCopyNode(ptr, 0);
+                xmlAddChild(q0, q1);
+            }
             for (ptr = ptr->children; ptr; ptr = ptr->next)
                 if (ptr->type == XML_ELEMENT_NODE &&
                     !strcmp((const char *) ptr->name, "result"))
                 {
                     if (cur++ == pos)
-                        break;
+                    {
+                        if (ndoc)
+                        {
+                            xmlNode *q2 = xmlCopyNode(ptr, 1);
+                            xmlAddChild(q1, q2);
+                        }
+                        return true;
+                    }
                 }
         }
     }
     if (sz)
         *sz = cur;
-    return ptr;
+    return false;
 }
 
 Z_Records *yf::SPARQL::Session::fetch(
@@ -344,6 +390,20 @@ Z_Records *yf::SPARQL::Session::fetch(
     int *number_returned, int *next_position)
 {
     Z_Records *rec = (Z_Records *) odr_malloc(odr, sizeof(Z_Records));
+    if (esn && esn->which == Z_ElementSetNames_generic &&
+        fset->conf->schema.length())
+    {
+        if (strcmp(esn->u.generic, fset->conf->schema.c_str()))
+        {
+            rec->which = Z_Records_NSD;
+            rec->u.nonSurrogateDiagnostic =
+                zget_DefaultDiagFormat(
+                    odr,
+                    YAZ_BIB1_SPECIFIED_ELEMENT_SET_NAME_NOT_VALID_FOR_SPECIFIED_,
+                    esn->u.generic);
+            return rec;
+        }
+    }
     rec->which = Z_Records_DBOSD;
     rec->u.databaseOrSurDiagnostics = (Z_NamePlusRecordList *)
         odr_malloc(odr, sizeof(Z_NamePlusRecordList));
@@ -357,17 +417,25 @@ Z_Records *yf::SPARQL::Session::fetch(
         Z_NamePlusRecord *npr = rec->u.databaseOrSurDiagnostics->records[i];
         npr->databaseName = odr_strdup(odr, fset->db.c_str());
         npr->which = Z_NamePlusRecord_databaseRecord;
+        xmlDoc *ndoc = 0;
 
-        xmlNode *node = get_result(fset->doc, 0, start - 1 + i);
-        if (!node)
+        if (!get_result(fset->doc, 0, start - 1 + i, &ndoc))
+        {
+            if (ndoc)
+                xmlFreeDoc(ndoc);
             break;
-        assert(node->type == XML_ELEMENT_NODE);
-        xmlNode *tmp = xmlCopyNode(node, 1);
+        }
+        xmlNode *ndoc_root = xmlDocGetRootElement(ndoc);
+        if (!ndoc_root)
+        {
+            xmlFreeDoc(ndoc);
+            break;
+        }
         xmlBufferPtr buf = xmlBufferCreate();
-        xmlNodeDump(buf, tmp->doc, tmp, 0, 0);
+        xmlNodeDump(buf, ndoc, ndoc_root, 0, 0);
         npr->u.databaseRecord =
             z_ext_record_xml(odr, (const char *) buf->content, buf->use);
-        xmlFreeNode(tmp);
+        xmlFreeDoc(ndoc);
         xmlBufferFree(buf);
     }
     rec->u.databaseOrSurDiagnostics->num_records = i;
@@ -383,18 +451,19 @@ Z_APDU *yf::SPARQL::Session::run_sparql(mp::Package &package,
                                         Z_APDU *apdu_req,
                                         mp::odr &odr,
                                         const char *sparql_query,
-                                        const char *uri)
+                                        ConfPtr conf)
 {
     Z_SearchRequest *req = apdu_req->u.searchRequest;
     Package http_package(package.session(), package.origin());
 
     http_package.copy_filter(package);
-    Z_GDU *gdu = z_get_HTTP_Request_uri(odr, uri, 0, 1);
+    Z_GDU *gdu = z_get_HTTP_Request_uri(odr, conf->uri.c_str(), 0, 1);
 
     z_HTTP_header_add(odr, &gdu->u.HTTP_Request->headers,
                       "Content-Type", "application/x-www-form-urlencoded");
     z_HTTP_header_add(odr, &gdu->u.HTTP_Request->headers,
-                      "Accept", "application/rdf+xml");
+                      "Accept", "application/sparql-results+xml,"
+                      "application/rdf+xml");
     const char *names[2];
     names[0] = "query";
     names[1] = 0;
@@ -437,6 +506,7 @@ Z_APDU *yf::SPARQL::Session::run_sparql(mp::Package &package,
 
         fset->doc = xmlParseMemory(resp->content_buf, resp->content_len);
         fset->db = req->databaseNames[0];
+        fset->conf = conf;
         if (!fset->doc)
             apdu_res = odr.create_searchResponse(apdu_req,
                                              YAZ_BIB1_TEMPORARY_SYSTEM_ERROR,
@@ -449,7 +519,7 @@ Z_APDU *yf::SPARQL::Session::run_sparql(mp::Package &package,
             int error_code = 0;
             std::string addinfo;
 
-            get_result(fset->doc, &fset->hits, -1);
+            get_result(fset->doc, &fset->hits, -1, 0);
             m_frontend_sets[req->resultSetName] = fset;
 
             Odr_int number = 0;
@@ -597,8 +667,7 @@ void yf::SPARQL::Session::handle_z(mp::Package &package, Z_APDU *apdu_req)
                 else
                 {
                     apdu_res = run_sparql(package, apdu_req, odr,
-                                          wrbuf_cstr(sparql_wr),
-                                          (*it)->uri.c_str());
+                                          wrbuf_cstr(sparql_wr), *it);
                 }
                 wrbuf_destroy(addinfo_wr);
                 wrbuf_destroy(sparql_wr);