Clean up docker-compose and run Black formatter over entire codebase

This commit is contained in:
Keannu Bernasol 2024-10-30 22:09:58 +08:00
parent 6c232b3e89
commit 069aba80b1
60 changed files with 1946 additions and 1485 deletions

View file

@ -35,6 +35,7 @@ gunicorn = "*"
django-silk = "*"
django-redis = "*"
granian = "*"
black = "*"
[dev-packages]

1492
Pipfile.lock generated

File diff suppressed because it is too large Load diff

View file

@ -6,11 +6,13 @@ from .models import CustomUser
class CustomUserAdmin(UserAdmin):
model = CustomUser
list_display = ('id', 'is_active', 'user_group',) + UserAdmin.list_display
list_display = (
"id",
"is_active",
"user_group",
) + UserAdmin.list_display
# Editable fields per instance
fieldsets = UserAdmin.fieldsets + (
(None, {'fields': ('avatar',)}),
)
fieldsets = UserAdmin.fieldsets + ((None, {"fields": ("avatar",)}),)
admin.site.register(CustomUser, CustomUserAdmin)

View file

@ -2,8 +2,8 @@ from django.apps import AppConfig
class AccountsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'accounts'
default_auto_field = "django.db.models.BigAutoField"
name = "accounts"
def ready(self):
import accounts.signals

View file

@ -13,38 +13,145 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
('auth', '0012_alter_user_first_name_max_length'),
('user_groups', '0001_initial'),
("auth", "0012_alter_user_first_name_max_length"),
("user_groups", "0001_initial"),
]
operations = [
migrations.CreateModel(
name='CustomUser',
name="CustomUser",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('password', models.CharField(max_length=128, verbose_name='password')),
('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')),
('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')),
('username', models.CharField(error_messages={'unique': 'A user with that username already exists.'}, help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', max_length=150, unique=True, validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name='username')),
('first_name', models.CharField(blank=True, max_length=150, verbose_name='first name')),
('last_name', models.CharField(blank=True, max_length=150, verbose_name='last name')),
('email', models.EmailField(blank=True, max_length=254, verbose_name='email address')),
('is_staff', models.BooleanField(default=False, help_text='Designates whether the user can log into this admin site.', verbose_name='staff status')),
('is_active', models.BooleanField(default=True, help_text='Designates whether this user should be treated as active. Unselect this instead of deleting accounts.', verbose_name='active')),
('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')),
('avatar', django_resized.forms.ResizedImageField(crop=None, force_format='WEBP', keep_meta=True, null=True, quality=100, scale=None, size=[1920, 1080], upload_to='avatars/')),
('onboarding', models.BooleanField(default=True)),
('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.group', verbose_name='groups')),
('user_group', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='user_groups.usergroup')),
('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.permission', verbose_name='user permissions')),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("password", models.CharField(max_length=128, verbose_name="password")),
(
"last_login",
models.DateTimeField(
blank=True, null=True, verbose_name="last login"
),
),
(
"is_superuser",
models.BooleanField(
default=False,
help_text="Designates that this user has all permissions without explicitly assigning them.",
verbose_name="superuser status",
),
),
(
"username",
models.CharField(
error_messages={
"unique": "A user with that username already exists."
},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150,
unique=True,
validators=[
django.contrib.auth.validators.UnicodeUsernameValidator()
],
verbose_name="username",
),
),
(
"first_name",
models.CharField(
blank=True, max_length=150, verbose_name="first name"
),
),
(
"last_name",
models.CharField(
blank=True, max_length=150, verbose_name="last name"
),
),
(
"email",
models.EmailField(
blank=True, max_length=254, verbose_name="email address"
),
),
(
"is_staff",
models.BooleanField(
default=False,
help_text="Designates whether the user can log into this admin site.",
verbose_name="staff status",
),
),
(
"is_active",
models.BooleanField(
default=True,
help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
verbose_name="active",
),
),
(
"date_joined",
models.DateTimeField(
default=django.utils.timezone.now, verbose_name="date joined"
),
),
(
"avatar",
django_resized.forms.ResizedImageField(
crop=None,
force_format="WEBP",
keep_meta=True,
null=True,
quality=100,
scale=None,
size=[1920, 1080],
upload_to="avatars/",
),
),
("onboarding", models.BooleanField(default=True)),
(
"groups",
models.ManyToManyField(
blank=True,
help_text="The groups this user belongs to. A user will get all permissions granted to each of their groups.",
related_name="user_set",
related_query_name="user",
to="auth.group",
verbose_name="groups",
),
),
(
"user_group",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="user_groups.usergroup",
),
),
(
"user_permissions",
models.ManyToManyField(
blank=True,
help_text="Specific permissions for this user.",
related_name="user_set",
related_query_name="user",
to="auth.permission",
verbose_name="user permissions",
),
),
],
options={
'verbose_name': 'user',
'verbose_name_plural': 'users',
'abstract': False,
"verbose_name": "user",
"verbose_name_plural": "users",
"abstract": False,
},
managers=[
('objects', django.contrib.auth.models.UserManager()),
("objects", django.contrib.auth.models.UserManager()),
],
),
]

View file

@ -15,14 +15,16 @@ class CustomUser(AbstractUser):
# is_admin inherited from base user class
avatar = ResizedImageField(
null=True, force_format="WEBP", quality=100, upload_to='avatars/')
null=True, force_format="WEBP", quality=100, upload_to="avatars/"
)
# Used for onboarding processes
# Set this to False later on once the user makes actions
onboarding = models.BooleanField(default=True)
user_group = models.ForeignKey(
'user_groups.UserGroup', on_delete=models.SET_NULL, null=True)
"user_groups.UserGroup", on_delete=models.SET_NULL, null=True
)
@property
def group_member(self):
@ -57,4 +59,4 @@ class CustomUser(AbstractUser):
@property
def admin_url(self):
return reverse('admin:users_customuser_change', args=(self.pk,))
return reverse("admin:users_customuser_change", args=(self.pk,))

View file

@ -8,13 +8,14 @@ from django.core.cache import cache
from django.core import exceptions as django_exceptions
from rest_framework.settings import api_settings
from django.contrib.auth.password_validation import validate_password
# There can be multiple subject instances with the same name, only differing in course, year level, and semester. We filter them here
class SimpleCustomUserSerializer(ModelSerializer):
class Meta(BaseUserSerializer.Meta):
model = CustomUser
fields = ('id', 'username', 'email', 'full_name')
fields = ("id", "username", "email", "full_name")
class CustomUserSerializer(BaseUserSerializer):
@ -22,19 +23,36 @@ class CustomUserSerializer(BaseUserSerializer):
class Meta(BaseUserSerializer.Meta):
model = CustomUser
fields = ('id', 'username', 'email', 'avatar', 'first_name',
'last_name', 'user_group', 'group_member', 'group_owner')
read_only_fields = ('id', 'username', 'email', 'user_group',
'group_member', 'group_owner')
fields = (
"id",
"username",
"email",
"avatar",
"first_name",
"is_new",
"last_name",
"user_group",
"group_member",
"group_owner",
)
read_only_fields = (
"id",
"username",
"email",
"user_group",
"group_member",
"group_owner",
)
def to_representation(self, instance):
representation = super().to_representation(instance)
representation['user_group'] = SimpleUserGroupSerializer(
instance.user_group, many=False).data
representation["user_group"] = SimpleUserGroupSerializer(
instance.user_group, many=False
).data
return representation
def update(self, instance, validated_data):
cache.delete(f'user:{instance.id}')
cache.delete(f"user:{instance.id}")
return super().update(instance, validated_data)
@ -42,16 +60,18 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
email = serializers.EmailField(required=True)
username = serializers.CharField(required=True)
password = serializers.CharField(
write_only=True, style={'input_type': 'password', 'placeholder': 'Password'})
write_only=True, style={"input_type": "password", "placeholder": "Password"}
)
first_name = serializers.CharField(
required=True, allow_blank=False, allow_null=False)
required=True, allow_blank=False, allow_null=False
)
last_name = serializers.CharField(
required=True, allow_blank=False, allow_null=False)
required=True, allow_blank=False, allow_null=False
)
class Meta:
model = CustomUser
fields = ['email', 'username', 'password',
'first_name', 'last_name']
fields = ["email", "username", "password", "first_name", "last_name"]
def validate(self, attrs):
user_attrs = attrs.copy()
@ -69,14 +89,15 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
raise serializers.ValidationError({"password": errors})
if self.Meta.model.objects.filter(username=attrs.get("username")).exists():
raise serializers.ValidationError(
"A user with that username already exists.")
"A user with that username already exists."
)
return super().validate(attrs)
def create(self, validated_data):
user = self.Meta.model(**validated_data)
user.username = validated_data['username']
user.username = validated_data["username"]
user.is_active = False
user.set_password(validated_data['password'])
user.set_password(validated_data["password"])
user.save()
return user

View file

@ -12,38 +12,37 @@ import json
@receiver(post_migrate)
def create_users(sender, **kwargs):
if sender.name == "accounts":
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
seed_data = json.loads(f.read())
for user in seed_data['users']:
USER = CustomUser.objects.filter(
email=user['email']).first()
for user in seed_data["users"]:
USER = CustomUser.objects.filter(email=user["email"]).first()
if not USER:
if user['password'] == 'USE_REGULAR':
password = get_secret('SEED_DATA_PASSWORD')
elif user['password'] == 'USE_ADMIN':
password = get_secret('SEED_DATA_ADMIN_PASSWORD')
if user["password"] == "USE_REGULAR":
password = get_secret("SEED_DATA_PASSWORD")
elif user["password"] == "USE_ADMIN":
password = get_secret("SEED_DATA_ADMIN_PASSWORD")
else:
password = user['password']
if (user['is_superuser'] == True):
password = user["password"]
if user["is_superuser"] == True:
# Admin users are created regardless of SEED_DATA value
USER = CustomUser.objects.create_superuser(
username=user['username'],
email=user['email'],
username=user["username"],
email=user["email"],
password=password,
)
print('Created Superuser:', user['email'])
print("Created Superuser:", user["email"])
else:
# Only create non-admin users if SEED_DATA=True
if SEED_DATA:
USER = CustomUser.objects.create_user(
username=user['email'],
email=user['email'],
username=user["email"],
email=user["email"],
password=password,
)
print('Created User:', user['email'])
print("Created User:", user["email"])
USER.first_name = user['first_name']
USER.last_name = user['last_name']
USER.first_name = user["first_name"]
USER.last_name = user["last_name"]
USER.is_active = True
USER.save()
@ -51,53 +50,57 @@ def create_users(sender, **kwargs):
@receiver(post_migrate)
def create_celery_beat_schedules(sender, **kwargs):
if sender.name == "django_celery_beat":
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
seed_data = json.loads(f.read())
# Creating Schedules
for schedule in seed_data['schedules']:
if schedule['type'] == 'crontab':
for schedule in seed_data["schedules"]:
if schedule["type"] == "crontab":
# Check if Schedule already exists
SCHEDULE = CrontabSchedule.objects.filter(minute=schedule['minute'],
hour=schedule['hour'],
day_of_week=schedule['day_of_week'],
day_of_month=schedule['day_of_month'],
month_of_year=schedule['month_of_year'],
timezone=schedule['timezone']
).first()
SCHEDULE = CrontabSchedule.objects.filter(
minute=schedule["minute"],
hour=schedule["hour"],
day_of_week=schedule["day_of_week"],
day_of_month=schedule["day_of_month"],
month_of_year=schedule["month_of_year"],
timezone=schedule["timezone"],
).first()
# If it does not exist, create a new Schedule
if not SCHEDULE:
SCHEDULE = CrontabSchedule.objects.create(
minute=schedule['minute'],
hour=schedule['hour'],
day_of_week=schedule['day_of_week'],
day_of_month=schedule['day_of_month'],
month_of_year=schedule['month_of_year'],
timezone=schedule['timezone']
minute=schedule["minute"],
hour=schedule["hour"],
day_of_week=schedule["day_of_week"],
day_of_month=schedule["day_of_month"],
month_of_year=schedule["month_of_year"],
timezone=schedule["timezone"],
)
print(
f'Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}')
f"Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}"
)
else:
print(
f'Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists')
for task in seed_data['scheduled_tasks']:
TASK = PeriodicTask.objects.filter(name=task['name']).first()
f"Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists"
)
for task in seed_data["scheduled_tasks"]:
TASK = PeriodicTask.objects.filter(name=task["name"]).first()
if not TASK:
if task['schedule']['type'] == 'crontab':
SCHEDULE = CrontabSchedule.objects.filter(minute=task['schedule']['minute'],
hour=task['schedule']['hour'],
day_of_week=task['schedule']['day_of_week'],
day_of_month=task['schedule']['day_of_month'],
month_of_year=task['schedule']['month_of_year'],
timezone=task['schedule']['timezone']
).first()
if task["schedule"]["type"] == "crontab":
SCHEDULE = CrontabSchedule.objects.filter(
minute=task["schedule"]["minute"],
hour=task["schedule"]["hour"],
day_of_week=task["schedule"]["day_of_week"],
day_of_month=task["schedule"]["day_of_month"],
month_of_year=task["schedule"]["month_of_year"],
timezone=task["schedule"]["timezone"],
).first()
TASK = PeriodicTask.objects.create(
crontab=SCHEDULE,
name=task['name'],
task=task['task'],
enabled=task['enabled']
name=task["name"],
task=task["task"],
enabled=task["enabled"],
)
print(f'Created Periodic Task: {TASK.name}')
print(f"Created Periodic Task: {TASK.name}")
else:
raise Exception('Schedule for Periodic Task not found')
raise Exception("Schedule for Periodic Task not found")
else:
print(f'Periodic Task: {TASK.name} already exists')
print(f"Periodic Task: {TASK.name} already exists")

