Add Ollama JSON schema for categorization

This commit is contained in:
Keannu Bernasol 2024-12-18 17:05:44 +08:00
parent 844113d44f
commit b24646d42d
5 changed files with 88 additions and 40 deletions

View file

@ -1,4 +1,3 @@
from ollama import ChatResponse
import base64 import base64
import httpx import httpx
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
@ -18,6 +17,9 @@ from django.core.files import File
import logging import logging
import time import time
from ollama import Client from ollama import Client
from pydantic import BaseModel
from typing import Optional
import json
class PDFHandler(FileSystemEventHandler): class PDFHandler(FileSystemEventHandler):
@ -83,55 +85,52 @@ class PDFHandler(FileSystemEventHandler):
# Perform OCR # Perform OCR
text = pytesseract.image_to_string(img).strip() text = pytesseract.image_to_string(img).strip()
# Get document category
# Try to pass image to the Ollama image recognition API first # Try to pass image to the Ollama image recognition API first
try: try:
class DocumentCategory(BaseModel):
category: str = "other"
explanation: Optional[str] = None
client = Client( client = Client(
host=get_secret("OLLAMA_URL"), host=get_secret("OLLAMA_URL"),
auth=httpx.BasicAuth( auth=httpx.BasicAuth(
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None,
) )
encoded_image = base64.b64encode( encoded_image = base64.b64encode(
img_buffer.getvalue()).decode() img_buffer.getvalue()).decode()
attempts = 0 possible_categories = set((Document.objects.all().values_list(
while True: "document_type", flat=True), "Documented Procedures Manual", "Form", "Special Order"))
if attempts >= 3: prompt = f"""
raise Exception( Read the text from the image and provide a category. Return as JSON.
"Unable to categorize using Ollama API")
attempts += 1
content = f""" Possible categories are: {possible_categories}
Read the text from the image and provide a category.
Possible categories are: Announcement, Manual, Form
Respond only with the category. No explanations are necessary.
""" """
response = client.chat(
model=get_secret("OLLAMA_MODEL"),
messages=[
{"role": "user",
"content": prompt,
"images": [encoded_image]},
],
format=DocumentCategory.model_json_schema(),
options={
"temperature": 0
},
response: ChatResponse = client.chat( )
model=get_secret("OLLAMA_MODEL"),
messages=[
{"role": "user", "content": content,
"images": [encoded_image]},
],
)
document_type = response["message"]["content"].replace( DocumentCategory.model_validate_json(
"*", "").replace(".", "") response.message.content)
result = json.loads(response.message.content)
# A few safety checks if the model does not follow through with output instructions document_type = result.get("category")
if len(document_type) > 16:
self.logger.warning(
f"Ollama API gave incorrect document category: {response['message']['content']}. Retrying...")
break
# If that fails, just use regular OCR read the title as a dirty fix/fallback # If that fails, just use regular OCR read the title as a dirty fix/fallback
except Exception as e: except Exception as e:
self.logger.warning(f"Error! {e}") self.logger.warning(f"Error! {e}")
self.logger.warning( self.logger.warning(
"Ollama OCR offloading failed. Falling back to default OCR") "Ollama OCR offload failed. Falling back to default OCR")
lines = text.split("\n") lines = text.split("\n")
for line in lines: for line in lines:
@ -158,7 +157,8 @@ class PDFHandler(FileSystemEventHandler):
DOCUMENT.file.save( DOCUMENT.file.save(
name=filename, content=File(open(file_path, "rb"))) name=filename, content=File(open(file_path, "rb")))
self.logger.info( self.logger.info(
f"Document '{filename}' created successfully with type '{document_type}'." f"Document '{filename}' created successfully with type '{
document_type}'."
) )
else: else:

View file

@ -0,0 +1,25 @@
# Generated by Django 5.1.3 on 2024-12-18 07:58
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("document_requests", "0001_initial"),
("questionnaires", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="documentrequest",
name="questionnaire",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="questionnaires.questionnaire",
),
),
]

View file

@ -6,12 +6,16 @@ class DocumentRequestUnit(models.Model):
document_request = models.ForeignKey( document_request = models.ForeignKey(
"document_requests.DocumentRequest", on_delete=models.CASCADE "document_requests.DocumentRequest", on_delete=models.CASCADE
) )
document = models.ForeignKey("documents.Document", on_delete=models.CASCADE) document = models.ForeignKey(
"documents.Document", on_delete=models.CASCADE)
copies = models.IntegerField(default=1, null=False, blank=False) copies = models.IntegerField(default=1, null=False, blank=False)
class DocumentRequest(models.Model): class DocumentRequest(models.Model):
requester = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE) requester = models.ForeignKey(
"accounts.CustomUser", on_delete=models.CASCADE)
questionnaire = models.ForeignKey(
"questionnaires.Questionnaire", on_delete=models.SET_NULL, null=True, blank=True)
documents = models.ManyToManyField("document_requests.DocumentRequestUnit") documents = models.ManyToManyField("document_requests.DocumentRequestUnit")
date_requested = models.DateTimeField(default=now, editable=False) date_requested = models.DateTimeField(default=now, editable=False)
college = models.CharField(max_length=64, blank=False, null=False) college = models.CharField(max_length=64, blank=False, null=False)
@ -23,11 +27,13 @@ class DocumentRequest(models.Model):
("denied", "Denied"), ("denied", "Denied"),
) )
status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending") status = models.CharField(
max_length=32, choices=STATUS_CHOICES, default="pending")
TYPE_CHOICES = ( TYPE_CHOICES = (
("softcopy", "Softcopy"), ("softcopy", "Softcopy"),
("hardcopy", "Hardcopy"), ("hardcopy", "Hardcopy"),
) )
type = models.CharField(max_length=16, choices=TYPE_CHOICES, default="softcopy") type = models.CharField(
max_length=16, choices=TYPE_CHOICES, default="softcopy")

