From b24646d42df516499a45579ce081be9f4dead7af Mon Sep 17 00:00:00 2001 From: Keannu Bernasol Date: Wed, 18 Dec 2024 17:05:44 +0800 Subject: [PATCH] Add Ollama JSON schema for categorization --- .../management/commands/start_watcher.py | 64 +++++++++---------- .../0002_documentrequest_questionnaire.py | 25 ++++++++ .../document_requests/models.py | 14 ++-- .../document_requests/serializers.py | 16 +++-- requirements.txt | 9 +++ 5 files changed, 88 insertions(+), 40 deletions(-) create mode 100644 docmanager_backend/document_requests/migrations/0002_documentrequest_questionnaire.py diff --git a/docmanager_backend/config/management/commands/start_watcher.py b/docmanager_backend/config/management/commands/start_watcher.py index 2fe93fe..fc456f9 100644 --- a/docmanager_backend/config/management/commands/start_watcher.py +++ b/docmanager_backend/config/management/commands/start_watcher.py @@ -1,4 +1,3 @@ -from ollama import ChatResponse import base64 import httpx from django.core.management.base import BaseCommand @@ -18,6 +17,9 @@ from django.core.files import File import logging import time from ollama import Client +from pydantic import BaseModel +from typing import Optional +import json class PDFHandler(FileSystemEventHandler): @@ -83,55 +85,52 @@ class PDFHandler(FileSystemEventHandler): # Perform OCR text = pytesseract.image_to_string(img).strip() - # Get document category # Try to pass image to the Ollama image recognition API first try: + class DocumentCategory(BaseModel): + category: str = "other" + explanation: Optional[str] = None + client = Client( host=get_secret("OLLAMA_URL"), auth=httpx.BasicAuth( - username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None + username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None, ) encoded_image = base64.b64encode( img_buffer.getvalue()).decode() - attempts = 0 - while True: - if attempts >= 3: - raise Exception( - "Unable to categorize using Ollama API") - attempts += 1 + possible_categories = set((Document.objects.all().values_list( + "document_type", flat=True), "Documented Procedures Manual", "Form", "Special Order")) + prompt = f""" + Read the text from the image and provide a category. Return as JSON. - content = f""" - Read the text from the image and provide a category. - - Possible categories are: Announcement, Manual, Form - - Respond only with the category. No explanations are necessary. + Possible categories are: {possible_categories} """ + response = client.chat( + model=get_secret("OLLAMA_MODEL"), + messages=[ + {"role": "user", + "content": prompt, + "images": [encoded_image]}, + ], + format=DocumentCategory.model_json_schema(), + options={ + "temperature": 0 + }, - response: ChatResponse = client.chat( - model=get_secret("OLLAMA_MODEL"), - messages=[ - {"role": "user", "content": content, - "images": [encoded_image]}, - ], - ) + ) - document_type = response["message"]["content"].replace( - "*", "").replace(".", "") - - # A few safety checks if the model does not follow through with output instructions - if len(document_type) > 16: - self.logger.warning( - f"Ollama API gave incorrect document category: {response['message']['content']}. Retrying...") - break + DocumentCategory.model_validate_json( + response.message.content) + result = json.loads(response.message.content) + document_type = result.get("category") # If that fails, just use regular OCR read the title as a dirty fix/fallback except Exception as e: self.logger.warning(f"Error! {e}") self.logger.warning( - "Ollama OCR offloading failed. Falling back to default OCR") + "Ollama OCR offload failed. Falling back to default OCR") lines = text.split("\n") for line in lines: @@ -158,7 +157,8 @@ class PDFHandler(FileSystemEventHandler): DOCUMENT.file.save( name=filename, content=File(open(file_path, "rb"))) self.logger.info( - f"Document '{filename}' created successfully with type '{document_type}'." + f"Document '{filename}' created successfully with type '{ + document_type}'." ) else: diff --git a/docmanager_backend/document_requests/migrations/0002_documentrequest_questionnaire.py b/docmanager_backend/document_requests/migrations/0002_documentrequest_questionnaire.py new file mode 100644 index 0000000..7f70605 --- /dev/null +++ b/docmanager_backend/document_requests/migrations/0002_documentrequest_questionnaire.py @@ -0,0 +1,25 @@ +# Generated by Django 5.1.3 on 2024-12-18 07:58 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("document_requests", "0001_initial"), + ("questionnaires", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="documentrequest", + name="questionnaire", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="questionnaires.questionnaire", + ), + ), + ] diff --git a/docmanager_backend/document_requests/models.py b/docmanager_backend/document_requests/models.py index c6e421c..1efc716 100644 --- a/docmanager_backend/document_requests/models.py +++ b/docmanager_backend/document_requests/models.py @@ -6,12 +6,16 @@ class DocumentRequestUnit(models.Model): document_request = models.ForeignKey( "document_requests.DocumentRequest", on_delete=models.CASCADE ) - document = models.ForeignKey("documents.Document", on_delete=models.CASCADE) + document = models.ForeignKey( + "documents.Document", on_delete=models.CASCADE) copies = models.IntegerField(default=1, null=False, blank=False) class DocumentRequest(models.Model): - requester = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE) + requester = models.ForeignKey( + "accounts.CustomUser", on_delete=models.CASCADE) + questionnaire = models.ForeignKey( + "questionnaires.Questionnaire", on_delete=models.SET_NULL, null=True, blank=True) documents = models.ManyToManyField("document_requests.DocumentRequestUnit") date_requested = models.DateTimeField(default=now, editable=False) college = models.CharField(max_length=64, blank=False, null=False) @@ -23,11 +27,13 @@ class DocumentRequest(models.Model): ("denied", "Denied"), ) - status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending") + status = models.CharField( + max_length=32, choices=STATUS_CHOICES, default="pending") TYPE_CHOICES = ( ("softcopy", "Softcopy"), ("hardcopy", "Hardcopy"), ) - type = models.CharField(max_length=16, choices=TYPE_CHOICES, default="softcopy") + type = models.CharField( + max_length=16, choices=TYPE_CHOICES, default="softcopy") diff --git a/docmanager_backend/document_requests/serializers.py b/docmanager_backend/document_requests/serializers.py index 0d57aa3..f9760f5 100644 --- a/docmanager_backend/document_requests/serializers.py +++ b/docmanager_backend/document_requests/serializers.py @@ -1,6 +1,7 @@ from rest_framework import serializers from documents.models import Document from documents.serializers import DocumentSerializer, DocumentFileSerializer +from questionnaires.models import Questionnaire from accounts.models import CustomUser from emails.templates import RequestUpdateEmail from .models import DocumentRequest, DocumentRequestUnit @@ -24,7 +25,8 @@ class DocumentRequestCreationSerializer(serializers.ModelSerializer): documents = DocumentRequestUnitCreationSerializer(many=True, required=True) college = serializers.CharField(max_length=64) purpose = serializers.CharField(max_length=512) - type = serializers.ChoiceField(choices=DocumentRequest.TYPE_CHOICES, required=True) + type = serializers.ChoiceField( + choices=DocumentRequest.TYPE_CHOICES, required=True) class Meta: model = DocumentRequest @@ -79,6 +81,12 @@ class DocumentRequestSerializer(serializers.ModelSerializer): queryset=CustomUser.objects.all(), required=False, ) + requester = serializers.SlugRelatedField( + many=False, + slug_field="id", + queryset=CustomUser.objects.all(), + required=False, + ) purpose = serializers.CharField(max_length=512) date_requested = serializers.DateTimeField( format="%m-%d-%Y %I:%M %p", read_only=True @@ -108,10 +116,10 @@ class DocumentRequestSerializer(serializers.ModelSerializer): ] def get_documents(self, obj): - if obj.status != "approved": - serializer_class = DocumentRequestUnitSerializer - else: + if obj.questionnaire and obj.status == "approved": serializer_class = DocumentRequestUnitWithFileSerializer + else: + serializer_class = DocumentRequestUnitSerializer return serializer_class(obj.documents, many=True).data diff --git a/requirements.txt b/requirements.txt index 5e3db18..7a802c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +annotated-types==0.7.0 +anyio==4.7.0 asgiref==3.8.1 attrs==24.2.0 black==24.10.0 @@ -22,6 +24,9 @@ drf-spectacular-sidecar==2024.11.1 filelock==3.16.1 fsspec==2024.10.0 gunicorn==23.0.0 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.27.2 idna==3.10 inflection==0.5.1 Jinja2==3.1.4 @@ -32,11 +37,14 @@ mpmath==1.3.0 mypy-extensions==1.0.0 networkx==3.4.2 oauthlib==3.2.2 +ollama==0.4.4 packaging==24.2 pathspec==0.12.1 pillow==11.0.0 platformdirs==4.3.6 pycparser==2.22 +pydantic==2.10.3 +pydantic_core==2.27.1 pyflakes==3.2.0 PyJWT==2.10.0 PyMuPDF==1.24.14 @@ -49,6 +57,7 @@ requests==2.32.3 requests-oauthlib==2.0.0 rpds-py==0.21.0 setuptools==70.2.0 +sniffio==1.3.1 social-auth-app-django==5.4.2 social-auth-core==4.5.4 sqlparse==0.5.2