View file

@ -4,20 +4,26 @@ from celery import shared_task
@shared_task
def get_paying_users():
from subscriptions.models import UserSubscription
# Get a list of user subscriptions
active_subscriptions = UserSubscription.objects.filter(
valid=True).distinct('user')
active_subscriptions = UserSubscription.objects.filter(valid=True).distinct("user")
# Get paying users
active_users = []
# Paying regular users
active_users += [
subscription.user.id for subscription in active_subscriptions if subscription.user is not None and subscription.user.user_group is None]
subscription.user.id
for subscription in active_subscriptions
if subscription.user is not None and subscription.user.user_group is None
]
# Paying users within groups
active_users += [
subscription.user_group.members for subscription in active_subscriptions if subscription.user_group is not None and subscription.user is None]
subscription.user_group.members
for subscription in active_subscriptions
if subscription.user_group is not None and subscription.user is None
]
# Return paying users
return active_users

View file

@ -3,10 +3,10 @@ from rest_framework.routers import DefaultRouter
from accounts import views
router = DefaultRouter()
router.register(r'users', views.CustomUserViewSet, basename='users')
router.register(r"users", views.CustomUserViewSet, basename="users")
urlpatterns = [
path('', include(router.urls)),
path('', include('djoser.urls')),
path('', include('djoser.urls.jwt')),
path("", include(router.urls)),
path("", include("djoser.urls")),
path("", include("djoser.urls.jwt")),
]

View file

@ -1,4 +1,3 @@
from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _
import re
@ -6,9 +5,10 @@ import re
class UppercaseValidator(object):
def validate(self, password, user=None):
if not re.findall('[A-Z]', password):
if not re.findall("[A-Z]", password):
raise ValidationError(
_("The password must contain at least 1 uppercase letter (A-Z)."))
_("The password must contain at least 1 uppercase letter (A-Z).")
)
def get_help_text(self):
return _("Your password must contain at least 1 uppercase letter (A-Z).")
@ -16,9 +16,10 @@ class UppercaseValidator(object):
class LowercaseValidator(object):
def validate(self, password, user=None):
if not re.findall('[a-z]', password):
if not re.findall("[a-z]", password):
raise ValidationError(
_("The password must contain at least 1 lowercase letter (a-z)."))
_("The password must contain at least 1 lowercase letter (a-z).")
)
def get_help_text(self):
return _("Your password must contain at least 1 lowercase letter (a-z).")
@ -26,19 +27,25 @@ class LowercaseValidator(object):
class SpecialCharacterValidator(object):
def validate(self, password, user=None):
if not re.findall('[@#$%^&*()_+/\<>;:!?]', password):
if not re.findall("[@#$%^&*()_+/\<>;:!?]", password):
raise ValidationError(
_("The password must contain at least 1 special character (@, #, $, etc.)."))
_(
"The password must contain at least 1 special character (@, #, $, etc.)."
)
)
def get_help_text(self):
return _("Your password must contain at least 1 special character (@, #, $, etc.).")
return _(
"Your password must contain at least 1 special character (@, #, $, etc.)."
)
class NumberValidator(object):
def validate(self, password, user=None):
if not any(char.isdigit() for char in password):
raise ValidationError(
_("The password must contain at least one numerical digit (0-9)."))
_("The password must contain at least one numerical digit (0-9).")
)
def get_help_text(self):
return _("Your password must contain at least numerical digit (0-9).")

View file