View file

@ -1,6 +1,7 @@
from rest_framework import serializers from rest_framework import serializers
from documents.models import Document from documents.models import Document
from documents.serializers import DocumentSerializer, DocumentFileSerializer from documents.serializers import DocumentSerializer, DocumentFileSerializer
from questionnaires.models import Questionnaire
from accounts.models import CustomUser from accounts.models import CustomUser
from emails.templates import RequestUpdateEmail from emails.templates import RequestUpdateEmail
from .models import DocumentRequest, DocumentRequestUnit from .models import DocumentRequest, DocumentRequestUnit
@ -24,7 +25,8 @@ class DocumentRequestCreationSerializer(serializers.ModelSerializer):
documents = DocumentRequestUnitCreationSerializer(many=True, required=True) documents = DocumentRequestUnitCreationSerializer(many=True, required=True)
college = serializers.CharField(max_length=64) college = serializers.CharField(max_length=64)
purpose = serializers.CharField(max_length=512) purpose = serializers.CharField(max_length=512)
type = serializers.ChoiceField(choices=DocumentRequest.TYPE_CHOICES, required=True) type = serializers.ChoiceField(
choices=DocumentRequest.TYPE_CHOICES, required=True)
class Meta: class Meta:
model = DocumentRequest model = DocumentRequest
@ -79,6 +81,12 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
queryset=CustomUser.objects.all(), queryset=CustomUser.objects.all(),
required=False, required=False,
) )
requester = serializers.SlugRelatedField(
many=False,
slug_field="id",
queryset=CustomUser.objects.all(),
required=False,
)
purpose = serializers.CharField(max_length=512) purpose = serializers.CharField(max_length=512)
date_requested = serializers.DateTimeField( date_requested = serializers.DateTimeField(
format="%m-%d-%Y %I:%M %p", read_only=True format="%m-%d-%Y %I:%M %p", read_only=True
@ -108,10 +116,10 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
] ]
def get_documents(self, obj): def get_documents(self, obj):
if obj.status != "approved": if obj.questionnaire and obj.status == "approved":
serializer_class = DocumentRequestUnitSerializer
else:
serializer_class = DocumentRequestUnitWithFileSerializer serializer_class = DocumentRequestUnitWithFileSerializer
else:
serializer_class = DocumentRequestUnitSerializer
return serializer_class(obj.documents, many=True).data return serializer_class(obj.documents, many=True).data

View file

@ -1,3 +1,5 @@
annotated-types==0.7.0
anyio==4.7.0
asgiref==3.8.1 asgiref==3.8.1
attrs==24.2.0 attrs==24.2.0
black==24.10.0 black==24.10.0
@ -22,6 +24,9 @@ drf-spectacular-sidecar==2024.11.1
filelock==3.16.1 filelock==3.16.1
fsspec==2024.10.0 fsspec==2024.10.0
gunicorn==23.0.0 gunicorn==23.0.0
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
idna==3.10 idna==3.10
inflection==0.5.1 inflection==0.5.1
Jinja2==3.1.4 Jinja2==3.1.4
@ -32,11 +37,14 @@ mpmath==1.3.0
mypy-extensions==1.0.0 mypy-extensions==1.0.0
networkx==3.4.2 networkx==3.4.2
oauthlib==3.2.2 oauthlib==3.2.2
ollama==0.4.4
packaging==24.2 packaging==24.2
pathspec==0.12.1 pathspec==0.12.1
pillow==11.0.0 pillow==11.0.0
platformdirs==4.3.6 platformdirs==4.3.6
pycparser==2.22 pycparser==2.22
pydantic==2.10.3
pydantic_core==2.27.1
pyflakes==3.2.0 pyflakes==3.2.0
PyJWT==2.10.0 PyJWT==2.10.0
PyMuPDF==1.24.14 PyMuPDF==1.24.14
@ -49,6 +57,7 @@ requests==2.32.3
requests-oauthlib==2.0.0 requests-oauthlib==2.0.0
rpds-py==0.21.0 rpds-py==0.21.0
setuptools==70.2.0 setuptools==70.2.0
sniffio==1.3.1
social-auth-app-django==5.4.2 social-auth-app-django==5.4.2
social-auth-core==4.5.4 social-auth-core==4.5.4
sqlparse==0.5.2 sqlparse==0.5.2