Add Ollama JSON schema for categorization

This commit is contained in:
Keannu Bernasol 2024-12-18 17:05:44 +08:00
parent 844113d44f
commit b24646d42d
5 changed files with 88 additions and 40 deletions

View file

@ -1,4 +1,3 @@
from ollama import ChatResponse
import base64
import httpx
from django.core.management.base import BaseCommand
@ -18,6 +17,9 @@ from django.core.files import File
import logging
import time
from ollama import Client
from pydantic import BaseModel
from typing import Optional
import json
class PDFHandler(FileSystemEventHandler):
@ -83,55 +85,52 @@ class PDFHandler(FileSystemEventHandler):
# Perform OCR
text = pytesseract.image_to_string(img).strip()
# Get document category
# Try to pass image to the Ollama image recognition API first
try:
class DocumentCategory(BaseModel):
category: str = "other"
explanation: Optional[str] = None
client = Client(
host=get_secret("OLLAMA_URL"),
auth=httpx.BasicAuth(
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None,
)
encoded_image = base64.b64encode(
img_buffer.getvalue()).decode()
attempts = 0
while True:
if attempts >= 3:
raise Exception(
"Unable to categorize using Ollama API")
attempts += 1
possible_categories = set((Document.objects.all().values_list(
"document_type", flat=True), "Documented Procedures Manual", "Form", "Special Order"))
prompt = f"""
Read the text from the image and provide a category. Return as JSON.
content = f"""
Read the text from the image and provide a category.
Possible categories are: Announcement, Manual, Form
Respond only with the category. No explanations are necessary.
Possible categories are: {possible_categories}
"""
response = client.chat(
model=get_secret("OLLAMA_MODEL"),
messages=[
{"role": "user",
"content": prompt,
"images": [encoded_image]},
],
format=DocumentCategory.model_json_schema(),
options={
"temperature": 0
},
response: ChatResponse = client.chat(
model=get_secret("OLLAMA_MODEL"),
messages=[
{"role": "user", "content": content,
"images": [encoded_image]},
],
)
)
document_type = response["message"]["content"].replace(
"*", "").replace(".", "")
# A few safety checks if the model does not follow through with output instructions
if len(document_type) > 16:
self.logger.warning(
f"Ollama API gave incorrect document category: {response['message']['content']}. Retrying...")
break
DocumentCategory.model_validate_json(
response.message.content)
result = json.loads(response.message.content)
document_type = result.get("category")
# If that fails, just use regular OCR read the title as a dirty fix/fallback
except Exception as e:
self.logger.warning(f"Error! {e}")
self.logger.warning(
"Ollama OCR offloading failed. Falling back to default OCR")
"Ollama OCR offload failed. Falling back to default OCR")
lines = text.split("\n")
for line in lines:
@ -158,7 +157,8 @@ class PDFHandler(FileSystemEventHandler):
DOCUMENT.file.save(
name=filename, content=File(open(file_path, "rb")))
self.logger.info(
f"Document '{filename}' created successfully with type '{document_type}'."
f"Document '{filename}' created successfully with type '{
document_type}'."
)
else:

View file

@ -0,0 +1,25 @@
# Generated by Django 5.1.3 on 2024-12-18 07:58
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("document_requests", "0001_initial"),
("questionnaires", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="documentrequest",
name="questionnaire",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="questionnaires.questionnaire",
),
),
]

View file

@ -6,12 +6,16 @@ class DocumentRequestUnit(models.Model):
document_request = models.ForeignKey(
"document_requests.DocumentRequest", on_delete=models.CASCADE
)
document = models.ForeignKey("documents.Document", on_delete=models.CASCADE)
document = models.ForeignKey(
"documents.Document", on_delete=models.CASCADE)
copies = models.IntegerField(default=1, null=False, blank=False)
class DocumentRequest(models.Model):
requester = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
requester = models.ForeignKey(
"accounts.CustomUser", on_delete=models.CASCADE)
questionnaire = models.ForeignKey(
"questionnaires.Questionnaire", on_delete=models.SET_NULL, null=True, blank=True)
documents = models.ManyToManyField("document_requests.DocumentRequestUnit")
date_requested = models.DateTimeField(default=now, editable=False)
college = models.CharField(max_length=64, blank=False, null=False)
@ -23,11 +27,13 @@ class DocumentRequest(models.Model):
("denied", "Denied"),
)
status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending")
status = models.CharField(
max_length=32, choices=STATUS_CHOICES, default="pending")
TYPE_CHOICES = (
("softcopy", "Softcopy"),
("hardcopy", "Hardcopy"),
)
type = models.CharField(max_length=16, choices=TYPE_CHOICES, default="softcopy")
type = models.CharField(
max_length=16, choices=TYPE_CHOICES, default="softcopy")

View file

@ -1,6 +1,7 @@
from rest_framework import serializers
from documents.models import Document
from documents.serializers import DocumentSerializer, DocumentFileSerializer
from questionnaires.models import Questionnaire
from accounts.models import CustomUser
from emails.templates import RequestUpdateEmail
from .models import DocumentRequest, DocumentRequestUnit
@ -24,7 +25,8 @@ class DocumentRequestCreationSerializer(serializers.ModelSerializer):
documents = DocumentRequestUnitCreationSerializer(many=True, required=True)
college = serializers.CharField(max_length=64)
purpose = serializers.CharField(max_length=512)
type = serializers.ChoiceField(choices=DocumentRequest.TYPE_CHOICES, required=True)
type = serializers.ChoiceField(
choices=DocumentRequest.TYPE_CHOICES, required=True)
class Meta:
model = DocumentRequest
@ -79,6 +81,12 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
queryset=CustomUser.objects.all(),
required=False,
)
requester = serializers.SlugRelatedField(
many=False,
slug_field="id",
queryset=CustomUser.objects.all(),
required=False,
)
purpose = serializers.CharField(max_length=512)
date_requested = serializers.DateTimeField(
format="%m-%d-%Y %I:%M %p", read_only=True
@ -108,10 +116,10 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
]
def get_documents(self, obj):
if obj.status != "approved":
serializer_class = DocumentRequestUnitSerializer
else:
if obj.questionnaire and obj.status == "approved":
serializer_class = DocumentRequestUnitWithFileSerializer
else:
serializer_class = DocumentRequestUnitSerializer
return serializer_class(obj.documents, many=True).data

View file

@ -1,3 +1,5 @@
annotated-types==0.7.0
anyio==4.7.0
asgiref==3.8.1
attrs==24.2.0
black==24.10.0
@ -22,6 +24,9 @@ drf-spectacular-sidecar==2024.11.1
filelock==3.16.1
fsspec==2024.10.0
gunicorn==23.0.0
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
idna==3.10
inflection==0.5.1
Jinja2==3.1.4
@ -32,11 +37,14 @@ mpmath==1.3.0
mypy-extensions==1.0.0
networkx==3.4.2
oauthlib==3.2.2
ollama==0.4.4
packaging==24.2
pathspec==0.12.1
pillow==11.0.0
platformdirs==4.3.6
pycparser==2.22
pydantic==2.10.3
pydantic_core==2.27.1
pyflakes==3.2.0
PyJWT==2.10.0
PyMuPDF==1.24.14
@ -49,6 +57,7 @@ requests==2.32.3
requests-oauthlib==2.0.0
rpds-py==0.21.0
setuptools==70.2.0
sniffio==1.3.1
social-auth-app-django==5.4.2
social-auth-core==4.5.4
sqlparse==0.5.2