@ -22,28 +22,27 @@ class CustomUserViewSet(DjoserUserViewSet):
user = self.request.user
# If user is admin, show all active users
if user.is_superuser:
key = 'users'
key = "users"
# Get cache
queryset = cache.get(key)
# Set cache if stale or does not exist
if not queryset:
queryset = CustomUser.objects.filter(is_active=True)
cache.set(key, queryset, 60*60)
cache.set(key, queryset, 60 * 60)
return queryset
elif not user.user_group:
key = f'user:{user.id}'
key = f"user:{user.id}"
queryset = cache.get(key)
if not queryset:
queryset = CustomUser.objects.filter(is_active=True)
cache.set(key, queryset, 60*60)
cache.set(key, queryset, 60 * 60)
return queryset
elif user.user_group:
key = f'usergroup_users:{user.user_group.id}'
key = f"usergroup_users:{user.user_group.id}"
queryset = cache.get(key)
if not queryset:
queryset = CustomUser.objects.filter(
user_group=user.user_group)
cache.set(key, queryset, 60*60)
queryset = CustomUser.objects.filter(user_group=user.user_group)
cache.set(key, queryset, 60 * 60)
return queryset
else:
return CustomUser.objects.none()
@ -52,10 +51,10 @@ class CustomUserViewSet(DjoserUserViewSet):
user = self.request.user
# Clear cache
cache.delete(f'users')
cache.delete(f'user:{user.id}')
cache.delete(f"users")
cache.delete(f"user:{user.id}")
if user.user_group:
cache.delete(f'usergroup_users:{user.user_group.id}')
cache.delete(f"usergroup_users:{user.user_group.id}")
super().perform_update(serializer, *args, **kwargs)
user = serializer.instance
@ -84,16 +83,18 @@ class CustomUserViewSet(DjoserUserViewSet):
settings.EMAIL.confirmation(self.request, context).send(to)
# Clear cache
cache.delete('users')
cache.delete(f'user:{user.id}')
cache.delete("users")
cache.delete(f"user:{user.id}")
if user.user_group:
cache.delete(f'usergroup_users:{user.user_group.id}')
cache.delete(f"usergroup_users:{user.user_group.id}")
except Exception as e:
print('Warning: Unable to send email')
print("Warning: Unable to send email")
print(e)
@action(methods=['post'], detail=False, url_path='activation', url_name='activation')
@action(
methods=["post"], detail=False, url_path="activation", url_name="activation"
)
def activation(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@ -103,16 +104,16 @@ class CustomUserViewSet(DjoserUserViewSet):
# Construct a response with user's first name, last name, and email
user_data = {
'first_name': user.first_name,
'last_name': user.last_name,
'email': user.email,
'username': user.username
"first_name": user.first_name,
"last_name": user.last_name,
"email": user.email,
"username": user.username,
}
# Clear cache
cache.delete('users')
cache.delete(f'user:{user.id}')
cache.delete("users")
cache.delete(f"user:{user.id}")
if user.user_group:
cache.delete(f'usergroup_users:{user.user_group.id}')
cache.delete(f"usergroup_users:{user.user_group.id}")
return Response(user_data, status=status.HTTP_200_OK)

View file

@ -1,27 +1,31 @@
from django.conf.urls.static import static
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
from django.urls import path, include
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
from drf_spectacular.views import (
SpectacularAPIView,
SpectacularRedocView,
SpectacularSwaggerView,
)
from django.contrib import admin
from config.settings import DEBUG, SERVE_MEDIA, MEDIA_ROOT
urlpatterns = [
path('accounts/', include('accounts.urls')),
path('subscriptions/', include('subscriptions.urls')),
path('notifications/', include('notifications.urls')),
path('billing/', include('billing.urls')),
path('stripe/', include('payments.urls')),
path('admin/', admin.site.urls),
path('schema/', SpectacularAPIView.as_view(), name='schema'),
path('swagger/',
SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'),
path('redoc/',
SpectacularRedocView.as_view(url_name='schema'), name='redoc'),
path("accounts/", include("accounts.urls")),
path("subscriptions/", include("subscriptions.urls")),
path("notifications/", include("notifications.urls")),
path("billing/", include("billing.urls")),
path("stripe/", include("payments.urls")),
path("admin/", admin.site.urls),
path("schema/", SpectacularAPIView.as_view(), name="schema"),
path(
"swagger/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"
),
path("redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"),
]
# URLs for local development
if DEBUG and SERVE_MEDIA:
urlpatterns += staticfiles_urlpatterns()
urlpatterns += static(
'media/', document_root=MEDIA_ROOT)
urlpatterns += static("media/", document_root=MEDIA_ROOT)
if DEBUG:
urlpatterns += [path('silk/', include('silk.urls', namespace='silk'))]
urlpatterns += [path("silk/", include("silk.urls", namespace="silk"))]

View file

@ -2,6 +2,5 @@ from django.urls import path
from billing import views
urlpatterns = [
path('',
views.BillingHistoryView.as_view()),
path("", views.BillingHistoryView.as_view()),
]

View file

@ -24,7 +24,7 @@ class BillingHistoryView(APIView):
email = requesting_user.email
# Check cache
key = f'billing_user:{requesting_user.id}'
key = f"billing_user:{requesting_user.id}"
billing_history = cache.get(key)
if not billing_history:
@ -39,23 +39,25 @@ class BillingHistoryView(APIView):
if len(customers.data) > 0:
# Retrieve the customer's charges (billing history)
charges = stripe.Charge.list(
limit=10, customer=customer.id)
charges = stripe.Charge.list(limit=10, customer=customer.id)
# Prepare the response
billing_history = [
{
'email': charge['billing_details']['email'],
'amount_charged': int(charge['amount']/100),
'paid': charge['paid'],
'refunded': int(charge['amount_refunded']/100) > 0,
'amount_refunded': int(charge['amount_refunded']/100),
'last_4': charge['payment_method_details']['card']['last4'],
'receipt_link': charge['receipt_url'],
'timestamp': datetime.fromtimestamp(charge['created']).strftime("%m-%d-%Y %I:%M %p"),
} for charge in charges.auto_paging_iter()
"email": charge["billing_details"]["email"],
"amount_charged": int(charge["amount"] / 100),
"paid": charge["paid"],
"refunded": int(charge["amount_refunded"] / 100) > 0,
"amount_refunded": int(charge["amount_refunded"] / 100),
"last_4": charge["payment_method_details"]["card"]["last4"],
"receipt_link": charge["receipt_url"],
"timestamp": datetime.fromtimestamp(
charge["created"]
).strftime("%m-%d-%Y %I:%M %p"),
}
for charge in charges.auto_paging_iter()
]
cache.set(key, billing_history, 60*60)
cache.set(key, billing_history, 60 * 60)
return Response(billing_history, status=status.HTTP_200_OK)

View file

@ -1,3 +1,3 @@
from .celery import app as celery_app
__all__ = ('celery_app',)
__all__ = ("celery_app",)

View file

@ -11,6 +11,6 @@ import os
from django.core.asgi import get_asgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
application = get_asgi_application()

View file

@ -3,15 +3,15 @@ import os
# Set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
app = Celery('config')
app = Celery("config")
# Using a string here means the worker doesn't have to serialize
# the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix.
app.config_from_object('django.conf:settings', namespace='CELERY')
app.config_from_object("django.conf:settings", namespace="CELERY")
# Load task modules from all registered Django apps.
app.autodiscover_tasks()

View file

@ -10,9 +10,9 @@ BASE_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
# If you're hosting this with a secret provider, have this set to True
USE_VAULT = bool(os.getenv('USE_VAULT', False) == 'True')
USE_VAULT = bool(os.getenv("USE_VAULT", False) == "True")
# Have this set to True to serve media and static contents directly via Django
SERVE_MEDIA = bool(os.getenv('SERVE_MEDIA', False) == 'True')
SERVE_MEDIA = bool(os.getenv("SERVE_MEDIA", False) == "True")
load_dotenv(find_dotenv())
@ -35,98 +35,97 @@ def get_secret(secret_name):
# URL Prefixes
URL_SCHEME = 'https' if (get_secret('USE_HTTPS') == 'True') else 'http'
URL_SCHEME = "https" if (get_secret("USE_HTTPS") == "True") else "http"
# Backend
BACKEND_ADDRESS = get_secret('BACKEND_ADDRESS')
BACKEND_PORT = get_secret('BACKEND_PORT')
BACKEND_ADDRESS = get_secret("BACKEND_ADDRESS")
BACKEND_PORT = get_secret("BACKEND_PORT")
# Frontend
FRONTEND_ADDRESS = get_secret('FRONTEND_ADDRESS')
FRONTEND_PORT = get_secret('FRONTEND_PORT')
FRONTEND_ADDRESS = get_secret("FRONTEND_ADDRESS")
FRONTEND_PORT = get_secret("FRONTEND_PORT")
ALLOWED_HOSTS = ['*']
ALLOWED_HOSTS = ["*"]
CSRF_TRUSTED_ORIGINS = [
# Frontend
f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}',
f'{URL_SCHEME}://{FRONTEND_ADDRESS}', # For external domains
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}",
f"{URL_SCHEME}://{FRONTEND_ADDRESS}", # For external domains
# Backend
f'{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}',
f'{URL_SCHEME}://{BACKEND_ADDRESS}' # For external domains
f"{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}",
f"{URL_SCHEME}://{BACKEND_ADDRESS}", # For external domains
# You can also set up https://*.name.xyz for wildcards here
]
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = (get_secret('BACKEND_DEBUG') == 'True')
DEBUG = get_secret("BACKEND_DEBUG") == "True"
# Determines whether or not to insert test data within tables
SEED_DATA = (get_secret('SEED_DATA') == 'True')
SEED_DATA = get_secret("SEED_DATA") == "True"
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = get_secret('SECRET_KEY')
SECRET_KEY = get_secret("SECRET_KEY")
# Selenium Config
# Initiate CAPTCHA solver in test mode
CAPTCHA_TESTING = (get_secret('CAPTCHA_TESTING') == 'True')
CAPTCHA_TESTING = get_secret("CAPTCHA_TESTING") == "True"
# If using Selenium and/or the provided CAPTCHA solver, determines whether or not to use proxies
USE_PROXY = (get_secret('USE_PROXY') == 'True')
USE_PROXY = get_secret("USE_PROXY") == "True"
# Stripe (For payments)
STRIPE_SECRET_KEY = get_secret(
"STRIPE_SECRET_KEY")
STRIPE_SECRET_WEBHOOK = get_secret('STRIPE_SECRET_WEBHOOK')
STRIPE_CHECKOUT_URL = f''
STRIPE_SECRET_KEY = get_secret("STRIPE_SECRET_KEY")
STRIPE_SECRET_WEBHOOK = get_secret("STRIPE_SECRET_WEBHOOK")
STRIPE_CHECKOUT_URL = f""
# Email credentials
EMAIL_HOST = get_secret('EMAIL_HOST')
EMAIL_HOST_USER = get_secret('EMAIL_HOST_USER')
EMAIL_HOST_PASSWORD = get_secret('EMAIL_HOST_PASSWORD')
EMAIL_PORT = get_secret('EMAIL_PORT')
EMAIL_USE_TLS = (get_secret('EMAIL_USE_TLS') == 'True')
EMAIL_ADDRESS = (get_secret('EMAIL_ADDRESS') == 'True')
EMAIL_HOST = get_secret("EMAIL_HOST")
EMAIL_HOST_USER = get_secret("EMAIL_HOST_USER")
EMAIL_HOST_PASSWORD = get_secret("EMAIL_HOST_PASSWORD")
EMAIL_PORT = get_secret("EMAIL_PORT")
EMAIL_USE_TLS = get_secret("EMAIL_USE_TLS") == "True"
EMAIL_ADDRESS = get_secret("EMAIL_ADDRESS") == "True"
# Application definition
INSTALLED_APPS = [
'config',
'unfold',
'unfold.contrib.filters',
'unfold.contrib.simple_history',
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'storages',
'django_extensions',
'rest_framework',
'rest_framework_simplejwt',
'django_celery_results',
'django_celery_beat',
'simple_history',
'djoser',
'corsheaders',
'drf_spectacular',
'drf_spectacular_sidecar',
'webdriver',
'accounts',
'user_groups',
'subscriptions',
'payments',
'billing',
'emails',
'notifications',
'search_results'
"config",
"unfold",
"unfold.contrib.filters",
"unfold.contrib.simple_history",
"django.contrib.admin",
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"storages",
"django_extensions",
"rest_framework",
"rest_framework_simplejwt",
"django_celery_results",
"django_celery_beat",
"simple_history",
"djoser",
"corsheaders",
"drf_spectacular",
"drf_spectacular_sidecar",
"webdriver",
"accounts",
"user_groups",
"subscriptions",
"payments",
"billing",
"emails",
"notifications",
"search_results",
]
if DEBUG:
INSTALLED_APPS += ['silk']
INSTALLED_APPS += ["silk"]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
"django.middleware.security.SecurityMiddleware",
"silk.middleware.SilkyMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"corsheaders.middleware.CorsMiddleware",
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
DJANGO_LOG_LEVEL = "DEBUG"
# Enables VS Code debugger to break on raised exceptions
@ -146,111 +145,101 @@ if DEBUG:
}
else:
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
"django.middleware.security.SecurityMiddleware",
"whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"corsheaders.middleware.CorsMiddleware",
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
ROOT_URLCONF = 'config.urls'
ROOT_URLCONF = "config.urls"
if SERVE_MEDIA:
# Cloud Storage Settings
# This is assuming you use the same bucket for media and static containers
CLOUD_BUCKET = get_secret('CLOUD_BUCKET')
MEDIA_CONTAINER = get_secret('MEDIA_CONTAINER')
STATIC_CONTAINER = get_secret('STATIC_CONTAINER')
CLOUD_BUCKET = get_secret("CLOUD_BUCKET")
MEDIA_CONTAINER = get_secret("MEDIA_CONTAINER")
STATIC_CONTAINER = get_secret("STATIC_CONTAINER")
MEDIA_URL = f'https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/'
MEDIA_ROOT = f'https://{CLOUD_BUCKET}/'
MEDIA_URL = f"https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/"
MEDIA_ROOT = f"https://{CLOUD_BUCKET}/"
STATIC_URL = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/'
STATIC_ROOT = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/'
STATIC_URL = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
STATIC_ROOT = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
# Consult django-storages documentation when filling in these values. This will vary depending on your cloud service provider
STORAGES = {
'default': {
"default": {
# TODO: Set this up here if you're using cloud storage
'BACKEND': None,
'OPTIONS': {
"BACKEND": None,
"OPTIONS": {
# Optional parameters
},
},
'staticfiles': {
"staticfiles": {
# TODO: Set this up here if you're using cloud storage
'BACKEND': None,
'OPTIONS': {
"BACKEND": None,
"OPTIONS": {
# Optional parameters
},
},
}
else:
STATIC_URL = 'static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
STATIC_URL = "static/"
STATIC_ROOT = os.path.join(BASE_DIR, "static")
STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage"
MEDIA_URL = 'api/v1/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
ROOT_URLCONF = 'config.urls'
MEDIA_URL = "api/v1/media/"
MEDIA_ROOT = os.path.join(BASE_DIR, "media")
ROOT_URLCONF = "config.urls"
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [
BASE_DIR / 'emails/templates/',
"BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [
BASE_DIR / "emails/templates/",
],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
"django.template.context_processors.debug",
"django.template.context_processors.request",
"django.contrib.auth.context_processors.auth",
"django.contrib.messages.context_processors.messages",
],
},
},
]
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework_simplejwt.authentication.JWTAuthentication',
"DEFAULT_AUTHENTICATION_CLASSES": (
"rest_framework_simplejwt.authentication.JWTAuthentication",
),
'DEFAULT_THROTTLE_CLASSES': [
'rest_framework.throttling.AnonRateThrottle',
'rest_framework.throttling.UserRateThrottle'
"DEFAULT_THROTTLE_CLASSES": [
"rest_framework.throttling.AnonRateThrottle",
"rest_framework.throttling.UserRateThrottle",
],
'DEFAULT_THROTTLE_RATES': {
'anon': '360/min',
'user': '1440/min'
},
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
"DEFAULT_THROTTLE_RATES": {"anon": "360/min", "user": "1440/min"},
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
}
# DRF-Spectacular
SPECTACULAR_SETTINGS = {
'TITLE': 'DRF-Template',
'DESCRIPTION': 'A Template Project by Keannu Bernasol',
'VERSION': '1.0.0',
'SERVE_INCLUDE_SCHEMA': False,
'SWAGGER_UI_DIST': 'SIDECAR',
'SWAGGER_UI_FAVICON_HREF': 'SIDECAR',
'REDOC_DIST': 'SIDECAR',
"TITLE": "DRF-Template",
"DESCRIPTION": "A Template Project by Keannu Bernasol",
"VERSION": "1.0.0",
"SERVE_INCLUDE_SCHEMA": False,
"SWAGGER_UI_DIST": "SIDECAR",
"SWAGGER_UI_FAVICON_HREF": "SIDECAR",
"REDOC_DIST": "SIDECAR",
}
WSGI_APPLICATION = 'config.wsgi.application'
WSGI_APPLICATION = "config.wsgi.application"
# If you're using an external connection bouncer (eg. PgBouncer), server side cursors must be disabled to avoid any issues
USE_BOUNCER = get_secret("USE_BOUNCER")
@ -266,15 +255,13 @@ else:
DATABASES = {
"default": {
"ENGINE": "django.db.backends.postgresql",
'DISABLE_SERVER_SIDE_CURSORS': DISABLE_SERVER_SIDE_CURSORS,
"DISABLE_SERVER_SIDE_CURSORS": DISABLE_SERVER_SIDE_CURSORS,
"NAME": get_secret("DB_DATABASE"),
"USER": get_secret("DB_USERNAME"),
"PASSWORD": get_secret("DB_PASSWORD"),
"HOST": DB_HOST,
"PORT": DB_PORT,
"OPTIONS": {
"sslmode": get_secret("DB_SSL_MODE")
},
"OPTIONS": {"sslmode": get_secret("DB_SSL_MODE")},
}
}
# Django Cache
@ -284,34 +271,34 @@ CACHES = {
"LOCATION": f"redis://{get_secret('REDIS_HOST')}:{get_secret('REDIS_PORT')}/2",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
}
}
AUTH_USER_MODEL = 'accounts.CustomUser'
AUTH_USER_MODEL = "accounts.CustomUser"
DJOSER = {
'SEND_ACTIVATION_EMAIL': True,
'SEND_CONFIRMATION_EMAIL': True,
'PASSWORD_RESET_CONFIRM_URL': 'reset_password_confirm/{uid}/{token}',
'ACTIVATION_URL': 'activation/{uid}/{token}',
'USER_AUTHENTICATION_RULES': ['djoser.authentication.TokenAuthenticationRule'],
'EMAIL': {
'activation': 'emails.templates.ActivationEmail',
'password_reset': 'emails.templates.PasswordResetEmail'
"SEND_ACTIVATION_EMAIL": True,
"SEND_CONFIRMATION_EMAIL": True,
"PASSWORD_RESET_CONFIRM_URL": "reset_password_confirm/{uid}/{token}",
"ACTIVATION_URL": "activation/{uid}/{token}",
"USER_AUTHENTICATION_RULES": ["djoser.authentication.TokenAuthenticationRule"],
"EMAIL": {
"activation": "emails.templates.ActivationEmail",
"password_reset": "emails.templates.PasswordResetEmail",
},
'SERIALIZERS': {
'user': 'accounts.serializers.CustomUserSerializer',
'current_user': 'accounts.serializers.CustomUserSerializer',
'user_create': 'accounts.serializers.UserRegistrationSerializer',
"SERIALIZERS": {
"user": "accounts.serializers.CustomUserSerializer",
"current_user": "accounts.serializers.CustomUserSerializer",
"user_create": "accounts.serializers.UserRegistrationSerializer",
},
'PERMISSIONS': {
"PERMISSIONS": {
# Disable some unneeded endpoints by setting them to admin only
'username_reset': ['rest_framework.permissions.IsAdminUser'],
'username_reset_confirm': ['rest_framework.permissions.IsAdminUser'],
'set_username': ['rest_framework.permissions.IsAdminUser'],
'set_password': ['rest_framework.permissions.IsAdminUser'],
}
"username_reset": ["rest_framework.permissions.IsAdminUser"],
"username_reset_confirm": ["rest_framework.permissions.IsAdminUser"],
"set_username": ["rest_framework.permissions.IsAdminUser"],
"set_password": ["rest_framework.permissions.IsAdminUser"],
},
}
# Password validation
@ -319,32 +306,32 @@ DJOSER = {
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
"OPTIONS": {
"min_length": 8,
}
},
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
# Additional password validators
{
'NAME': 'accounts.validators.SpecialCharacterValidator',
"NAME": "accounts.validators.SpecialCharacterValidator",
},
{
'NAME': 'accounts.validators.LowercaseValidator',
"NAME": "accounts.validators.LowercaseValidator",
},
{
'NAME': 'accounts.validators.UppercaseValidator',
"NAME": "accounts.validators.UppercaseValidator",
},
{
'NAME': 'accounts.validators.NumberValidator',
"NAME": "accounts.validators.NumberValidator",
},
]
@ -352,9 +339,9 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization
# https://docs.djangoproject.com/en/4.2/topics/i18n/
LANGUAGE_CODE = 'en-us'
LANGUAGE_CODE = "en-us"
TIME_ZONE = get_secret('TIMEZONE')
TIME_ZONE = get_secret("TIMEZONE")
USE_I18N = True
@ -364,14 +351,14 @@ USE_TZ = True
# Default primary key field type
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
SITE_NAME = 'DRF-Template'
SITE_NAME = "DRF-Template"
# JWT Token Lifetimes
SIMPLE_JWT = {
"ACCESS_TOKEN_LIFETIME": timedelta(hours=1),
"REFRESH_TOKEN_LIFETIME": timedelta(days=3)
"REFRESH_TOKEN_LIFETIME": timedelta(days=3),
}
CORS_ALLOW_ALL_ORIGINS = True
@ -388,11 +375,19 @@ CELERY_RESULT_BACKEND = get_secret("CELERY_RESULT_BACKEND")
CELERY_RESULT_EXTENDED = True
# Celery Beat Options
CELERY_BEAT_SCHEDULER = 'django_celery_beat.schedulers:DatabaseScheduler'
CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler"
# Maximum number of rows that can be updated within the Django admin panel
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20480
GRAPH_MODELS = {
'app_labels': ['accounts', 'user_groups', 'billing', 'emails', 'payments', 'subscriptions', 'search_results']
"app_labels": [
"accounts",
"user_groups",
"billing",
"emails",
"payments",
"subscriptions",
"search_results",
]
}

View file

@ -1,5 +1,5 @@
from django.urls import path, include
urlpatterns = [
path('api/v1/', include('api.urls')),
path("api/v1/", include("api.urls")),
]

View file

@ -11,6 +11,6 @@ import os
from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
application = get_wsgi_application()

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class EmailsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'emails'
default_auto_field = "django.db.models.BigAutoField"
name = "emails"

View file

@ -3,11 +3,11 @@ from django.utils import timezone
class ActivationEmail(email.ActivationEmail):
template_name = 'email_activation.html'
template_name = "email_activation.html"
class PasswordResetEmail(email.PasswordResetEmail):
template_name = 'password_change.html'
template_name = "password_change.html"
class SubscriptionAvailedEmail(email.BaseEmailMessage):
@ -19,7 +19,7 @@ class SubscriptionAvailedEmail(email.BaseEmailMessage):
context["subscription_plan"] = context.get("subscription_plan")
context["subscription"] = context.get("subscription")
context["price_paid"] = context.get("price_paid")
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
context.update(self.context)
return context
@ -32,7 +32,7 @@ class SubscriptionRefundedEmail(email.BaseEmailMessage):
context["user"] = context.get("user")
context["subscription_plan"] = context.get("subscription_plan")
context["refund"] = context.get("refund")
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
context.update(self.context)
return context
@ -44,6 +44,6 @@ class SubscriptionCancelledEmail(email.BaseEmailMessage):
context = super().get_context_data()
context["user"] = context.get("user")
context["subscription_plan"] = context.get("subscription_plan")
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
context.update(self.context)
return context

View file

@ -6,7 +6,7 @@ import sys
def main():
"""Run administrative tasks."""
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
@ -18,5 +18,5 @@ def main():
execute_from_command_line(sys.argv)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -6,5 +6,5 @@ from .models import Notification
@admin.register(Notification)
class NotificationAdmin(ModelAdmin):
model = Notification
search_fields = ('id', 'content')
list_display = ['id', 'dismissed']
search_fields = ("id", "content")
list_display = ["id", "dismissed"]

View file

@ -2,8 +2,8 @@ from django.apps import AppConfig
class NotificationsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'notifications'
default_auto_field = "django.db.models.BigAutoField"
name = "notifications"
def ready(self):
import notifications.signals

View file

@ -15,13 +15,27 @@ class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
name='Notification',
name="Notification",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('content', models.CharField(max_length=1000, null=True)),
('timestamp', models.DateTimeField(auto_now_add=True)),
('dismissed', models.BooleanField(default=False)),
('recipient', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("content", models.CharField(max_length=1000, null=True)),
("timestamp", models.DateTimeField(auto_now_add=True)),
("dismissed", models.BooleanField(default=False)),
(
"recipient",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
),
]

View file

@ -2,8 +2,7 @@ from django.db import models
class Notification(models.Model):
recipient = models.ForeignKey(
'accounts.CustomUser', on_delete=models.CASCADE)
recipient = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
content = models.CharField(max_length=1000, null=True)
timestamp = models.DateTimeField(auto_now_add=True, editable=False)
dismissed = models.BooleanField(default=False)

View file

@ -3,10 +3,9 @@ from notifications.models import Notification
class NotificationSerializer(serializers.ModelSerializer):
timestamp = serializers.DateTimeField(
format="%m-%d-%Y %I:%M %p", read_only=True)
timestamp = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
class Meta:
model = Notification
fields = '__all__'
read_only_fields = ('id', 'recipient', 'content', 'timestamp')
fields = "__all__"
read_only_fields = ("id", "recipient", "content", "timestamp")

View file

@ -9,5 +9,5 @@ from django.core.cache import cache
@receiver(post_save, sender=Notification)
def clear_cache_after_notification_update(sender, instance, **kwargs):
# Clear cache
cache.delete('notifications')
cache.delete(f'notifications_user:{instance.recipient.id}')
cache.delete("notifications")
cache.delete(f"notifications_user:{instance.recipient.id}")

View file

@ -9,5 +9,4 @@ def cleanup_notifications():
three_days_ago = timezone.now() - timezone.timedelta(days=3)
# Delete notifications that are older than 3 days and dismissed
Notification.objects.filter(
dismissed=True, timestamp__lte=three_days_ago).delete()
Notification.objects.filter(dismissed=True, timestamp__lte=three_days_ago).delete()

View file

@ -3,8 +3,7 @@ from notifications.views import NotificationViewSet
from rest_framework.routers import DefaultRouter
router = DefaultRouter()
router.register(r'', NotificationViewSet,
basename="Notifications")
router.register(r"", NotificationViewSet, basename="Notifications")
urlpatterns = [
path('', include(router.urls)),
path("", include(router.urls)),
]

View file

@ -6,30 +6,33 @@ from django.core.cache import cache
class NotificationViewSet(viewsets.ModelViewSet):
http_method_names = ['get', 'patch', 'delete']
http_method_names = ["get", "patch", "delete"]
serializer_class = NotificationSerializer
queryset = Notification.objects.all()
def get_queryset(self):
user = self.request.user
key = f'notifications_user:{user.id}'
key = f"notifications_user:{user.id}"
queryset = cache.get(key)
if not queryset:
queryset = Notification.objects.filter(
recipient=user).order_by('-timestamp')
cache.set(key, queryset, 60*60)
queryset = Notification.objects.filter(recipient=user).order_by(
"-timestamp"
)
cache.set(key, queryset, 60 * 60)
return queryset
def update(self, request, *args, **kwargs):
instance = self.get_object()
if instance.recipient != request.user:
raise PermissionDenied(
"You do not have permission to update this notification.")
"You do not have permission to update this notification."
)
return super().update(request, *args, **kwargs)
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
if instance.recipient != request.user:
raise PermissionDenied(
"You do not have permission to delete this notification.")
"You do not have permission to delete this notification."
)
return super().destroy(request, *args, **kwargs)

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class PaymentsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'payments'
default_auto_field = "django.db.models.BigAutoField"
name = "payments"

View file

@ -3,6 +3,6 @@ from payments import views
urlpatterns = [
path('checkout_session/', views.StripeCheckoutView.as_view()),
path('webhook/', views.stripe_webhook_view, name='Stripe Webhook'),
path("checkout_session/", views.StripeCheckoutView.as_view()),
path("webhook/", views.stripe_webhook_view, name="Stripe Webhook"),
]

View file

@ -1,4 +1,10 @@
from config.settings import STRIPE_SECRET_KEY, STRIPE_SECRET_WEBHOOK, URL_SCHEME, FRONTEND_ADDRESS, FRONTEND_PORT
from config.settings import (
STRIPE_SECRET_KEY,
STRIPE_SECRET_WEBHOOK,
URL_SCHEME,
FRONTEND_ADDRESS,
FRONTEND_PORT,
)
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
from rest_framework.response import Response
@ -12,16 +18,19 @@ from accounts.models import CustomUser
from rest_framework.decorators import api_view
from subscriptions.tasks import get_user_subscription
import json
from emails.templates import SubscriptionAvailedEmail, SubscriptionRefundedEmail, SubscriptionCancelledEmail
from emails.templates import (
SubscriptionAvailedEmail,
SubscriptionRefundedEmail,
SubscriptionCancelledEmail,
)
from django.core.cache import cache
from payments.serializers import CheckoutSerializer
from drf_spectacular.utils import extend_schema
stripe.api_key = STRIPE_SECRET_KEY
@extend_schema(
request=CheckoutSerializer
)
@extend_schema(request=CheckoutSerializer)
class StripeCheckoutView(APIView):
permission_classes = [IsAuthenticated]
@ -30,41 +39,46 @@ class StripeCheckoutView(APIView):
# Get subscription ID from POST
USER = CustomUser.objects.get(id=self.request.user.id)
data = json.loads(request.body)
subscription_id = data.get('subscription_id')
annual = data.get('annual')
subscription_id = data.get("subscription_id")
annual = data.get("annual")
# Validation for subscription_id field
try:
subscription_id = int(subscription_id)
except:
return Response({
'error': 'Invalid value specified in subscription_id field'
}, status=status.HTTP_403_FORBIDDEN)
return Response(
{"error": "Invalid value specified in subscription_id field"},
status=status.HTTP_403_FORBIDDEN,
)
# Validation for annual field
try:
annual = bool(annual)
except:
return Response({
'error': 'Invalid value specified in annual field'
}, status=status.HTTP_403_FORBIDDEN)
return Response(
{"error": "Invalid value specified in annual field"},
status=status.HTTP_403_FORBIDDEN,
)
# Return an error if the user already has an active subscription
EXISTING_SUBSCRIPTION = get_user_subscription(USER.id)
if EXISTING_SUBSCRIPTION:
return Response({
'error': f'User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}'
}, status=status.HTTP_403_FORBIDDEN)
return Response(
{
"error": f"User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}"
},
status=status.HTTP_403_FORBIDDEN,
)
# Attempt to query the subscription
SUBSCRIPTION = SubscriptionPlan.objects.filter(
id=subscription_id).first()
SUBSCRIPTION = SubscriptionPlan.objects.filter(id=subscription_id).first()
# Return an error if the plan does not exist
if not SUBSCRIPTION:
return Response({
'error': 'Subscription plan not found'
}, status=status.HTTP_404_NOT_FOUND)
return Response(
{"error": "Subscription plan not found"},
status=status.HTTP_404_NOT_FOUND,
)
# Get the stripe_price_id from the related StripePrice instances
if annual:
@ -74,52 +88,58 @@ class StripeCheckoutView(APIView):
# Return 404 if no price is set
if not PRICE:
return Response({
'error': 'Specified price does not exist for plan'
}, status=status.HTTP_404_NOT_FOUND)
return Response(
{"error": "Specified price does not exist for plan"},
status=status.HTTP_404_NOT_FOUND,
)
PRICE_ID = PRICE.stripe_price_id
prorated = PRICE.prorated
# Return an error if a user is in a user_group and is availing pro-rated plans
if not USER.user_group and SUBSCRIPTION.group_exclusive:
return Response({
'error': 'Regular users cannot avail prorated plans'
}, status=status.HTTP_403_FORBIDDEN)
return Response(
{"error": "Regular users cannot avail prorated plans"},
status=status.HTTP_403_FORBIDDEN,
)
success_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \
'/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}'
cancel_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \
'/user/subscription/payment?success=false&user_group=False'
success_url = (
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
+ "/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}"
)
cancel_url = (
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
+ "/user/subscription/payment?success=false&user_group=False"
)
checkout_session = stripe.checkout.Session.create(
line_items=[
{
'price': PRICE_ID,
'quantity': 1
} if not prorated else
{
'price': PRICE_ID,
}
(
{"price": PRICE_ID, "quantity": 1}
if not prorated
else {
"price": PRICE_ID,
}
)
],
mode='subscription',
payment_method_types=['card'],
mode="subscription",
payment_method_types=["card"],
success_url=success_url,
cancel_url=cancel_url,
)
return Response({"url": checkout_session.url})
except Exception as e:
logging.error(str(e))
return Response({
'error': str(e)
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@ api_view(['POST'])
@ csrf_exempt
@api_view(["POST"])
@csrf_exempt
def stripe_webhook_view(request):
payload = request.body
sig_header = request.META['HTTP_STRIPE_SIGNATURE']
sig_header = request.META["HTTP_STRIPE_SIGNATURE"]
event = None
try:
@ -133,12 +153,12 @@ def stripe_webhook_view(request):
# Invalid signature
return Response(status=401)
if event['type'] == 'customer.subscription.created':
subscription = event['data']['object']
if event["type"] == "customer.subscription.created":
subscription = event["data"]["object"]
# Get the Invoice object from the Subscription object
invoice = stripe.Invoice.retrieve(subscription['latest_invoice'])
invoice = stripe.Invoice.retrieve(subscription["latest_invoice"])
# Get the Charge object from the Invoice object
charge = stripe.Charge.retrieve(invoice['charge'])
charge = stripe.Charge.retrieve(invoice["charge"])
# Get paying user
customer = stripe.Customer.retrieve(subscription["customer"])
@ -146,18 +166,20 @@ def stripe_webhook_view(request):
product = subscription["items"]["data"][0]
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
stripe_product_id=product["plan"]["product"])
stripe_product_id=product["plan"]["product"]
)
SUBSCRIPTION = UserSubscription.objects.create(
subscription=SUBSCRIPTION_PLAN,
annual=product["plan"]["interval"] == "year",
valid=True,
user=USER,
stripe_id=subscription['id'])
stripe_id=subscription["id"],
)
email = SubscriptionAvailedEmail()
paid = {
"amount": charge['amount']/100,
"currency": str(charge['currency']).upper()
"amount": charge["amount"] / 100,
"currency": str(charge["currency"]).upper(),
}
email.context = {
@ -169,19 +191,20 @@ def stripe_webhook_view(request):
email.send(to=[customer.email])
# Clear cache
cache.delete(f'billing_user:{USER.id}')
cache.delete(f'subscriptions_user:{USER.id}')
cache.delete(f"billing_user:{USER.id}")
cache.delete(f"subscriptions_user:{USER.id}")
# On chargebacks/refunds, invalidate the subscription
elif event['type'] == 'charge.refunded':
charge = event['data']['object']
elif event["type"] == "charge.refunded":
charge = event["data"]["object"]
# Get the Invoice object from the Charge object
invoice = stripe.Invoice.retrieve(charge['invoice'])
invoice = stripe.Invoice.retrieve(charge["invoice"])
# Check if the subscription exists
SUBSCRIPTION = UserSubscription.objects.filter(
stripe_id=invoice['subscription']).first()
stripe_id=invoice["subscription"]
).first()
if not (SUBSCRIPTION):
return HttpResponse(status=404)
@ -196,8 +219,8 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
refund = {
"amount": charge['amount_refunded']/100,
"currency": str(charge['currency']).upper()
"amount": charge["amount_refunded"] / 100,
"currency": str(charge["currency"]).upper(),
}
# Send an email
@ -206,13 +229,13 @@ def stripe_webhook_view(request):
email.context = {
"user": USER,
"subscription_plan": SUBSCRIPTION_PLAN,
"refund": refund
"refund": refund,
}
email.send(to=[USER.email])
# Clear cache
cache.delete(f'billing_user:{USER.id}')
cache.delete(f"billing_user:{USER.id}")
elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner
@ -223,8 +246,8 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
refund = {
"amount": charge['amount_refunded']/100,
"currency": str(charge['currency']).upper()
"amount": charge["amount_refunded"] / 100,
"currency": str(charge["currency"]).upper(),
}
# Send en email
@ -233,36 +256,38 @@ def stripe_webhook_view(request):
email.context = {
"user": OWNER,
"subscription_plan": SUBSCRIPTION_PLAN,
"refund": refund
"refund": refund,
}
email.send(to=[OWNER.email])
# Clear cache
cache.delete(f'billing_user:{USER.id}')
cache.delete(f'subscriptions_user:{USER.id}')
cache.delete(f"billing_user:{USER.id}")
cache.delete(f"subscriptions_user:{USER.id}")
elif event['type'] == 'customer.subscription.updated':
subscription = event['data']['object']
elif event["type"] == "customer.subscription.updated":
subscription = event["data"]["object"]
# Check if the subscription exists
SUBSCRIPTION = UserSubscription.objects.filter(
stripe_id=subscription['id']).first()
stripe_id=subscription["id"]
).first()
if not (SUBSCRIPTION):
return HttpResponse(status=404)
# Check if a subscription has been upgraded/downgraded
new_stripe_product_id = subscription['items']['data'][0]['plan']['product']
new_stripe_product_id = subscription["items"]["data"][0]["plan"]["product"]
current_stripe_product_id = SUBSCRIPTION.subscription.stripe_product_id
if new_stripe_product_id != current_stripe_product_id:
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
stripe_product_id=new_stripe_product_id)
stripe_product_id=new_stripe_product_id
)
SUBSCRIPTION.subscription = SUBSCRIPTION_PLAN
SUBSCRIPTION.save()
# TODO: Add a plan upgraded email message here
# Subscription activation/reactivation
if subscription['status'] == 'active':
if subscription["status"] == "active":
SUBSCRIPTION.valid = True
SUBSCRIPTION.save()
@ -270,26 +295,24 @@ def stripe_webhook_view(request):
USER = SUBSCRIPTION.user
# Clear cache
cache.delete(f'billing_user:{USER.id}')
cache.delete(
f'subscriptions_user:{USER.id}')
cache.delete(f"billing_user:{USER.id}")
cache.delete(f"subscriptions_user:{USER.id}")
elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner
# Clear cache
cache.delete(f'billing_user:{OWNER.id}')
cache.delete(
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
cache.delete(f"billing_user:{OWNER.id}")
cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
# TODO: Add notification here to inform users if their plan has been reactivated
elif subscription['status'] == 'past_due':
elif subscription["status"] == "past_due":
# TODO: Add notification here to inform users if their payment method for an existing subscription payment is failing
pass
# If subscriptions get cancelled due to non-payment, invalidate the UserSubscription
elif subscription['status'] == 'cancelled':
elif subscription["status"] == "cancelled":
if SUBSCRIPTION.user:
USER = SUBSCRIPTION.user
@ -310,8 +333,8 @@ def stripe_webhook_view(request):
email.send(to=[USER.email])
# Clear cache
cache.delete(f'billing_user:{USER.id}')
cache.delete(f'subscriptions_user:{USER.id}')
cache.delete(f"billing_user:{USER.id}")
cache.delete(f"subscriptions_user:{USER.id}")
elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner
@ -325,24 +348,21 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
email.context = {
"user": OWNER,
"subscription_plan": SUBSCRIPTION_PLAN
}
email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
email.send(to=[OWNER.email])
# Clear cache
cache.delete(f'billing_user:{OWNER.id}')
cache.delete(
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
cache.delete(f"billing_user:{OWNER.id}")
cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
# If a subscription gets cancelled, invalidate it
elif event['type'] == 'customer.subscription.deleted':
subscription = event['data']['object']
elif event["type"] == "customer.subscription.deleted":
subscription = event["data"]["object"]
# Check if the subscription exists
SUBSCRIPTION = UserSubscription.objects.filter(
stripe_id=subscription['id']).first()
stripe_id=subscription["id"]
).first()
if not (SUBSCRIPTION):
return HttpResponse(status=404)
@ -367,7 +387,7 @@ def stripe_webhook_view(request):
email.send(to=[USER.email])
# Clear cache
cache.delete(f'billing_user:{USER.id}')
cache.delete(f"billing_user:{USER.id}")
elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner
@ -381,14 +401,11 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
email.context = {
"user": OWNER,
"subscription_plan": SUBSCRIPTION_PLAN
}
email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
email.send(to=[OWNER.email])
# Clear cache
cache.delete(f'billing_user:{OWNER.id}')
cache.delete(f"billing_user:{OWNER.id}")
# Passed signature verification
return HttpResponse(status=200)

View file

@ -805,6 +805,9 @@ components:
first_name:
type: string
maxLength: 150
is_new:
type: string
readOnly: true
last_name:
type: string
maxLength: 150
@ -824,6 +827,7 @@ components:
- group_member
- group_owner
- id
- is_new
- user_group
- username
Notification:
@ -885,6 +889,9 @@ components:
first_name:
type: string
maxLength: 150
is_new:
type: string
readOnly: true
last_name:
type: string
maxLength: 150

View file

@ -7,12 +7,11 @@ from unfold.contrib.filters.admin import RangeDateFilter
@admin.register(SearchResult)
class SearchResultAdmin(ModelAdmin):
model = SearchResult
search_fields = ('id', 'title', 'link')
list_display = ['id', 'title', 'timestamp']
search_fields = ("id", "title", "link")
list_display = ["id", "title", "timestamp"]
list_filter_submit = True
list_filter = ((
"timestamp", RangeDateFilter
), (
"timestamp", RangeDateFilter
),)
list_filter = (
("timestamp", RangeDateFilter),
("timestamp", RangeDateFilter),
)

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class SearchResultsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'search_results'
default_auto_field = "django.db.models.BigAutoField"
name = "search_results"

View file

@ -7,17 +7,24 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
]
dependencies = []
operations = [
migrations.CreateModel(
name='SearchResult',
name="SearchResult",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(max_length=1000)),
('link', models.CharField(max_length=1000)),
('timestamp', models.DateTimeField(auto_now_add=True)),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("title", models.CharField(max_length=1000)),
("link", models.CharField(max_length=1000)),
("timestamp", models.DateTimeField(auto_now_add=True)),
],
),
]

View file

@ -1,16 +1,13 @@
from celery import shared_task
from .models import SearchResult
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 0, 'countdown': 5})
@shared_task(
autoretry_for=(Exception,), retry_kwargs={"max_retries": 0, "countdown": 5}
)
def create_search_result(title, link):
if SearchResult.objects.filter(title=title, link=link).exists():
return ("SearchResult entry already exists")
return "SearchResult entry already exists"
else:
SearchResult.objects.create(
title=title,
link=link
)
SearchResult.objects.create(title=title, link=link)
return f"Created new SearchResult entry titled: {title}"

View file

@ -6,10 +6,24 @@ from unfold.contrib.filters.admin import RangeDateFilter
@admin.register(StripePrice)
class StripePriceAdmin(ModelAdmin):
search_fields = ["id", "lookup_key",
"stripe_price_id","price","currency", "prorated", "annual"]
list_display = ["id", "lookup_key",
"stripe_price_id", "price", "currency", "prorated", "annual"]
search_fields = [
"id",
"lookup_key",
"stripe_price_id",
"price",
"currency",
"prorated",
"annual",
]
list_display = [
"id",
"lookup_key",
"stripe_price_id",
"price",
"currency",
"prorated",
"annual",
]
@admin.register(SubscriptionPlan)
@ -21,9 +35,6 @@ class SubscriptionPlanAdmin(ModelAdmin):
@admin.register(UserSubscription)
class UserSubscriptionAdmin(ModelAdmin):
list_filter_submit = True
list_filter = ((
"date", RangeDateFilter
),)
list_display = ["id", "__str__", "valid", "annual",
"date"]
list_filter = (("date", RangeDateFilter),)
list_display = ["id", "__str__", "valid", "annual", "date"]
search_fields = ["id", "date"]

View file

@ -2,8 +2,8 @@ from django.apps import AppConfig
class SubscriptionConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'subscriptions'
default_auto_field = "django.db.models.BigAutoField"
name = "subscriptions"
def ready(self):
import subscriptions.signals

View file

@ -11,46 +11,118 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
('user_groups', '0001_initial'),
("user_groups", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='StripePrice',
name="StripePrice",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('annual', models.BooleanField(default=False)),
('stripe_price_id', models.CharField(max_length=100)),
('price', models.DecimalField(decimal_places=2, default=0.0, max_digits=10)),
('currency', models.CharField(max_length=20)),
('lookup_key', models.CharField(blank=True, max_length=100, null=True)),
('prorated', models.BooleanField(default=False)),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("annual", models.BooleanField(default=False)),
("stripe_price_id", models.CharField(max_length=100)),
(
"price",
models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
),
("currency", models.CharField(max_length=20)),
("lookup_key", models.CharField(blank=True, max_length=100, null=True)),
("prorated", models.BooleanField(default=False)),
],
),
migrations.CreateModel(
name='SubscriptionPlan',
name="SubscriptionPlan",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
('description', models.TextField(max_length=1024, null=True)),
('stripe_product_id', models.CharField(max_length=100)),
('group_exclusive', models.BooleanField(default=False)),
('annual_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='annual_plan', to='subscriptions.stripeprice')),
('monthly_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='monthly_plan', to='subscriptions.stripeprice')),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=100)),
("description", models.TextField(max_length=1024, null=True)),
("stripe_product_id", models.CharField(max_length=100)),
("group_exclusive", models.BooleanField(default=False)),
(
"annual_price",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="annual_plan",
to="subscriptions.stripeprice",
),
),
(
"monthly_price",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="monthly_plan",
to="subscriptions.stripeprice",
),
),
],
),
migrations.CreateModel(
name='UserSubscription',
name="UserSubscription",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('stripe_id', models.CharField(max_length=100)),
('date', models.DateTimeField(default=django.utils.timezone.now, editable=False)),
('valid', models.BooleanField()),
('annual', models.BooleanField()),
('subscription', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='subscriptions.subscriptionplan')),
('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
('user_group', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='user_groups.usergroup')),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("stripe_id", models.CharField(max_length=100)),
(
"date",
models.DateTimeField(
default=django.utils.timezone.now, editable=False
),
),
("valid", models.BooleanField()),
("annual", models.BooleanField()),
(
"subscription",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="subscriptions.subscriptionplan",
),
),
(
"user",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
(
"user_group",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="user_groups.usergroup",
),
),
],
),
]

View file

@ -1,4 +1,3 @@
from django.db import models
from accounts.models import CustomUser
from user_groups.models import UserGroup
@ -25,9 +24,11 @@ class SubscriptionPlan(models.Model):
description = models.TextField(max_length=1024, null=True)
stripe_product_id = models.CharField(max_length=100)
annual_price = models.ForeignKey(
StripePrice, on_delete=models.SET_NULL, related_name='annual_plan', null=True)
StripePrice, on_delete=models.SET_NULL, related_name="annual_plan", null=True
)
monthly_price = models.ForeignKey(
StripePrice, on_delete=models.SET_NULL, related_name='monthly_plan', null=True)
StripePrice, on_delete=models.SET_NULL, related_name="monthly_plan", null=True
)
group_exclusive = models.BooleanField(default=False)
def __str__(self):
@ -39,11 +40,14 @@ class SubscriptionPlan(models.Model):
class UserSubscription(models.Model):
user = models.ForeignKey(
CustomUser, on_delete=models.CASCADE, blank=True, null=True)
CustomUser, on_delete=models.CASCADE, blank=True, null=True
)
user_group = models.ForeignKey(
UserGroup, on_delete=models.CASCADE, blank=True, null=True)
UserGroup, on_delete=models.CASCADE, blank=True, null=True
)
subscription = models.ForeignKey(
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True)
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True
)
stripe_id = models.CharField(max_length=100)
date = models.DateTimeField(default=now, editable=False)
valid = models.BooleanField()
@ -51,6 +55,6 @@ class UserSubscription(models.Model):
def __str__(self):
if self.user:
return f'Subscription {self.subscription.name} for {self.user}'
return f"Subscription {self.subscription.name} for {self.user}"
else:
return f'Subscription {self.subscription.name} for {self.user_group}'
return f"Subscription {self.subscription.name} for {self.user_group}"

View file

@ -7,38 +7,46 @@ class SimpleStripePriceSerializer(serializers.ModelSerializer):
class Meta:
model = StripePrice
fields = ['price', 'currency', 'prorated']
fields = ["price", "currency", "prorated"]
class SubscriptionPlanSerializer(serializers.ModelSerializer):
class Meta:
model = SubscriptionPlan
fields = ['id', 'name', 'description',
'annual_price', 'monthly_price', 'group_exclusive']
fields = [
"id",
"name",
"description",
"annual_price",
"monthly_price",
"group_exclusive",
]
def to_representation(self, instance):
representation = super().to_representation(instance)
representation['annual_price'] = SimpleStripePriceSerializer(
instance.annual_price, many=False).data
representation['monthly_price'] = SimpleStripePriceSerializer(
instance.monthly_price, many=False).data
representation["annual_price"] = SimpleStripePriceSerializer(
instance.annual_price, many=False
).data
representation["monthly_price"] = SimpleStripePriceSerializer(
instance.monthly_price, many=False
).data
return representation
class UserSubscriptionSerializer(serializers.ModelSerializer):
date = serializers.DateTimeField(
format="%m-%d-%Y %I:%M %p", read_only=True)
date = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
class Meta:
model = UserSubscription
fields = ['id', 'user', 'user_group', 'subscription',
'date', 'valid', 'annual']
fields = ["id", "user", "user_group", "subscription", "date", "valid", "annual"]
def to_representation(self, instance):
representation = super().to_representation(instance)
representation['user'] = SimpleCustomUserSerializer(
instance.user, many=False).data
representation['subscription'] = SubscriptionPlanSerializer(
instance.subscription, many=False).data
representation["user"] = SimpleCustomUserSerializer(
instance.user, many=False
).data
representation["subscription"] = SubscriptionPlanSerializer(
instance.subscription, many=False
).data
return representation

View file

@ -4,6 +4,7 @@ from .models import UserSubscription, StripePrice, SubscriptionPlan
from django.core.cache import cache
from config.settings import STRIPE_SECRET_KEY
import stripe
stripe.api_key = STRIPE_SECRET_KEY
# Template for running actions after user have paid for a subscription
@ -12,7 +13,7 @@ stripe.api_key = STRIPE_SECRET_KEY
@receiver(post_save, sender=SubscriptionPlan)
def clear_cache_after_plan_updates(sender, instance, **kwargs):
# Clear cache
cache.delete('subscriptionplans')
cache.delete("subscriptionplans")
@receiver(post_save, sender=UserSubscription)
@ -25,8 +26,8 @@ def scan_after_payment(sender, instance, **kwargs):
@receiver(post_migrate)
def create_subscriptions(sender, **kwargs):
if sender.name == 'subscriptions':
print('Importing data from Stripe')
if sender.name == "subscriptions":
print("Importing data from Stripe")
created_prices = 0
created_plans = 0
skipped_prices = 0
@ -35,16 +36,19 @@ def create_subscriptions(sender, **kwargs):
prices = stripe.Price.list(expand=["data.tiers"], active=True)
# Create the StripePrice
for price in prices['data']:
annual = (price['recurring']['interval'] ==
'year') if price['recurring'] else False
for price in prices["data"]:
annual = (
(price["recurring"]["interval"] == "year")
if price["recurring"]
else False
)
STRIPE_PRICE, CREATED = StripePrice.objects.get_or_create(
stripe_price_id=price['id'],
price=price['unit_amount'] / 100,
stripe_price_id=price["id"],
price=price["unit_amount"] / 100,
annual=annual,
lookup_key=price['lookup_key'],
prorated=price['recurring']['usage_type'] == 'metered',
currency=price['currency']
lookup_key=price["lookup_key"],
prorated=price["recurring"]["usage_type"] == "metered",
currency=price["currency"],
)
if CREATED:
created_prices += 1
@ -52,13 +56,13 @@ def create_subscriptions(sender, **kwargs):
skipped_prices += 1
# Create the SubscriptionPlan
for product in products['data']:
for product in products["data"]:
ANNUAL_PRICE = None
MONTHLY_PRICE = None
for price in prices['data']:
if price['product'] == product['id']:
for price in prices["data"]:
if price["product"] == product["id"]:
STRIPE_PRICE = StripePrice.objects.get(
stripe_price_id=price['id'],
stripe_price_id=price["id"],
)
if STRIPE_PRICE.annual:
ANNUAL_PRICE = STRIPE_PRICE
@ -66,12 +70,12 @@ def create_subscriptions(sender, **kwargs):
MONTHLY_PRICE = STRIPE_PRICE
if ANNUAL_PRICE or MONTHLY_PRICE:
SUBSCRIPTION_PLAN, CREATED = SubscriptionPlan.objects.get_or_create(
name=product['name'],
description=product['description'],
stripe_product_id=product['id'],
name=product["name"],
description=product["description"],
stripe_product_id=product["id"],
annual_price=ANNUAL_PRICE,
monthly_price=MONTHLY_PRICE,
group_exclusive=product['metadata']['group_exclusive'] == 'True'
group_exclusive=product["metadata"]["group_exclusive"] == "True",
)
if CREATED:
created_plans += 1
@ -79,13 +83,12 @@ def create_subscriptions(sender, **kwargs):
skipped_plans += 1
# Skip over plans with missing pricing rates
else:
print('Skipping plan' +
product['name'] + 'with missing pricing data')
print("Skipping plan" + product["name"] + "with missing pricing data")
# Assign the StripePrice to the SubscriptionPlan
SUBSCRIPTION_PLAN.save()
print('Created', created_plans, 'new plans')
print('Skipped', skipped_plans, 'existing plans')
print('Created', created_prices, 'new prices')
print('Skipped', skipped_prices, 'existing prices')
print("Created", created_plans, "new plans")
print("Skipped", skipped_plans, "existing plans")
print("Created", created_prices, "new prices")
print("Skipped", skipped_prices, "existing prices")

View file

@ -12,10 +12,10 @@ def get_user_subscription(user_id):
active_subscriptions = None
if USER.user_group:
active_subscriptions = UserSubscription.objects.filter(
user_group=USER.user_group, valid=True)
user_group=USER.user_group, valid=True
)
else:
active_subscriptions = UserSubscription.objects.filter(
user=USER, valid=True)
active_subscriptions = UserSubscription.objects.filter(user=USER, valid=True)
# Return first valid subscription if there is one
if len(active_subscriptions) > 0:
@ -33,7 +33,8 @@ def get_user_group_subscription(user_group):
# Get a list of subscriptions for the specified user
active_subscriptions = None
active_subscriptions = UserSubscription.objects.filter(
user_group=USER_GROUP, valid=True)
user_group=USER_GROUP, valid=True
)
# Return first valid subscription if there is one
if len(active_subscriptions) > 0:

View file

@ -3,12 +3,11 @@ from subscriptions import views
from rest_framework.routers import DefaultRouter
router = DefaultRouter()
router.register(r'plans', views.SubscriptionPlanViewset,
basename="Subscription Plans")
router.register(r'self', views.UserSubscriptionViewset,
basename="Self Subscriptions")
router.register(r'user_group', views.UserGroupSubscriptionViewet,
basename="Group Subscriptions")
router.register(r"plans", views.SubscriptionPlanViewset, basename="Subscription Plans")
router.register(r"self", views.UserSubscriptionViewset, basename="Self Subscriptions")
router.register(
r"user_group", views.UserGroupSubscriptionViewet, basename="Group Subscriptions"
)
urlpatterns = [
path('', include(router.urls)),
path("", include(router.urls)),
]

View file

@ -1,4 +1,7 @@
from subscriptions.serializers import SubscriptionPlanSerializer, UserSubscriptionSerializer
from subscriptions.serializers import (
SubscriptionPlanSerializer,
UserSubscriptionSerializer,
)
from subscriptions.models import SubscriptionPlan, UserSubscription
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework import viewsets
@ -6,38 +9,38 @@ from django.core.cache import cache
class SubscriptionPlanViewset(viewsets.ModelViewSet):
http_method_names = ['get']
http_method_names = ["get"]
serializer_class = SubscriptionPlanSerializer
permission_classes = [AllowAny]
queryset = SubscriptionPlan.objects.all()
def get_queryset(self):
key = 'subscriptionplans'
key = "subscriptionplans"
queryset = cache.get(key)
if not queryset:
queryset = super().get_queryset()
cache.set(key, queryset, 60*60)
cache.set(key, queryset, 60 * 60)
return queryset
class UserSubscriptionViewset(viewsets.ModelViewSet):
http_method_names = ['get']
http_method_names = ["get"]
serializer_class = UserSubscriptionSerializer
permission_classes = [IsAuthenticated]
queryset = UserSubscription.objects.all()
def get_queryset(self):
user = self.request.user
key = f'subscriptions_user:{user.id}'
key = f"subscriptions_user:{user.id}"
queryset = cache.get(key)
if not queryset:
queryset = UserSubscription.objects.filter(user=user)
cache.set(key, queryset, 60*60)
cache.set(key, queryset, 60 * 60)
return queryset
class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
http_method_names = ['get']
http_method_names = ["get"]
serializer_class = UserSubscriptionSerializer
permission_classes = [IsAuthenticated]
queryset = UserSubscription.objects.all()
@ -47,10 +50,9 @@ class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
if not user.user_group:
return UserSubscription.objects.none()
else:
key = f'subscriptions_usergroup:{user.user_group.id}'
key = f"subscriptions_usergroup:{user.user_group.id}"
queryset = cache.get(key)
if not cache:
queryset = UserSubscription.objects.filter(
user_group=user.user_group)
cache.set(key, queryset, 60*60)
queryset = UserSubscription.objects.filter(user_group=user.user_group)
cache.set(key, queryset, 60 * 60)
return queryset

View file

@ -7,9 +7,7 @@ from unfold.contrib.filters.admin import RangeDateFilter
@admin.register(UserGroup)
class UserGroupAdmin(ModelAdmin):
list_filter_submit = True
list_filter = ((
"date_created", RangeDateFilter
),)
list_filter = (("date_created", RangeDateFilter),)
list_display = ['id', 'name']
search_fields = ['id', 'name']
list_display = ["id", "name"]
search_fields = ["id", "name"]

View file

@ -8,16 +8,28 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
]
dependencies = []
operations = [
migrations.CreateModel(
name='UserGroup',
name="UserGroup",
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=128)),
('date_created', models.DateTimeField(default=django.utils.timezone.now, editable=False)),
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=128)),
(
"date_created",
models.DateTimeField(
default=django.utils.timezone.now, editable=False
),
),
],
),
]

View file

@ -8,24 +8,33 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('user_groups', '0001_initial'),
("user_groups", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AddField(
model_name='usergroup',
name='managers',
field=models.ManyToManyField(related_name='usergroup_managers', to=settings.AUTH_USER_MODEL),
model_name="usergroup",
name="managers",
field=models.ManyToManyField(
related_name="usergroup_managers", to=settings.AUTH_USER_MODEL
),
),
migrations.AddField(
model_name='usergroup',
name='members',
field=models.ManyToManyField(related_name='usergroup_members', to=settings.AUTH_USER_MODEL),
model_name="usergroup",
name="members",
field=models.ManyToManyField(
related_name="usergroup_members", to=settings.AUTH_USER_MODEL
),
),
migrations.AddField(
model_name='usergroup',
name='owner',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='usergroup_owner', to=settings.AUTH_USER_MODEL),
model_name="usergroup",
name="owner",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="usergroup_owner",
to=settings.AUTH_USER_MODEL,
),
),
]

View file

@ -2,17 +2,24 @@ from django.db import models
from django.utils.timezone import now
from config.settings import STRIPE_SECRET_KEY
import stripe
stripe.api_key = STRIPE_SECRET_KEY
class UserGroup(models.Model):
name = models.CharField(max_length=128, null=False)
owner = models.ForeignKey(
'accounts.CustomUser', on_delete=models.SET_NULL, null=True, related_name='usergroup_owner')
"accounts.CustomUser",
on_delete=models.SET_NULL,
null=True,
related_name="usergroup_owner",
)
managers = models.ManyToManyField(
'accounts.CustomUser', related_name='usergroup_managers')
"accounts.CustomUser", related_name="usergroup_managers"
)
members = models.ManyToManyField(
'accounts.CustomUser', related_name='usergroup_members')
"accounts.CustomUser", related_name="usergroup_members"
)
date_created = models.DateTimeField(default=now, editable=False)
# Derived from email of owner, may be used for billing

View file

@ -3,10 +3,9 @@ from .models import UserGroup
class SimpleUserGroupSerializer(serializers.ModelSerializer):
date_created = serializers.DateTimeField(
format="%m-%d-%Y %I:%M %p", read_only=True)
date_created = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
class Meta:
model = UserGroup
fields = ['id', 'name', 'date_created']
read_only_fields = ['id', 'name', 'date_created']
fields = ["id", "name", "date_created"]
read_only_fields = ["id", "name", "date_created"]

View file

@ -8,15 +8,16 @@ from config.settings import STRIPE_SECRET_KEY, ROOT_DIR
import os
import json
import stripe
stripe.api_key = STRIPE_SECRET_KEY
@receiver(m2m_changed, sender=UserGroup.managers.through)
def update_group_managers(sender, instance, action, **kwargs):
# When adding new managers to a UserGroup, associate them with it
if action == 'post_add':
if action == "post_add":
# Get the newly added managers
new_managers = kwargs.get('pk_set', set())
new_managers = kwargs.get("pk_set", set())
for manager in new_managers:
# Retrieve the member
USER = CustomUser.objects.get(pk=manager)
@ -27,8 +28,8 @@ def update_group_managers(sender, instance, action, **kwargs):
if USER not in instance.members.all():
instance.members.add(USER)
# When removing managers from a UserGroup, remove their association with it
elif action == 'post_remove':
for manager in kwargs['pk_set']:
elif action == "post_remove":
for manager in kwargs["pk_set"]:
# Retrieve the manager
USER = CustomUser.objects.get(pk=manager)
if USER not in instance.members.all():
@ -39,9 +40,9 @@ def update_group_managers(sender, instance, action, **kwargs):
@receiver(m2m_changed, sender=UserGroup.members.through)
def update_group_members(sender, instance, action, **kwargs):
# When adding new members to a UserGroup, associate them with it
if action == 'post_add':
if action == "post_add":
# Get the newly added members
new_members = kwargs.get('pk_set', set())
new_members = kwargs.get("pk_set", set())
for member in new_members:
# Retrieve the member
USER = CustomUser.objects.get(pk=member)
@ -50,10 +51,13 @@ def update_group_members(sender, instance, action, **kwargs):
USER.user_group = instance
USER.save()
# When removing members from a UserGroup, remove their association with it
elif action == 'post_remove':
for client in kwargs['pk_set']:
elif action == "post_remove":
for client in kwargs["pk_set"]:
USER = CustomUser.objects.get(pk=client)
if USER not in instance.members.all() and USER not in instance.managers.all():
if (
USER not in instance.members.all()
and USER not in instance.managers.all()
):
USER.user_group = None
USER.save()
# Update usage records
@ -66,42 +70,42 @@ def update_group_members(sender, instance, action, **kwargs):
stripe.SubscriptionItem.create_usage_record(
SUBSCRIPTION_ITEM.stripe_id,
quantity=len(instance.members.all()),
action="set"
action="set",
)
except:
print(
f'Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}')
f"Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}"
)
@receiver(post_migrate)
def create_groups(sender, **kwargs):
if sender.name == "agencies":
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
seed_data = json.loads(f.read())
for user_group in seed_data['user_groups']:
OWNER = CustomUser.objects.filter(
email=user_group['owner']).first()
for user_group in seed_data["user_groups"]:
OWNER = CustomUser.objects.filter(email=user_group["owner"]).first()
USER_GROUP, CREATED = UserGroup.objects.get_or_create(
owner=OWNER,
agency_name=user_group['name'],
agency_name=user_group["name"],
)
if CREATED:
print(f"Created UserGroup {USER_GROUP.agency_name}")
# Add managers
USERS = CustomUser.objects.filter(
email__in=user_group['managers'])
USERS = CustomUser.objects.filter(email__in=user_group["managers"])
for USER in USERS:
if USER not in USER_GROUP.managers.all():
print(
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}")
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}"
)
USER_GROUP.managers.add(USER)
# Add members
USERS = CustomUser.objects.filter(
email__in=user_group['members'])
USERS = CustomUser.objects.filter(email__in=user_group["members"])
for USER in USERS:
if USER not in USER_GROUP.members.all():
print(
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}")
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}"
)
USER_GROUP.clients.add(USER)
USER_GROUP.save()

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class EmailsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'webdriver'
default_auto_field = "django.db.models.BigAutoField"
name = "webdriver"

View file

@ -1,11 +1,19 @@
from celery import shared_task
from webdriver.utils import setup_webdriver, selenium_action_template, google_search, get_element, get_elements
from webdriver.utils import (
setup_webdriver,
selenium_action_template,
google_search,
get_element,
get_elements,
)
from selenium.webdriver.common.by import By
from search_results.tasks import create_search_result
# Task template
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5})
@shared_task(
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
)
def sample_selenium_task():
driver = setup_webdriver(use_proxy=False, use_saved_session=False)
@ -18,27 +26,29 @@ def sample_selenium_task():
driver.close()
driver.quit()
# Sample task to scrape Google for search results based on a keyword
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5})
@shared_task(
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
)
def simple_google_search():
driver = setup_webdriver(driver_type="firefox",
use_proxy=False, use_saved_session=False)
driver = setup_webdriver(
driver_type="firefox", use_proxy=False, use_saved_session=False
)
driver.get(f"https://google.com/")
google_search(driver, search_term="cat blog posts")
# Count number of Google search results
search_items = get_elements(
driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
search_items = get_elements(driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
for item in search_items:
title = item.find_element(By.TAG_NAME, 'h3').text
link = item.find_element(By.TAG_NAME, 'a').get_attribute('href')
title = item.find_element(By.TAG_NAME, "h3").text
link = item.find_element(By.TAG_NAME, "a").get_attribute("href")
create_search_result.apply_async(
kwargs={"title": title, "link": link})
create_search_result.apply_async(kwargs={"title": title, "link": link})
driver.close()
driver.quit()

View file

@ -1,6 +1,7 @@
"""
Settings file to hold constants and functions
"""
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from config.settings import get_secret
@ -18,24 +19,26 @@ import os
import random
def take_snapshot(driver, filename='dump.png'):
def take_snapshot(driver, filename="dump.png"):
# Set window size
required_width = driver.execute_script(
'return document.body.parentNode.scrollWidth')
"return document.body.parentNode.scrollWidth"
)
required_height = driver.execute_script(
'return document.body.parentNode.scrollHeight')
driver.set_window_size(
required_width, required_height+(required_height*0.05))
"return document.body.parentNode.scrollHeight"
)
driver.set_window_size(required_width, required_height + (required_height * 0.05))
# Take the snapshot
driver.find_element(By.TAG_NAME,
'body').screenshot('/dumps/'+filename) # avoids any scrollbars
print('Snapshot saved')
driver.find_element(By.TAG_NAME, "body").screenshot(
"/dumps/" + filename
) # avoids any scrollbars
print("Snapshot saved")
def dump_html(driver, filename='dump.html'):
def dump_html(driver, filename="dump.html"):
# Save the page source to error.html
with open(('/dumps/'+filename), 'w', encoding='utf-8') as file:
with open(("/dumps/" + filename), "w", encoding="utf-8") as file:
file.write(driver.page_source)
@ -44,83 +47,83 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
if not USE_PROXY:
use_proxy = False
if use_proxy:
print('Running driver with proxy enabled')
print("Running driver with proxy enabled")
else:
print('Running driver with proxy disabled')
print("Running driver with proxy disabled")
if use_saved_session:
print('Running with saved session')
print("Running with saved session")
else:
print('Running without using saved session')
print("Running without using saved session")
if driver_type == "chrome":
print('Using Chrome driver')
print("Using Chrome driver")
opts = uc.ChromeOptions()
if use_saved_session:
if os.path.exists("/tmp_chrome_profile"):
print('Existing Chrome ephemeral profile found')
print("Existing Chrome ephemeral profile found")
else:
print('No existing Chrome ephemeral profile found')
print("No existing Chrome ephemeral profile found")
os.system("mkdir /tmp_chrome_profile")
if os.path.exists('/chrome'):
print('Copying Chrome Profile to ephemeral directory')
if os.path.exists("/chrome"):
print("Copying Chrome Profile to ephemeral directory")
# Flush any non-essential cache directories from the existing profile as they may balloon in size overtime
os.system(
'rm -rf "/chrome/Selenium Profile/Code Cache/*"')
os.system('rm -rf "/chrome/Selenium Profile/Code Cache/*"')
# Create a copy of the Chrome Profile
os.system("cp -r /chrome/* /tmp_chrome_profile")
try:
# Remove some items related to file locks
os.remove('/tmp_chrome_profile/SingletonLock')
os.remove('/tmp_chrome_profile/SingletonSocket')
os.remove('/tmp_chrome_profile/SingletonLock')
os.remove("/tmp_chrome_profile/SingletonLock")
os.remove("/tmp_chrome_profile/SingletonSocket")
os.remove("/tmp_chrome_profile/SingletonLock")
except:
pass
else:
print('No existing Chrome Profile found. Creating one from scratch')
print("No existing Chrome Profile found. Creating one from scratch")
if use_saved_session:
# Specify the user data directory
opts.add_argument(f'--user-data-dir=/tmp_chrome_profile')
opts.add_argument('--profile-directory=Selenium Profile')
opts.add_argument(f"--user-data-dir=/tmp_chrome_profile")
opts.add_argument("--profile-directory=Selenium Profile")
# Set proxy
if use_proxy:
opts.add_argument(
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}')
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}'
)
opts.add_argument("--disable-extensions")
opts.add_argument('--disable-application-cache')
opts.add_argument("--disable-application-cache")
opts.add_argument("--disable-setuid-sandbox")
opts.add_argument('--disable-dev-shm-usage')
opts.add_argument("--disable-dev-shm-usage")
opts.add_argument("--disable-gpu")
opts.add_argument("--no-sandbox")
opts.add_argument("--headless=new")
driver = uc.Chrome(options=opts)
elif driver_type == "firefox":
print('Using firefox driver')
print("Using firefox driver")
opts = FirefoxOptions()
if use_saved_session:
if not os.path.exists("/firefox"):
print('No profile found')
print("No profile found")
os.makedirs("/firefox")
else:
print('Existing profile found')
print("Existing profile found")
# Specify a profile if it exists
opts.profile = "/firefox"
# Set proxy
if use_proxy:
opts.set_preference('network.proxy.type', 1)
opts.set_preference('network.proxy.socks',
get_secret('PROXY_IP'))
opts.set_preference('network.proxy.socks_port',
int(get_secret('PROXY_PORT_IP_AUTH')))
opts.set_preference('network.proxy.socks_remote_dns', False)
opts.set_preference("network.proxy.type", 1)
opts.set_preference("network.proxy.socks", get_secret("PROXY_IP"))
opts.set_preference(
"network.proxy.socks_port", int(get_secret("PROXY_PORT_IP_AUTH"))
)
opts.set_preference("network.proxy.socks_remote_dns", False)
opts.add_argument('--disable-dev-shm-usage')
opts.add_argument("--disable-dev-shm-usage")
opts.add_argument("--headless")
opts.add_argument("--disable-gpu")
driver = webdriver.Firefox(options=opts)
@ -128,13 +131,15 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
driver.maximize_window()
# Check if proxy is working
driver.get('https://api.ipify.org/')
driver.get("https://api.ipify.org/")
body = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body")))
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
ip_address = body.text
print(f'External IP: {ip_address}')
print(f"External IP: {ip_address}")
return driver
# These are wrapper function for quickly automating multiple steps in webscraping (logins, button presses, text inputs, etc.)
# Depending on your use case, you may have to opt out of using this
@ -151,10 +156,11 @@ def get_element(driver, by, key, hidden_element=False, timeout=8):
wait = WebDriverWait(driver, timeout=timeout)
if not hidden_element:
element = wait.until(
EC.element_to_be_clickable((by, key)) and EC.visibility_of_element_located((by, key)))
EC.element_to_be_clickable((by, key))
and EC.visibility_of_element_located((by, key))
)
else:
element = wait.until(EC.presence_of_element_located(
(by, key)))
element = wait.until(EC.presence_of_element_located((by, key)))
return element
except Exception:
dump_html(driver)
@ -173,13 +179,12 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
wait = WebDriverWait(driver, timeout=timeout)
if hidden_element:
elements = wait.until(
EC.presence_of_all_elements_located((by, key)))
elements = wait.until(EC.presence_of_all_elements_located((by, key)))
else:
visible_elements = wait.until(
EC.visibility_of_any_elements_located((by, key)))
elements = [
element for element in visible_elements if element.is_enabled()]
EC.visibility_of_any_elements_located((by, key))
)
elements = [element for element in visible_elements if element.is_enabled()]
return elements
except Exception:
@ -193,17 +198,22 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
def execute_selenium_elements(driver, timeout, elements):
try:
for index, element in enumerate(elements):
print('Waiting...')
print("Waiting...")
# Element may have a keyword specified, check if that exists before running any actions
if "keyword" in element:
# Skip a step if the keyword does not exist
if element['keyword'] not in driver.page_source:
if element["keyword"] not in driver.page_source:
print(
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}')
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}'
)
continue
elif element['keyword'] in driver.page_source and element['type'] == 'skip':
elif (
element["keyword"] in driver.page_source
and element["type"] == "skip"
):
print(
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}')
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}'
)
break
print(f'Step: {index+1} - {element["name"]}')
# Revert to default iframe action
@ -217,31 +227,47 @@ def execute_selenium_elements(driver, timeout, elements):
else:
values = element["input"]
if type(values) is list:
raise Exception(
'Invalid input value specified for "callback" type')
raise Exception('Invalid input value specified for "callback" type')
else:
# For single input values
driver.execute_script(
f'onRecaptcha("{values}");')
driver.execute_script(f'onRecaptcha("{values}");')
continue
try:
# Try to get default element
if "hidden" in element:
site_element = get_element(
driver, element["default"]["type"], element["default"]["key"], hidden_element=True, timeout=timeout)
driver,
element["default"]["type"],
element["default"]["key"],
hidden_element=True,
timeout=timeout,
)
else:
site_element = get_element(
driver, element["default"]["type"], element["default"]["key"], timeout=timeout)
driver,
element["default"]["type"],
element["default"]["key"],
timeout=timeout,
)
except Exception as e:
print(f'Failed to find primary element')
print(f"Failed to find primary element")
# If that fails, try to get the failover one
print('Trying to find legacy element')
print("Trying to find legacy element")
if "hidden" in element:
site_element = get_element(
driver, element["failover"]["type"], element["failover"]["key"], hidden_element=True, timeout=timeout)
driver,
element["failover"]["type"],
element["failover"]["key"],
hidden_element=True,
timeout=timeout,
)
else:
site_element = get_element(
driver, element["failover"]["type"], element["failover"]["key"], timeout=timeout)
driver,
element["failover"]["type"],
element["failover"]["key"],
timeout=timeout,
)
# Clicking an element
if element["type"] == "click":
site_element.click()
@ -272,11 +298,13 @@ def execute_selenium_elements(driver, timeout, elements):
values = element["input"]
if type(values) is list:
raise Exception(
'Invalid input value specified for "input_replace" type')
'Invalid input value specified for "input_replace" type'
)
else:
# For single input values
driver.execute_script(
f'arguments[0].value = "{values}";', site_element)
f'arguments[0].value = "{values}";', site_element
)
except Exception as e:
take_snapshot(driver)
dump_html(driver)
@ -285,30 +313,33 @@ def execute_selenium_elements(driver, timeout, elements):
raise Exception(e)
def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=False, use_proxy=True):
def solve_captcha(
site_key, url, retry_attempts=3, version="v2", enterprise=False, use_proxy=True
):
# Manual proxy override set via $ENV
if not USE_PROXY:
use_proxy = False
if CAPTCHA_TESTING:
print('Initializing CAPTCHA solver in dummy mode')
print("Initializing CAPTCHA solver in dummy mode")
code = random.randint()
print("CAPTCHA Successful")
return code
elif use_proxy:
print('Using CAPTCHA solver with proxy')
print("Using CAPTCHA solver with proxy")
else:
print('Using CAPTCHA solver without proxy')
print("Using CAPTCHA solver without proxy")
captcha_params = {
"url": url,
"sitekey": site_key,
"version": version,
"enterprise": 1 if enterprise else 0,
"proxy": {
'type': 'socks5',
'uri': get_secret('PROXY_USER_AUTH')
} if use_proxy else None
"proxy": (
{"type": "socks5", "uri": get_secret("PROXY_USER_AUTH")}
if use_proxy
else None
),
}
# Keep retrying until max attempts is reached
@ -316,12 +347,12 @@ def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=Fals
# Solver uses 2CAPTCHA by default
solver = TwoCaptcha(get_secret("CAPTCHA_API_KEY"))
try:
print('Waiting for CAPTCHA code...')
print("Waiting for CAPTCHA code...")
code = solver.recaptcha(**captcha_params)["code"]
print("CAPTCHA Successful")
return code
except Exception as e:
print(f'CAPTCHA Failed! {e}')
print(f"CAPTCHA Failed! {e}")
raise Exception(f"CAPTCHA API Failed!")
@ -339,13 +370,12 @@ def save_browser_session(driver):
# Copy over the profile once we finish logging in
if isinstance(driver, webdriver.Firefox):
# Copy process for Firefox
print('Updating saved Firefox profile')
print("Updating saved Firefox profile")
# Get the current profile directory from about:support page
driver.get("about:support")
box = get_element(
driver, "id", "profile-dir-box", timeout=4)
box = get_element(driver, "id", "profile-dir-box", timeout=4)
temp_profile_path = os.path.join(os.getcwd(), box.text)
profile_path = '/firefox'
profile_path = "/firefox"
# Create the command
copy_command = "cp -r " + temp_profile_path + "/* " + profile_path
# Copy over the Firefox profile
@ -353,13 +383,13 @@ def save_browser_session(driver):
print("Firefox profile saved")
elif isinstance(driver, uc.Chrome):
# Copy the Chrome profile
print('Updating non-ephemeral Chrome profile')
print("Updating non-ephemeral Chrome profile")
# Flush Code Cache again to speed up copy
os.system(
'rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
os.system('rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
if os.system("cp -r /tmp_chrome_profile/* /chrome"):
print("Chrome profile saved")
# Sample function
# Call this within a Celery task
# TODO: Modify as needed to your needs
@ -370,7 +400,7 @@ def selenium_action_template(driver):
info = {
"sample_field1": "sample_data",
"sample_field2": "sample_data",
"captcha_code": lambda: solve_captcha('SITE_KEY', 'SITE_URL')
"captcha_code": lambda: solve_captcha("SITE_KEY", "SITE_URL"),
}
elements = [
@ -382,13 +412,10 @@ def selenium_action_template(driver):
"default": {
# See get_element() for possible selector types
"type": "xpath",
"key": ''
"key": "",
},
# If a site implements canary design releases, you can place the ID for the element in the old design here
"failover": {
"type": "xpath",
"key": ''
}
"failover": {"type": "xpath", "key": ""},
},
]
@ -398,8 +425,8 @@ def selenium_action_template(driver):
# Fill in final fstring values in elements
for element in elements:
if 'input' in element and '{' in element['input']:
a = element['input'].strip('{}')
if "input" in element and "{" in element["input"]:
a = element["input"].strip("{}")
if a in info:
value = info[a]
# Check if the value is a callable (a lambda function) and call it if so
@ -411,11 +438,12 @@ def selenium_action_template(driver):
# Use the stored value
value = site_form_values[a]
# Replace the placeholder with the actual value
element['input'] = str(value)
element["input"] = str(value)
# Execute the selenium actions
execute_selenium_elements(driver, 8, elements)
# Sample task for Google search
@ -429,40 +457,28 @@ def google_search(driver, search_term):
"name": "Type in search term",
"type": "input",
"input": "{search_term}",
"default": {
"type": "xpath",
"key": '//*[@id="APjFqb"]'
},
"failover": {
"type": "xpath",
"key": '//*[@id="APjFqb"]'
}
"default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
"failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
},
{
"name": "Press enter",
"type": "input_enter",
"default": {
"type": "xpath",
"key": '//*[@id="APjFqb"]'
},
"failover": {
"type": "xpath",
"key": '//*[@id="APjFqb"]'
}
"default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
"failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
},
]
site_form_values = {}
for element in elements:
if 'input' in element and '{' in element['input']:
a = element['input'].strip('{}')
if "input" in element and "{" in element["input"]:
a = element["input"].strip("{}")
if a in info:
value = info[a]
if callable(value):
if a not in site_form_values:
site_form_values[a] = value()
value = site_form_values[a]
element['input'] = str(value)
element["input"] = str(value)
execute_selenium_elements(driver, 8, elements)

View file

@ -5,7 +5,7 @@ services:
build:
context: .
dockerfile: Dockerfile
image: drf_template:latest
image: drf_template
ports:
- "${BACKEND_PORT}:${BACKEND_PORT}"
environment:
@ -23,7 +23,7 @@ services:
env_file: .env
environment:
- RUN_TYPE=worker
image: drf_template:latest
image: drf_template
volumes:
- .:/code
- ./chrome:/chrome
@ -42,7 +42,7 @@ services:
env_file: .env
environment:
- RUN_TYPE=beat
image: drf_template:latest
image: drf_template
volumes:
- .:/code
depends_on:
@ -58,7 +58,7 @@ services:
env_file: .env
environment:
- RUN_TYPE=monitor
image: drf_template:latest
image: drf_template
ports:
- "${CELERY_FLOWER_PORT}:5555"
volumes: